1 package net.bmahe.genetics4j.samples.symbolicregression;
2
3 import java.io.File;
4 import java.io.IOException;
5 import java.util.Comparator;
6 import java.util.Optional;
7 import java.util.Random;
8 import java.util.stream.IntStream;
9
10 import org.apache.commons.cli.CommandLine;
11 import org.apache.commons.cli.CommandLineParser;
12 import org.apache.commons.cli.DefaultParser;
13 import org.apache.commons.cli.HelpFormatter;
14 import org.apache.commons.cli.Options;
15 import org.apache.commons.cli.ParseException;
16 import org.apache.commons.io.FileUtils;
17 import org.apache.commons.lang3.StringUtils;
18 import org.apache.commons.lang3.Validate;
19 import org.apache.logging.log4j.LogManager;
20 import org.apache.logging.log4j.Logger;
21
22 import net.bmahe.genetics4j.core.EASystem;
23 import net.bmahe.genetics4j.core.EASystemFactory;
24 import net.bmahe.genetics4j.core.Fitness;
25 import net.bmahe.genetics4j.core.Genotype;
26 import net.bmahe.genetics4j.core.chromosomes.TreeChromosome;
27 import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
28 import net.bmahe.genetics4j.core.spec.EAConfiguration;
29 import net.bmahe.genetics4j.core.spec.EAExecutionContext;
30 import net.bmahe.genetics4j.core.spec.EvolutionResult;
31 import net.bmahe.genetics4j.core.spec.Optimization;
32 import net.bmahe.genetics4j.core.spec.mutation.MultiMutations;
33 import net.bmahe.genetics4j.core.spec.replacement.Elitism;
34 import net.bmahe.genetics4j.core.termination.Terminations;
35 import net.bmahe.genetics4j.gp.Operation;
36 import net.bmahe.genetics4j.gp.math.SimplificationRules;
37 import net.bmahe.genetics4j.gp.program.Program;
38 import net.bmahe.genetics4j.gp.spec.GPEAExecutionContexts;
39 import net.bmahe.genetics4j.gp.spec.chromosome.ProgramTreeChromosomeSpec;
40 import net.bmahe.genetics4j.gp.spec.combination.ProgramRandomCombine;
41 import net.bmahe.genetics4j.gp.spec.mutation.NodeReplacement;
42 import net.bmahe.genetics4j.gp.spec.mutation.ProgramApplyRules;
43 import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomMutate;
44 import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomPrune;
45 import net.bmahe.genetics4j.gp.utils.ProgramUtils;
46 import net.bmahe.genetics4j.gp.utils.TreeNodeUtils;
47 import net.bmahe.genetics4j.moo.FitnessVector;
48 import net.bmahe.genetics4j.moo.nsga2.spec.NSGA2Selection;
49 import net.bmahe.genetics4j.moo.nsga2.spec.TournamentNSGA2Selection;
50
51 public class SymbolicRegressionWithMOO {
52 final static public Logger logger = LogManager.getLogger(SymbolicRegressionWithMOO.class);
53
54 final static public String PARAM_DEST_CSV = "d";
55 final static public String LONG_PARAM_DEST_CSV = "csv-dest";
56
57 final static public String PARAM_POPULATION_SIZE = "p";
58 final static public String LONG_PARAM_POPULATION_SIZE = "population-size";
59
60 final static public String DEFAULT_DEST_CSV = SymbolicRegressionWithMOO.class.getSimpleName() + ".csv";
61
62 final static public int DEFAULT_POPULATION_SIZE = 500;
63
64 public static void cliError(final Options options, final String errorMessage) {
65 final HelpFormatter formatter = new HelpFormatter();
66 logger.error(errorMessage);
67 formatter.printHelp(SymbolicRegressionWithMOO.class.getSimpleName(), options);
68 System.exit(-1);
69 }
70
71 @SuppressWarnings("unchecked")
72 public void run(String csvFilename, int populationSize) {
73 Validate.isTrue(StringUtils.isNotBlank(csvFilename));
74 Validate.isTrue(populationSize > 0);
75
76 final Random random = new Random();
77
78 final Program program = SymbolicRegressionUtils.buildProgram(random);
79
80 final Comparator<Genotype> deduplicator = (a, b) -> TreeNodeUtils.compare(a, b, 0);
81
82
83 final Fitness<FitnessVector<Double>> computeFitness = (genoType) -> {
84 final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genoType.getChromosome(0);
85 final Double[][] inputs = new Double[100][1];
86 for (int i = 0; i < 100; i++) {
87 inputs[i][0] = (i - 50) * 1.2;
88 }
89
90 double mse = 0;
91 for (final Double[] input : inputs) {
92
93 final double x = input[0];
94 final double expected = SymbolicRegressionUtils.evaluate(x);
95 final Object result = ProgramUtils.execute(chromosome, input);
96
97 if (Double.isFinite(expected)) {
98 final Double resultDouble = (Double) result;
99 if (Double.isFinite(resultDouble)) {
100 mse += (expected - resultDouble) * (expected - resultDouble);
101 } else {
102 mse += 1_000_000_000;
103 }
104 }
105 }
106
107 return Double.isFinite(mse) ? new FitnessVector<Double>(mse / 100.0,
108 (double) chromosome.getRoot()
109 .getSize())
110 : new FitnessVector<Double>(Double.MAX_VALUE, Double.MAX_VALUE);
111 };
112
113
114
115 final var eaConfigurationBuilder = new EAConfiguration.Builder<FitnessVector<Double>>();
116 eaConfigurationBuilder.chromosomeSpecs(ProgramTreeChromosomeSpec.of(program))
117 .parentSelectionPolicy(TournamentNSGA2Selection.ofFitnessVector(2, 3, deduplicator))
118 .replacementStrategy(Elitism.builder()
119 .offspringRatio(0.995)
120 .offspringSelectionPolicy(TournamentNSGA2Selection.ofFitnessVector(2, 3, deduplicator))
121 .survivorSelectionPolicy(NSGA2Selection.ofFitnessVector(2, deduplicator))
122 .build())
123 .combinationPolicy(ProgramRandomCombine.build())
124 .mutationPolicies(MultiMutations
125 .of(ProgramRandomMutate.of(0.15 * 3), ProgramRandomPrune.of(0.15 * 3), NodeReplacement.of(0.15 * 3)),
126 ProgramApplyRules.of(SimplificationRules.SIMPLIFY_RULES))
127 .optimization(Optimization.MINIMIZE)
128 .termination(Terminations.or(Terminations.<FitnessVector<Double>>ofMaxGeneration(200),
129 (eaConfiguration, generation, population, fitness) -> fitness.stream()
130 .anyMatch(fv -> fv.get(0) <= 0.000001 && fv.get(1) <= 20)))
131 .fitness(computeFitness);
132 final EAConfiguration<FitnessVector<Double>> eaConfiguration = eaConfigurationBuilder.build();
133
134
135
136 final var eaExecutionContextBuilder = GPEAExecutionContexts.<FitnessVector<Double>>forGP(random);
137
138 eaExecutionContextBuilder.populationSize(populationSize);
139 eaExecutionContextBuilder.numberOfPartitions(Math.max(1,
140 Runtime.getRuntime()
141 .availableProcessors() - 3));
142
143 eaExecutionContextBuilder.addEvolutionListeners(
144 EvolutionListeners.ofLogTopN(logger,
145 5,
146 Comparator.<FitnessVector<Double>, Double>comparing(fv -> fv.get(0))
147 .reversed(),
148 (genotype) -> TreeNodeUtils.toStringTreeNode(genotype, 0)),
149 SymbolicRegressionUtils.csvLogger(csvFilename,
150 evolutionStep -> evolutionStep.fitness()
151 .get(0),
152 evolutionStep -> evolutionStep.fitness()
153 .get(1)));
154
155 final EAExecutionContext<FitnessVector<Double>> eaExecutionContext = eaExecutionContextBuilder.build();
156 final EASystem<FitnessVector<Double>> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
157
158 final EvolutionResult<FitnessVector<Double>> evolutionResult = eaSystem.evolve();
159 final Genotype bestGenotype = evolutionResult.bestGenotype();
160 final TreeChromosome<Operation<?>> bestChromosome = (TreeChromosome<Operation<?>>) bestGenotype.getChromosome(0);
161 logger.info("Best genotype: {}", bestChromosome.getRoot());
162 logger.info("Best genotype - pretty print: {}", TreeNodeUtils.toStringTreeNode(bestChromosome.getRoot()));
163
164 final int depthIdx = 1;
165 for (int i = 0; i < 15; i++) {
166 final int depth = i;
167 final Optional<Integer> optIdx = IntStream.range(0,
168 evolutionResult.fitness()
169 .size())
170 .boxed()
171 .filter((idx) -> evolutionResult.fitness()
172 .get(idx)
173 .get(depthIdx) == depth)
174 .sorted((a, b) -> Double.compare(evolutionResult.fitness()
175 .get(a)
176 .get(0),
177 evolutionResult.fitness()
178 .get(b)
179 .get(0)))
180 .findFirst();
181
182 optIdx.stream()
183 .forEach((idx) -> {
184 final TreeChromosome<Operation<?>> treeChromosome = (TreeChromosome<Operation<?>>) evolutionResult
185 .population()
186 .get(idx)
187 .getChromosome(0);
188
189 logger.info("Best genotype for depth {} - score {} -> {}",
190 depth,
191 evolutionResult.fitness()
192 .get(idx)
193 .get(0),
194 TreeNodeUtils.toStringTreeNode(treeChromosome.getRoot()));
195 });
196 }
197 }
198
199 public static void main(String[] args) throws IOException {
200
201
202
203
204
205 final CommandLineParser parser = new DefaultParser();
206
207 final Options options = new Options();
208 options.addOption(PARAM_DEST_CSV, LONG_PARAM_DEST_CSV, true, "destination csv file");
209
210 options.addOption(PARAM_POPULATION_SIZE, LONG_PARAM_POPULATION_SIZE, true, "Population size");
211
212 String csvFilename = DEFAULT_DEST_CSV;
213 int populationSize = DEFAULT_POPULATION_SIZE;
214 try {
215 final CommandLine line = parser.parse(options, args);
216
217 if (line.hasOption(PARAM_DEST_CSV)) {
218 csvFilename = line.getOptionValue(PARAM_DEST_CSV);
219 }
220
221 if (line.hasOption(PARAM_POPULATION_SIZE)) {
222 populationSize = Integer.parseInt(line.getOptionValue(PARAM_POPULATION_SIZE));
223 }
224
225 } catch (ParseException exp) {
226 cliError(options, "Unexpected exception:" + exp.getMessage());
227 }
228
229 logger.info("Population size: {}", populationSize);
230
231 logger.info("CSV output located at {}", csvFilename);
232 FileUtils.forceMkdirParent(new File(csvFilename));
233
234 final var symbolicRegression = new SymbolicRegressionWithMOO();
235 symbolicRegression.run(csvFilename, populationSize);
236 }
237 }