1 package net.bmahe.genetics4j.samples.symbolicregression; 2 3 import java.util.Comparator; 4 import java.util.Random; 5 6 import org.apache.logging.log4j.LogManager; 7 import org.apache.logging.log4j.Logger; 8 9 import net.bmahe.genetics4j.core.EASystem; 10 import net.bmahe.genetics4j.core.EASystemFactory; 11 import net.bmahe.genetics4j.core.Fitness; 12 import net.bmahe.genetics4j.core.Genotype; 13 import net.bmahe.genetics4j.core.chromosomes.TreeChromosome; 14 import net.bmahe.genetics4j.core.chromosomes.TreeNode; 15 import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners; 16 import net.bmahe.genetics4j.core.spec.EAConfiguration; 17 import net.bmahe.genetics4j.core.spec.EAExecutionContext; 18 import net.bmahe.genetics4j.core.spec.EAExecutionContexts; 19 import net.bmahe.genetics4j.core.spec.EvolutionResult; 20 import net.bmahe.genetics4j.core.spec.Optimization; 21 import net.bmahe.genetics4j.core.spec.selection.Tournament; 22 import net.bmahe.genetics4j.core.termination.Terminations; 23 import net.bmahe.genetics4j.gp.Operation; 24 import net.bmahe.genetics4j.gp.math.SimplificationRules; 25 import net.bmahe.genetics4j.gp.program.Program; 26 import net.bmahe.genetics4j.gp.spec.GPEAExecutionContexts; 27 import net.bmahe.genetics4j.gp.spec.chromosome.ProgramTreeChromosomeSpec; 28 import net.bmahe.genetics4j.gp.spec.combination.ProgramRandomCombine; 29 import net.bmahe.genetics4j.gp.spec.mutation.NodeReplacement; 30 import net.bmahe.genetics4j.gp.spec.mutation.ProgramApplyRules; 31 import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomMutate; 32 import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomPrune; 33 import net.bmahe.genetics4j.gp.spec.mutation.TrimTree; 34 import net.bmahe.genetics4j.gp.utils.ProgramUtils; 35 import net.bmahe.genetics4j.gp.utils.TreeNodeUtils; 36 37 public class SymbolicRegressionWithEnforcedMaxDepth { 38 final static public Logger logger = LogManager.getLogger(SymbolicRegressionWithEnforcedMaxDepth.class); 39 40 @SuppressWarnings("unchecked") 41 public void run() { 42 final Random random = new Random(); 43 44 final Program program = SymbolicRegressionUtils.buildProgram(random); 45 46 final Fitness<Double> computeFitness = (genoType) -> { 47 final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genoType.getChromosome(0); 48 final Double[][] inputs = new Double[100][1]; 49 for (int i = 0; i < 100; i++) { 50 inputs[i][0] = (i - 50) * 1.2; 51 } 52 53 double mse = 0; 54 for (final Double[] input : inputs) { 55 56 final double x = input[0]; 57 final double expected = SymbolicRegressionUtils.evaluate(x); 58 final Object result = ProgramUtils.execute(chromosome, input); 59 60 if (Double.isFinite(expected)) { 61 if (result instanceof Double) { 62 final Double resultDouble = (Double) result; 63 mse += Double.isFinite(resultDouble) ? (expected - resultDouble) * (expected - resultDouble) 64 : 1_000_000_000; 65 } else { 66 logger.error("NOT A DOUBLE: {}", result); 67 mse += 1000; 68 } 69 } 70 } 71 return Double.isFinite(mse) ? mse / 100.0 : Double.MAX_VALUE; 72 }; 73 74 final var eaConfigurationBuilder = new EAConfiguration.Builder<Double>(); 75 eaConfigurationBuilder.chromosomeSpecs(ProgramTreeChromosomeSpec.of(program)) 76 .parentSelectionPolicy(Tournament.of(3)) 77 .combinationPolicy(ProgramRandomCombine.build()) 78 .mutationPolicies(ProgramRandomMutate.of(0.10), 79 ProgramRandomPrune.of(0.12), 80 NodeReplacement.of(0.05), 81 TrimTree.build(), 82 ProgramApplyRules.of(SimplificationRules.SIMPLIFY_RULES)) 83 .optimization(Optimization.MINIMIZE) 84 .termination(Terminations.or(Terminations.ofMaxGeneration(100), Terminations.ofFitnessAtMost(0.0001d))) 85 .fitness(computeFitness); 86 final EAConfiguration<Double> eaConfiguration = eaConfigurationBuilder.build(); 87 88 final var eaExecutionContextBuilder = GPEAExecutionContexts.<Double>forGP(random); 89 EAExecutionContexts.enrichForScalarFitness(eaExecutionContextBuilder); 90 91 eaExecutionContextBuilder.populationSize(1500); 92 eaExecutionContextBuilder.numberOfPartitions(Math.max(1, 93 Runtime.getRuntime() 94 .availableProcessors() - 1)); 95 96 eaExecutionContextBuilder.addEvolutionListeners( 97 EvolutionListeners.ofLogTopN(logger, 5, Comparator.<Double>reverseOrder(), (genotype) -> { 98 final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genotype.getChromosome(0); 99 final TreeNode<Operation<?>> root = chromosome.getRoot(); 100 101 return TreeNodeUtils.toStringTreeNode(root); 102 }), 103 SymbolicRegressionUtils.csvLoggerDouble("symbolicregression-output-enforced-max-depth.csv", 104 evolutionStep -> evolutionStep.fitness(), 105 evolutionStep -> (double) evolutionStep.individual() 106 .getChromosome(0, TreeChromosome.class) 107 .getSize())); 108 109 final EAExecutionContext<Double> eaExecutionContext = eaExecutionContextBuilder.build(); 110 final EASystem<Double> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext); 111 112 final EvolutionResult<Double> evolutionResult = eaSystem.evolve(); 113 final Genotype bestGenotype = evolutionResult.bestGenotype(); 114 final TreeChromosome<Operation<?>> bestChromosome = (TreeChromosome<Operation<?>>) bestGenotype.getChromosome(0); 115 logger.info("Best genotype: {}", bestChromosome.getRoot()); 116 logger.info("Best genotype - pretty print: {}", TreeNodeUtils.toStringTreeNode(bestChromosome.getRoot())); 117 } 118 119 public static int main(String[] args) { 120 121 final var symbolicRegression = new SymbolicRegressionWithEnforcedMaxDepth(); 122 symbolicRegression.run(); 123 124 return 0; 125 } 126 }