View Javadoc
1   package net.bmahe.genetics4j.samples.clustering;
2   
3   import java.util.HashMap;
4   import java.util.HashSet;
5   import java.util.Map;
6   import java.util.Set;
7   
8   import org.apache.commons.lang3.Validate;
9   
10  import net.bmahe.genetics4j.core.Fitness;
11  
12  public class FitnessUtils {
13  
14  	// tag::a_i[]
15  	private final static double a_i(final double[][] data, final double[][] distances,
16  			final Map<Integer, Set<Integer>> clusterToMembers, final int clusterIndex, final int i) {
17  		Validate.notNull(data);
18  		Validate.notNull(distances);
19  		Validate.notNull(clusterToMembers);
20  
21  		final var members = clusterToMembers.get(clusterIndex);
22  
23  		double sumDistances = 0.0;
24  		for (final int memberIndex : members) {
25  			if (memberIndex != i) {
26  				sumDistances += distances[i][memberIndex];
27  			}
28  		}
29  
30  		return sumDistances / ((double) members.size() - 1.0d);
31  	}
32  	// end::a_i[]
33  
34  	// tag::b_i[]
35  	private final static double b_i(final double[][] data, final double[][] distances,
36  			final Map<Integer, Set<Integer>> clusterToMembers, final int numClusters, final int clusterIndex,
37  			final int i) {
38  		Validate.notNull(data);
39  		Validate.notNull(distances);
40  		Validate.notNull(clusterToMembers);
41  		Validate.isTrue(numClusters > 0);
42  		Validate.inclusiveBetween(0, numClusters - 1, clusterIndex);
43  
44  		double minMean = -1;
45  		for (int otherClusterIndex = 0; otherClusterIndex < numClusters; otherClusterIndex++) {
46  
47  			if (otherClusterIndex != clusterIndex) {
48  
49  				final var members = clusterToMembers.get(otherClusterIndex);
50  
51  				if (members != null && members.size() > 0) {
52  					double sumDistances = 0.0;
53  					for (final int memberIndex : members) {
54  						sumDistances += distances[i][memberIndex];
55  					}
56  
57  					final double meanDistance = sumDistances / members.size();
58  
59  					if (minMean < 0 || meanDistance < minMean) {
60  						minMean = meanDistance;
61  					}
62  				}
63  			}
64  		}
65  
66  		return minMean;
67  	}
68  	// end::b_i[]
69  
70  	public final static int[] assignDataToClusters(final double[][] data, double[][] distances,
71  			final double[][] clusters) {
72  
73  		final double[] closestClusterDistance = new double[data.length];
74  		final int[] closestClusterIndex = new int[data.length];
75  
76  		for (int i = 0; i < data.length; i++) {
77  			closestClusterIndex[i] = -1;
78  
79  			final double dataX = data[i][0];
80  			final double dataY = data[i][1];
81  
82  			for (int clusterIndex = 0; clusterIndex < clusters.length; clusterIndex++) {
83  				final double clusterX = clusters[clusterIndex][0];
84  				final double clusterY = clusters[clusterIndex][1];
85  
86  				final double distance = Math
87  						.sqrt(((clusterX - dataX) * (clusterX - dataX)) + ((clusterY - dataY) * (clusterY - dataY)));
88  
89  				if (closestClusterIndex[i] == -1 || distance < closestClusterDistance[i]) {
90  					closestClusterIndex[i] = clusterIndex;
91  					closestClusterDistance[i] = distance;
92  				}
93  			}
94  		}
95  
96  		return closestClusterIndex;
97  	}
98  
99  	public final static double computeSilhouetteScore(final double[][] data, double[][] distances, final int numClusters,
100 			final Map<Integer, Set<Integer>> clusterToMembers, final int[] closestClusterIndex, final int i) {
101 
102 		final int clusterI = closestClusterIndex[i];
103 
104 		double silhouetteScore = 0.0d;
105 		if (clusterToMembers.getOrDefault(clusterI, Set.of()).size() > 1) {
106 			final double ai = a_i(data, distances, clusterToMembers, clusterI, i);
107 			final double bi = b_i(data, distances, clusterToMembers, numClusters, clusterI, i);
108 
109 			silhouetteScore = (bi - ai) / Math.max(ai, bi);
110 		}
111 
112 		return silhouetteScore;
113 	}
114 
115 	public final static double computeSumSquaredErrors(final double[][] data, double[][] distances,
116 			final double[][] clusters, final Map<Integer, Set<Integer>> clusterToMembers,
117 			final int[] closestClusterIndex) {
118 
119 		double sumSquareErrors = 0.0d;
120 		for (int i = 0; i < data.length; i++) {
121 			final double[] cluster = clusters[closestClusterIndex[i]];
122 
123 			sumSquareErrors += (cluster[0] - data[i][0]) * (cluster[0] - data[i][0]);
124 			sumSquareErrors += (cluster[1] - data[i][1]) * (cluster[1] - data[i][1]);
125 		}
126 
127 		return sumSquareErrors;
128 	}
129 
130 	// tag::fitness[]
131 	public final static Fitness<Double> computeFitness(final int numDataPoints, final double[][] data,
132 			double[][] distances, final int numClusters) {
133 		Validate.notNull(data);
134 		Validate.notNull(distances);
135 		Validate.isTrue(numDataPoints > 0);
136 		Validate.isTrue(numDataPoints == data.length);
137 		Validate.isTrue(numDataPoints == distances.length);
138 		Validate.isTrue(numClusters > 0);
139 
140 		return (genoType) -> {
141 
142 			final double[][] clusters = PhenotypeUtils.toPhenotype(genoType);
143 
144 			final int[] closestClusterIndex = assignDataToClusters(data, distances, clusters);
145 
146 			final Map<Integer, Set<Integer>> clusterToMembers = new HashMap<>();
147 
148 			for (int i = 0; i < numDataPoints; i++) {
149 				final var members = clusterToMembers.computeIfAbsent(closestClusterIndex[i], k -> new HashSet<>());
150 				members.add(i);
151 			}
152 
153 			double sum_si = 0.0;
154 			for (int i = 0; i < numDataPoints; i++) {
155 				sum_si += computeSilhouetteScore(data, distances, numClusters, clusterToMembers, closestClusterIndex, i);
156 			}
157 
158 			return sum_si;
159 		};
160 	}
161 	// end::fitness[]
162 
163 	// Copy/pasted for the Clustering doc
164 	// tag::fitness_with_sse[]
165 	public final static Fitness<Double> computeFitnessWithSSE(final int numDataPoints, final double[][] data,
166 			double[][] distances, final int numClusters) {
167 		Validate.notNull(data);
168 		Validate.notNull(distances);
169 		Validate.isTrue(numDataPoints > 0);
170 		Validate.isTrue(numDataPoints == data.length);
171 		Validate.isTrue(numDataPoints == distances.length);
172 		Validate.isTrue(numClusters > 0);
173 
174 		return (genoType) -> {
175 
176 			final double[][] clusters = PhenotypeUtils.toPhenotype(genoType);
177 
178 			final int[] closestClusterIndex = assignDataToClusters(data, distances, clusters);
179 
180 			final Map<Integer, Set<Integer>> clusterToMembers = new HashMap<>();
181 
182 			for (int i = 0; i < numDataPoints; i++) {
183 				final var members = clusterToMembers.computeIfAbsent(closestClusterIndex[i], k -> new HashSet<>());
184 				members.add(i);
185 			}
186 
187 			double sum_si = 0.0;
188 			for (int i = 0; i < numDataPoints; i++) {
189 				sum_si += computeSilhouetteScore(data, distances, numClusters, clusterToMembers, closestClusterIndex, i);
190 			}
191 
192 			final double sumSquaredError = computeSumSquaredErrors(data,
193 					distances,
194 					clusters,
195 					clusterToMembers,
196 					closestClusterIndex);
197 
198 			return sum_si + (1.0 / sumSquaredError);
199 		};
200 	}
201 	// end::fitness_with_sse[]
202 
203 }