View Javadoc
1   package net.bmahe.genetics4j.samples.clustering;
2   
3   import static net.bmahe.genetics4j.core.termination.Terminations.or;
4   
5   import java.io.IOException;
6   import java.util.ArrayList;
7   import java.util.List;
8   import java.util.Optional;
9   import java.util.Random;
10  
11  import org.apache.commons.cli.CommandLine;
12  import org.apache.commons.cli.CommandLineParser;
13  import org.apache.commons.cli.DefaultParser;
14  import org.apache.commons.cli.Options;
15  import org.apache.commons.cli.ParseException;
16  import org.apache.commons.lang3.Validate;
17  import org.apache.commons.lang3.time.DurationFormatUtils;
18  import org.apache.commons.math3.ml.clustering.CentroidCluster;
19  import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
20  import org.apache.logging.log4j.LogManager;
21  import org.apache.logging.log4j.Logger;
22  
23  import net.bmahe.genetics4j.core.EASystemFactory;
24  import net.bmahe.genetics4j.core.Fitness;
25  import net.bmahe.genetics4j.core.Genotype;
26  import net.bmahe.genetics4j.core.chromosomes.DoubleChromosome;
27  import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
28  import net.bmahe.genetics4j.core.spec.EAConfiguration;
29  import net.bmahe.genetics4j.core.spec.EAExecutionContexts;
30  import net.bmahe.genetics4j.core.spec.chromosome.DoubleChromosomeSpec;
31  import net.bmahe.genetics4j.core.spec.combination.MultiCombinations;
32  import net.bmahe.genetics4j.core.spec.combination.MultiPointArithmetic;
33  import net.bmahe.genetics4j.core.spec.combination.MultiPointCrossover;
34  import net.bmahe.genetics4j.core.spec.mutation.CreepMutation;
35  import net.bmahe.genetics4j.core.spec.mutation.MultiMutations;
36  import net.bmahe.genetics4j.core.spec.mutation.RandomMutation;
37  import net.bmahe.genetics4j.core.spec.selection.Tournament;
38  import net.bmahe.genetics4j.core.termination.Termination;
39  import net.bmahe.genetics4j.core.termination.Terminations;
40  import net.bmahe.genetics4j.extras.evolutionlisteners.CSVEvolutionListener;
41  import net.bmahe.genetics4j.extras.evolutionlisteners.ColumnExtractor;
42  import net.bmahe.genetics4j.samples.CLIUtils;
43  
44  public class Clustering {
45  	final static public Logger logger = LogManager.getLogger(Clustering.class);
46  
47  	final static public int DEFAULT_NUM_CLUSTERS = 6;
48  	final static public int DEFAULT_NUMBER_TOURNAMENTS = 2;
49  	final static public int DEFAULT_POPULATION_SIZE = 120;
50  	final static public double DEFAULT_RANDOM_MUTATION_RATE = 0.15d;
51  	final static public double DEFAULT_CREEP_MUTATION_RATE = 0.20d;
52  	final static public double DEFAULT_CREEP_MUTATION_MEAN = 0.0d;
53  	final static public double DEFAULT_CREEP_MUTATION_STDDEV = 5;
54  	final static public int DEFAULT_COMBINATION_ARITHMETIC = 3;
55  	final static public int DEFAULT_COMBINATION_CROSSOVER_ARITHMETIC = 3;
56  
57  	final static public String PARAM_NUM_CLUSTERS = "n";
58  	final static public String LONG_PARAM_NUM_CLUSTERS = "num-clusters";
59  
60  	final static public String PARAM_NUMBER_TOURNAMENTS = "t";
61  	final static public String LONG_PARAM_NUMBER_TOURNAMENTS = "num-tournaments";
62  
63  	final static public String PARAM_POPULATION_SIZE = "p";
64  	final static public String LONG_PARAM_POPULATION_SIZE = "population-size";
65  
66  	final static public String PARAM_SOURCE_CLUSTERS_CSV = "c";
67  	final static public String LONG_PARAM_SOURCE_CUSTERS_CSV = "source-clusters";
68  
69  	final static public String PARAM_SOURCE_DATA_CSV = "s";
70  	final static public String LONG_PARAM_SOURCE_DATA_CSV = "source-data";
71  
72  	final static public String PARAM_FIXED_TERMINATION = "f";
73  	final static public String LONG_PARAM_FIXED_TERMINATION = "fixed-termination";
74  
75  	final static public String PARAM_RANDOM_MUTATION_RATE = "r";
76  	final static public String LONG_PARAM_RANDOM_MUTATION_RATE = "random-mutation-rate";
77  
78  	final static public String PARAM_CREEP_MUTATION_RATE = "m";
79  	final static public String LONG_PARAM_CREEP_MUTATION_RATE = "creep-mutation-rate";
80  
81  	final static public String PARAM_CREEP_MUTATION_MEAN = "a";
82  	final static public String LONG_PARAM_CREEP_MUTATION_MEAN = "creep-mutation-mean";
83  
84  	final static public String PARAM_CREEP_MUTATION_STD_DEV = "d";
85  	final static public String LONG_PARAM_CREEP_MUTATION_STD_DEV = "creep-mutation-std-dev";
86  
87  	final static public String PARAM_COMBINATION_ARITHMETIC = "b";
88  	final static public String LONG_PARAM_COMBINATION_ARITHMETIC = "combination-arithmetic";
89  
90  	final static public String PARAM_COMBINATION_CROSSOVER = "e";
91  	final static public String LONG_PARAM_COMBINATION_CROSSOVER = "combination-crossover";
92  
93  	final static public String PARAM_OUTPUT_CSV = "o";
94  	final static public String LONG_PARAM_OUTPUT_CSV = "output";
95  
96  	final static public String PARAM_OUTPUT_WITH_SSE_CSV = "g";
97  	final static public String LONG_PARAM_OUTPUT_WITH_SSE_CSV = "output-sse";
98  
99  	final static public String PARAM_BASE_DIR_OUTPUT = "i";
100 	final static public String LONG_PARAM_BASE_DIR_OUTPUT = "base-dir";
101 
102 	public static void cliError(final Options options, final String errorMessage) {
103 		CLIUtils.cliHelpAndExit(logger, Clustering.class, options, errorMessage);
104 	}
105 
106 	private final static double computeDistance(final double[][] array, final int i, final int j) {
107 		final double xDiff = array[j][0] - array[i][0];
108 		final double yDiff = array[j][1] - array[i][1];
109 		return Math.sqrt((xDiff * xDiff) + (yDiff * yDiff));
110 	}
111 
112 	private final static double[][] computeAllDistances(final double[][] array) {
113 
114 		final double[][] distances = new double[array.length][array.length];
115 
116 		for (int i = 0; i < array.length; i++) {
117 			distances[i][i] = 0.0;
118 		}
119 
120 		for (int i = 0; i < array.length; i++) {
121 			for (int j = 0; j < i; j++) {
122 				final double distance = computeDistance(array, i, j);
123 				distances[i][j] = distance;
124 				distances[j][i] = distance;
125 			}
126 		}
127 
128 		return distances;
129 	}
130 
131 	// tag::cluster_generation[]
132 	public static double[][] generateClusters(final Random random, final int numClusters, final double minX,
133 			final double maxX, final double minY, final double maxY) {
134 		Validate.notNull(random);
135 		Validate.isTrue(numClusters > 0);
136 		Validate.isTrue(minX <= maxX);
137 		Validate.isTrue(minY <= maxY);
138 
139 		logger.info("Generating {} clusters", numClusters);
140 
141 		final double[][] clusters = new double[numClusters][2];
142 		for (int i = 0; i < numClusters; i++) {
143 			clusters[i][0] = minX + random.nextDouble() * (maxX - minX);
144 			clusters[i][1] = minY + random.nextDouble() * (maxY - minY);
145 		}
146 
147 		return clusters;
148 	}
149 	// end::cluster_generation[]
150 
151 	// tag::data_generation[]
152 	public static double[][] generateDataPoints(final Random random, final double[][] clusters, final int numDataPoints,
153 			final double radius) {
154 		Validate.notNull(random);
155 		Validate.notNull(clusters);
156 		Validate.isTrue(clusters.length > 0);
157 
158 		final int numClusters = clusters.length;
159 		final double[][] data = new double[numDataPoints][3];
160 		for (int i = 0; i < numDataPoints; i++) {
161 			final int clusterIndex = i % numClusters;
162 
163 			data[i][0] = random.nextGaussian() * radius + clusters[clusterIndex][0];
164 			data[i][1] = random.nextGaussian() * radius + clusters[clusterIndex][1];
165 			data[i][2] = clusterIndex;
166 		}
167 
168 		return data;
169 	}
170 	// end::data_generation[]
171 
172 	public static void doGA(final int k, final double min, final double max, final int numberTournaments,
173 			final int combinationArithmetic, final int combinationCrossover, final double randomMutationRate,
174 			final double creepMutationRate, final double creepMutationMean, final double creepMutationStdDev,
175 			final Fitness<Double> fitnessFunction, final Termination<Double> terminations, final int populationSize,
176 			final String outputCSV, final double[][] data, final double[][] distances, final String baseDir,
177 			final String filenameSuffix) throws IOException {
178 
179 		// tag::ea_configuration[]
180 		final var eaConfigurationBuilder = new EAConfiguration.Builder<Double>();
181 		eaConfigurationBuilder.chromosomeSpecs(DoubleChromosomeSpec.of(k * 2, min, max))
182 				.parentSelectionPolicy(Tournament.of(numberTournaments))
183 				.combinationPolicy(MultiCombinations.of(MultiPointArithmetic.of(combinationArithmetic, 0.5),
184 						MultiPointCrossover.of(combinationCrossover)))
185 				.mutationPolicies(MultiMutations.of(RandomMutation.of(randomMutationRate),
186 						CreepMutation.ofNormal(creepMutationRate, creepMutationMean, creepMutationStdDev)))
187 				.fitness(fitnessFunction)
188 				.postEvaluationProcessor(FitnessSharingUtils.clusterDistance)
189 				.termination(terminations);
190 		final var eaConfiguration = eaConfigurationBuilder.build();
191 		// end::ea_configuration[]
192 
193 		final var eaExecutionContext = EAExecutionContexts.<Double>forScalarFitness()
194 				.populationSize(populationSize)
195 				.addEvolutionListeners(EvolutionListeners.ofLogTopN(logger, 3),
196 						new CSVEvolutionListener.Builder<Double, Double>().filename(outputCSV)
197 								.columnExtractors(List.of(
198 										ColumnExtractor.of("generation", (evolutionStep) -> evolutionStep.generation()),
199 										ColumnExtractor.of("fitness", (evolutionStep) -> evolutionStep.fitness()),
200 										ColumnExtractor.of("combination_arithmetic", (evolutionStep) -> combinationArithmetic),
201 										ColumnExtractor.of("combination_crossover", (evolutionStep) -> combinationCrossover),
202 										ColumnExtractor.of("random_mutation_rate", (evolutionStep) -> randomMutationRate),
203 										ColumnExtractor.of("creep_mutation_mean", (evolutionStep) -> creepMutationMean),
204 										ColumnExtractor.of("creep_mutation_stddev", (evolutionStep) -> creepMutationStdDev),
205 										ColumnExtractor.of("creep_mutation_rate", (evolutionStep) -> creepMutationRate)))
206 								.build())
207 				.build();
208 
209 		final var eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
210 
211 		final var evolutionResult = eaSystem.evolve();
212 		logger.info("Best genotype: {}", evolutionResult.bestGenotype());
213 		logger.info("  with fitness: {}", evolutionResult.bestFitness());
214 		logger.info("  at generation: {}", evolutionResult.generation());
215 
216 		final Genotype bestGenotype = evolutionResult.bestGenotype();
217 		final double[][] bestPhenotype = PhenotypeUtils.toPhenotype(bestGenotype);
218 		logger.info("Best phenotype:");
219 		for (int i = 0; i < k; i++) {
220 			logger.info("\tx: {} - y: {}", bestPhenotype[i][0], bestPhenotype[i][1]);
221 		}
222 		final int[] bestClusterMembership = FitnessUtils.assignDataToClusters(data, distances, bestPhenotype);
223 		IOUtils
224 				.persistDataPoints(data, bestClusterMembership, baseDir + "clustering-result-ga" + filenameSuffix + ".csv");
225 		IOUtils.persistClusters(bestPhenotype, baseDir + "clustering-result-clusters-ga" + filenameSuffix + ".csv");
226 
227 	}
228 
229 	public static List<CentroidCluster<LocationWrapper>> apacheCommonsMathCluster(final double[][] clusters,
230 			final double[][] data) {
231 
232 		logger.info("Initializing kmeans from Apache Commons Math");
233 
234 		final long startTs = System.currentTimeMillis();
235 
236 		final int numClusters = clusters.length;
237 		final int numDataPoints = data.length;
238 
239 		final List<LocationWrapper> clusterInput = new ArrayList<LocationWrapper>(numDataPoints);
240 		for (int i = 0; i < numDataPoints; i++) {
241 			clusterInput.add(new LocationWrapper(data[i]));
242 		}
243 
244 		logger.info("Running kmeans");
245 
246 		final KMeansPlusPlusClusterer<LocationWrapper> clusterer = new KMeansPlusPlusClusterer<LocationWrapper>(
247 				numClusters,
248 				10_000);
249 		final List<CentroidCluster<LocationWrapper>> clusterResults = clusterer.cluster(clusterInput);
250 
251 		final long durationMs = System.currentTimeMillis() - startTs;
252 		logger.info("Computation time: {}", DurationFormatUtils.formatDurationHMS(durationMs));
253 
254 		return clusterResults;
255 	}
256 
257 	public static void main(String[] args) throws IOException {
258 		logger.info("Starting");
259 
260 		final Random random = new Random();
261 
262 		final double min = -100;
263 		final double max = 100;
264 		final double minX = min;
265 		final double maxX = max;
266 		final double minY = min;
267 		final double maxY = max;
268 
269 		final double radius = 8;
270 
271 		final int numDataPoints = 1_000;
272 
273 		/**
274 		 * Parse CLI
275 		 */
276 
277 		final CommandLineParser parser = new DefaultParser();
278 
279 		final Options options = new Options();
280 		options.addOption(PARAM_NUM_CLUSTERS, LONG_PARAM_NUM_CLUSTERS, true, "number of clusters");
281 		options.addOption(PARAM_NUMBER_TOURNAMENTS, LONG_PARAM_NUMBER_TOURNAMENTS, true, "number of tournaments");
282 		options.addOption(PARAM_SOURCE_CLUSTERS_CSV, LONG_PARAM_SOURCE_CUSTERS_CSV, true, "source csv file for clusters");
283 		options.addOption(PARAM_SOURCE_DATA_CSV, LONG_PARAM_SOURCE_DATA_CSV, true, "source csv file for data");
284 		options.addOption(PARAM_OUTPUT_CSV, LONG_PARAM_OUTPUT_CSV, true, "output csv");
285 		options.addOption(PARAM_OUTPUT_WITH_SSE_CSV, LONG_PARAM_OUTPUT_WITH_SSE_CSV, true, "output with sse csv");
286 		options
287 				.addOption(PARAM_COMBINATION_ARITHMETIC, LONG_PARAM_COMBINATION_ARITHMETIC, true, "combination arithmetic");
288 		options.addOption(PARAM_COMBINATION_CROSSOVER, LONG_PARAM_COMBINATION_CROSSOVER, true, "combination crossover");
289 		options.addOption(PARAM_POPULATION_SIZE, LONG_PARAM_POPULATION_SIZE, true, "population size");
290 		options.addOption(PARAM_BASE_DIR_OUTPUT, LONG_PARAM_BASE_DIR_OUTPUT, true, "base directory");
291 		options.addOption(PARAM_CREEP_MUTATION_STD_DEV,
292 				LONG_PARAM_CREEP_MUTATION_STD_DEV,
293 				true,
294 				"creep mutation std dev. Default: " + DEFAULT_CREEP_MUTATION_STDDEV);
295 		options.addOption(PARAM_CREEP_MUTATION_MEAN,
296 				LONG_PARAM_CREEP_MUTATION_MEAN,
297 				true,
298 				"creep mutation mean. Default: " + DEFAULT_CREEP_MUTATION_MEAN);
299 		options.addOption(PARAM_CREEP_MUTATION_RATE,
300 				LONG_PARAM_CREEP_MUTATION_RATE,
301 				true,
302 				"creep mutation rate. Default: " + DEFAULT_CREEP_MUTATION_RATE);
303 		options.addOption(PARAM_RANDOM_MUTATION_RATE,
304 				LONG_PARAM_RANDOM_MUTATION_RATE,
305 				true,
306 				"random mutation rate. Default: " + DEFAULT_RANDOM_MUTATION_RATE);
307 		options.addOption(PARAM_FIXED_TERMINATION,
308 				LONG_PARAM_FIXED_TERMINATION,
309 				true,
310 				"Fix the termination to the specified number of generations");
311 
312 		Optional<String> paramSourceClustersCSV = Optional.empty();
313 		Optional<String> paramSourceDataCSV = Optional.empty();
314 		Optional<Integer> paramNumClusters = Optional.empty();
315 		Optional<String> paramOutputCSV = Optional.empty();
316 		Optional<String> paramOutputWithSSECSV = Optional.empty();
317 		Optional<Long> paramFixedTermination = Optional.empty();
318 
319 		int numberTournaments = DEFAULT_NUMBER_TOURNAMENTS;
320 		int populationSize = DEFAULT_POPULATION_SIZE;
321 		final double randomMutationRate;
322 		final double creepMutationRate;
323 		final double creepMutationMean;
324 		final double creepMutationStdDev;
325 		final int combinationArithmetic;
326 		final int combinationCrossover;
327 		final String baseDir;
328 		try {
329 			final CommandLine line = parser.parse(options, args);
330 
331 			if (line.hasOption(PARAM_NUMBER_TOURNAMENTS)) {
332 				numberTournaments = Integer.parseInt(line.getOptionValue(PARAM_NUMBER_TOURNAMENTS)
333 						.strip());
334 			}
335 
336 			baseDir = Optional.ofNullable(line.getOptionValue(PARAM_BASE_DIR_OUTPUT))
337 					.orElse("");
338 
339 			combinationArithmetic = Optional.ofNullable(line.getOptionValue(PARAM_COMBINATION_ARITHMETIC))
340 					.map(String::strip)
341 					.map(Integer::parseInt)
342 					.orElse(DEFAULT_COMBINATION_ARITHMETIC);
343 
344 			combinationCrossover = Optional.ofNullable(line.getOptionValue(PARAM_COMBINATION_CROSSOVER))
345 					.map(String::strip)
346 					.map(Integer::parseInt)
347 					.orElse(DEFAULT_COMBINATION_CROSSOVER_ARITHMETIC);
348 
349 			paramNumClusters = Optional.ofNullable(line.getOptionValue(PARAM_NUM_CLUSTERS))
350 					.map(String::strip)
351 					.map(Integer::parseInt);
352 			populationSize = Optional.ofNullable(line.getOptionValue(PARAM_POPULATION_SIZE))
353 					.map(String::strip)
354 					.map(Integer::parseInt)
355 					.orElse(DEFAULT_POPULATION_SIZE);
356 
357 			paramSourceClustersCSV = Optional.ofNullable(line.getOptionValue(PARAM_SOURCE_CLUSTERS_CSV))
358 					.map(String::strip);
359 
360 			paramSourceDataCSV = Optional.ofNullable(line.getOptionValue(PARAM_SOURCE_DATA_CSV))
361 					.map(String::strip);
362 
363 			paramOutputCSV = Optional.ofNullable(line.getOptionValue(PARAM_OUTPUT_CSV))
364 					.map(String::strip);
365 			paramOutputWithSSECSV = Optional.ofNullable(line.getOptionValue(PARAM_OUTPUT_WITH_SSE_CSV))
366 					.map(String::strip);
367 
368 			paramFixedTermination = Optional.ofNullable(line.getOptionValue(PARAM_FIXED_TERMINATION))
369 					.map(String::strip)
370 					.map(Long::parseLong);
371 
372 			randomMutationRate = Optional.ofNullable(line.getOptionValue(PARAM_RANDOM_MUTATION_RATE))
373 					.map(String::strip)
374 					.map(Double::parseDouble)
375 					.orElse(DEFAULT_RANDOM_MUTATION_RATE);
376 
377 			creepMutationRate = Optional.ofNullable(line.getOptionValue(PARAM_CREEP_MUTATION_RATE))
378 					.map(String::strip)
379 					.map(Double::parseDouble)
380 					.orElse(DEFAULT_CREEP_MUTATION_RATE);
381 
382 			creepMutationMean = Optional.ofNullable(line.getOptionValue(PARAM_CREEP_MUTATION_MEAN))
383 					.map(String::strip)
384 					.map(Double::parseDouble)
385 					.orElse(DEFAULT_CREEP_MUTATION_MEAN);
386 
387 			creepMutationStdDev = Optional.ofNullable(line.getOptionValue(PARAM_CREEP_MUTATION_STD_DEV))
388 					.map(String::strip)
389 					.map(Double::parseDouble)
390 					.orElse(DEFAULT_CREEP_MUTATION_STDDEV);
391 
392 			logger.info("Unrecognized args:");
393 			boolean hasError = false;
394 			for (final String extraArg : line.getArgList()) {
395 				logger.info("\t[{}]", extraArg);
396 				if (extraArg.isBlank() == false) {
397 					hasError = true;
398 				}
399 			}
400 
401 			if (hasError) {
402 				throw new RuntimeException();
403 			}
404 		} catch (ParseException exp) {
405 			cliError(options, "Unexpected exception:" + exp.getMessage());
406 
407 			// This piece will never execute
408 			throw new RuntimeException(); // java doesn't detect the System.exit in cliError and create some issues with
409 													// potential not initialized final parameters.
410 		}
411 
412 		final int numClusters = paramNumClusters.orElse(DEFAULT_NUM_CLUSTERS);
413 
414 		logger.info("Preparing {} clusters", numClusters);
415 		final double[][] clusters = paramSourceClustersCSV.map(IOUtils::loadClusters)
416 				.orElseGet(() -> generateClusters(random, numClusters, minX, maxX, minY, maxY));
417 		logger.info("Found {} clusters", clusters.length);
418 
419 		logger.info("Preparing data points");
420 		final double[][] data = paramSourceDataCSV.map(sourceDataCSVFileName -> {
421 			try {
422 				logger.info("Loading data points");
423 				return IOUtils.loadDataPoints(sourceDataCSVFileName);
424 			} catch (IOException e) {
425 				throw new RuntimeException("Could not load " + sourceDataCSVFileName, e);
426 			}
427 		})
428 				.orElseGet(() -> generateDataPoints(random, clusters, numDataPoints, radius));
429 		final double[][] distances = computeAllDistances(data);
430 
431 		final String originalClustersFilename = "originalClusters.csv";
432 		if (paramSourceClustersCSV.isPresent()) {
433 			logger.info("Not persisting clusters since it was provided");
434 		} else {
435 			logger.info("Saving clusters to CSV: {}", originalClustersFilename);
436 			IOUtils.persistClusters(clusters, originalClustersFilename);
437 		}
438 
439 		final String originalDataFilename = "originalData.csv";
440 		if (paramSourceDataCSV.isPresent()) {
441 			logger.info("Not persisting data since it was provided");
442 		} else {
443 			logger.info("Saving data to CSV: {}", originalDataFilename);
444 			IOUtils.persistDataPoints(data, originalDataFilename);
445 		}
446 
447 		logger.info("Clustering data with Apache Commons Math");
448 
449 		final List<CentroidCluster<LocationWrapper>> clusterResults = apacheCommonsMathCluster(clusters, data);
450 
451 		logger.info("Definition of genetic problem");
452 
453 		final int k = numClusters;
454 		final var fitnessFunction = FitnessUtils.computeFitness(numDataPoints, data, distances, k);
455 
456 		final var terminations = paramFixedTermination
457 				.map(maxGeneration -> Terminations.<Double>ofMaxGeneration(maxGeneration))
458 				.orElseGet(() -> or(Terminations.<Double>ofMaxGeneration(500), Terminations.ofStableFitness(50)));
459 
460 		logger.info("Terminations: {}", paramFixedTermination);
461 		logger.info("Parameters: random_mutation_rate: {} - creep_mutation_rate: {} - creep_mutation_stddev: {} ",
462 				randomMutationRate,
463 				creepMutationRate,
464 				creepMutationStdDev);
465 		logger.info("Combinations: arithmetic {} ; crossover {}", combinationArithmetic, combinationCrossover);
466 
467 		logger.info("Running GA with Silhouette score");
468 		doGA(k,
469 				min,
470 				max,
471 				numberTournaments,
472 				combinationArithmetic,
473 				combinationCrossover,
474 				randomMutationRate,
475 				creepMutationRate,
476 				creepMutationMean,
477 				creepMutationStdDev,
478 				fitnessFunction,
479 				terminations,
480 				populationSize,
481 				paramOutputCSV.orElse("output.csv"),
482 				data,
483 				distances,
484 				baseDir,
485 				"");
486 
487 		logger.info("Running GA with Silhouette score + SSE");
488 		final var fitnessFunctionWithSumSquareErrors = FitnessUtils
489 				.computeFitnessWithSSE(numDataPoints, data, distances, k);
490 		doGA(k,
491 				min,
492 				max,
493 				numberTournaments,
494 				combinationArithmetic,
495 				combinationCrossover,
496 				randomMutationRate,
497 				creepMutationRate,
498 				creepMutationMean,
499 				creepMutationStdDev,
500 				fitnessFunctionWithSumSquareErrors,
501 				terminations,
502 				populationSize,
503 				paramOutputWithSSECSV.orElse("output-with-sse.csv"),
504 				data,
505 				distances,
506 				baseDir,
507 				"-with-sse");
508 
509 		logger.info("Original clusters:");
510 		final double[] originalMeans = new double[numClusters * 2];
511 		for (int i = 0; i < numClusters; i++) {
512 			logger.info("\tx: {} - y: {}", clusters[i][0], clusters[i][1]);
513 			originalMeans[i * 2] = clusters[i][0];
514 			originalMeans[i * 2 + 1] = clusters[i][1];
515 		}
516 		final var originalFitness = fitnessFunction
517 				.compute(new Genotype(new DoubleChromosome(k * 2, -100.0d, 100.0d, originalMeans)));
518 		logger.info("Original fitness: {}", originalFitness);
519 		final int[] originalClusterMembership = FitnessUtils.assignDataToClusters(data, distances, clusters);
520 		IOUtils.persistDataPoints(data, originalClusterMembership, baseDir + "clustering-result-original.csv");
521 		IOUtils.persistClusters(clusters, baseDir + "clustering-result-clusters-original.csv");
522 
523 		logger.info("kmeans output:");
524 		// output the clusters
525 		final double[] kmeansClusters = new double[numClusters * 2];
526 		for (int i = 0; i < clusterResults.size(); i++) {
527 			final CentroidCluster<LocationWrapper> centroidCluster = clusterResults.get(i);
528 			logger.info("\t{}", centroidCluster.getCenter());
529 			kmeansClusters[i * 2] = centroidCluster.getCenter()
530 					.getPoint()[0];
531 			kmeansClusters[i * 2 + 1] = centroidCluster.getCenter()
532 					.getPoint()[1];
533 		}
534 		final var kmeansGenotype = new Genotype(new DoubleChromosome(k * 2, -100.0d, 100.0d, kmeansClusters));
535 		final var kmeansFitness = fitnessFunction.compute(kmeansGenotype);
536 		logger.info("kmeans fitness: {}", kmeansFitness);
537 
538 		final int[] kmeansClusterMembership = FitnessUtils
539 				.assignDataToClusters(data, distances, PhenotypeUtils.toPhenotype(kmeansGenotype));
540 		IOUtils.persistDataPoints(data, kmeansClusterMembership, baseDir + "clustering-result-kmeans.csv");
541 		IOUtils.persistClusters(PhenotypeUtils.toPhenotype(kmeansGenotype),
542 				baseDir + "clustering-result-clusters-kmeans.csv");
543 
544 		logger.info("Done");
545 	}
546 }