View Javadoc
1   package net.bmahe.genetics4j.samples.mixturemodel;
2   
3   import java.io.IOException;
4   import java.util.Collection;
5   import java.util.HashSet;
6   import java.util.List;
7   import java.util.Set;
8   import java.util.stream.IntStream;
9   
10  import org.apache.commons.collections4.CollectionUtils;
11  import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
12  import org.apache.commons.math3.exception.MathUnsupportedOperationException;
13  import org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException;
14  import org.apache.commons.math3.linear.SingularMatrixException;
15  import org.apache.logging.log4j.LogManager;
16  import org.apache.logging.log4j.Logger;
17  
18  import net.bmahe.genetics4j.core.EASystem;
19  import net.bmahe.genetics4j.core.EASystemFactory;
20  import net.bmahe.genetics4j.core.Fitness;
21  import net.bmahe.genetics4j.core.Genotype;
22  import net.bmahe.genetics4j.core.chromosomes.DoubleChromosome;
23  import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
24  import net.bmahe.genetics4j.core.spec.EAConfiguration;
25  import net.bmahe.genetics4j.core.spec.EAConfiguration.Builder;
26  import net.bmahe.genetics4j.core.spec.EAExecutionContexts;
27  import net.bmahe.genetics4j.core.spec.EvolutionResult;
28  import net.bmahe.genetics4j.core.spec.chromosome.DoubleChromosomeSpec;
29  import net.bmahe.genetics4j.core.spec.combination.MultiCombinations;
30  import net.bmahe.genetics4j.core.spec.combination.MultiPointArithmetic;
31  import net.bmahe.genetics4j.core.spec.combination.MultiPointCrossover;
32  import net.bmahe.genetics4j.core.spec.combination.SinglePointArithmetic;
33  import net.bmahe.genetics4j.core.spec.combination.SinglePointCrossover;
34  import net.bmahe.genetics4j.core.spec.mutation.CreepMutation;
35  import net.bmahe.genetics4j.core.spec.mutation.MultiMutations;
36  import net.bmahe.genetics4j.core.spec.mutation.RandomMutation;
37  import net.bmahe.genetics4j.core.spec.mutation.SwapMutation;
38  import net.bmahe.genetics4j.core.spec.selection.Tournament;
39  import net.bmahe.genetics4j.core.spec.statistics.distributions.NormalDistribution;
40  import net.bmahe.genetics4j.core.termination.Terminations;
41  import net.bmahe.genetics4j.extras.evolutionlisteners.CSVEvolutionListener;
42  import net.bmahe.genetics4j.extras.evolutionlisteners.ColumnExtractor;
43  
44  public class SingleObjectiveMethod {
45  	final static public Logger logger = LogManager.getLogger(SingleObjectiveMethod.class);
46  
47  	private final int distributionNumParameters;
48  	private final String baseDir;
49  	private final int maxGenerations;
50  
51  	public SingleObjectiveMethod(final int _distributionNumParameters, final String _baseDir,
52  			final int _maxGenerations) {
53  
54  		this.distributionNumParameters = _distributionNumParameters;
55  		this.baseDir = _baseDir;
56  		this.maxGenerations = _maxGenerations;
57  	}
58  
59  	// tag::som_fitness[]
60  	public Fitness<Double> fitnessCPU(final int numDistributions, final double[][] samples) {
61  		return (genotype) -> {
62  			final var fChromosome = genotype.getChromosome(0, DoubleChromosome.class);
63  
64  			/**
65  			 * Normalize alpha
66  			 */
67  			double sumAlpha = 0.0f;
68  			int k = 0;
69  			while (k < fChromosome.getSize()) {
70  				sumAlpha += fChromosome.getAllele(k);
71  				k += distributionNumParameters;
72  			}
73  
74  			double[] likelyhoods = new double[samples.length];
75  			int i = 0;
76  			while (i < fChromosome.getSize()) {
77  
78  				final double alpha = fChromosome.getAllele(i) / sumAlpha;
79  				if (alpha > 0.0001) {
80  					final double[] mean = new double[] { fChromosome.getAllele(i + 1), fChromosome.getAllele(i + 2) };
81  					final double[][] covariances = new double[][] {
82  							{ fChromosome.getAllele(i + 3) - 15, fChromosome.getAllele(i + 4) - 15 },
83  							{ fChromosome.getAllele(i + 4) - 15, fChromosome.getAllele(i + 5) - 15 } };
84  
85  					try {
86  						final var multivariateNormalDistribution = new MultivariateNormalDistribution(mean, covariances);
87  
88  						for (int j = 0; j < samples.length; j++) {
89  							final var density = multivariateNormalDistribution.density(samples[j]);
90  							likelyhoods[j] += alpha * density;
91  						}
92  					} catch (NonPositiveDefiniteMatrixException | MathUnsupportedOperationException
93  							| SingularMatrixException e) {
94  						// Ignore invalid mixtures
95  					}
96  				}
97  				i += distributionNumParameters;
98  			}
99  
100 			double sumLogs = 0.0f;
101 			for (int j = 0; j < samples.length; j++) {
102 				sumLogs += Math.log(likelyhoods[j]);
103 			}
104 
105 			return sumLogs / samples.length;
106 		};
107 	}
108 	// end::som_fitness[]
109 
110 	public EvolutionResult<Double> run(final int maxPossibleDistributions, final double[][] samples, final float[] x,
111 			final float[] y, final String algorithmName, final Collection<Genotype> seedPopulation) throws IOException {
112 
113 		// tag::som_config[]
114 		final Builder<Double> eaConfigurationBuilder = new EAConfiguration.Builder<>();
115 		eaConfigurationBuilder
116 				.chromosomeSpecs(DoubleChromosomeSpec.of(distributionNumParameters * maxPossibleDistributions, 0, 30))
117 				.parentSelectionPolicy(Tournament.of(2))
118 				.combinationPolicy(MultiCombinations.of(SinglePointArithmetic.of(0.9),
119 						SinglePointCrossover.build(),
120 						MultiPointCrossover.of(2),
121 						MultiPointArithmetic.of(2, 0.9)))
122 				.mutationPolicies(MultiMutations.of(RandomMutation.of(0.40),
123 						CreepMutation.of(0.40, NormalDistribution.of(0.0, 2)),
124 						SwapMutation.of(0.30, 5, false)))
125 				.fitness(fitnessCPU(maxPossibleDistributions, samples))
126 				.termination(Terminations.ofMaxGeneration(maxGenerations));
127 
128 		if (CollectionUtils.isNotEmpty(seedPopulation)) {
129 			eaConfigurationBuilder.seedPopulation(seedPopulation);
130 		}
131 
132 		final var eaConfiguration = eaConfigurationBuilder.build();
133 		// end::som_config[]
134 
135 		// tag::som_eaexeccontext[]
136 		final var csvEvolutionListener = CSVEvolutionListener.<Double, Void>of(baseDir + "mixturemodel-so-cpu.csv",
137 				List.of(ColumnExtractor.of("generation", e -> e.generation()),
138 						ColumnExtractor.of("fitness", e -> e.fitness())));
139 
140 		final var eaExecutionContextBuilder = EAExecutionContexts.<Double>standard();
141 		eaExecutionContextBuilder.populationSize(250)
142 				.addEvolutionListeners(csvEvolutionListener, EvolutionListeners.ofLogTopN(logger, 5))
143 				.build();
144 		final var eaExecutionContext = eaExecutionContextBuilder.build();
145 		// end::som_eaexeccontext[]
146 
147 		final EASystem<Double> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
148 
149 		final EvolutionResult<Double> evolutionResult = eaSystem.evolve();
150 		logger.info("Best genotype: {}", evolutionResult.bestGenotype());
151 		logger.info("  with fitness: {}", evolutionResult.bestFitness());
152 		logger.info("  at generation: {}", evolutionResult.generation());
153 
154 		final int[] assignedClusters = ClusteringUtils
155 				.assignClustersDoubleChromosome(distributionNumParameters, samples, evolutionResult.bestGenotype());
156 		final Set<Integer> uniqueAssigned = new HashSet<>();
157 		uniqueAssigned.addAll(IntStream.of(assignedClusters)
158 				.boxed()
159 				.toList());
160 
161 		ClusteringUtils
162 				.persistClusters(x, y, assignedClusters, baseDir + "assigned-so-" + uniqueAssigned.size() + ".csv");
163 
164 		return evolutionResult;
165 	}
166 }