View Javadoc
1   package net.bmahe.genetics4j.samples.mixturemodel;
2   
3   import java.io.IOException;
4   import java.nio.charset.StandardCharsets;
5   import java.nio.file.Path;
6   import java.util.HashSet;
7   import java.util.Map;
8   import java.util.Map.Entry;
9   import java.util.Set;
10  import java.util.TreeMap;
11  import java.util.stream.IntStream;
12  
13  import org.apache.commons.csv.CSVFormat;
14  import org.apache.commons.csv.CSVPrinter;
15  import org.apache.commons.lang3.Validate;
16  import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
17  import org.apache.commons.math3.exception.MathUnsupportedOperationException;
18  import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException;
19  import org.apache.commons.math3.linear.SingularMatrixException;
20  import org.apache.logging.log4j.LogManager;
21  import org.apache.logging.log4j.Logger;
22  
23  import net.bmahe.genetics4j.core.Genotype;
24  import net.bmahe.genetics4j.core.Individual;
25  import net.bmahe.genetics4j.core.chromosomes.DoubleChromosome;
26  import net.bmahe.genetics4j.core.chromosomes.FloatChromosome;
27  import net.bmahe.genetics4j.core.spec.EvolutionResult;
28  import net.bmahe.genetics4j.moo.FitnessVector;
29  
30  public class ClusteringUtils {
31  	final static public Logger logger = LogManager.getLogger(ClusteringUtils.class);
32  
33  	public static int[] assignClustersDoubleChromosome(final int distributionNumParameters, final double[][] samples,
34  			final Genotype genotype) {
35  
36  		final var fChromosome = genotype.getChromosome(0, DoubleChromosome.class);
37  		final int[] clusters = new int[samples.length];
38  		final double[] bestProb = new double[samples.length];
39  
40  		for (int c = 0; c < clusters.length; c++) {
41  			clusters[c] = 0;
42  			bestProb[c] = Double.MIN_VALUE;
43  		}
44  
45  		double sumAlpha = 0.0f;
46  		int k = 0;
47  		while (k < fChromosome.getSize()) {
48  			sumAlpha += fChromosome.getAllele(k);
49  			k += distributionNumParameters;
50  		}
51  
52  		int i = 0;
53  		int clusterIndex = 0;
54  		while (i < fChromosome.getSize()) {
55  
56  			final double alpha = fChromosome.getAllele(i) / sumAlpha;
57  			final double[] mean = new double[] { fChromosome.getAllele(i + 1), fChromosome.getAllele(i + 2) };
58  			final double[][] covariance = new double[][] {
59  					{ fChromosome.getAllele(i + 3) - 15, fChromosome.getAllele(i + 4) - 15 },
60  					{ fChromosome.getAllele(i + 4) - 15, fChromosome.getAllele(i + 5) - 15 } };
61  
62  			try {
63  				final var multivariateNormalDistribution = new MultivariateNormalDistribution(mean, covariance);
64  
65  				for (int j = 0; j < samples.length; j++) {
66  					float likelyhood = (float) (alpha * multivariateNormalDistribution.density(samples[j]));
67  
68  					if (clusters[j] < 0 || bestProb[j] < likelyhood) {
69  						bestProb[j] = likelyhood;
70  						clusters[j] = clusterIndex;
71  					}
72  				}
73  			} catch (NonPositiveDefiniteMatrixException | SingularMatrixException | MathUnsupportedOperationException e) {
74  			}
75  
76  			i += distributionNumParameters;
77  			clusterIndex++;
78  		}
79  
80  		return clusters;
81  	}
82  
83  	public static int[] assignClustersFloatChromosome(final int distributionNumParameters, final double[][] samples,
84  			final Genotype genotype) {
85  
86  		final var fChromosome = genotype.getChromosome(0, FloatChromosome.class);
87  		final int[] clusters = new int[samples.length];
88  		final double[] bestProb = new double[samples.length];
89  
90  		for (int c = 0; c < clusters.length; c++) {
91  			clusters[c] = 0;
92  			bestProb[c] = Double.MIN_VALUE;
93  		}
94  
95  		double sumAlpha = 0.0f;
96  		int k = 0;
97  		while (k < fChromosome.getSize()) {
98  			sumAlpha += fChromosome.getAllele(k);
99  			k += distributionNumParameters;
100 		}
101 
102 		int i = 0;
103 		int clusterIndex = 0;
104 		while (i < fChromosome.getSize()) {
105 
106 			final double alpha = fChromosome.getAllele(i) / sumAlpha;
107 			final double[] mean = new double[] { fChromosome.getAllele(i + 1), fChromosome.getAllele(i + 2) };
108 			final double[][] covariance = new double[][] {
109 					{ fChromosome.getAllele(i + 3) - 15, fChromosome.getAllele(i + 4) - 15 },
110 					{ fChromosome.getAllele(i + 4) - 15, fChromosome.getAllele(i + 5) - 15 } };
111 
112 			try {
113 				final var multivariateNormalDistribution = new MultivariateNormalDistribution(mean, covariance);
114 
115 				for (int j = 0; j < samples.length; j++) {
116 					float likelyhood = (float) (alpha * multivariateNormalDistribution.density(samples[j]));
117 
118 					if (clusters[j] < 0 || bestProb[j] < likelyhood) {
119 						bestProb[j] = likelyhood;
120 						clusters[j] = clusterIndex;
121 					}
122 				}
123 			} catch (NonPositiveDefiniteMatrixException | SingularMatrixException | MathUnsupportedOperationException e) {
124 			}
125 
126 			i += distributionNumParameters;
127 			clusterIndex++;
128 		}
129 
130 		return clusters;
131 	}
132 
133 	public static void persistClusters(final float[] x, final float[] y, final int[] cluster, final String filename)
134 			throws IOException {
135 		Validate.isTrue(x.length == y.length);
136 		Validate.isTrue(x.length == cluster.length);
137 		logger.info("Saving clusters to CSV: {}", filename);
138 
139 		final CSVPrinter csvPrinter;
140 		try {
141 			csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
142 					.withHeader(new String[] { "cluster", "x", "y" })
143 					.print(Path.of(filename), StandardCharsets.UTF_8);
144 		} catch (IOException e) {
145 			logger.error("Could not open {}", filename, e);
146 			throw new RuntimeException("Could not open file " + filename, e);
147 		}
148 
149 		for (int i = 0; i < cluster.length; i++) {
150 			try {
151 				csvPrinter.printRecord(cluster[i], x[i], y[i]);
152 			} catch (IOException e) {
153 				throw new RuntimeException("Could not write data", e);
154 			}
155 		}
156 		csvPrinter.close(true);
157 	}
158 
159 	// TODO fix duplication
160 	public static void persistClusters(final double[] x, final double[] y, final int[] cluster, final String filename)
161 			throws IOException {
162 		Validate.isTrue(x.length == y.length);
163 		Validate.isTrue(x.length == cluster.length);
164 		logger.info("Saving clusters to CSV: {}", filename);
165 
166 		final CSVPrinter csvPrinter;
167 		try {
168 			csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
169 					.withHeader(new String[] { "cluster", "x", "y" })
170 					.print(Path.of(filename), StandardCharsets.UTF_8);
171 		} catch (IOException e) {
172 			logger.error("Could not open {}", filename, e);
173 			throw new RuntimeException("Could not open file " + filename, e);
174 		}
175 
176 		for (int i = 0; i < cluster.length; i++) {
177 			try {
178 				csvPrinter.printRecord(cluster[i], x[i], y[i]);
179 			} catch (IOException e) {
180 				throw new RuntimeException("Could not write data", e);
181 			}
182 		}
183 		csvPrinter.close(true);
184 	}
185 
186 	public static Map<Integer, Individual<FitnessVector<Float>>> groupByNumClusters(final double[][] samplesDouble,
187 			final EvolutionResult<FitnessVector<Float>> evolutionResult) {
188 		Validate.notEmpty(samplesDouble);
189 		Validate.notNull(evolutionResult);
190 
191 		final Map<Integer, Individual<FitnessVector<Float>>> groups = new TreeMap<>();
192 
193 		final var listFitnessResult = evolutionResult.fitness();
194 		final var populationResult = evolutionResult.population();
195 
196 		for (int i = 0; i < populationResult.size(); i++) {
197 
198 			final var genotype = populationResult.get(i);
199 			final var fitness = listFitnessResult.get(i);
200 
201 			groups.compute(Math.round(fitness.get(1)),
202 					(k, currentBestIndividual) -> currentBestIndividual == null || currentBestIndividual.fitness()
203 							.get(0) < fitness.get(0) ? Individual.of(genotype, fitness) : currentBestIndividual);
204 		}
205 
206 		return groups;
207 	}
208 
209 	public static void categorizeByNumClusters(final int distributionNumParameters, final int maxPossibleDistributions,
210 			final float[] x, final float[] y, final double[][] samplesDouble,
211 			final EvolutionResult<FitnessVector<Float>> evolutionResult, final String baseDir, final String type)
212 			throws IOException {
213 		Validate.notEmpty(samplesDouble);
214 		Validate.notNull(evolutionResult);
215 		Validate.notBlank(baseDir);
216 		Validate.notBlank(type);
217 
218 		final var groupedByNumClusters = groupByNumClusters(samplesDouble, evolutionResult);
219 		logger.info("Groups:");
220 		for (Entry<Integer, Individual<FitnessVector<Float>>> entry : groupedByNumClusters.entrySet()) {
221 			final int numUnusedClusters = entry.getKey();
222 			final var individual = entry.getValue();
223 
224 			final int numClusters = maxPossibleDistributions - numUnusedClusters;
225 
226 			logger.info("\tNum Clusters: {} - Unused Clusters: {} - Fitness: {}",
227 					numClusters,
228 					numUnusedClusters,
229 					individual.fitness());
230 
231 			final int[] assignedClusters = ClusteringUtils
232 					.assignClustersFloatChromosome(distributionNumParameters, samplesDouble, individual.genotype());
233 			final Set<Integer> uniqueAssigned = new HashSet<>();
234 			uniqueAssigned.addAll(IntStream.of(assignedClusters)
235 					.boxed()
236 					.toList());
237 
238 			ClusteringUtils.persistClusters(x,
239 					y,
240 					assignedClusters,
241 					baseDir + "assigned-" + type + "-" + uniqueAssigned.size() + ".csv");
242 		}
243 	}
244 
245 	public static void writeCSVReferenceValue(final String filename, final int generations, final Number value)
246 			throws IOException {
247 		Validate.notBlank(filename);
248 		Validate.isTrue(generations > 0);
249 		Validate.notNull(value);
250 
251 		final var csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
252 				.withHeader("generation", "fitness")
253 				.print(Path.of(filename), StandardCharsets.UTF_8);
254 
255 		for (int i = 0; i < generations; i++) {
256 			csvPrinter.printRecord(i, value);
257 		}
258 		csvPrinter.close();
259 	}
260 }