1 package net.bmahe.genetics4j.samples.symbolicregression;
2
3 import java.util.Comparator;
4 import java.util.Random;
5
6 import org.apache.logging.log4j.LogManager;
7 import org.apache.logging.log4j.Logger;
8
9 import net.bmahe.genetics4j.core.EASystem;
10 import net.bmahe.genetics4j.core.EASystemFactory;
11 import net.bmahe.genetics4j.core.Fitness;
12 import net.bmahe.genetics4j.core.Genotype;
13 import net.bmahe.genetics4j.core.chromosomes.TreeChromosome;
14 import net.bmahe.genetics4j.core.chromosomes.TreeNode;
15 import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
16 import net.bmahe.genetics4j.core.spec.EAConfiguration;
17 import net.bmahe.genetics4j.core.spec.EAExecutionContext;
18 import net.bmahe.genetics4j.core.spec.EAExecutionContexts;
19 import net.bmahe.genetics4j.core.spec.EvolutionResult;
20 import net.bmahe.genetics4j.core.spec.Optimization;
21 import net.bmahe.genetics4j.core.spec.selection.Tournament;
22 import net.bmahe.genetics4j.core.termination.Terminations;
23 import net.bmahe.genetics4j.gp.Operation;
24 import net.bmahe.genetics4j.gp.math.SimplificationRules;
25 import net.bmahe.genetics4j.gp.program.Program;
26 import net.bmahe.genetics4j.gp.spec.GPEAExecutionContexts;
27 import net.bmahe.genetics4j.gp.spec.chromosome.ProgramTreeChromosomeSpec;
28 import net.bmahe.genetics4j.gp.spec.combination.ProgramRandomCombine;
29 import net.bmahe.genetics4j.gp.spec.mutation.NodeReplacement;
30 import net.bmahe.genetics4j.gp.spec.mutation.ProgramApplyRules;
31 import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomMutate;
32 import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomPrune;
33 import net.bmahe.genetics4j.gp.spec.mutation.TrimTree;
34 import net.bmahe.genetics4j.gp.utils.ProgramUtils;
35 import net.bmahe.genetics4j.gp.utils.TreeNodeUtils;
36
37 public class SymbolicRegressionWithEnforcedMaxDepth {
38 final static public Logger logger = LogManager.getLogger(SymbolicRegressionWithEnforcedMaxDepth.class);
39
40 @SuppressWarnings("unchecked")
41 public void run() {
42 final Random random = new Random();
43
44 final Program program = SymbolicRegressionUtils.buildProgram(random);
45
46 final Fitness<Double> computeFitness = (genoType) -> {
47 final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genoType.getChromosome(0);
48 final Double[][] inputs = new Double[100][1];
49 for (int i = 0; i < 100; i++) {
50 inputs[i][0] = (i - 50) * 1.2;
51 }
52
53 double mse = 0;
54 for (final Double[] input : inputs) {
55
56 final double x = input[0];
57 final double expected = SymbolicRegressionUtils.evaluate(x);
58 final Object result = ProgramUtils.execute(chromosome, input);
59
60 if (Double.isFinite(expected)) {
61 if (result instanceof Double) {
62 final Double resultDouble = (Double) result;
63 mse += Double.isFinite(resultDouble) ? (expected - resultDouble) * (expected - resultDouble)
64 : 1_000_000_000;
65 } else {
66 logger.error("NOT A DOUBLE: {}", result);
67 mse += 1000;
68 }
69 }
70 }
71 return Double.isFinite(mse) ? mse / 100.0 : Double.MAX_VALUE;
72 };
73
74 final var eaConfigurationBuilder = new EAConfiguration.Builder<Double>();
75 eaConfigurationBuilder.chromosomeSpecs(ProgramTreeChromosomeSpec.of(program))
76 .parentSelectionPolicy(Tournament.of(3))
77 .combinationPolicy(ProgramRandomCombine.build())
78 .mutationPolicies(ProgramRandomMutate.of(0.10),
79 ProgramRandomPrune.of(0.12),
80 NodeReplacement.of(0.05),
81 TrimTree.build(),
82 ProgramApplyRules.of(SimplificationRules.SIMPLIFY_RULES))
83 .optimization(Optimization.MINIMIZE)
84 .termination(Terminations.or(Terminations.ofMaxGeneration(100), Terminations.ofFitnessAtMost(0.0001d)))
85 .fitness(computeFitness);
86 final EAConfiguration<Double> eaConfiguration = eaConfigurationBuilder.build();
87
88 final var eaExecutionContextBuilder = GPEAExecutionContexts.<Double>forGP(random);
89 EAExecutionContexts.enrichForScalarFitness(eaExecutionContextBuilder);
90
91 eaExecutionContextBuilder.populationSize(1500);
92 eaExecutionContextBuilder.numberOfPartitions(Math.max(1,
93 Runtime.getRuntime()
94 .availableProcessors() - 1));
95
96 eaExecutionContextBuilder.addEvolutionListeners(
97 EvolutionListeners.ofLogTopN(logger, 5, Comparator.<Double>reverseOrder(), (genotype) -> {
98 final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genotype.getChromosome(0);
99 final TreeNode<Operation<?>> root = chromosome.getRoot();
100
101 return TreeNodeUtils.toStringTreeNode(root);
102 }),
103 SymbolicRegressionUtils.csvLoggerDouble("symbolicregression-output-enforced-max-depth.csv",
104 evolutionStep -> evolutionStep.fitness(),
105 evolutionStep -> (double) evolutionStep.individual()
106 .getChromosome(0, TreeChromosome.class)
107 .getSize()));
108
109 final EAExecutionContext<Double> eaExecutionContext = eaExecutionContextBuilder.build();
110 final EASystem<Double> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
111
112 final EvolutionResult<Double> evolutionResult = eaSystem.evolve();
113 final Genotype bestGenotype = evolutionResult.bestGenotype();
114 final TreeChromosome<Operation<?>> bestChromosome = (TreeChromosome<Operation<?>>) bestGenotype.getChromosome(0);
115 logger.info("Best genotype: {}", bestChromosome.getRoot());
116 logger.info("Best genotype - pretty print: {}", TreeNodeUtils.toStringTreeNode(bestChromosome.getRoot()));
117 }
118
119 public static int main(String[] args) {
120
121 final var symbolicRegression = new SymbolicRegressionWithEnforcedMaxDepth();
122 symbolicRegression.run();
123
124 return 0;
125 }
126 }