1 package net.bmahe.genetics4j.samples.clustering;
2
3 import static net.bmahe.genetics4j.core.termination.Terminations.or;
4
5 import java.io.IOException;
6 import java.util.ArrayList;
7 import java.util.List;
8 import java.util.Optional;
9 import java.util.Random;
10
11 import org.apache.commons.cli.CommandLine;
12 import org.apache.commons.cli.CommandLineParser;
13 import org.apache.commons.cli.DefaultParser;
14 import org.apache.commons.cli.Options;
15 import org.apache.commons.cli.ParseException;
16 import org.apache.commons.lang3.Validate;
17 import org.apache.commons.lang3.time.DurationFormatUtils;
18 import org.apache.commons.math3.ml.clustering.CentroidCluster;
19 import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
20 import org.apache.logging.log4j.LogManager;
21 import org.apache.logging.log4j.Logger;
22
23 import net.bmahe.genetics4j.core.EASystemFactory;
24 import net.bmahe.genetics4j.core.Fitness;
25 import net.bmahe.genetics4j.core.Genotype;
26 import net.bmahe.genetics4j.core.chromosomes.DoubleChromosome;
27 import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
28 import net.bmahe.genetics4j.core.spec.EAConfiguration;
29 import net.bmahe.genetics4j.core.spec.EAExecutionContexts;
30 import net.bmahe.genetics4j.core.spec.chromosome.DoubleChromosomeSpec;
31 import net.bmahe.genetics4j.core.spec.combination.MultiCombinations;
32 import net.bmahe.genetics4j.core.spec.combination.MultiPointArithmetic;
33 import net.bmahe.genetics4j.core.spec.combination.MultiPointCrossover;
34 import net.bmahe.genetics4j.core.spec.mutation.CreepMutation;
35 import net.bmahe.genetics4j.core.spec.mutation.MultiMutations;
36 import net.bmahe.genetics4j.core.spec.mutation.RandomMutation;
37 import net.bmahe.genetics4j.core.spec.selection.Tournament;
38 import net.bmahe.genetics4j.core.termination.Termination;
39 import net.bmahe.genetics4j.core.termination.Terminations;
40 import net.bmahe.genetics4j.extras.evolutionlisteners.CSVEvolutionListener;
41 import net.bmahe.genetics4j.extras.evolutionlisteners.ColumnExtractor;
42 import net.bmahe.genetics4j.samples.CLIUtils;
43
44 public class Clustering {
45 final static public Logger logger = LogManager.getLogger(Clustering.class);
46
47 final static public int DEFAULT_NUM_CLUSTERS = 6;
48 final static public int DEFAULT_NUMBER_TOURNAMENTS = 2;
49 final static public int DEFAULT_POPULATION_SIZE = 120;
50 final static public double DEFAULT_RANDOM_MUTATION_RATE = 0.15d;
51 final static public double DEFAULT_CREEP_MUTATION_RATE = 0.20d;
52 final static public double DEFAULT_CREEP_MUTATION_MEAN = 0.0d;
53 final static public double DEFAULT_CREEP_MUTATION_STDDEV = 5;
54 final static public int DEFAULT_COMBINATION_ARITHMETIC = 3;
55 final static public int DEFAULT_COMBINATION_CROSSOVER_ARITHMETIC = 3;
56
57 final static public String PARAM_NUM_CLUSTERS = "n";
58 final static public String LONG_PARAM_NUM_CLUSTERS = "num-clusters";
59
60 final static public String PARAM_NUMBER_TOURNAMENTS = "t";
61 final static public String LONG_PARAM_NUMBER_TOURNAMENTS = "num-tournaments";
62
63 final static public String PARAM_POPULATION_SIZE = "p";
64 final static public String LONG_PARAM_POPULATION_SIZE = "population-size";
65
66 final static public String PARAM_SOURCE_CLUSTERS_CSV = "c";
67 final static public String LONG_PARAM_SOURCE_CUSTERS_CSV = "source-clusters";
68
69 final static public String PARAM_SOURCE_DATA_CSV = "s";
70 final static public String LONG_PARAM_SOURCE_DATA_CSV = "source-data";
71
72 final static public String PARAM_FIXED_TERMINATION = "f";
73 final static public String LONG_PARAM_FIXED_TERMINATION = "fixed-termination";
74
75 final static public String PARAM_RANDOM_MUTATION_RATE = "r";
76 final static public String LONG_PARAM_RANDOM_MUTATION_RATE = "random-mutation-rate";
77
78 final static public String PARAM_CREEP_MUTATION_RATE = "m";
79 final static public String LONG_PARAM_CREEP_MUTATION_RATE = "creep-mutation-rate";
80
81 final static public String PARAM_CREEP_MUTATION_MEAN = "a";
82 final static public String LONG_PARAM_CREEP_MUTATION_MEAN = "creep-mutation-mean";
83
84 final static public String PARAM_CREEP_MUTATION_STD_DEV = "d";
85 final static public String LONG_PARAM_CREEP_MUTATION_STD_DEV = "creep-mutation-std-dev";
86
87 final static public String PARAM_COMBINATION_ARITHMETIC = "b";
88 final static public String LONG_PARAM_COMBINATION_ARITHMETIC = "combination-arithmetic";
89
90 final static public String PARAM_COMBINATION_CROSSOVER = "e";
91 final static public String LONG_PARAM_COMBINATION_CROSSOVER = "combination-crossover";
92
93 final static public String PARAM_OUTPUT_CSV = "o";
94 final static public String LONG_PARAM_OUTPUT_CSV = "output";
95
96 final static public String PARAM_OUTPUT_WITH_SSE_CSV = "g";
97 final static public String LONG_PARAM_OUTPUT_WITH_SSE_CSV = "output-sse";
98
99 final static public String PARAM_BASE_DIR_OUTPUT = "i";
100 final static public String LONG_PARAM_BASE_DIR_OUTPUT = "base-dir";
101
102 public static void cliError(final Options options, final String errorMessage) {
103 CLIUtils.cliHelpAndExit(logger, Clustering.class, options, errorMessage);
104 }
105
106 private final static double computeDistance(final double[][] array, final int i, final int j) {
107 final double xDiff = array[j][0] - array[i][0];
108 final double yDiff = array[j][1] - array[i][1];
109 return Math.sqrt((xDiff * xDiff) + (yDiff * yDiff));
110 }
111
112 private final static double[][] computeAllDistances(final double[][] array) {
113
114 final double[][] distances = new double[array.length][array.length];
115
116 for (int i = 0; i < array.length; i++) {
117 distances[i][i] = 0.0;
118 }
119
120 for (int i = 0; i < array.length; i++) {
121 for (int j = 0; j < i; j++) {
122 final double distance = computeDistance(array, i, j);
123 distances[i][j] = distance;
124 distances[j][i] = distance;
125 }
126 }
127
128 return distances;
129 }
130
131
132 public static double[][] generateClusters(final Random random, final int numClusters, final double minX,
133 final double maxX, final double minY, final double maxY) {
134 Validate.notNull(random);
135 Validate.isTrue(numClusters > 0);
136 Validate.isTrue(minX <= maxX);
137 Validate.isTrue(minY <= maxY);
138
139 logger.info("Generating {} clusters", numClusters);
140
141 final double[][] clusters = new double[numClusters][2];
142 for (int i = 0; i < numClusters; i++) {
143 clusters[i][0] = minX + random.nextDouble() * (maxX - minX);
144 clusters[i][1] = minY + random.nextDouble() * (maxY - minY);
145 }
146
147 return clusters;
148 }
149
150
151
152 public static double[][] generateDataPoints(final Random random, final double[][] clusters, final int numDataPoints,
153 final double radius) {
154 Validate.notNull(random);
155 Validate.notNull(clusters);
156 Validate.isTrue(clusters.length > 0);
157
158 final int numClusters = clusters.length;
159 final double[][] data = new double[numDataPoints][3];
160 for (int i = 0; i < numDataPoints; i++) {
161 final int clusterIndex = i % numClusters;
162
163 data[i][0] = random.nextGaussian() * radius + clusters[clusterIndex][0];
164 data[i][1] = random.nextGaussian() * radius + clusters[clusterIndex][1];
165 data[i][2] = clusterIndex;
166 }
167
168 return data;
169 }
170
171
172 public static void doGA(final int k, final double min, final double max, final int numberTournaments,
173 final int combinationArithmetic, final int combinationCrossover, final double randomMutationRate,
174 final double creepMutationRate, final double creepMutationMean, final double creepMutationStdDev,
175 final Fitness<Double> fitnessFunction, final Termination<Double> terminations, final int populationSize,
176 final String outputCSV, final double[][] data, final double[][] distances, final String baseDir,
177 final String filenameSuffix) throws IOException {
178
179
180 final var eaConfigurationBuilder = new EAConfiguration.Builder<Double>();
181 eaConfigurationBuilder.chromosomeSpecs(DoubleChromosomeSpec.of(k * 2, min, max))
182 .parentSelectionPolicy(Tournament.of(numberTournaments))
183 .combinationPolicy(
184 MultiCombinations.of(
185 MultiPointArithmetic.of(combinationArithmetic, 0.5),
186 MultiPointCrossover.of(combinationCrossover)))
187 .mutationPolicies(
188 MultiMutations.of(
189 RandomMutation.of(randomMutationRate),
190 CreepMutation.ofNormal(creepMutationRate, creepMutationMean, creepMutationStdDev)))
191 .fitness(fitnessFunction)
192 .postEvaluationProcessor(FitnessSharingUtils.clusterDistance)
193 .termination(terminations);
194 final var eaConfiguration = eaConfigurationBuilder.build();
195
196
197 final var eaExecutionContext = EAExecutionContexts.<Double>forScalarFitness()
198 .populationSize(populationSize)
199 .addEvolutionListeners(
200 EvolutionListeners.ofLogTopN(logger, 3),
201 new CSVEvolutionListener.Builder<Double, Double>().filename(outputCSV)
202 .columnExtractors(
203 List.of(
204 ColumnExtractor.of("generation", (evolutionStep) -> evolutionStep.generation()),
205 ColumnExtractor.of("fitness", (evolutionStep) -> evolutionStep.fitness()),
206 ColumnExtractor
207 .of("combination_arithmetic", (evolutionStep) -> combinationArithmetic),
208 ColumnExtractor
209 .of("combination_crossover", (evolutionStep) -> combinationCrossover),
210 ColumnExtractor
211 .of("random_mutation_rate", (evolutionStep) -> randomMutationRate),
212 ColumnExtractor
213 .of("creep_mutation_mean", (evolutionStep) -> creepMutationMean),
214 ColumnExtractor
215 .of("creep_mutation_stddev", (evolutionStep) -> creepMutationStdDev),
216 ColumnExtractor
217 .of("creep_mutation_rate", (evolutionStep) -> creepMutationRate)))
218 .build())
219 .build();
220
221 final var eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
222
223 final var evolutionResult = eaSystem.evolve();
224 logger.info("Best genotype: {}", evolutionResult.bestGenotype());
225 logger.info(" with fitness: {}", evolutionResult.bestFitness());
226 logger.info(" at generation: {}", evolutionResult.generation());
227
228 final Genotype bestGenotype = evolutionResult.bestGenotype();
229 final double[][] bestPhenotype = PhenotypeUtils.toPhenotype(bestGenotype);
230 logger.info("Best phenotype:");
231 for (int i = 0; i < k; i++) {
232 logger.info("\tx: {} - y: {}", bestPhenotype[i][0], bestPhenotype[i][1]);
233 }
234 final int[] bestClusterMembership = FitnessUtils.assignDataToClusters(data, distances, bestPhenotype);
235 IOUtils
236 .persistDataPoints(data, bestClusterMembership, baseDir + "clustering-result-ga" + filenameSuffix + ".csv");
237 IOUtils.persistClusters(bestPhenotype, baseDir + "clustering-result-clusters-ga" + filenameSuffix + ".csv");
238
239 }
240
241 public static List<CentroidCluster<LocationWrapper>> apacheCommonsMathCluster(final double[][] clusters,
242 final double[][] data) {
243
244 logger.info("Initializing kmeans from Apache Commons Math");
245
246 final long startTs = System.currentTimeMillis();
247
248 final int numClusters = clusters.length;
249 final int numDataPoints = data.length;
250
251 final List<LocationWrapper> clusterInput = new ArrayList<LocationWrapper>(numDataPoints);
252 for (int i = 0; i < numDataPoints; i++) {
253 clusterInput.add(new LocationWrapper(data[i]));
254 }
255
256 logger.info("Running kmeans");
257
258 final KMeansPlusPlusClusterer<LocationWrapper> clusterer = new KMeansPlusPlusClusterer<LocationWrapper>(
259 numClusters,
260 10_000);
261 final List<CentroidCluster<LocationWrapper>> clusterResults = clusterer.cluster(clusterInput);
262
263 final long durationMs = System.currentTimeMillis() - startTs;
264 logger.info("Computation time: {}", DurationFormatUtils.formatDurationHMS(durationMs));
265
266 return clusterResults;
267 }
268
269 public static void main(String[] args) throws IOException {
270 logger.info("Starting");
271
272 final Random random = new Random();
273
274 final double min = -100;
275 final double max = 100;
276 final double minX = min;
277 final double maxX = max;
278 final double minY = min;
279 final double maxY = max;
280
281 final double radius = 8;
282
283 final int numDataPoints = 1_000;
284
285
286
287
288
289 final CommandLineParser parser = new DefaultParser();
290
291 final Options options = new Options();
292 options.addOption(PARAM_NUM_CLUSTERS, LONG_PARAM_NUM_CLUSTERS, true, "number of clusters");
293 options.addOption(PARAM_NUMBER_TOURNAMENTS, LONG_PARAM_NUMBER_TOURNAMENTS, true, "number of tournaments");
294 options.addOption(PARAM_SOURCE_CLUSTERS_CSV, LONG_PARAM_SOURCE_CUSTERS_CSV, true, "source csv file for clusters");
295 options.addOption(PARAM_SOURCE_DATA_CSV, LONG_PARAM_SOURCE_DATA_CSV, true, "source csv file for data");
296 options.addOption(PARAM_OUTPUT_CSV, LONG_PARAM_OUTPUT_CSV, true, "output csv");
297 options.addOption(PARAM_OUTPUT_WITH_SSE_CSV, LONG_PARAM_OUTPUT_WITH_SSE_CSV, true, "output with sse csv");
298 options
299 .addOption(PARAM_COMBINATION_ARITHMETIC, LONG_PARAM_COMBINATION_ARITHMETIC, true, "combination arithmetic");
300 options.addOption(PARAM_COMBINATION_CROSSOVER, LONG_PARAM_COMBINATION_CROSSOVER, true, "combination crossover");
301 options.addOption(PARAM_POPULATION_SIZE, LONG_PARAM_POPULATION_SIZE, true, "population size");
302 options.addOption(PARAM_BASE_DIR_OUTPUT, LONG_PARAM_BASE_DIR_OUTPUT, true, "base directory");
303 options.addOption(
304 PARAM_CREEP_MUTATION_STD_DEV,
305 LONG_PARAM_CREEP_MUTATION_STD_DEV,
306 true,
307 "creep mutation std dev. Default: " + DEFAULT_CREEP_MUTATION_STDDEV);
308 options.addOption(
309 PARAM_CREEP_MUTATION_MEAN,
310 LONG_PARAM_CREEP_MUTATION_MEAN,
311 true,
312 "creep mutation mean. Default: " + DEFAULT_CREEP_MUTATION_MEAN);
313 options.addOption(
314 PARAM_CREEP_MUTATION_RATE,
315 LONG_PARAM_CREEP_MUTATION_RATE,
316 true,
317 "creep mutation rate. Default: " + DEFAULT_CREEP_MUTATION_RATE);
318 options.addOption(
319 PARAM_RANDOM_MUTATION_RATE,
320 LONG_PARAM_RANDOM_MUTATION_RATE,
321 true,
322 "random mutation rate. Default: " + DEFAULT_RANDOM_MUTATION_RATE);
323 options.addOption(
324 PARAM_FIXED_TERMINATION,
325 LONG_PARAM_FIXED_TERMINATION,
326 true,
327 "Fix the termination to the specified number of generations");
328
329 Optional<String> paramSourceClustersCSV = Optional.empty();
330 Optional<String> paramSourceDataCSV = Optional.empty();
331 Optional<Integer> paramNumClusters = Optional.empty();
332 Optional<String> paramOutputCSV = Optional.empty();
333 Optional<String> paramOutputWithSSECSV = Optional.empty();
334 Optional<Long> paramFixedTermination = Optional.empty();
335
336 int numberTournaments = DEFAULT_NUMBER_TOURNAMENTS;
337 int populationSize = DEFAULT_POPULATION_SIZE;
338 final double randomMutationRate;
339 final double creepMutationRate;
340 final double creepMutationMean;
341 final double creepMutationStdDev;
342 final int combinationArithmetic;
343 final int combinationCrossover;
344 final String baseDir;
345 try {
346 final CommandLine line = parser.parse(options, args);
347
348 if (line.hasOption(PARAM_NUMBER_TOURNAMENTS)) {
349 numberTournaments = Integer.parseInt(line.getOptionValue(PARAM_NUMBER_TOURNAMENTS).strip());
350 }
351
352 baseDir = Optional.ofNullable(line.getOptionValue(PARAM_BASE_DIR_OUTPUT)).orElse("");
353
354 combinationArithmetic = Optional.ofNullable(line.getOptionValue(PARAM_COMBINATION_ARITHMETIC))
355 .map(String::strip)
356 .map(Integer::parseInt)
357 .orElse(DEFAULT_COMBINATION_ARITHMETIC);
358
359 combinationCrossover = Optional.ofNullable(line.getOptionValue(PARAM_COMBINATION_CROSSOVER))
360 .map(String::strip)
361 .map(Integer::parseInt)
362 .orElse(DEFAULT_COMBINATION_CROSSOVER_ARITHMETIC);
363
364 paramNumClusters = Optional.ofNullable(line.getOptionValue(PARAM_NUM_CLUSTERS))
365 .map(String::strip)
366 .map(Integer::parseInt);
367 populationSize = Optional.ofNullable(line.getOptionValue(PARAM_POPULATION_SIZE))
368 .map(String::strip)
369 .map(Integer::parseInt)
370 .orElse(DEFAULT_POPULATION_SIZE);
371
372 paramSourceClustersCSV = Optional.ofNullable(line.getOptionValue(PARAM_SOURCE_CLUSTERS_CSV))
373 .map(String::strip);
374
375 paramSourceDataCSV = Optional.ofNullable(line.getOptionValue(PARAM_SOURCE_DATA_CSV)).map(String::strip);
376
377 paramOutputCSV = Optional.ofNullable(line.getOptionValue(PARAM_OUTPUT_CSV)).map(String::strip);
378 paramOutputWithSSECSV = Optional.ofNullable(line.getOptionValue(PARAM_OUTPUT_WITH_SSE_CSV)).map(String::strip);
379
380 paramFixedTermination = Optional.ofNullable(line.getOptionValue(PARAM_FIXED_TERMINATION))
381 .map(String::strip)
382 .map(Long::parseLong);
383
384 randomMutationRate = Optional.ofNullable(line.getOptionValue(PARAM_RANDOM_MUTATION_RATE))
385 .map(String::strip)
386 .map(Double::parseDouble)
387 .orElse(DEFAULT_RANDOM_MUTATION_RATE);
388
389 creepMutationRate = Optional.ofNullable(line.getOptionValue(PARAM_CREEP_MUTATION_RATE))
390 .map(String::strip)
391 .map(Double::parseDouble)
392 .orElse(DEFAULT_CREEP_MUTATION_RATE);
393
394 creepMutationMean = Optional.ofNullable(line.getOptionValue(PARAM_CREEP_MUTATION_MEAN))
395 .map(String::strip)
396 .map(Double::parseDouble)
397 .orElse(DEFAULT_CREEP_MUTATION_MEAN);
398
399 creepMutationStdDev = Optional.ofNullable(line.getOptionValue(PARAM_CREEP_MUTATION_STD_DEV))
400 .map(String::strip)
401 .map(Double::parseDouble)
402 .orElse(DEFAULT_CREEP_MUTATION_STDDEV);
403
404 logger.info("Unrecognized args:");
405 boolean hasError = false;
406 for (final String extraArg : line.getArgList()) {
407 logger.info("\t[{}]", extraArg);
408 if (extraArg.isBlank() == false) {
409 hasError = true;
410 }
411 }
412
413 if (hasError) {
414 throw new RuntimeException();
415 }
416 } catch (ParseException exp) {
417 cliError(options, "Unexpected exception:" + exp.getMessage());
418
419
420 throw new RuntimeException();
421
422 }
423
424 final int numClusters = paramNumClusters.orElse(DEFAULT_NUM_CLUSTERS);
425
426 logger.info("Preparing {} clusters", numClusters);
427 final double[][] clusters = paramSourceClustersCSV.map(IOUtils::loadClusters)
428 .orElseGet(() -> generateClusters(random, numClusters, minX, maxX, minY, maxY));
429 logger.info("Found {} clusters", clusters.length);
430
431 logger.info("Preparing data points");
432 final double[][] data = paramSourceDataCSV.map(sourceDataCSVFileName -> {
433 try {
434 logger.info("Loading data points");
435 return IOUtils.loadDataPoints(sourceDataCSVFileName);
436 } catch (IOException e) {
437 throw new RuntimeException("Could not load " + sourceDataCSVFileName, e);
438 }
439 }).orElseGet(() -> generateDataPoints(random, clusters, numDataPoints, radius));
440 final double[][] distances = computeAllDistances(data);
441
442 final String originalClustersFilename = "originalClusters.csv";
443 if (paramSourceClustersCSV.isPresent()) {
444 logger.info("Not persisting clusters since it was provided");
445 } else {
446 logger.info("Saving clusters to CSV: {}", originalClustersFilename);
447 IOUtils.persistClusters(clusters, originalClustersFilename);
448 }
449
450 final String originalDataFilename = "originalData.csv";
451 if (paramSourceDataCSV.isPresent()) {
452 logger.info("Not persisting data since it was provided");
453 } else {
454 logger.info("Saving data to CSV: {}", originalDataFilename);
455 IOUtils.persistDataPoints(data, originalDataFilename);
456 }
457
458 logger.info("Clustering data with Apache Commons Math");
459
460 final List<CentroidCluster<LocationWrapper>> clusterResults = apacheCommonsMathCluster(clusters, data);
461
462 logger.info("Definition of genetic problem");
463
464 final int k = numClusters;
465 final var fitnessFunction = FitnessUtils.computeFitness(numDataPoints, data, distances, k);
466
467 final var terminations = paramFixedTermination
468 .map(maxGeneration -> Terminations.<Double>ofMaxGeneration(maxGeneration))
469 .orElseGet(() -> or(Terminations.<Double>ofMaxGeneration(500), Terminations.ofStableFitness(50)));
470
471 logger.info("Terminations: {}", paramFixedTermination);
472 logger.info(
473 "Parameters: random_mutation_rate: {} - creep_mutation_rate: {} - creep_mutation_stddev: {} ",
474 randomMutationRate,
475 creepMutationRate,
476 creepMutationStdDev);
477 logger.info("Combinations: arithmetic {} ; crossover {}", combinationArithmetic, combinationCrossover);
478
479 logger.info("Running GA with Silhouette score");
480 doGA(
481 k,
482 min,
483 max,
484 numberTournaments,
485 combinationArithmetic,
486 combinationCrossover,
487 randomMutationRate,
488 creepMutationRate,
489 creepMutationMean,
490 creepMutationStdDev,
491 fitnessFunction,
492 terminations,
493 populationSize,
494 paramOutputCSV.orElse("output.csv"),
495 data,
496 distances,
497 baseDir,
498 "");
499
500 logger.info("Running GA with Silhouette score + SSE");
501 final var fitnessFunctionWithSumSquareErrors = FitnessUtils
502 .computeFitnessWithSSE(numDataPoints, data, distances, k);
503 doGA(
504 k,
505 min,
506 max,
507 numberTournaments,
508 combinationArithmetic,
509 combinationCrossover,
510 randomMutationRate,
511 creepMutationRate,
512 creepMutationMean,
513 creepMutationStdDev,
514 fitnessFunctionWithSumSquareErrors,
515 terminations,
516 populationSize,
517 paramOutputWithSSECSV.orElse("output-with-sse.csv"),
518 data,
519 distances,
520 baseDir,
521 "-with-sse");
522
523 logger.info("Original clusters:");
524 final double[] originalMeans = new double[numClusters * 2];
525 for (int i = 0; i < numClusters; i++) {
526 logger.info("\tx: {} - y: {}", clusters[i][0], clusters[i][1]);
527 originalMeans[i * 2] = clusters[i][0];
528 originalMeans[i * 2 + 1] = clusters[i][1];
529 }
530 final var originalFitness = fitnessFunction
531 .compute(new Genotype(new DoubleChromosome(k * 2, -100.0d, 100.0d, originalMeans)));
532 logger.info("Original fitness: {}", originalFitness);
533 final int[] originalClusterMembership = FitnessUtils.assignDataToClusters(data, distances, clusters);
534 IOUtils.persistDataPoints(data, originalClusterMembership, baseDir + "clustering-result-original.csv");
535 IOUtils.persistClusters(clusters, baseDir + "clustering-result-clusters-original.csv");
536
537 logger.info("kmeans output:");
538
539 final double[] kmeansClusters = new double[numClusters * 2];
540 for (int i = 0; i < clusterResults.size(); i++) {
541 final CentroidCluster<LocationWrapper> centroidCluster = clusterResults.get(i);
542 logger.info("\t{}", centroidCluster.getCenter());
543 kmeansClusters[i * 2] = centroidCluster.getCenter().getPoint()[0];
544 kmeansClusters[i * 2 + 1] = centroidCluster.getCenter().getPoint()[1];
545 }
546 final var kmeansGenotype = new Genotype(new DoubleChromosome(k * 2, -100.0d, 100.0d, kmeansClusters));
547 final var kmeansFitness = fitnessFunction.compute(kmeansGenotype);
548 logger.info("kmeans fitness: {}", kmeansFitness);
549
550 final int[] kmeansClusterMembership = FitnessUtils
551 .assignDataToClusters(data, distances, PhenotypeUtils.toPhenotype(kmeansGenotype));
552 IOUtils.persistDataPoints(data, kmeansClusterMembership, baseDir + "clustering-result-kmeans.csv");
553 IOUtils.persistClusters(
554 PhenotypeUtils.toPhenotype(kmeansGenotype),
555 baseDir + "clustering-result-clusters-kmeans.csv");
556
557 logger.info("Done");
558 }
559 }