1 package net.bmahe.genetics4j.samples.mixturemodel;
2
3 import java.io.IOException;
4 import java.util.Collection;
5 import java.util.HashSet;
6 import java.util.List;
7 import java.util.Set;
8 import java.util.stream.IntStream;
9
10 import org.apache.commons.collections4.CollectionUtils;
11 import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
12 import org.apache.commons.math3.exception.MathUnsupportedOperationException;
13 import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException;
14 import org.apache.commons.math3.linear.SingularMatrixException;
15 import org.apache.logging.log4j.LogManager;
16 import org.apache.logging.log4j.Logger;
17
18 import net.bmahe.genetics4j.core.EASystem;
19 import net.bmahe.genetics4j.core.EASystemFactory;
20 import net.bmahe.genetics4j.core.Fitness;
21 import net.bmahe.genetics4j.core.Genotype;
22 import net.bmahe.genetics4j.core.chromosomes.DoubleChromosome;
23 import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
24 import net.bmahe.genetics4j.core.spec.EAConfiguration;
25 import net.bmahe.genetics4j.core.spec.EAConfiguration.Builder;
26 import net.bmahe.genetics4j.core.spec.EAExecutionContexts;
27 import net.bmahe.genetics4j.core.spec.EvolutionResult;
28 import net.bmahe.genetics4j.core.spec.chromosome.DoubleChromosomeSpec;
29 import net.bmahe.genetics4j.core.spec.combination.MultiCombinations;
30 import net.bmahe.genetics4j.core.spec.combination.MultiPointArithmetic;
31 import net.bmahe.genetics4j.core.spec.combination.MultiPointCrossover;
32 import net.bmahe.genetics4j.core.spec.combination.SinglePointArithmetic;
33 import net.bmahe.genetics4j.core.spec.combination.SinglePointCrossover;
34 import net.bmahe.genetics4j.core.spec.mutation.CreepMutation;
35 import net.bmahe.genetics4j.core.spec.mutation.MultiMutations;
36 import net.bmahe.genetics4j.core.spec.mutation.RandomMutation;
37 import net.bmahe.genetics4j.core.spec.mutation.SwapMutation;
38 import net.bmahe.genetics4j.core.spec.selection.Tournament;
39 import net.bmahe.genetics4j.core.spec.statistics.distributions.NormalDistribution;
40 import net.bmahe.genetics4j.core.termination.Terminations;
41 import net.bmahe.genetics4j.extras.evolutionlisteners.CSVEvolutionListener;
42 import net.bmahe.genetics4j.extras.evolutionlisteners.ColumnExtractor;
43
44 public class SingleObjectiveMethod {
45 final static public Logger logger = LogManager.getLogger(SingleObjectiveMethod.class);
46
47 private final int distributionNumParameters;
48 private final String baseDir;
49 private final int maxGenerations;
50
51 public SingleObjectiveMethod(final int _distributionNumParameters, final String _baseDir,
52 final int _maxGenerations) {
53
54 this.distributionNumParameters = _distributionNumParameters;
55 this.baseDir = _baseDir;
56 this.maxGenerations = _maxGenerations;
57 }
58
59
60 public Fitness<Double> fitnessCPU(final int numDistributions, final double[][] samples) {
61 return (genotype) -> {
62 final var fChromosome = genotype.getChromosome(0, DoubleChromosome.class);
63
64
65
66
67 double sumAlpha = 0.0f;
68 int k = 0;
69 while (k < fChromosome.getSize()) {
70 sumAlpha += fChromosome.getAllele(k);
71 k += distributionNumParameters;
72 }
73
74 double[] likelyhoods = new double[samples.length];
75 int i = 0;
76 while (i < fChromosome.getSize()) {
77
78 final double alpha = fChromosome.getAllele(i) / sumAlpha;
79 if (alpha > 0.0001) {
80 final double[] mean = new double[] { fChromosome.getAllele(i + 1), fChromosome.getAllele(i + 2) };
81 final double[][] covariances = new double[][] {
82 { fChromosome.getAllele(i + 3) - 15, fChromosome.getAllele(i + 4) - 15 },
83 { fChromosome.getAllele(i + 4) - 15, fChromosome.getAllele(i + 5) - 15 } };
84
85 try {
86 final var multivariateNormalDistribution = new MultivariateNormalDistribution(mean, covariances);
87
88 for (int j = 0; j < samples.length; j++) {
89 final var density = multivariateNormalDistribution.density(samples[j]);
90 likelyhoods[j] += alpha * density;
91 }
92 } catch (NonPositiveDefiniteMatrixException | MathUnsupportedOperationException
93 | SingularMatrixException e) {
94
95 }
96 }
97 i += distributionNumParameters;
98 }
99
100 double sumLogs = 0.0f;
101 for (int j = 0; j < samples.length; j++) {
102 sumLogs += Math.log(likelyhoods[j]);
103 }
104
105 return sumLogs / samples.length;
106 };
107 }
108
109
110 public EvolutionResult<Double> run(final int maxPossibleDistributions, final double[][] samples, final float[] x,
111 final float[] y, final String algorithmName, final Collection<Genotype> seedPopulation) throws IOException {
112
113
114 final Builder<Double> eaConfigurationBuilder = new EAConfiguration.Builder<>();
115 eaConfigurationBuilder
116 .chromosomeSpecs(DoubleChromosomeSpec.of(distributionNumParameters * maxPossibleDistributions, 0, 30))
117 .parentSelectionPolicy(Tournament.of(2))
118 .combinationPolicy(MultiCombinations.of(SinglePointArithmetic.of(0.9),
119 SinglePointCrossover.build(),
120 MultiPointCrossover.of(2),
121 MultiPointArithmetic.of(2, 0.9)))
122 .mutationPolicies(MultiMutations.of(RandomMutation.of(0.40),
123 CreepMutation.of(0.40, NormalDistribution.of(0.0, 2)),
124 SwapMutation.of(0.30, 5, false)))
125 .fitness(fitnessCPU(maxPossibleDistributions, samples))
126 .termination(Terminations.ofMaxGeneration(maxGenerations));
127
128 if (CollectionUtils.isNotEmpty(seedPopulation)) {
129 eaConfigurationBuilder.seedPopulation(seedPopulation);
130 }
131
132 final var eaConfiguration = eaConfigurationBuilder.build();
133
134
135
136 final var csvEvolutionListener = CSVEvolutionListener.<Double, Void>of(baseDir + "mixturemodel-so-cpu.csv",
137 List.of(ColumnExtractor.of("generation", e -> e.generation()),
138 ColumnExtractor.of("fitness", e -> e.fitness())));
139
140 final var eaExecutionContextBuilder = EAExecutionContexts.<Double>standard();
141 eaExecutionContextBuilder.populationSize(250)
142 .addEvolutionListeners(csvEvolutionListener, EvolutionListeners.ofLogTopN(logger, 5))
143 .build();
144 final var eaExecutionContext = eaExecutionContextBuilder.build();
145
146
147 final EASystem<Double> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
148
149 final EvolutionResult<Double> evolutionResult = eaSystem.evolve();
150 logger.info("Best genotype: {}", evolutionResult.bestGenotype());
151 logger.info(" with fitness: {}", evolutionResult.bestFitness());
152 logger.info(" at generation: {}", evolutionResult.generation());
153
154 final int[] assignedClusters = ClusteringUtils
155 .assignClustersDoubleChromosome(distributionNumParameters, samples, evolutionResult.bestGenotype());
156 final Set<Integer> uniqueAssigned = new HashSet<>();
157 uniqueAssigned.addAll(IntStream.of(assignedClusters)
158 .boxed()
159 .toList());
160
161 ClusteringUtils
162 .persistClusters(x, y, assignedClusters, baseDir + "assigned-so-" + uniqueAssigned.size() + ".csv");
163
164 return evolutionResult;
165 }
166 }