1 package net.bmahe.genetics4j.samples.symbolicregression;
2
3 import java.io.File;
4 import java.io.IOException;
5 import java.util.Comparator;
6 import java.util.List;
7 import java.util.Optional;
8 import java.util.Random;
9 import java.util.stream.IntStream;
10
11 import org.apache.commons.cli.CommandLine;
12 import org.apache.commons.cli.CommandLineParser;
13 import org.apache.commons.cli.DefaultParser;
14 import org.apache.commons.cli.HelpFormatter;
15 import org.apache.commons.cli.Options;
16 import org.apache.commons.cli.ParseException;
17 import org.apache.commons.io.FileUtils;
18 import org.apache.commons.lang3.StringUtils;
19 import org.apache.commons.lang3.Validate;
20 import org.apache.commons.lang3.time.DurationFormatUtils;
21 import org.apache.logging.log4j.LogManager;
22 import org.apache.logging.log4j.Logger;
23
24 import net.bmahe.genetics4j.core.EASystem;
25 import net.bmahe.genetics4j.core.EASystemFactory;
26 import net.bmahe.genetics4j.core.Fitness;
27 import net.bmahe.genetics4j.core.Genotype;
28 import net.bmahe.genetics4j.core.chromosomes.TreeChromosome;
29 import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListener;
30 import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
31 import net.bmahe.genetics4j.core.spec.EAConfiguration;
32 import net.bmahe.genetics4j.core.spec.EAExecutionContext;
33 import net.bmahe.genetics4j.core.spec.EvolutionResult;
34 import net.bmahe.genetics4j.core.spec.Optimization;
35 import net.bmahe.genetics4j.core.spec.mutation.MultiMutations;
36 import net.bmahe.genetics4j.core.termination.Terminations;
37 import net.bmahe.genetics4j.gp.Operation;
38 import net.bmahe.genetics4j.gp.math.SimplificationRules;
39 import net.bmahe.genetics4j.gp.program.Program;
40 import net.bmahe.genetics4j.gp.spec.GPEAExecutionContexts;
41 import net.bmahe.genetics4j.gp.spec.chromosome.ProgramTreeChromosomeSpec;
42 import net.bmahe.genetics4j.gp.spec.combination.ProgramRandomCombine;
43 import net.bmahe.genetics4j.gp.spec.mutation.NodeReplacement;
44 import net.bmahe.genetics4j.gp.spec.mutation.ProgramApplyRules;
45 import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomMutate;
46 import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomPrune;
47 import net.bmahe.genetics4j.gp.utils.ProgramUtils;
48 import net.bmahe.genetics4j.gp.utils.TreeNodeUtils;
49 import net.bmahe.genetics4j.moo.FitnessVector;
50 import net.bmahe.genetics4j.moo.nsga2.spec.TournamentNSGA2Selection;
51 import net.bmahe.genetics4j.moo.spea2.spec.replacement.SPEA2Replacement;
52
53 public class SymbolicRegressionWithMOOSPEA2 {
54 final static public Logger logger = LogManager.getLogger(SymbolicRegressionWithMOOSPEA2.class);
55
56 final static public String PARAM_DEST_CSV = "d";
57 final static public String LONG_PARAM_DEST_CSV = "csv-dest";
58
59 final static public String PARAM_POPULATION_SIZE = "p";
60 final static public String LONG_PARAM_POPULATION_SIZE = "population-size";
61
62 final static public String DEFAULT_DEST_CSV = SymbolicRegressionWithMOOSPEA2.class.getSimpleName() + ".csv";
63
64 final static public int DEFAULT_POPULATION_SIZE = 500;
65
66 public static void cliError(final Options options, final String errorMessage) {
67 final HelpFormatter formatter = new HelpFormatter();
68 logger.error(errorMessage);
69 formatter.printHelp(SymbolicRegressionWithMOOSPEA2.class.getSimpleName(), options);
70 System.exit(-1);
71 }
72
73 @SuppressWarnings("unchecked")
74 public void run(String csvFilename, int populationSize) {
75 Validate.isTrue(StringUtils.isNotBlank(csvFilename));
76 Validate.isTrue(populationSize > 0);
77
78 final Random random = new Random();
79
80 final Program program = SymbolicRegressionUtils.buildProgram(random);
81
82 final Comparator<Genotype> deduplicator = (a, b) -> TreeNodeUtils.compare(a, b, 0);
83
84
85 final Fitness<FitnessVector<Double>> computeFitness = (genoType) -> {
86 final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genoType.getChromosome(0);
87 final Double[][] inputs = new Double[100][1];
88 for (int i = 0; i < 100; i++) {
89 inputs[i][0] = (i - 50) * 1.2;
90 }
91
92 double mse = 0;
93 for (final Double[] input : inputs) {
94
95 final double x = input[0];
96 final double expected = SymbolicRegressionUtils.evaluate(x);
97 final Object result = ProgramUtils.execute(chromosome, input);
98
99 if (Double.isFinite(expected)) {
100 final Double resultDouble = (Double) result;
101 if (Double.isFinite(resultDouble)) {
102 mse += (expected - resultDouble) * (expected - resultDouble);
103 } else {
104 mse += 1_000_000_000;
105 }
106 }
107 }
108
109 return Double.isFinite(mse) ? new FitnessVector<Double>(mse / 100.0,
110 (double) chromosome.getRoot()
111 .getSize())
112 : new FitnessVector<Double>(Double.MAX_VALUE, Double.MAX_VALUE);
113 };
114
115
116
117 final var eaConfigurationBuilder = new EAConfiguration.Builder<FitnessVector<Double>>();
118 eaConfigurationBuilder.chromosomeSpecs(ProgramTreeChromosomeSpec.of(program))
119 .parentSelectionPolicy(TournamentNSGA2Selection.ofFitnessVector(2, 3, deduplicator))
120 .replacementStrategy(SPEA2Replacement.ofFitnessVector(deduplicator))
121 .combinationPolicy(ProgramRandomCombine.build())
122 .mutationPolicies(MultiMutations
123 .of(ProgramRandomMutate.of(0.15 * 3), ProgramRandomPrune.of(0.15 * 3), NodeReplacement.of(0.15 * 3)),
124 ProgramApplyRules.of(SimplificationRules.SIMPLIFY_RULES))
125 .optimization(Optimization.MINIMIZE)
126 .termination(Terminations.or(Terminations.<FitnessVector<Double>>ofMaxGeneration(200),
127 (eaConfiguration, generation, population, fitness) -> fitness.stream()
128 .anyMatch(fv -> fv.get(0) <= 0.00001 && fv.get(1) <= 20)))
129 .fitness(computeFitness);
130 final EAConfiguration<FitnessVector<Double>> eaConfiguration = eaConfigurationBuilder.build();
131
132
133 final var eaExecutionContextBuilder = GPEAExecutionContexts.<FitnessVector<Double>>forGP(random);
134 eaExecutionContextBuilder.populationSize(populationSize);
135 eaExecutionContextBuilder.numberOfPartitions(Math.max(1,
136 Runtime.getRuntime()
137 .availableProcessors() - 3));
138
139 eaExecutionContextBuilder.addEvolutionListeners(
140 EvolutionListeners.ofLogTopN(logger,
141 5,
142 Comparator.<FitnessVector<Double>, Double>comparing(fv -> fv.get(0))
143 .reversed(),
144 (genotype) -> TreeNodeUtils.toStringTreeNode(genotype, 0)),
145 SymbolicRegressionUtils.csvLogger(csvFilename,
146 evolutionStep -> evolutionStep.fitness()
147 .get(0),
148 evolutionStep -> evolutionStep.fitness()
149 .get(1)),
150 new EvolutionListener<FitnessVector<Double>>() {
151
152 long startTime = -1;
153 long previousTime = -1;
154
155 @Override
156 public void onEvolution(final long generation, final List<Genotype> population,
157 final List<FitnessVector<Double>> fitness, final boolean isDone) {
158 final long now = System.currentTimeMillis();
159
160 if (startTime < 0) {
161 startTime = now;
162 }
163
164 if (previousTime < 0) {
165 previousTime = now;
166 }
167
168 if (generation > 1) {
169 logger.info("Execution time:");
170 logger.info("\tCurrent generation duration: {}",
171 DurationFormatUtils.formatDurationHMS(now - previousTime));
172 logger.info("\tAverage duration: {}",
173 DurationFormatUtils.formatDurationHMS((now - startTime) / (generation - 1)));
174 }
175
176 previousTime = now;
177 }
178 });
179
180 final EAExecutionContext<FitnessVector<Double>> eaExecutionContext = eaExecutionContextBuilder.build();
181 final EASystem<FitnessVector<Double>> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
182
183 final EvolutionResult<FitnessVector<Double>> evolutionResult = eaSystem.evolve();
184 final Genotype bestGenotype = evolutionResult.bestGenotype();
185 final TreeChromosome<Operation<?>> bestChromosome = (TreeChromosome<Operation<?>>) bestGenotype.getChromosome(0);
186 logger.info("Best genotype: {}", bestChromosome.getRoot());
187 logger.info("Best genotype - pretty print: {}", TreeNodeUtils.toStringTreeNode(bestChromosome.getRoot()));
188
189 final int depthIdx = 1;
190 for (int i = 0; i < 15; i++) {
191 final int depth = i;
192 final Optional<Integer> optIdx = IntStream.range(0,
193 evolutionResult.fitness()
194 .size())
195 .boxed()
196 .filter((idx) -> evolutionResult.fitness()
197 .get(idx)
198 .get(depthIdx) == depth)
199 .sorted((a, b) -> Double.compare(evolutionResult.fitness()
200 .get(a)
201 .get(0),
202 evolutionResult.fitness()
203 .get(b)
204 .get(0)))
205 .findFirst();
206
207 optIdx.stream()
208 .forEach((idx) -> {
209 final TreeChromosome<Operation<?>> treeChromosome = (TreeChromosome<Operation<?>>) evolutionResult
210 .population()
211 .get(idx)
212 .getChromosome(0);
213
214 logger.info("Best genotype for depth {} - score {} -> {}",
215 depth,
216 evolutionResult.fitness()
217 .get(idx)
218 .get(0),
219 TreeNodeUtils.toStringTreeNode(treeChromosome.getRoot()));
220 });
221 }
222 }
223
224 public static void main(String[] args) throws IOException {
225
226
227
228
229
230 final CommandLineParser parser = new DefaultParser();
231
232 final Options options = new Options();
233 options.addOption(PARAM_DEST_CSV, LONG_PARAM_DEST_CSV, true, "destination csv file");
234
235 options.addOption(PARAM_POPULATION_SIZE, LONG_PARAM_POPULATION_SIZE, true, "Population size");
236
237 String csvFilename = DEFAULT_DEST_CSV;
238 int populationSize = DEFAULT_POPULATION_SIZE;
239 try {
240 final CommandLine line = parser.parse(options, args);
241
242 if (line.hasOption(PARAM_DEST_CSV)) {
243 csvFilename = line.getOptionValue(PARAM_DEST_CSV);
244 }
245
246 if (line.hasOption(PARAM_POPULATION_SIZE)) {
247 populationSize = Integer.parseInt(line.getOptionValue(PARAM_POPULATION_SIZE));
248 }
249
250 } catch (ParseException exp) {
251 cliError(options, "Unexpected exception:" + exp.getMessage());
252 }
253
254 logger.info("Population size: {}", populationSize);
255
256 logger.info("CSV output located at {}", csvFilename);
257 FileUtils.forceMkdirParent(new File(csvFilename));
258
259 final var symbolicRegression = new SymbolicRegressionWithMOOSPEA2();
260 symbolicRegression.run(csvFilename, populationSize);
261 }
262 }