1 package net.bmahe.genetics4j.samples.clustering;
2
3 import java.util.HashMap;
4 import java.util.HashSet;
5 import java.util.Map;
6 import java.util.Set;
7
8 import org.apache.commons.lang3.Validate;
9
10 import net.bmahe.genetics4j.core.Fitness;
11
12 public class FitnessUtils {
13
14
15 private final static double a_i(final double[][] data, final double[][] distances,
16 final Map<Integer, Set<Integer>> clusterToMembers, final int clusterIndex, final int i) {
17 Validate.notNull(data);
18 Validate.notNull(distances);
19 Validate.notNull(clusterToMembers);
20
21 final var members = clusterToMembers.get(clusterIndex);
22
23 double sumDistances = 0.0;
24 for (final int memberIndex : members) {
25 if (memberIndex != i) {
26 sumDistances += distances[i][memberIndex];
27 }
28 }
29
30 return sumDistances / ((double) members.size() - 1.0d);
31 }
32
33
34
35 private final static double b_i(final double[][] data, final double[][] distances,
36 final Map<Integer, Set<Integer>> clusterToMembers, final int numClusters, final int clusterIndex,
37 final int i) {
38 Validate.notNull(data);
39 Validate.notNull(distances);
40 Validate.notNull(clusterToMembers);
41 Validate.isTrue(numClusters > 0);
42 Validate.inclusiveBetween(0, numClusters - 1, clusterIndex);
43
44 double minMean = -1;
45 for (int otherClusterIndex = 0; otherClusterIndex < numClusters; otherClusterIndex++) {
46
47 if (otherClusterIndex != clusterIndex) {
48
49 final var members = clusterToMembers.get(otherClusterIndex);
50
51 if (members != null && members.size() > 0) {
52 double sumDistances = 0.0;
53 for (final int memberIndex : members) {
54 sumDistances += distances[i][memberIndex];
55 }
56
57 final double meanDistance = sumDistances / members.size();
58
59 if (minMean < 0 || meanDistance < minMean) {
60 minMean = meanDistance;
61 }
62 }
63 }
64 }
65
66 return minMean;
67 }
68
69
70 public final static int[] assignDataToClusters(final double[][] data, double[][] distances,
71 final double[][] clusters) {
72
73 final double[] closestClusterDistance = new double[data.length];
74 final int[] closestClusterIndex = new int[data.length];
75
76 for (int i = 0; i < data.length; i++) {
77 closestClusterIndex[i] = -1;
78
79 final double dataX = data[i][0];
80 final double dataY = data[i][1];
81
82 for (int clusterIndex = 0; clusterIndex < clusters.length; clusterIndex++) {
83 final double clusterX = clusters[clusterIndex][0];
84 final double clusterY = clusters[clusterIndex][1];
85
86 final double distance = Math
87 .sqrt(((clusterX - dataX) * (clusterX - dataX)) + ((clusterY - dataY) * (clusterY - dataY)));
88
89 if (closestClusterIndex[i] == -1 || distance < closestClusterDistance[i]) {
90 closestClusterIndex[i] = clusterIndex;
91 closestClusterDistance[i] = distance;
92 }
93 }
94 }
95
96 return closestClusterIndex;
97 }
98
99 public final static double computeSilhouetteScore(final double[][] data, double[][] distances, final int numClusters,
100 final Map<Integer, Set<Integer>> clusterToMembers, final int[] closestClusterIndex, final int i) {
101
102 final int clusterI = closestClusterIndex[i];
103
104 double silhouetteScore = 0.0d;
105 if (clusterToMembers.getOrDefault(clusterI, Set.of())
106 .size() > 1) {
107 final double ai = a_i(data, distances, clusterToMembers, clusterI, i);
108 final double bi = b_i(data, distances, clusterToMembers, numClusters, clusterI, i);
109
110 silhouetteScore = (bi - ai) / Math.max(ai, bi);
111 }
112
113 return silhouetteScore;
114 }
115
116 public final static double computeSumSquaredErrors(final double[][] data, double[][] distances,
117 final double[][] clusters, final Map<Integer, Set<Integer>> clusterToMembers,
118 final int[] closestClusterIndex) {
119
120 double sumSquareErrors = 0.0d;
121 for (int i = 0; i < data.length; i++) {
122 final double[] cluster = clusters[closestClusterIndex[i]];
123
124 sumSquareErrors += (cluster[0] - data[i][0]) * (cluster[0] - data[i][0]);
125 sumSquareErrors += (cluster[1] - data[i][1]) * (cluster[1] - data[i][1]);
126 }
127
128 return sumSquareErrors;
129 }
130
131
132 public final static Fitness<Double> computeFitness(final int numDataPoints, final double[][] data,
133 double[][] distances, final int numClusters) {
134 Validate.notNull(data);
135 Validate.notNull(distances);
136 Validate.isTrue(numDataPoints > 0);
137 Validate.isTrue(numDataPoints == data.length);
138 Validate.isTrue(numDataPoints == distances.length);
139 Validate.isTrue(numClusters > 0);
140
141 return (genoType) -> {
142
143 final double[][] clusters = PhenotypeUtils.toPhenotype(genoType);
144
145 final int[] closestClusterIndex = assignDataToClusters(data, distances, clusters);
146
147 final Map<Integer, Set<Integer>> clusterToMembers = new HashMap<>();
148
149 for (int i = 0; i < numDataPoints; i++) {
150 final var members = clusterToMembers.computeIfAbsent(closestClusterIndex[i], k -> new HashSet<>());
151 members.add(i);
152 }
153
154 double sum_si = 0.0;
155 for (int i = 0; i < numDataPoints; i++) {
156 sum_si += computeSilhouetteScore(data, distances, numClusters, clusterToMembers, closestClusterIndex, i);
157 }
158
159 return sum_si;
160 };
161 }
162
163
164
165 public final static Fitness<Double> computeFitnessWithSSE(final int numDataPoints, final double[][] data,
166 double[][] distances, final int numClusters) {
167 Validate.notNull(data);
168 Validate.notNull(distances);
169 Validate.isTrue(numDataPoints > 0);
170 Validate.isTrue(numDataPoints == data.length);
171 Validate.isTrue(numDataPoints == distances.length);
172 Validate.isTrue(numClusters > 0);
173
174 return (genoType) -> {
175
176 final double[][] clusters = PhenotypeUtils.toPhenotype(genoType);
177
178 final int[] closestClusterIndex = assignDataToClusters(data, distances, clusters);
179
180 final Map<Integer, Set<Integer>> clusterToMembers = new HashMap<>();
181
182 for (int i = 0; i < numDataPoints; i++) {
183 final var members = clusterToMembers.computeIfAbsent(closestClusterIndex[i], k -> new HashSet<>());
184 members.add(i);
185 }
186
187 double sum_si = 0.0;
188 for (int i = 0; i < numDataPoints; i++) {
189 sum_si += computeSilhouetteScore(data, distances, numClusters, clusterToMembers, closestClusterIndex, i);
190 }
191
192 final double sumSquaredError = computeSumSquaredErrors(data,
193 distances,
194 clusters,
195 clusterToMembers,
196 closestClusterIndex);
197
198 return sum_si + (1.0 / sumSquaredError);
199 };
200 }
201
202
203 }