1 package net.bmahe.genetics4j.samples.mixturemodel;
2
3 import java.io.IOException;
4 import java.nio.charset.StandardCharsets;
5 import java.nio.file.Path;
6 import java.util.HashSet;
7 import java.util.Map;
8 import java.util.Map.Entry;
9 import java.util.Set;
10 import java.util.TreeMap;
11 import java.util.stream.IntStream;
12
13 import org.apache.commons.csv.CSVFormat;
14 import org.apache.commons.csv.CSVPrinter;
15 import org.apache.commons.lang3.Validate;
16 import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
17 import org.apache.commons.math3.exception.MathUnsupportedOperationException;
18 import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException;
19 import org.apache.commons.math3.linear.SingularMatrixException;
20 import org.apache.logging.log4j.LogManager;
21 import org.apache.logging.log4j.Logger;
22
23 import net.bmahe.genetics4j.core.Genotype;
24 import net.bmahe.genetics4j.core.Individual;
25 import net.bmahe.genetics4j.core.chromosomes.DoubleChromosome;
26 import net.bmahe.genetics4j.core.chromosomes.FloatChromosome;
27 import net.bmahe.genetics4j.core.spec.EvolutionResult;
28 import net.bmahe.genetics4j.moo.FitnessVector;
29
30 public class ClusteringUtils {
31 final static public Logger logger = LogManager.getLogger(ClusteringUtils.class);
32
33 public static int[] assignClustersDoubleChromosome(final int distributionNumParameters, final double[][] samples,
34 final Genotype genotype) {
35
36 final var fChromosome = genotype.getChromosome(0, DoubleChromosome.class);
37 final int[] clusters = new int[samples.length];
38 final double[] bestProb = new double[samples.length];
39
40 for (int c = 0; c < clusters.length; c++) {
41 clusters[c] = 0;
42 bestProb[c] = Double.MIN_VALUE;
43 }
44
45 double sumAlpha = 0.0f;
46 int k = 0;
47 while (k < fChromosome.getSize()) {
48 sumAlpha += fChromosome.getAllele(k);
49 k += distributionNumParameters;
50 }
51
52 int i = 0;
53 int clusterIndex = 0;
54 while (i < fChromosome.getSize()) {
55
56 final double alpha = fChromosome.getAllele(i) / sumAlpha;
57 final double[] mean = new double[] { fChromosome.getAllele(i + 1), fChromosome.getAllele(i + 2) };
58 final double[][] covariance = new double[][] {
59 { fChromosome.getAllele(i + 3) - 15, fChromosome.getAllele(i + 4) - 15 },
60 { fChromosome.getAllele(i + 4) - 15, fChromosome.getAllele(i + 5) - 15 } };
61
62 try {
63 final var multivariateNormalDistribution = new MultivariateNormalDistribution(mean, covariance);
64
65 for (int j = 0; j < samples.length; j++) {
66 float likelyhood = (float) (alpha * multivariateNormalDistribution.density(samples[j]));
67
68 if (clusters[j] < 0 || bestProb[j] < likelyhood) {
69 bestProb[j] = likelyhood;
70 clusters[j] = clusterIndex;
71 }
72 }
73 } catch (NonPositiveDefiniteMatrixException | SingularMatrixException | MathUnsupportedOperationException e) {
74 }
75
76 i += distributionNumParameters;
77 clusterIndex++;
78 }
79
80 return clusters;
81 }
82
83 public static int[] assignClustersFloatChromosome(final int distributionNumParameters, final double[][] samples,
84 final Genotype genotype) {
85
86 final var fChromosome = genotype.getChromosome(0, FloatChromosome.class);
87 final int[] clusters = new int[samples.length];
88 final double[] bestProb = new double[samples.length];
89
90 for (int c = 0; c < clusters.length; c++) {
91 clusters[c] = 0;
92 bestProb[c] = Double.MIN_VALUE;
93 }
94
95 double sumAlpha = 0.0f;
96 int k = 0;
97 while (k < fChromosome.getSize()) {
98 sumAlpha += fChromosome.getAllele(k);
99 k += distributionNumParameters;
100 }
101
102 int i = 0;
103 int clusterIndex = 0;
104 while (i < fChromosome.getSize()) {
105
106 final double alpha = fChromosome.getAllele(i) / sumAlpha;
107 final double[] mean = new double[] { fChromosome.getAllele(i + 1), fChromosome.getAllele(i + 2) };
108 final double[][] covariance = new double[][] {
109 { fChromosome.getAllele(i + 3) - 15, fChromosome.getAllele(i + 4) - 15 },
110 { fChromosome.getAllele(i + 4) - 15, fChromosome.getAllele(i + 5) - 15 } };
111
112 try {
113 final var multivariateNormalDistribution = new MultivariateNormalDistribution(mean, covariance);
114
115 for (int j = 0; j < samples.length; j++) {
116 float likelyhood = (float) (alpha * multivariateNormalDistribution.density(samples[j]));
117
118 if (clusters[j] < 0 || bestProb[j] < likelyhood) {
119 bestProb[j] = likelyhood;
120 clusters[j] = clusterIndex;
121 }
122 }
123 } catch (NonPositiveDefiniteMatrixException | SingularMatrixException | MathUnsupportedOperationException e) {
124 }
125
126 i += distributionNumParameters;
127 clusterIndex++;
128 }
129
130 return clusters;
131 }
132
133 public static void persistClusters(final float[] x, final float[] y, final int[] cluster, final String filename)
134 throws IOException {
135 Validate.isTrue(x.length == y.length);
136 Validate.isTrue(x.length == cluster.length);
137 logger.info("Saving clusters to CSV: {}", filename);
138
139 final CSVPrinter csvPrinter;
140 try {
141 csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
142 .withHeader(new String[] { "cluster", "x", "y" })
143 .print(Path.of(filename), StandardCharsets.UTF_8);
144 } catch (IOException e) {
145 logger.error("Could not open {}", filename, e);
146 throw new RuntimeException("Could not open file " + filename, e);
147 }
148
149 for (int i = 0; i < cluster.length; i++) {
150 try {
151 csvPrinter.printRecord(cluster[i], x[i], y[i]);
152 } catch (IOException e) {
153 throw new RuntimeException("Could not write data", e);
154 }
155 }
156 csvPrinter.close(true);
157 }
158
159
160 public static void persistClusters(final double[] x, final double[] y, final int[] cluster, final String filename)
161 throws IOException {
162 Validate.isTrue(x.length == y.length);
163 Validate.isTrue(x.length == cluster.length);
164 logger.info("Saving clusters to CSV: {}", filename);
165
166 final CSVPrinter csvPrinter;
167 try {
168 csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
169 .withHeader(new String[] { "cluster", "x", "y" })
170 .print(Path.of(filename), StandardCharsets.UTF_8);
171 } catch (IOException e) {
172 logger.error("Could not open {}", filename, e);
173 throw new RuntimeException("Could not open file " + filename, e);
174 }
175
176 for (int i = 0; i < cluster.length; i++) {
177 try {
178 csvPrinter.printRecord(cluster[i], x[i], y[i]);
179 } catch (IOException e) {
180 throw new RuntimeException("Could not write data", e);
181 }
182 }
183 csvPrinter.close(true);
184 }
185
186 public static Map<Integer, Individual<FitnessVector<Float>>> groupByNumClusters(final double[][] samplesDouble,
187 final EvolutionResult<FitnessVector<Float>> evolutionResult) {
188 Validate.notEmpty(samplesDouble);
189 Validate.notNull(evolutionResult);
190
191 final Map<Integer, Individual<FitnessVector<Float>>> groups = new TreeMap<>();
192
193 final var listFitnessResult = evolutionResult.fitness();
194 final var populationResult = evolutionResult.population();
195
196 for (int i = 0; i < populationResult.size(); i++) {
197
198 final var genotype = populationResult.get(i);
199 final var fitness = listFitnessResult.get(i);
200
201 groups.compute(Math.round(fitness.get(1)),
202 (k, currentBestIndividual) -> currentBestIndividual == null || currentBestIndividual.fitness()
203 .get(0) < fitness.get(0) ? Individual.of(genotype, fitness) : currentBestIndividual);
204 }
205
206 return groups;
207 }
208
209 public static void categorizeByNumClusters(final int distributionNumParameters, final int maxPossibleDistributions,
210 final float[] x, final float[] y, final double[][] samplesDouble,
211 final EvolutionResult<FitnessVector<Float>> evolutionResult, final String baseDir, final String type)
212 throws IOException {
213 Validate.notEmpty(samplesDouble);
214 Validate.notNull(evolutionResult);
215 Validate.notBlank(baseDir);
216 Validate.notBlank(type);
217
218 final var groupedByNumClusters = groupByNumClusters(samplesDouble, evolutionResult);
219 logger.info("Groups:");
220 for (Entry<Integer, Individual<FitnessVector<Float>>> entry : groupedByNumClusters.entrySet()) {
221 final int numUnusedClusters = entry.getKey();
222 final var individual = entry.getValue();
223
224 final int numClusters = maxPossibleDistributions - numUnusedClusters;
225
226 logger.info("\tNum Clusters: {} - Unused Clusters: {} - Fitness: {}",
227 numClusters,
228 numUnusedClusters,
229 individual.fitness());
230
231 final int[] assignedClusters = ClusteringUtils
232 .assignClustersFloatChromosome(distributionNumParameters, samplesDouble, individual.genotype());
233 final Set<Integer> uniqueAssigned = new HashSet<>();
234 uniqueAssigned.addAll(IntStream.of(assignedClusters)
235 .boxed()
236 .toList());
237
238 ClusteringUtils.persistClusters(x,
239 y,
240 assignedClusters,
241 baseDir + "assigned-" + type + "-" + uniqueAssigned.size() + ".csv");
242 }
243 }
244
245 public static void writeCSVReferenceValue(final String filename, final int generations, final Number value)
246 throws IOException {
247 Validate.notBlank(filename);
248 Validate.isTrue(generations > 0);
249 Validate.notNull(value);
250
251 final var csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
252 .withHeader("generation", "fitness")
253 .print(Path.of(filename), StandardCharsets.UTF_8);
254
255 for (int i = 0; i < generations; i++) {
256 csvPrinter.printRecord(i, value);
257 }
258 csvPrinter.close();
259 }
260 }