1 package net.bmahe.genetics4j.samples.clustering;
2
3 import static net.bmahe.genetics4j.core.termination.Terminations.or;
4
5 import java.io.IOException;
6 import java.util.ArrayList;
7 import java.util.List;
8 import java.util.Optional;
9 import java.util.Random;
10
11 import org.apache.commons.cli.CommandLine;
12 import org.apache.commons.cli.CommandLineParser;
13 import org.apache.commons.cli.DefaultParser;
14 import org.apache.commons.cli.Options;
15 import org.apache.commons.cli.ParseException;
16 import org.apache.commons.lang3.Validate;
17 import org.apache.commons.lang3.time.DurationFormatUtils;
18 import org.apache.commons.math3.ml.clustering.CentroidCluster;
19 import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
20 import org.apache.logging.log4j.LogManager;
21 import org.apache.logging.log4j.Logger;
22
23 import net.bmahe.genetics4j.core.EASystemFactory;
24 import net.bmahe.genetics4j.core.Fitness;
25 import net.bmahe.genetics4j.core.Genotype;
26 import net.bmahe.genetics4j.core.chromosomes.DoubleChromosome;
27 import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
28 import net.bmahe.genetics4j.core.spec.EAConfiguration;
29 import net.bmahe.genetics4j.core.spec.EAExecutionContexts;
30 import net.bmahe.genetics4j.core.spec.chromosome.DoubleChromosomeSpec;
31 import net.bmahe.genetics4j.core.spec.combination.MultiCombinations;
32 import net.bmahe.genetics4j.core.spec.combination.MultiPointArithmetic;
33 import net.bmahe.genetics4j.core.spec.combination.MultiPointCrossover;
34 import net.bmahe.genetics4j.core.spec.mutation.CreepMutation;
35 import net.bmahe.genetics4j.core.spec.mutation.MultiMutations;
36 import net.bmahe.genetics4j.core.spec.mutation.RandomMutation;
37 import net.bmahe.genetics4j.core.spec.selection.Tournament;
38 import net.bmahe.genetics4j.core.termination.Termination;
39 import net.bmahe.genetics4j.core.termination.Terminations;
40 import net.bmahe.genetics4j.extras.evolutionlisteners.CSVEvolutionListener;
41 import net.bmahe.genetics4j.extras.evolutionlisteners.ColumnExtractor;
42 import net.bmahe.genetics4j.samples.CLIUtils;
43
44 public class Clustering {
45 final static public Logger logger = LogManager.getLogger(Clustering.class);
46
47 final static public int DEFAULT_NUM_CLUSTERS = 6;
48 final static public int DEFAULT_NUMBER_TOURNAMENTS = 2;
49 final static public int DEFAULT_POPULATION_SIZE = 120;
50 final static public double DEFAULT_RANDOM_MUTATION_RATE = 0.15d;
51 final static public double DEFAULT_CREEP_MUTATION_RATE = 0.20d;
52 final static public double DEFAULT_CREEP_MUTATION_MEAN = 0.0d;
53 final static public double DEFAULT_CREEP_MUTATION_STDDEV = 5;
54 final static public int DEFAULT_COMBINATION_ARITHMETIC = 3;
55 final static public int DEFAULT_COMBINATION_CROSSOVER_ARITHMETIC = 3;
56
57 final static public String PARAM_NUM_CLUSTERS = "n";
58 final static public String LONG_PARAM_NUM_CLUSTERS = "num-clusters";
59
60 final static public String PARAM_NUMBER_TOURNAMENTS = "t";
61 final static public String LONG_PARAM_NUMBER_TOURNAMENTS = "num-tournaments";
62
63 final static public String PARAM_POPULATION_SIZE = "p";
64 final static public String LONG_PARAM_POPULATION_SIZE = "population-size";
65
66 final static public String PARAM_SOURCE_CLUSTERS_CSV = "c";
67 final static public String LONG_PARAM_SOURCE_CUSTERS_CSV = "source-clusters";
68
69 final static public String PARAM_SOURCE_DATA_CSV = "s";
70 final static public String LONG_PARAM_SOURCE_DATA_CSV = "source-data";
71
72 final static public String PARAM_FIXED_TERMINATION = "f";
73 final static public String LONG_PARAM_FIXED_TERMINATION = "fixed-termination";
74
75 final static public String PARAM_RANDOM_MUTATION_RATE = "r";
76 final static public String LONG_PARAM_RANDOM_MUTATION_RATE = "random-mutation-rate";
77
78 final static public String PARAM_CREEP_MUTATION_RATE = "m";
79 final static public String LONG_PARAM_CREEP_MUTATION_RATE = "creep-mutation-rate";
80
81 final static public String PARAM_CREEP_MUTATION_MEAN = "a";
82 final static public String LONG_PARAM_CREEP_MUTATION_MEAN = "creep-mutation-mean";
83
84 final static public String PARAM_CREEP_MUTATION_STD_DEV = "d";
85 final static public String LONG_PARAM_CREEP_MUTATION_STD_DEV = "creep-mutation-std-dev";
86
87 final static public String PARAM_COMBINATION_ARITHMETIC = "b";
88 final static public String LONG_PARAM_COMBINATION_ARITHMETIC = "combination-arithmetic";
89
90 final static public String PARAM_COMBINATION_CROSSOVER = "e";
91 final static public String LONG_PARAM_COMBINATION_CROSSOVER = "combination-crossover";
92
93 final static public String PARAM_OUTPUT_CSV = "o";
94 final static public String LONG_PARAM_OUTPUT_CSV = "output";
95
96 final static public String PARAM_OUTPUT_WITH_SSE_CSV = "g";
97 final static public String LONG_PARAM_OUTPUT_WITH_SSE_CSV = "output-sse";
98
99 final static public String PARAM_BASE_DIR_OUTPUT = "i";
100 final static public String LONG_PARAM_BASE_DIR_OUTPUT = "base-dir";
101
102 public static void cliError(final Options options, final String errorMessage) {
103 CLIUtils.cliHelpAndExit(logger, Clustering.class, options, errorMessage);
104 }
105
106 private final static double computeDistance(final double[][] array, final int i, final int j) {
107 final double xDiff = array[j][0] - array[i][0];
108 final double yDiff = array[j][1] - array[i][1];
109 return Math.sqrt((xDiff * xDiff) + (yDiff * yDiff));
110 }
111
112 private final static double[][] computeAllDistances(final double[][] array) {
113
114 final double[][] distances = new double[array.length][array.length];
115
116 for (int i = 0; i < array.length; i++) {
117 distances[i][i] = 0.0;
118 }
119
120 for (int i = 0; i < array.length; i++) {
121 for (int j = 0; j < i; j++) {
122 final double distance = computeDistance(array, i, j);
123 distances[i][j] = distance;
124 distances[j][i] = distance;
125 }
126 }
127
128 return distances;
129 }
130
131
132 public static double[][] generateClusters(final Random random, final int numClusters, final double minX,
133 final double maxX, final double minY, final double maxY) {
134 Validate.notNull(random);
135 Validate.isTrue(numClusters > 0);
136 Validate.isTrue(minX <= maxX);
137 Validate.isTrue(minY <= maxY);
138
139 logger.info("Generating {} clusters", numClusters);
140
141 final double[][] clusters = new double[numClusters][2];
142 for (int i = 0; i < numClusters; i++) {
143 clusters[i][0] = minX + random.nextDouble() * (maxX - minX);
144 clusters[i][1] = minY + random.nextDouble() * (maxY - minY);
145 }
146
147 return clusters;
148 }
149
150
151
152 public static double[][] generateDataPoints(final Random random, final double[][] clusters, final int numDataPoints,
153 final double radius) {
154 Validate.notNull(random);
155 Validate.notNull(clusters);
156 Validate.isTrue(clusters.length > 0);
157
158 final int numClusters = clusters.length;
159 final double[][] data = new double[numDataPoints][3];
160 for (int i = 0; i < numDataPoints; i++) {
161 final int clusterIndex = i % numClusters;
162
163 data[i][0] = random.nextGaussian() * radius + clusters[clusterIndex][0];
164 data[i][1] = random.nextGaussian() * radius + clusters[clusterIndex][1];
165 data[i][2] = clusterIndex;
166 }
167
168 return data;
169 }
170
171
172 public static void doGA(final int k, final double min, final double max, final int numberTournaments,
173 final int combinationArithmetic, final int combinationCrossover, final double randomMutationRate,
174 final double creepMutationRate, final double creepMutationMean, final double creepMutationStdDev,
175 final Fitness<Double> fitnessFunction, final Termination<Double> terminations, final int populationSize,
176 final String outputCSV, final double[][] data, final double[][] distances, final String baseDir,
177 final String filenameSuffix) throws IOException {
178
179
180 final var eaConfigurationBuilder = new EAConfiguration.Builder<Double>();
181 eaConfigurationBuilder.chromosomeSpecs(DoubleChromosomeSpec.of(k * 2, min, max))
182 .parentSelectionPolicy(Tournament.of(numberTournaments))
183 .combinationPolicy(MultiCombinations.of(MultiPointArithmetic.of(combinationArithmetic, 0.5),
184 MultiPointCrossover.of(combinationCrossover)))
185 .mutationPolicies(MultiMutations.of(RandomMutation.of(randomMutationRate),
186 CreepMutation.ofNormal(creepMutationRate, creepMutationMean, creepMutationStdDev)))
187 .fitness(fitnessFunction)
188 .postEvaluationProcessor(FitnessSharingUtils.clusterDistance)
189 .termination(terminations);
190 final var eaConfiguration = eaConfigurationBuilder.build();
191
192
193 final var eaExecutionContext = EAExecutionContexts.<Double>forScalarFitness()
194 .populationSize(populationSize)
195 .addEvolutionListeners(EvolutionListeners.ofLogTopN(logger, 3),
196 new CSVEvolutionListener.Builder<Double, Double>().filename(outputCSV)
197 .columnExtractors(List.of(
198 ColumnExtractor.of("generation", (evolutionStep) -> evolutionStep.generation()),
199 ColumnExtractor.of("fitness", (evolutionStep) -> evolutionStep.fitness()),
200 ColumnExtractor.of("combination_arithmetic", (evolutionStep) -> combinationArithmetic),
201 ColumnExtractor.of("combination_crossover", (evolutionStep) -> combinationCrossover),
202 ColumnExtractor.of("random_mutation_rate", (evolutionStep) -> randomMutationRate),
203 ColumnExtractor.of("creep_mutation_mean", (evolutionStep) -> creepMutationMean),
204 ColumnExtractor.of("creep_mutation_stddev", (evolutionStep) -> creepMutationStdDev),
205 ColumnExtractor.of("creep_mutation_rate", (evolutionStep) -> creepMutationRate)))
206 .build())
207 .build();
208
209 final var eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
210
211 final var evolutionResult = eaSystem.evolve();
212 logger.info("Best genotype: {}", evolutionResult.bestGenotype());
213 logger.info(" with fitness: {}", evolutionResult.bestFitness());
214 logger.info(" at generation: {}", evolutionResult.generation());
215
216 final Genotype bestGenotype = evolutionResult.bestGenotype();
217 final double[][] bestPhenotype = PhenotypeUtils.toPhenotype(bestGenotype);
218 logger.info("Best phenotype:");
219 for (int i = 0; i < k; i++) {
220 logger.info("\tx: {} - y: {}", bestPhenotype[i][0], bestPhenotype[i][1]);
221 }
222 final int[] bestClusterMembership = FitnessUtils.assignDataToClusters(data, distances, bestPhenotype);
223 IOUtils
224 .persistDataPoints(data, bestClusterMembership, baseDir + "clustering-result-ga" + filenameSuffix + ".csv");
225 IOUtils.persistClusters(bestPhenotype, baseDir + "clustering-result-clusters-ga" + filenameSuffix + ".csv");
226
227 }
228
229 public static List<CentroidCluster<LocationWrapper>> apacheCommonsMathCluster(final double[][] clusters,
230 final double[][] data) {
231
232 logger.info("Initializing kmeans from Apache Commons Math");
233
234 final long startTs = System.currentTimeMillis();
235
236 final int numClusters = clusters.length;
237 final int numDataPoints = data.length;
238
239 final List<LocationWrapper> clusterInput = new ArrayList<LocationWrapper>(numDataPoints);
240 for (int i = 0; i < numDataPoints; i++) {
241 clusterInput.add(new LocationWrapper(data[i]));
242 }
243
244 logger.info("Running kmeans");
245
246 final KMeansPlusPlusClusterer<LocationWrapper> clusterer = new KMeansPlusPlusClusterer<LocationWrapper>(
247 numClusters,
248 10_000);
249 final List<CentroidCluster<LocationWrapper>> clusterResults = clusterer.cluster(clusterInput);
250
251 final long durationMs = System.currentTimeMillis() - startTs;
252 logger.info("Computation time: {}", DurationFormatUtils.formatDurationHMS(durationMs));
253
254 return clusterResults;
255 }
256
257 public static void main(String[] args) throws IOException {
258 logger.info("Starting");
259
260 final Random random = new Random();
261
262 final double min = -100;
263 final double max = 100;
264 final double minX = min;
265 final double maxX = max;
266 final double minY = min;
267 final double maxY = max;
268
269 final double radius = 8;
270
271 final int numDataPoints = 1_000;
272
273
274
275
276
277 final CommandLineParser parser = new DefaultParser();
278
279 final Options options = new Options();
280 options.addOption(PARAM_NUM_CLUSTERS, LONG_PARAM_NUM_CLUSTERS, true, "number of clusters");
281 options.addOption(PARAM_NUMBER_TOURNAMENTS, LONG_PARAM_NUMBER_TOURNAMENTS, true, "number of tournaments");
282 options.addOption(PARAM_SOURCE_CLUSTERS_CSV, LONG_PARAM_SOURCE_CUSTERS_CSV, true, "source csv file for clusters");
283 options.addOption(PARAM_SOURCE_DATA_CSV, LONG_PARAM_SOURCE_DATA_CSV, true, "source csv file for data");
284 options.addOption(PARAM_OUTPUT_CSV, LONG_PARAM_OUTPUT_CSV, true, "output csv");
285 options.addOption(PARAM_OUTPUT_WITH_SSE_CSV, LONG_PARAM_OUTPUT_WITH_SSE_CSV, true, "output with sse csv");
286 options
287 .addOption(PARAM_COMBINATION_ARITHMETIC, LONG_PARAM_COMBINATION_ARITHMETIC, true, "combination arithmetic");
288 options.addOption(PARAM_COMBINATION_CROSSOVER, LONG_PARAM_COMBINATION_CROSSOVER, true, "combination crossover");
289 options.addOption(PARAM_POPULATION_SIZE, LONG_PARAM_POPULATION_SIZE, true, "population size");
290 options.addOption(PARAM_BASE_DIR_OUTPUT, LONG_PARAM_BASE_DIR_OUTPUT, true, "base directory");
291 options.addOption(PARAM_CREEP_MUTATION_STD_DEV,
292 LONG_PARAM_CREEP_MUTATION_STD_DEV,
293 true,
294 "creep mutation std dev. Default: " + DEFAULT_CREEP_MUTATION_STDDEV);
295 options.addOption(PARAM_CREEP_MUTATION_MEAN,
296 LONG_PARAM_CREEP_MUTATION_MEAN,
297 true,
298 "creep mutation mean. Default: " + DEFAULT_CREEP_MUTATION_MEAN);
299 options.addOption(PARAM_CREEP_MUTATION_RATE,
300 LONG_PARAM_CREEP_MUTATION_RATE,
301 true,
302 "creep mutation rate. Default: " + DEFAULT_CREEP_MUTATION_RATE);
303 options.addOption(PARAM_RANDOM_MUTATION_RATE,
304 LONG_PARAM_RANDOM_MUTATION_RATE,
305 true,
306 "random mutation rate. Default: " + DEFAULT_RANDOM_MUTATION_RATE);
307 options.addOption(PARAM_FIXED_TERMINATION,
308 LONG_PARAM_FIXED_TERMINATION,
309 true,
310 "Fix the termination to the specified number of generations");
311
312 Optional<String> paramSourceClustersCSV = Optional.empty();
313 Optional<String> paramSourceDataCSV = Optional.empty();
314 Optional<Integer> paramNumClusters = Optional.empty();
315 Optional<String> paramOutputCSV = Optional.empty();
316 Optional<String> paramOutputWithSSECSV = Optional.empty();
317 Optional<Long> paramFixedTermination = Optional.empty();
318
319 int numberTournaments = DEFAULT_NUMBER_TOURNAMENTS;
320 int populationSize = DEFAULT_POPULATION_SIZE;
321 final double randomMutationRate;
322 final double creepMutationRate;
323 final double creepMutationMean;
324 final double creepMutationStdDev;
325 final int combinationArithmetic;
326 final int combinationCrossover;
327 final String baseDir;
328 try {
329 final CommandLine line = parser.parse(options, args);
330
331 if (line.hasOption(PARAM_NUMBER_TOURNAMENTS)) {
332 numberTournaments = Integer.parseInt(line.getOptionValue(PARAM_NUMBER_TOURNAMENTS)
333 .strip());
334 }
335
336 baseDir = Optional.ofNullable(line.getOptionValue(PARAM_BASE_DIR_OUTPUT))
337 .orElse("");
338
339 combinationArithmetic = Optional.ofNullable(line.getOptionValue(PARAM_COMBINATION_ARITHMETIC))
340 .map(String::strip)
341 .map(Integer::parseInt)
342 .orElse(DEFAULT_COMBINATION_ARITHMETIC);
343
344 combinationCrossover = Optional.ofNullable(line.getOptionValue(PARAM_COMBINATION_CROSSOVER))
345 .map(String::strip)
346 .map(Integer::parseInt)
347 .orElse(DEFAULT_COMBINATION_CROSSOVER_ARITHMETIC);
348
349 paramNumClusters = Optional.ofNullable(line.getOptionValue(PARAM_NUM_CLUSTERS))
350 .map(String::strip)
351 .map(Integer::parseInt);
352 populationSize = Optional.ofNullable(line.getOptionValue(PARAM_POPULATION_SIZE))
353 .map(String::strip)
354 .map(Integer::parseInt)
355 .orElse(DEFAULT_POPULATION_SIZE);
356
357 paramSourceClustersCSV = Optional.ofNullable(line.getOptionValue(PARAM_SOURCE_CLUSTERS_CSV))
358 .map(String::strip);
359
360 paramSourceDataCSV = Optional.ofNullable(line.getOptionValue(PARAM_SOURCE_DATA_CSV))
361 .map(String::strip);
362
363 paramOutputCSV = Optional.ofNullable(line.getOptionValue(PARAM_OUTPUT_CSV))
364 .map(String::strip);
365 paramOutputWithSSECSV = Optional.ofNullable(line.getOptionValue(PARAM_OUTPUT_WITH_SSE_CSV))
366 .map(String::strip);
367
368 paramFixedTermination = Optional.ofNullable(line.getOptionValue(PARAM_FIXED_TERMINATION))
369 .map(String::strip)
370 .map(Long::parseLong);
371
372 randomMutationRate = Optional.ofNullable(line.getOptionValue(PARAM_RANDOM_MUTATION_RATE))
373 .map(String::strip)
374 .map(Double::parseDouble)
375 .orElse(DEFAULT_RANDOM_MUTATION_RATE);
376
377 creepMutationRate = Optional.ofNullable(line.getOptionValue(PARAM_CREEP_MUTATION_RATE))
378 .map(String::strip)
379 .map(Double::parseDouble)
380 .orElse(DEFAULT_CREEP_MUTATION_RATE);
381
382 creepMutationMean = Optional.ofNullable(line.getOptionValue(PARAM_CREEP_MUTATION_MEAN))
383 .map(String::strip)
384 .map(Double::parseDouble)
385 .orElse(DEFAULT_CREEP_MUTATION_MEAN);
386
387 creepMutationStdDev = Optional.ofNullable(line.getOptionValue(PARAM_CREEP_MUTATION_STD_DEV))
388 .map(String::strip)
389 .map(Double::parseDouble)
390 .orElse(DEFAULT_CREEP_MUTATION_STDDEV);
391
392 logger.info("Unrecognized args:");
393 boolean hasError = false;
394 for (final String extraArg : line.getArgList()) {
395 logger.info("\t[{}]", extraArg);
396 if (extraArg.isBlank() == false) {
397 hasError = true;
398 }
399 }
400
401 if (hasError) {
402 throw new RuntimeException();
403 }
404 } catch (ParseException exp) {
405 cliError(options, "Unexpected exception:" + exp.getMessage());
406
407
408 throw new RuntimeException();
409
410 }
411
412 final int numClusters = paramNumClusters.orElse(DEFAULT_NUM_CLUSTERS);
413
414 logger.info("Preparing {} clusters", numClusters);
415 final double[][] clusters = paramSourceClustersCSV.map(IOUtils::loadClusters)
416 .orElseGet(() -> generateClusters(random, numClusters, minX, maxX, minY, maxY));
417 logger.info("Found {} clusters", clusters.length);
418
419 logger.info("Preparing data points");
420 final double[][] data = paramSourceDataCSV.map(sourceDataCSVFileName -> {
421 try {
422 logger.info("Loading data points");
423 return IOUtils.loadDataPoints(sourceDataCSVFileName);
424 } catch (IOException e) {
425 throw new RuntimeException("Could not load " + sourceDataCSVFileName, e);
426 }
427 })
428 .orElseGet(() -> generateDataPoints(random, clusters, numDataPoints, radius));
429 final double[][] distances = computeAllDistances(data);
430
431 final String originalClustersFilename = "originalClusters.csv";
432 if (paramSourceClustersCSV.isPresent()) {
433 logger.info("Not persisting clusters since it was provided");
434 } else {
435 logger.info("Saving clusters to CSV: {}", originalClustersFilename);
436 IOUtils.persistClusters(clusters, originalClustersFilename);
437 }
438
439 final String originalDataFilename = "originalData.csv";
440 if (paramSourceDataCSV.isPresent()) {
441 logger.info("Not persisting data since it was provided");
442 } else {
443 logger.info("Saving data to CSV: {}", originalDataFilename);
444 IOUtils.persistDataPoints(data, originalDataFilename);
445 }
446
447 logger.info("Clustering data with Apache Commons Math");
448
449 final List<CentroidCluster<LocationWrapper>> clusterResults = apacheCommonsMathCluster(clusters, data);
450
451 logger.info("Definition of genetic problem");
452
453 final int k = numClusters;
454 final var fitnessFunction = FitnessUtils.computeFitness(numDataPoints, data, distances, k);
455
456 final var terminations = paramFixedTermination
457 .map(maxGeneration -> Terminations.<Double>ofMaxGeneration(maxGeneration))
458 .orElseGet(() -> or(Terminations.<Double>ofMaxGeneration(500), Terminations.ofStableFitness(50)));
459
460 logger.info("Terminations: {}", paramFixedTermination);
461 logger.info("Parameters: random_mutation_rate: {} - creep_mutation_rate: {} - creep_mutation_stddev: {} ",
462 randomMutationRate,
463 creepMutationRate,
464 creepMutationStdDev);
465 logger.info("Combinations: arithmetic {} ; crossover {}", combinationArithmetic, combinationCrossover);
466
467 logger.info("Running GA with Silhouette score");
468 doGA(k,
469 min,
470 max,
471 numberTournaments,
472 combinationArithmetic,
473 combinationCrossover,
474 randomMutationRate,
475 creepMutationRate,
476 creepMutationMean,
477 creepMutationStdDev,
478 fitnessFunction,
479 terminations,
480 populationSize,
481 paramOutputCSV.orElse("output.csv"),
482 data,
483 distances,
484 baseDir,
485 "");
486
487 logger.info("Running GA with Silhouette score + SSE");
488 final var fitnessFunctionWithSumSquareErrors = FitnessUtils
489 .computeFitnessWithSSE(numDataPoints, data, distances, k);
490 doGA(k,
491 min,
492 max,
493 numberTournaments,
494 combinationArithmetic,
495 combinationCrossover,
496 randomMutationRate,
497 creepMutationRate,
498 creepMutationMean,
499 creepMutationStdDev,
500 fitnessFunctionWithSumSquareErrors,
501 terminations,
502 populationSize,
503 paramOutputWithSSECSV.orElse("output-with-sse.csv"),
504 data,
505 distances,
506 baseDir,
507 "-with-sse");
508
509 logger.info("Original clusters:");
510 final double[] originalMeans = new double[numClusters * 2];
511 for (int i = 0; i < numClusters; i++) {
512 logger.info("\tx: {} - y: {}", clusters[i][0], clusters[i][1]);
513 originalMeans[i * 2] = clusters[i][0];
514 originalMeans[i * 2 + 1] = clusters[i][1];
515 }
516 final var originalFitness = fitnessFunction
517 .compute(new Genotype(new DoubleChromosome(k * 2, -100.0d, 100.0d, originalMeans)));
518 logger.info("Original fitness: {}", originalFitness);
519 final int[] originalClusterMembership = FitnessUtils.assignDataToClusters(data, distances, clusters);
520 IOUtils.persistDataPoints(data, originalClusterMembership, baseDir + "clustering-result-original.csv");
521 IOUtils.persistClusters(clusters, baseDir + "clustering-result-clusters-original.csv");
522
523 logger.info("kmeans output:");
524
525 final double[] kmeansClusters = new double[numClusters * 2];
526 for (int i = 0; i < clusterResults.size(); i++) {
527 final CentroidCluster<LocationWrapper> centroidCluster = clusterResults.get(i);
528 logger.info("\t{}", centroidCluster.getCenter());
529 kmeansClusters[i * 2] = centroidCluster.getCenter()
530 .getPoint()[0];
531 kmeansClusters[i * 2 + 1] = centroidCluster.getCenter()
532 .getPoint()[1];
533 }
534 final var kmeansGenotype = new Genotype(new DoubleChromosome(k * 2, -100.0d, 100.0d, kmeansClusters));
535 final var kmeansFitness = fitnessFunction.compute(kmeansGenotype);
536 logger.info("kmeans fitness: {}", kmeansFitness);
537
538 final int[] kmeansClusterMembership = FitnessUtils
539 .assignDataToClusters(data, distances, PhenotypeUtils.toPhenotype(kmeansGenotype));
540 IOUtils.persistDataPoints(data, kmeansClusterMembership, baseDir + "clustering-result-kmeans.csv");
541 IOUtils.persistClusters(PhenotypeUtils.toPhenotype(kmeansGenotype),
542 baseDir + "clustering-result-clusters-kmeans.csv");
543
544 logger.info("Done");
545 }
546 }