View Javadoc
1   package net.bmahe.genetics4j.samples.mixturemodel;
2   
3   import java.io.IOException;
4   import java.util.Collection;
5   import java.util.HashSet;
6   import java.util.List;
7   import java.util.Set;
8   import java.util.stream.IntStream;
9   
10  import org.apache.commons.collections4.CollectionUtils;
11  import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
12  import org.apache.commons.math3.exception.MathUnsupportedOperationException;
13  import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException;
14  import org.apache.commons.math3.linear.SingularMatrixException;
15  import org.apache.logging.log4j.LogManager;
16  import org.apache.logging.log4j.Logger;
17  
18  import net.bmahe.genetics4j.core.EASystem;
19  import net.bmahe.genetics4j.core.EASystemFactory;
20  import net.bmahe.genetics4j.core.Fitness;
21  import net.bmahe.genetics4j.core.Genotype;
22  import net.bmahe.genetics4j.core.chromosomes.DoubleChromosome;
23  import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
24  import net.bmahe.genetics4j.core.spec.EAConfiguration;
25  import net.bmahe.genetics4j.core.spec.EAConfiguration.Builder;
26  import net.bmahe.genetics4j.core.spec.EAExecutionContexts;
27  import net.bmahe.genetics4j.core.spec.EvolutionResult;
28  import net.bmahe.genetics4j.core.spec.chromosome.DoubleChromosomeSpec;
29  import net.bmahe.genetics4j.core.spec.combination.MultiCombinations;
30  import net.bmahe.genetics4j.core.spec.combination.MultiPointArithmetic;
31  import net.bmahe.genetics4j.core.spec.combination.MultiPointCrossover;
32  import net.bmahe.genetics4j.core.spec.combination.SinglePointArithmetic;
33  import net.bmahe.genetics4j.core.spec.combination.SinglePointCrossover;
34  import net.bmahe.genetics4j.core.spec.mutation.CreepMutation;
35  import net.bmahe.genetics4j.core.spec.mutation.MultiMutations;
36  import net.bmahe.genetics4j.core.spec.mutation.RandomMutation;
37  import net.bmahe.genetics4j.core.spec.mutation.SwapMutation;
38  import net.bmahe.genetics4j.core.spec.selection.Tournament;
39  import net.bmahe.genetics4j.core.spec.statistics.distributions.NormalDistribution;
40  import net.bmahe.genetics4j.core.termination.Terminations;
41  import net.bmahe.genetics4j.extras.evolutionlisteners.CSVEvolutionListener;
42  import net.bmahe.genetics4j.extras.evolutionlisteners.ColumnExtractor;
43  
44  public class SingleObjectiveMethod {
45  	final static public Logger logger = LogManager.getLogger(SingleObjectiveMethod.class);
46  
47  	private final int distributionNumParameters;
48  	private final String baseDir;
49  	private final int maxGenerations;
50  
51  	public SingleObjectiveMethod(final int _distributionNumParameters,
52  			final String _baseDir,
53  			final int _maxGenerations) {
54  
55  		this.distributionNumParameters = _distributionNumParameters;
56  		this.baseDir = _baseDir;
57  		this.maxGenerations = _maxGenerations;
58  	}
59  
60  	// tag::som_fitness[]
61  	public Fitness<Double> fitnessCPU(final int numDistributions, final double[][] samples) {
62  		return (genotype) -> {
63  			final var fChromosome = genotype.getChromosome(0, DoubleChromosome.class);
64  
65  			/**
66  			 * Normalize alpha
67  			 */
68  			double sumAlpha = 0.0f;
69  			int k = 0;
70  			while (k < fChromosome.getSize()) {
71  				sumAlpha += fChromosome.getAllele(k);
72  				k += distributionNumParameters;
73  			}
74  
75  			double[] likelyhoods = new double[samples.length];
76  			int i = 0;
77  			while (i < fChromosome.getSize()) {
78  
79  				final double alpha = fChromosome.getAllele(i) / sumAlpha;
80  				if (alpha > 0.0001) {
81  					final double[] mean = new double[] { fChromosome.getAllele(i + 1), fChromosome.getAllele(i + 2) };
82  					final double[][] covariances = new double[][] {
83  							{ fChromosome.getAllele(i + 3) - 15, fChromosome.getAllele(i + 4) - 15 },
84  							{ fChromosome.getAllele(i + 4) - 15, fChromosome.getAllele(i + 5) - 15 } };
85  
86  					try {
87  						final var multivariateNormalDistribution = new MultivariateNormalDistribution(mean, covariances);
88  
89  						for (int j = 0; j < samples.length; j++) {
90  							final var density = multivariateNormalDistribution.density(samples[j]);
91  							likelyhoods[j] += alpha * density;
92  						}
93  					} catch (NonPositiveDefiniteMatrixException | MathUnsupportedOperationException
94  							| SingularMatrixException e) {
95  						// Ignore invalid mixtures
96  					}
97  				}
98  				i += distributionNumParameters;
99  			}
100 
101 			double sumLogs = 0.0f;
102 			for (int j = 0; j < samples.length; j++) {
103 				sumLogs += Math.log(likelyhoods[j]);
104 			}
105 
106 			return sumLogs / samples.length;
107 		};
108 	}
109 	// end::som_fitness[]
110 
111 	public EvolutionResult<Double> run(final int maxPossibleDistributions, final double[][] samples, final float[] x,
112 			final float[] y, final String algorithmName, final Collection<Genotype> seedPopulation) throws IOException {
113 
114 		// tag::som_config[]
115 		final Builder<Double> eaConfigurationBuilder = new EAConfiguration.Builder<>();
116 		eaConfigurationBuilder
117 				.chromosomeSpecs(DoubleChromosomeSpec.of(distributionNumParameters * maxPossibleDistributions, 0, 30))
118 				.parentSelectionPolicy(Tournament.of(2))
119 				.combinationPolicy(
120 						MultiCombinations.of(
121 								SinglePointArithmetic.of(0.9),
122 									SinglePointCrossover.build(),
123 									MultiPointCrossover.of(2),
124 									MultiPointArithmetic.of(2, 0.9)))
125 				.mutationPolicies(
126 						MultiMutations.of(
127 								RandomMutation.of(0.40),
128 									CreepMutation.of(0.40, NormalDistribution.of(0.0, 2)),
129 									SwapMutation.of(0.30, 5, false)))
130 				.fitness(fitnessCPU(maxPossibleDistributions, samples))
131 				.termination(Terminations.ofMaxGeneration(maxGenerations));
132 
133 		if (CollectionUtils.isNotEmpty(seedPopulation)) {
134 			eaConfigurationBuilder.seedPopulation(seedPopulation);
135 		}
136 
137 		final var eaConfiguration = eaConfigurationBuilder.build();
138 		// end::som_config[]
139 
140 		// tag::som_eaexeccontext[]
141 		final var csvEvolutionListener = CSVEvolutionListener.<Double, Void>of(
142 				baseDir + "mixturemodel-so-cpu.csv",
143 					List.of(
144 							ColumnExtractor.of("generation", e -> e.generation()),
145 								ColumnExtractor.of("fitness", e -> e.fitness())));
146 
147 		final var eaExecutionContextBuilder = EAExecutionContexts.<Double>standard();
148 		eaExecutionContextBuilder.populationSize(250)
149 				.addEvolutionListeners(csvEvolutionListener, EvolutionListeners.ofLogTopN(logger, 5))
150 				.build();
151 		final var eaExecutionContext = eaExecutionContextBuilder.build();
152 		// end::som_eaexeccontext[]
153 
154 		final EASystem<Double> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
155 
156 		final EvolutionResult<Double> evolutionResult = eaSystem.evolve();
157 		logger.info("Best genotype: {}", evolutionResult.bestGenotype());
158 		logger.info("  with fitness: {}", evolutionResult.bestFitness());
159 		logger.info("  at generation: {}", evolutionResult.generation());
160 
161 		final int[] assignedClusters = ClusteringUtils
162 				.assignClustersDoubleChromosome(distributionNumParameters, samples, evolutionResult.bestGenotype());
163 		final Set<Integer> uniqueAssigned = new HashSet<>();
164 		uniqueAssigned.addAll(IntStream.of(assignedClusters).boxed().toList());
165 
166 		ClusteringUtils
167 				.persistClusters(x, y, assignedClusters, baseDir + "assigned-so-" + uniqueAssigned.size() + ".csv");
168 
169 		return evolutionResult;
170 	}
171 }