1 package net.bmahe.genetics4j.samples.mixturemodel;
2
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.IntFunction;
import java.util.stream.IntStream;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.lang3.Validate;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.exception.MathUnsupportedOperationException;
import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.core.Individual;
import net.bmahe.genetics4j.core.chromosomes.DoubleChromosome;
import net.bmahe.genetics4j.core.chromosomes.FloatChromosome;
import net.bmahe.genetics4j.core.spec.EvolutionResult;
import net.bmahe.genetics4j.moo.FitnessVector;
29
30 public class ClusteringUtils {
31 final static public Logger logger = LogManager.getLogger(ClusteringUtils.class);
32
33 public static int[] assignClustersDoubleChromosome(final int distributionNumParameters, final double[][] samples,
34 final Genotype genotype) {
35
36 final var fChromosome = genotype.getChromosome(0, DoubleChromosome.class);
37 final int[] clusters = new int[samples.length];
38 final double[] bestProb = new double[samples.length];
39
40 for (int c = 0; c < clusters.length; c++) {
41 clusters[c] = 0;
42 bestProb[c] = Double.MIN_VALUE;
43 }
44
45 double sumAlpha = 0.0f;
46 int k = 0;
47 while (k < fChromosome.getSize()) {
48 sumAlpha += fChromosome.getAllele(k);
49 k += distributionNumParameters;
50 }
51
52 int i = 0;
53 int clusterIndex = 0;
54 while (i < fChromosome.getSize()) {
55
56 final double alpha = fChromosome.getAllele(i) / sumAlpha;
57 final double[] mean = new double[] { fChromosome.getAllele(i + 1), fChromosome.getAllele(i + 2) };
58 final double[][] covariance = new double[][] {
59 { fChromosome.getAllele(i + 3) - 15, fChromosome.getAllele(i + 4) - 15 },
60 { fChromosome.getAllele(i + 4) - 15, fChromosome.getAllele(i + 5) - 15 } };
61
62 try {
63 final var multivariateNormalDistribution = new MultivariateNormalDistribution(mean, covariance);
64
65 for (int j = 0; j < samples.length; j++) {
66 float likelyhood = (float) (alpha * multivariateNormalDistribution.density(samples[j]));
67
68 if (clusters[j] < 0 || bestProb[j] < likelyhood) {
69 bestProb[j] = likelyhood;
70 clusters[j] = clusterIndex;
71 }
72 }
73 } catch (NonPositiveDefiniteMatrixException | SingularMatrixException | MathUnsupportedOperationException e) {
74 }
75
76 i += distributionNumParameters;
77 clusterIndex++;
78 }
79
80 return clusters;
81 }
82
83 public static int[] assignClustersFloatChromosome(final int distributionNumParameters, final double[][] samples,
84 final Genotype genotype) {
85
86 final var fChromosome = genotype.getChromosome(0, FloatChromosome.class);
87 final int[] clusters = new int[samples.length];
88 final double[] bestProb = new double[samples.length];
89
90 for (int c = 0; c < clusters.length; c++) {
91 clusters[c] = 0;
92 bestProb[c] = Double.MIN_VALUE;
93 }
94
95 double sumAlpha = 0.0f;
96 int k = 0;
97 while (k < fChromosome.getSize()) {
98 sumAlpha += fChromosome.getAllele(k);
99 k += distributionNumParameters;
100 }
101
102 int i = 0;
103 int clusterIndex = 0;
104 while (i < fChromosome.getSize()) {
105
106 final double alpha = fChromosome.getAllele(i) / sumAlpha;
107 final double[] mean = new double[] { fChromosome.getAllele(i + 1), fChromosome.getAllele(i + 2) };
108 final double[][] covariance = new double[][] {
109 { fChromosome.getAllele(i + 3) - 15, fChromosome.getAllele(i + 4) - 15 },
110 { fChromosome.getAllele(i + 4) - 15, fChromosome.getAllele(i + 5) - 15 } };
111
112 try {
113 final var multivariateNormalDistribution = new MultivariateNormalDistribution(mean, covariance);
114
115 for (int j = 0; j < samples.length; j++) {
116 float likelyhood = (float) (alpha * multivariateNormalDistribution.density(samples[j]));
117
118 if (clusters[j] < 0 || bestProb[j] < likelyhood) {
119 bestProb[j] = likelyhood;
120 clusters[j] = clusterIndex;
121 }
122 }
123 } catch (NonPositiveDefiniteMatrixException | SingularMatrixException | MathUnsupportedOperationException e) {
124 }
125
126 i += distributionNumParameters;
127 clusterIndex++;
128 }
129
130 return clusters;
131 }
132
133 public static void persistClusters(final float[] x, final float[] y, final int[] cluster, final String filename)
134 throws IOException {
135 Validate.isTrue(x.length == y.length);
136 Validate.isTrue(x.length == cluster.length);
137 logger.info("Saving clusters to CSV: {}", filename);
138
139 final CSVPrinter csvPrinter;
140 try {
141 csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
142 .withHeader(new String[] { "cluster", "x", "y" })
143 .print(Path.of(filename), StandardCharsets.UTF_8);
144 } catch (IOException e) {
145 logger.error("Could not open {}", filename, e);
146 throw new RuntimeException("Could not open file " + filename, e);
147 }
148
149 for (int i = 0; i < cluster.length; i++) {
150 try {
151 csvPrinter.printRecord(cluster[i], x[i], y[i]);
152 } catch (IOException e) {
153 throw new RuntimeException("Could not write data", e);
154 }
155 }
156 csvPrinter.close(true);
157 }
158
159
160 public static void persistClusters(final double[] x, final double[] y, final int[] cluster, final String filename)
161 throws IOException {
162 Validate.isTrue(x.length == y.length);
163 Validate.isTrue(x.length == cluster.length);
164 logger.info("Saving clusters to CSV: {}", filename);
165
166 final CSVPrinter csvPrinter;
167 try {
168 csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
169 .withHeader(new String[] { "cluster", "x", "y" })
170 .print(Path.of(filename), StandardCharsets.UTF_8);
171 } catch (IOException e) {
172 logger.error("Could not open {}", filename, e);
173 throw new RuntimeException("Could not open file " + filename, e);
174 }
175
176 for (int i = 0; i < cluster.length; i++) {
177 try {
178 csvPrinter.printRecord(cluster[i], x[i], y[i]);
179 } catch (IOException e) {
180 throw new RuntimeException("Could not write data", e);
181 }
182 }
183 csvPrinter.close(true);
184 }
185
186 public static Map<Integer, Individual<FitnessVector<Float>>> groupByNumClusters(final double[][] samplesDouble,
187 final EvolutionResult<FitnessVector<Float>> evolutionResult) {
188 Validate.notEmpty(samplesDouble);
189 Validate.notNull(evolutionResult);
190
191 final Map<Integer, Individual<FitnessVector<Float>>> groups = new TreeMap<>();
192
193 final var listFitnessResult = evolutionResult.fitness();
194 final var populationResult = evolutionResult.population();
195
196 for (int i = 0; i < populationResult.size(); i++) {
197
198 final var genotype = populationResult.get(i);
199 final var fitness = listFitnessResult.get(i);
200
201 groups.compute(
202 Math.round(fitness.get(1)),
203 (k, currentBestIndividual) -> currentBestIndividual == null
204 || currentBestIndividual.fitness().get(0) < fitness.get(0) ? Individual.of(genotype, fitness)
205 : currentBestIndividual);
206 }
207
208 return groups;
209 }
210
211 public static void categorizeByNumClusters(final int distributionNumParameters, final int maxPossibleDistributions,
212 final float[] x, final float[] y, final double[][] samplesDouble,
213 final EvolutionResult<FitnessVector<Float>> evolutionResult, final String baseDir, final String type)
214 throws IOException {
215 Validate.notEmpty(samplesDouble);
216 Validate.notNull(evolutionResult);
217 Validate.notBlank(baseDir);
218 Validate.notBlank(type);
219
220 final var groupedByNumClusters = groupByNumClusters(samplesDouble, evolutionResult);
221 logger.info("Groups:");
222 for (Entry<Integer, Individual<FitnessVector<Float>>> entry : groupedByNumClusters.entrySet()) {
223 final int numUnusedClusters = entry.getKey();
224 final var individual = entry.getValue();
225
226 final int numClusters = maxPossibleDistributions - numUnusedClusters;
227
228 logger.info(
229 "\tNum Clusters: {} - Unused Clusters: {} - Fitness: {}",
230 numClusters,
231 numUnusedClusters,
232 individual.fitness());
233
234 final int[] assignedClusters = ClusteringUtils
235 .assignClustersFloatChromosome(distributionNumParameters, samplesDouble, individual.genotype());
236 final Set<Integer> uniqueAssigned = new HashSet<>();
237 uniqueAssigned.addAll(IntStream.of(assignedClusters).boxed().toList());
238
239 ClusteringUtils.persistClusters(
240 x,
241 y,
242 assignedClusters,
243 baseDir + "assigned-" + type + "-" + uniqueAssigned.size() + ".csv");
244 }
245 }
246
247 public static void writeCSVReferenceValue(final String filename, final int generations, final Number value)
248 throws IOException {
249 Validate.notBlank(filename);
250 Validate.isTrue(generations > 0);
251 Validate.notNull(value);
252
253 final var csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
254 .withHeader("generation", "fitness")
255 .print(Path.of(filename), StandardCharsets.UTF_8);
256
257 for (int i = 0; i < generations; i++) {
258 csvPrinter.printRecord(i, value);
259 }
260 csvPrinter.close();
261 }
262 }