1 package net.bmahe.genetics4j.samples.clustering;
2
3 import java.util.HashMap;
4 import java.util.HashSet;
5 import java.util.Map;
6 import java.util.Set;
7
8 import org.apache.commons.lang3.Validate;
9
10 import net.bmahe.genetics4j.core.Fitness;
11
12 public class FitnessUtils {
13
14
15 private final static double a_i(final double[][] data, final double[][] distances,
16 final Map<Integer, Set<Integer>> clusterToMembers, final int clusterIndex, final int i) {
17 Validate.notNull(data);
18 Validate.notNull(distances);
19 Validate.notNull(clusterToMembers);
20
21 final var members = clusterToMembers.get(clusterIndex);
22
23 double sumDistances = 0.0;
24 for (final int memberIndex : members) {
25 if (memberIndex != i) {
26 sumDistances += distances[i][memberIndex];
27 }
28 }
29
30 return sumDistances / ((double) members.size() - 1.0d);
31 }
32
33
34
35 private final static double b_i(final double[][] data, final double[][] distances,
36 final Map<Integer, Set<Integer>> clusterToMembers, final int numClusters, final int clusterIndex,
37 final int i) {
38 Validate.notNull(data);
39 Validate.notNull(distances);
40 Validate.notNull(clusterToMembers);
41 Validate.isTrue(numClusters > 0);
42 Validate.inclusiveBetween(0, numClusters - 1, clusterIndex);
43
44 double minMean = -1;
45 for (int otherClusterIndex = 0; otherClusterIndex < numClusters; otherClusterIndex++) {
46
47 if (otherClusterIndex != clusterIndex) {
48
49 final var members = clusterToMembers.get(otherClusterIndex);
50
51 if (members != null && members.size() > 0) {
52 double sumDistances = 0.0;
53 for (final int memberIndex : members) {
54 sumDistances += distances[i][memberIndex];
55 }
56
57 final double meanDistance = sumDistances / members.size();
58
59 if (minMean < 0 || meanDistance < minMean) {
60 minMean = meanDistance;
61 }
62 }
63 }
64 }
65
66 return minMean;
67 }
68
69
70 public final static int[] assignDataToClusters(final double[][] data, double[][] distances,
71 final double[][] clusters) {
72
73 final double[] closestClusterDistance = new double[data.length];
74 final int[] closestClusterIndex = new int[data.length];
75
76 for (int i = 0; i < data.length; i++) {
77 closestClusterIndex[i] = -1;
78
79 final double dataX = data[i][0];
80 final double dataY = data[i][1];
81
82 for (int clusterIndex = 0; clusterIndex < clusters.length; clusterIndex++) {
83 final double clusterX = clusters[clusterIndex][0];
84 final double clusterY = clusters[clusterIndex][1];
85
86 final double distance = Math
87 .sqrt(((clusterX - dataX) * (clusterX - dataX)) + ((clusterY - dataY) * (clusterY - dataY)));
88
89 if (closestClusterIndex[i] == -1 || distance < closestClusterDistance[i]) {
90 closestClusterIndex[i] = clusterIndex;
91 closestClusterDistance[i] = distance;
92 }
93 }
94 }
95
96 return closestClusterIndex;
97 }
98
99 public final static double computeSilhouetteScore(final double[][] data, double[][] distances, final int numClusters,
100 final Map<Integer, Set<Integer>> clusterToMembers, final int[] closestClusterIndex, final int i) {
101
102 final int clusterI = closestClusterIndex[i];
103
104 double silhouetteScore = 0.0d;
105 if (clusterToMembers.getOrDefault(clusterI, Set.of()).size() > 1) {
106 final double ai = a_i(data, distances, clusterToMembers, clusterI, i);
107 final double bi = b_i(data, distances, clusterToMembers, numClusters, clusterI, i);
108
109 silhouetteScore = (bi - ai) / Math.max(ai, bi);
110 }
111
112 return silhouetteScore;
113 }
114
115 public final static double computeSumSquaredErrors(final double[][] data, double[][] distances,
116 final double[][] clusters, final Map<Integer, Set<Integer>> clusterToMembers,
117 final int[] closestClusterIndex) {
118
119 double sumSquareErrors = 0.0d;
120 for (int i = 0; i < data.length; i++) {
121 final double[] cluster = clusters[closestClusterIndex[i]];
122
123 sumSquareErrors += (cluster[0] - data[i][0]) * (cluster[0] - data[i][0]);
124 sumSquareErrors += (cluster[1] - data[i][1]) * (cluster[1] - data[i][1]);
125 }
126
127 return sumSquareErrors;
128 }
129
130
131 public final static Fitness<Double> computeFitness(final int numDataPoints, final double[][] data,
132 double[][] distances, final int numClusters) {
133 Validate.notNull(data);
134 Validate.notNull(distances);
135 Validate.isTrue(numDataPoints > 0);
136 Validate.isTrue(numDataPoints == data.length);
137 Validate.isTrue(numDataPoints == distances.length);
138 Validate.isTrue(numClusters > 0);
139
140 return (genoType) -> {
141
142 final double[][] clusters = PhenotypeUtils.toPhenotype(genoType);
143
144 final int[] closestClusterIndex = assignDataToClusters(data, distances, clusters);
145
146 final Map<Integer, Set<Integer>> clusterToMembers = new HashMap<>();
147
148 for (int i = 0; i < numDataPoints; i++) {
149 final var members = clusterToMembers.computeIfAbsent(closestClusterIndex[i], k -> new HashSet<>());
150 members.add(i);
151 }
152
153 double sum_si = 0.0;
154 for (int i = 0; i < numDataPoints; i++) {
155 sum_si += computeSilhouetteScore(data, distances, numClusters, clusterToMembers, closestClusterIndex, i);
156 }
157
158 return sum_si;
159 };
160 }
161
162
163
164
165 public final static Fitness<Double> computeFitnessWithSSE(final int numDataPoints, final double[][] data,
166 double[][] distances, final int numClusters) {
167 Validate.notNull(data);
168 Validate.notNull(distances);
169 Validate.isTrue(numDataPoints > 0);
170 Validate.isTrue(numDataPoints == data.length);
171 Validate.isTrue(numDataPoints == distances.length);
172 Validate.isTrue(numClusters > 0);
173
174 return (genoType) -> {
175
176 final double[][] clusters = PhenotypeUtils.toPhenotype(genoType);
177
178 final int[] closestClusterIndex = assignDataToClusters(data, distances, clusters);
179
180 final Map<Integer, Set<Integer>> clusterToMembers = new HashMap<>();
181
182 for (int i = 0; i < numDataPoints; i++) {
183 final var members = clusterToMembers.computeIfAbsent(closestClusterIndex[i], k -> new HashSet<>());
184 members.add(i);
185 }
186
187 double sum_si = 0.0;
188 for (int i = 0; i < numDataPoints; i++) {
189 sum_si += computeSilhouetteScore(data, distances, numClusters, clusterToMembers, closestClusterIndex, i);
190 }
191
192 final double sumSquaredError = computeSumSquaredErrors(data,
193 distances,
194 clusters,
195 clusterToMembers,
196 closestClusterIndex);
197
198 return sum_si + (1.0 / sumSquaredError);
199 };
200 }
201
202
203 }