View Javadoc
1   package net.bmahe.genetics4j.samples.mixturemodel;
2   
3   import java.io.IOException;
4   import java.nio.charset.StandardCharsets;
5   import java.nio.file.Path;
6   import java.util.HashSet;
7   import java.util.Map;
8   import java.util.Map.Entry;
9   import java.util.Set;
10  import java.util.TreeMap;
11  import java.util.stream.IntStream;
12  
13  import org.apache.commons.csv.CSVFormat;
14  import org.apache.commons.csv.CSVPrinter;
15  import org.apache.commons.lang3.Validate;
16  import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
17  import org.apache.commons.math3.exception.MathUnsupportedOperationException;
18  import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException;
19  import org.apache.commons.math3.linear.SingularMatrixException;
20  import org.apache.logging.log4j.LogManager;
21  import org.apache.logging.log4j.Logger;
22  
23  import net.bmahe.genetics4j.core.Genotype;
24  import net.bmahe.genetics4j.core.Individual;
25  import net.bmahe.genetics4j.core.chromosomes.DoubleChromosome;
26  import net.bmahe.genetics4j.core.chromosomes.FloatChromosome;
27  import net.bmahe.genetics4j.core.spec.EvolutionResult;
28  import net.bmahe.genetics4j.moo.FitnessVector;
29  
30  public class ClusteringUtils {
31  	final static public Logger logger = LogManager.getLogger(ClusteringUtils.class);
32  
33  	public static int[] assignClustersDoubleChromosome(final int distributionNumParameters, final double[][] samples,
34  			final Genotype genotype) {
35  
36  		final var fChromosome = genotype.getChromosome(0, DoubleChromosome.class);
37  		final int[] clusters = new int[samples.length];
38  		final double[] bestProb = new double[samples.length];
39  
40  		for (int c = 0; c < clusters.length; c++) {
41  			clusters[c] = 0;
42  			bestProb[c] = Double.MIN_VALUE;
43  		}
44  
45  		double sumAlpha = 0.0f;
46  		int k = 0;
47  		while (k < fChromosome.getSize()) {
48  			sumAlpha += fChromosome.getAllele(k);
49  			k += distributionNumParameters;
50  		}
51  
52  		int i = 0;
53  		int clusterIndex = 0;
54  		while (i < fChromosome.getSize()) {
55  
56  			final double alpha = fChromosome.getAllele(i) / sumAlpha;
57  			final double[] mean = new double[] { fChromosome.getAllele(i + 1), fChromosome.getAllele(i + 2) };
58  			final double[][] covariance = new double[][] {
59  					{ fChromosome.getAllele(i + 3) - 15, fChromosome.getAllele(i + 4) - 15 },
60  					{ fChromosome.getAllele(i + 4) - 15, fChromosome.getAllele(i + 5) - 15 } };
61  
62  			try {
63  				final var multivariateNormalDistribution = new MultivariateNormalDistribution(mean, covariance);
64  
65  				for (int j = 0; j < samples.length; j++) {
66  					float likelyhood = (float) (alpha * multivariateNormalDistribution.density(samples[j]));
67  
68  					if (clusters[j] < 0 || bestProb[j] < likelyhood) {
69  						bestProb[j] = likelyhood;
70  						clusters[j] = clusterIndex;
71  					}
72  				}
73  			} catch (NonPositiveDefiniteMatrixException | SingularMatrixException | MathUnsupportedOperationException e) {
74  			}
75  
76  			i += distributionNumParameters;
77  			clusterIndex++;
78  		}
79  
80  		return clusters;
81  	}
82  
83  	public static int[] assignClustersFloatChromosome(final int distributionNumParameters, final double[][] samples,
84  			final Genotype genotype) {
85  
86  		final var fChromosome = genotype.getChromosome(0, FloatChromosome.class);
87  		final int[] clusters = new int[samples.length];
88  		final double[] bestProb = new double[samples.length];
89  
90  		for (int c = 0; c < clusters.length; c++) {
91  			clusters[c] = 0;
92  			bestProb[c] = Double.MIN_VALUE;
93  		}
94  
95  		double sumAlpha = 0.0f;
96  		int k = 0;
97  		while (k < fChromosome.getSize()) {
98  			sumAlpha += fChromosome.getAllele(k);
99  			k += distributionNumParameters;
100 		}
101 
102 		int i = 0;
103 		int clusterIndex = 0;
104 		while (i < fChromosome.getSize()) {
105 
106 			final double alpha = fChromosome.getAllele(i) / sumAlpha;
107 			final double[] mean = new double[] { fChromosome.getAllele(i + 1), fChromosome.getAllele(i + 2) };
108 			final double[][] covariance = new double[][] {
109 					{ fChromosome.getAllele(i + 3) - 15, fChromosome.getAllele(i + 4) - 15 },
110 					{ fChromosome.getAllele(i + 4) - 15, fChromosome.getAllele(i + 5) - 15 } };
111 
112 			try {
113 				final var multivariateNormalDistribution = new MultivariateNormalDistribution(mean, covariance);
114 
115 				for (int j = 0; j < samples.length; j++) {
116 					float likelyhood = (float) (alpha * multivariateNormalDistribution.density(samples[j]));
117 
118 					if (clusters[j] < 0 || bestProb[j] < likelyhood) {
119 						bestProb[j] = likelyhood;
120 						clusters[j] = clusterIndex;
121 					}
122 				}
123 			} catch (NonPositiveDefiniteMatrixException | SingularMatrixException | MathUnsupportedOperationException e) {
124 			}
125 
126 			i += distributionNumParameters;
127 			clusterIndex++;
128 		}
129 
130 		return clusters;
131 	}
132 
133 	public static void persistClusters(final float[] x, final float[] y, final int[] cluster, final String filename)
134 			throws IOException {
135 		Validate.isTrue(x.length == y.length);
136 		Validate.isTrue(x.length == cluster.length);
137 		logger.info("Saving clusters to CSV: {}", filename);
138 
139 		final CSVPrinter csvPrinter;
140 		try {
141 			csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
142 					.withHeader(new String[] { "cluster", "x", "y" })
143 					.print(Path.of(filename), StandardCharsets.UTF_8);
144 		} catch (IOException e) {
145 			logger.error("Could not open {}", filename, e);
146 			throw new RuntimeException("Could not open file " + filename, e);
147 		}
148 
149 		for (int i = 0; i < cluster.length; i++) {
150 			try {
151 				csvPrinter.printRecord(cluster[i], x[i], y[i]);
152 			} catch (IOException e) {
153 				throw new RuntimeException("Could not write data", e);
154 			}
155 		}
156 		csvPrinter.close(true);
157 	}
158 
159 	// TODO fix duplication
160 	public static void persistClusters(final double[] x, final double[] y, final int[] cluster, final String filename)
161 			throws IOException {
162 		Validate.isTrue(x.length == y.length);
163 		Validate.isTrue(x.length == cluster.length);
164 		logger.info("Saving clusters to CSV: {}", filename);
165 
166 		final CSVPrinter csvPrinter;
167 		try {
168 			csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
169 					.withHeader(new String[] { "cluster", "x", "y" })
170 					.print(Path.of(filename), StandardCharsets.UTF_8);
171 		} catch (IOException e) {
172 			logger.error("Could not open {}", filename, e);
173 			throw new RuntimeException("Could not open file " + filename, e);
174 		}
175 
176 		for (int i = 0; i < cluster.length; i++) {
177 			try {
178 				csvPrinter.printRecord(cluster[i], x[i], y[i]);
179 			} catch (IOException e) {
180 				throw new RuntimeException("Could not write data", e);
181 			}
182 		}
183 		csvPrinter.close(true);
184 	}
185 
186 	public static Map<Integer, Individual<FitnessVector<Float>>> groupByNumClusters(final double[][] samplesDouble,
187 			final EvolutionResult<FitnessVector<Float>> evolutionResult) {
188 		Validate.notEmpty(samplesDouble);
189 		Validate.notNull(evolutionResult);
190 
191 		final Map<Integer, Individual<FitnessVector<Float>>> groups = new TreeMap<>();
192 
193 		final var listFitnessResult = evolutionResult.fitness();
194 		final var populationResult = evolutionResult.population();
195 
196 		for (int i = 0; i < populationResult.size(); i++) {
197 
198 			final var genotype = populationResult.get(i);
199 			final var fitness = listFitnessResult.get(i);
200 
201 			groups.compute(
202 					Math.round(fitness.get(1)),
203 						(k, currentBestIndividual) -> currentBestIndividual == null
204 								|| currentBestIndividual.fitness().get(0) < fitness.get(0) ? Individual.of(genotype, fitness)
205 										: currentBestIndividual);
206 		}
207 
208 		return groups;
209 	}
210 
211 	public static void categorizeByNumClusters(final int distributionNumParameters, final int maxPossibleDistributions,
212 			final float[] x, final float[] y, final double[][] samplesDouble,
213 			final EvolutionResult<FitnessVector<Float>> evolutionResult, final String baseDir, final String type)
214 			throws IOException {
215 		Validate.notEmpty(samplesDouble);
216 		Validate.notNull(evolutionResult);
217 		Validate.notBlank(baseDir);
218 		Validate.notBlank(type);
219 
220 		final var groupedByNumClusters = groupByNumClusters(samplesDouble, evolutionResult);
221 		logger.info("Groups:");
222 		for (Entry<Integer, Individual<FitnessVector<Float>>> entry : groupedByNumClusters.entrySet()) {
223 			final int numUnusedClusters = entry.getKey();
224 			final var individual = entry.getValue();
225 
226 			final int numClusters = maxPossibleDistributions - numUnusedClusters;
227 
228 			logger.info(
229 					"\tNum Clusters: {} - Unused Clusters: {} - Fitness: {}",
230 						numClusters,
231 						numUnusedClusters,
232 						individual.fitness());
233 
234 			final int[] assignedClusters = ClusteringUtils
235 					.assignClustersFloatChromosome(distributionNumParameters, samplesDouble, individual.genotype());
236 			final Set<Integer> uniqueAssigned = new HashSet<>();
237 			uniqueAssigned.addAll(IntStream.of(assignedClusters).boxed().toList());
238 
239 			ClusteringUtils.persistClusters(
240 					x,
241 						y,
242 						assignedClusters,
243 						baseDir + "assigned-" + type + "-" + uniqueAssigned.size() + ".csv");
244 		}
245 	}
246 
247 	public static void writeCSVReferenceValue(final String filename, final int generations, final Number value)
248 			throws IOException {
249 		Validate.notBlank(filename);
250 		Validate.isTrue(generations > 0);
251 		Validate.notNull(value);
252 
253 		final var csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
254 				.withHeader("generation", "fitness")
255 				.print(Path.of(filename), StandardCharsets.UTF_8);
256 
257 		for (int i = 0; i < generations; i++) {
258 			csvPrinter.printRecord(i, value);
259 		}
260 		csvPrinter.close();
261 	}
262 }