1 package net.bmahe.genetics4j.samples.symbolicregression;
2
3 import static net.bmahe.genetics4j.core.termination.Terminations.ofFitnessAtMost;
4 import static net.bmahe.genetics4j.core.termination.Terminations.ofMaxGeneration;
5 import static net.bmahe.genetics4j.core.termination.Terminations.or;
6
7 import java.io.File;
8 import java.io.IOException;
9 import java.util.Comparator;
10 import java.util.Random;
11
12 import org.apache.commons.cli.CommandLine;
13 import org.apache.commons.cli.CommandLineParser;
14 import org.apache.commons.cli.DefaultParser;
15 import org.apache.commons.cli.HelpFormatter;
16 import org.apache.commons.cli.Options;
17 import org.apache.commons.cli.ParseException;
18 import org.apache.commons.io.FileUtils;
19 import org.apache.commons.lang3.StringUtils;
20 import org.apache.commons.lang3.Validate;
21 import org.apache.logging.log4j.LogManager;
22 import org.apache.logging.log4j.Logger;
23
24 import net.bmahe.genetics4j.core.EASystem;
25 import net.bmahe.genetics4j.core.EASystemFactory;
26 import net.bmahe.genetics4j.core.Fitness;
27 import net.bmahe.genetics4j.core.Genotype;
28 import net.bmahe.genetics4j.core.Individual;
29 import net.bmahe.genetics4j.core.chromosomes.TreeChromosome;
30 import net.bmahe.genetics4j.core.chromosomes.TreeNode;
31 import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
32 import net.bmahe.genetics4j.core.spec.EAConfiguration;
33 import net.bmahe.genetics4j.core.spec.EAExecutionContext;
34 import net.bmahe.genetics4j.core.spec.EAExecutionContexts;
35 import net.bmahe.genetics4j.core.spec.EvolutionResult;
36 import net.bmahe.genetics4j.core.spec.Optimization;
37 import net.bmahe.genetics4j.core.spec.replacement.Elitism;
38 import net.bmahe.genetics4j.core.spec.selection.SelectiveRefinementTournament;
39 import net.bmahe.genetics4j.core.spec.selection.Tournament;
40 import net.bmahe.genetics4j.gp.Operation;
41 import net.bmahe.genetics4j.gp.math.SimplificationRules;
42 import net.bmahe.genetics4j.gp.program.Program;
43 import net.bmahe.genetics4j.gp.spec.GPEAExecutionContexts;
44 import net.bmahe.genetics4j.gp.spec.chromosome.ProgramTreeChromosomeSpec;
45 import net.bmahe.genetics4j.gp.spec.combination.ProgramRandomCombine;
46 import net.bmahe.genetics4j.gp.spec.mutation.NodeReplacement;
47 import net.bmahe.genetics4j.gp.spec.mutation.ProgramApplyRules;
48 import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomMutate;
49 import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomPrune;
50 import net.bmahe.genetics4j.gp.utils.ProgramUtils;
51 import net.bmahe.genetics4j.gp.utils.TreeNodeUtils;
52
53 public class SymbolicRegressionWithSRT {
54 final static public Logger logger = LogManager.getLogger(SymbolicRegressionWithSRT.class);
55
56 final static public String PARAM_DEST_CSV = "d";
57 final static public String LONG_PARAM_DEST_CSV = "csv-dest";
58
59 final static public String PARAM_POPULATION_SIZE = "p";
60 final static public String LONG_PARAM_POPULATION_SIZE = "population-size";
61
62 final static public String DEFAULT_DEST_CSV = SymbolicRegressionWithSRT.class.getSimpleName() + ".csv";
63
64 final static public int DEFAULT_POPULATION_SIZE = 500;
65
66 public static void cliError(final Options options, final String errorMessage) {
67 final HelpFormatter formatter = new HelpFormatter();
68 logger.error(errorMessage);
69 formatter.printHelp(SymbolicRegressionWithSRT.class.getSimpleName(), options);
70 System.exit(-1);
71 }
72
73 @SuppressWarnings("unchecked")
74 public void run(String csvFilename, int populationSize) {
75 Validate.isTrue(StringUtils.isNotBlank(csvFilename));
76 Validate.isTrue(populationSize > 0);
77
78 final Random random = new Random();
79
80 final Program program = SymbolicRegressionUtils.buildProgram(random);
81
82
83 final Double[][] inputs = new Double[100][1];
84 for (int i = 0; i < 100; i++) {
85 inputs[i][0] = (i - 50) * 1.2;
86 }
87
88 final Fitness<Double> computeFitness = (genoType) -> {
89 final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genoType.getChromosome(0);
90
91 double mse = 0;
92 for (final Double[] input : inputs) {
93
94 final double x = input[0];
95 final double expected = SymbolicRegressionUtils.evaluate(x);
96 final Object result = ProgramUtils.execute(chromosome, input);
97
98 if (Double.isFinite(expected)) {
99 final Double resultDouble = (Double) result;
100 mse += Double.isFinite(resultDouble) ? (expected - resultDouble) * (expected - resultDouble)
101 : 1_000_000_000;
102 }
103 }
104 return Double.isFinite(mse) ? mse / 100.0d : Double.MAX_VALUE;
105 };
106
107
108
109 final Comparator<Individual<Double>> parsimonyComparator = (a, b) -> {
110 final var treeChromosomeA = a.genotype()
111 .getChromosome(0, TreeChromosome.class);
112 final var treeChromosomeB = b.genotype()
113 .getChromosome(0, TreeChromosome.class);
114
115 return -Integer.compare(treeChromosomeA.getSize(), treeChromosomeB.getSize());
116 };
117
118 final SelectiveRefinementTournament<Double> selectiveRefinementTournament = SelectiveRefinementTournament
119 .<Double>builder()
120 .tournament(Tournament.of(3))
121 .refinementComparator(parsimonyComparator)
122 .refinementRatio(0.65f)
123 .build();
124
125
126
127
128 final var eaConfigurationBuilder = new EAConfiguration.Builder<Double>();
129 eaConfigurationBuilder.chromosomeSpecs(ProgramTreeChromosomeSpec.of(program))
130 .parentSelectionPolicy(selectiveRefinementTournament)
131 .replacementStrategy(Elitism.builder()
132 .offspringRatio(0.99)
133 .offspringSelectionPolicy(selectiveRefinementTournament)
134 .survivorSelectionPolicy(selectiveRefinementTournament)
135 .build())
136 .combinationPolicy(ProgramRandomCombine.build())
137 .mutationPolicies(ProgramRandomMutate.of(0.10),
138 ProgramRandomPrune.of(0.12),
139 NodeReplacement.of(0.05),
140 ProgramApplyRules.of(SimplificationRules.SIMPLIFY_RULES))
141 .optimization(Optimization.MINIMIZE)
142 .termination(or(ofMaxGeneration(200), ofFitnessAtMost(0.00001)))
143 .fitness(computeFitness);
144 final EAConfiguration<Double> eaConfiguration = eaConfigurationBuilder.build();
145
146
147 final var eaExecutionContextBuilder = GPEAExecutionContexts.<Double>forGP(random);
148 EAExecutionContexts.enrichForScalarFitness(eaExecutionContextBuilder);
149
150 eaExecutionContextBuilder.populationSize(populationSize);
151 eaExecutionContextBuilder.numberOfPartitions(Math.max(1,
152 Runtime.getRuntime()
153 .availableProcessors() - 1));
154
155 eaExecutionContextBuilder.addEvolutionListeners(
156 EvolutionListeners.ofLogTopN(logger, 5, Comparator.<Double>reverseOrder(), (genotype) -> {
157 final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genotype.getChromosome(0);
158 final TreeNode<Operation<?>> root = chromosome.getRoot();
159
160 return TreeNodeUtils.toStringTreeNode(root);
161 }),
162 SymbolicRegressionUtils.csvLoggerDouble(csvFilename,
163 evolutionStep -> evolutionStep.fitness(),
164 evolutionStep -> (double) evolutionStep.individual()
165 .getChromosome(0, TreeChromosome.class)
166 .getSize()));
167
168 final EAExecutionContext<Double> eaExecutionContext = eaExecutionContextBuilder.build();
169 final EASystem<Double> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
170
171 final EvolutionResult<Double> evolutionResult = eaSystem.evolve();
172 final Genotype bestGenotype = evolutionResult.bestGenotype();
173 final TreeChromosome<Operation<?>> bestChromosome = (TreeChromosome<Operation<?>>) bestGenotype.getChromosome(0);
174 logger.info("Best genotype: {}", bestChromosome.getRoot());
175 logger.info("Best genotype - pretty print: {}", TreeNodeUtils.toStringTreeNode(bestChromosome.getRoot()));
176 }
177
178 public static void main(String[] args) throws IOException {
179
180
181
182
183
184 final CommandLineParser parser = new DefaultParser();
185
186 final Options options = new Options();
187 options.addOption(PARAM_DEST_CSV, LONG_PARAM_DEST_CSV, true, "destination csv file");
188
189 options.addOption(PARAM_POPULATION_SIZE, LONG_PARAM_POPULATION_SIZE, true, "Population size");
190
191 String csvFilename = DEFAULT_DEST_CSV;
192 int populationSize = DEFAULT_POPULATION_SIZE;
193 try {
194 final CommandLine line = parser.parse(options, args);
195
196 if (line.hasOption(PARAM_DEST_CSV)) {
197 csvFilename = line.getOptionValue(PARAM_DEST_CSV);
198 }
199
200 if (line.hasOption(PARAM_POPULATION_SIZE)) {
201 populationSize = Integer.parseInt(line.getOptionValue(PARAM_POPULATION_SIZE));
202 }
203
204 } catch (ParseException exp) {
205 cliError(options, "Unexpected exception:" + exp.getMessage());
206 }
207
208 logger.info("Population size: {}", populationSize);
209
210 logger.info("CSV output located at {}", csvFilename);
211 FileUtils.forceMkdirParent(new File(csvFilename));
212
213 final var symbolicRegression = new SymbolicRegressionWithSRT();
214 symbolicRegression.run(csvFilename, populationSize);
215 }
216 }