View Javadoc
1   package net.bmahe.genetics4j.samples.symbolicregression;
2   
3   import java.util.Comparator;
4   import java.util.Random;
5   
6   import org.apache.logging.log4j.LogManager;
7   import org.apache.logging.log4j.Logger;
8   
9   import net.bmahe.genetics4j.core.EASystem;
10  import net.bmahe.genetics4j.core.EASystemFactory;
11  import net.bmahe.genetics4j.core.Fitness;
12  import net.bmahe.genetics4j.core.Genotype;
13  import net.bmahe.genetics4j.core.chromosomes.TreeChromosome;
14  import net.bmahe.genetics4j.core.chromosomes.TreeNode;
15  import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
16  import net.bmahe.genetics4j.core.spec.EAConfiguration;
17  import net.bmahe.genetics4j.core.spec.EAExecutionContext;
18  import net.bmahe.genetics4j.core.spec.EAExecutionContexts;
19  import net.bmahe.genetics4j.core.spec.EvolutionResult;
20  import net.bmahe.genetics4j.core.spec.Optimization;
21  import net.bmahe.genetics4j.core.spec.selection.Tournament;
22  import net.bmahe.genetics4j.core.termination.Terminations;
23  import net.bmahe.genetics4j.gp.Operation;
24  import net.bmahe.genetics4j.gp.math.SimplificationRules;
25  import net.bmahe.genetics4j.gp.program.Program;
26  import net.bmahe.genetics4j.gp.spec.GPEAExecutionContexts;
27  import net.bmahe.genetics4j.gp.spec.chromosome.ProgramTreeChromosomeSpec;
28  import net.bmahe.genetics4j.gp.spec.combination.ProgramRandomCombine;
29  import net.bmahe.genetics4j.gp.spec.mutation.NodeReplacement;
30  import net.bmahe.genetics4j.gp.spec.mutation.ProgramApplyRules;
31  import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomMutate;
32  import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomPrune;
33  import net.bmahe.genetics4j.gp.spec.mutation.TrimTree;
34  import net.bmahe.genetics4j.gp.utils.ProgramUtils;
35  import net.bmahe.genetics4j.gp.utils.TreeNodeUtils;
36  
37  public class SymbolicRegressionWithEnforcedMaxDepth {
38  	final static public Logger logger = LogManager.getLogger(SymbolicRegressionWithEnforcedMaxDepth.class);
39  
40  	@SuppressWarnings("unchecked")
41  	public void run() {
42  		final Random random = new Random();
43  
44  		final Program program = SymbolicRegressionUtils.buildProgram(random);
45  
46  		final Fitness<Double> computeFitness = (genoType) -> {
47  			final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genoType.getChromosome(0);
48  			final Double[][] inputs = new Double[100][1];
49  			for (int i = 0; i < 100; i++) {
50  				inputs[i][0] = (i - 50) * 1.2;
51  			}
52  
53  			double mse = 0;
54  			for (final Double[] input : inputs) {
55  
56  				final double x = input[0];
57  				final double expected = SymbolicRegressionUtils.evaluate(x);
58  				final Object result = ProgramUtils.execute(chromosome, input);
59  
60  				if (Double.isFinite(expected)) {
61  					if (result instanceof Double) {
62  						final Double resultDouble = (Double) result;
63  						mse += Double.isFinite(resultDouble) ? (expected - resultDouble) * (expected - resultDouble)
64  								: 1_000_000_000;
65  					} else {
66  						logger.error("NOT A DOUBLE: {}", result);
67  						mse += 1000;
68  					}
69  				}
70  			}
71  			return Double.isFinite(mse) ? mse / 100.0 : Double.MAX_VALUE;
72  		};
73  
74  		final var eaConfigurationBuilder = new EAConfiguration.Builder<Double>();
75  		eaConfigurationBuilder.chromosomeSpecs(ProgramTreeChromosomeSpec.of(program))
76  				.parentSelectionPolicy(Tournament.of(3))
77  				.combinationPolicy(ProgramRandomCombine.build())
78  				.mutationPolicies(ProgramRandomMutate.of(0.10),
79  						ProgramRandomPrune.of(0.12),
80  						NodeReplacement.of(0.05),
81  						TrimTree.build(),
82  						ProgramApplyRules.of(SimplificationRules.SIMPLIFY_RULES))
83  				.optimization(Optimization.MINIMIZE)
84  				.termination(Terminations.or(Terminations.ofMaxGeneration(100), Terminations.ofFitnessAtMost(0.0001d)))
85  				.fitness(computeFitness);
86  		final EAConfiguration<Double> eaConfiguration = eaConfigurationBuilder.build();
87  
88  		final var eaExecutionContextBuilder = GPEAExecutionContexts.<Double>forGP(random);
89  		EAExecutionContexts.enrichForScalarFitness(eaExecutionContextBuilder);
90  
91  		eaExecutionContextBuilder.populationSize(1500);
92  		eaExecutionContextBuilder.numberOfPartitions(Math.max(1,
93  				Runtime.getRuntime()
94  						.availableProcessors() - 1));
95  
96  		eaExecutionContextBuilder.addEvolutionListeners(
97  				EvolutionListeners.ofLogTopN(logger, 5, Comparator.<Double>reverseOrder(), (genotype) -> {
98  					final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genotype.getChromosome(0);
99  					final TreeNode<Operation<?>> root = chromosome.getRoot();
100 
101 					return TreeNodeUtils.toStringTreeNode(root);
102 				}),
103 				SymbolicRegressionUtils.csvLoggerDouble("symbolicregression-output-enforced-max-depth.csv",
104 						evolutionStep -> evolutionStep.fitness(),
105 						evolutionStep -> (double) evolutionStep.individual()
106 								.getChromosome(0, TreeChromosome.class)
107 								.getSize()));
108 
109 		final EAExecutionContext<Double> eaExecutionContext = eaExecutionContextBuilder.build();
110 		final EASystem<Double> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
111 
112 		final EvolutionResult<Double> evolutionResult = eaSystem.evolve();
113 		final Genotype bestGenotype = evolutionResult.bestGenotype();
114 		final TreeChromosome<Operation<?>> bestChromosome = (TreeChromosome<Operation<?>>) bestGenotype.getChromosome(0);
115 		logger.info("Best genotype: {}", bestChromosome.getRoot());
116 		logger.info("Best genotype - pretty print: {}", TreeNodeUtils.toStringTreeNode(bestChromosome.getRoot()));
117 	}
118 
119 	public static int main(String[] args) {
120 
121 		final var symbolicRegression = new SymbolicRegressionWithEnforcedMaxDepth();
122 		symbolicRegression.run();
123 
124 		return 0;
125 	}
126 }