View Javadoc
1   package net.bmahe.genetics4j.samples.symbolicregression;
2   
3   import java.io.File;
4   import java.io.IOException;
5   import java.util.Comparator;
6   import java.util.Optional;
7   import java.util.Random;
8   import java.util.stream.IntStream;
9   
10  import org.apache.commons.cli.CommandLine;
11  import org.apache.commons.cli.CommandLineParser;
12  import org.apache.commons.cli.DefaultParser;
13  import org.apache.commons.cli.HelpFormatter;
14  import org.apache.commons.cli.Options;
15  import org.apache.commons.cli.ParseException;
16  import org.apache.commons.io.FileUtils;
17  import org.apache.commons.lang3.StringUtils;
18  import org.apache.commons.lang3.Validate;
19  import org.apache.logging.log4j.LogManager;
20  import org.apache.logging.log4j.Logger;
21  
22  import net.bmahe.genetics4j.core.EASystem;
23  import net.bmahe.genetics4j.core.EASystemFactory;
24  import net.bmahe.genetics4j.core.Fitness;
25  import net.bmahe.genetics4j.core.Genotype;
26  import net.bmahe.genetics4j.core.chromosomes.TreeChromosome;
27  import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
28  import net.bmahe.genetics4j.core.spec.EAConfiguration;
29  import net.bmahe.genetics4j.core.spec.EAExecutionContext;
30  import net.bmahe.genetics4j.core.spec.EvolutionResult;
31  import net.bmahe.genetics4j.core.spec.Optimization;
32  import net.bmahe.genetics4j.core.spec.mutation.MultiMutations;
33  import net.bmahe.genetics4j.core.spec.replacement.Elitism;
34  import net.bmahe.genetics4j.core.termination.Terminations;
35  import net.bmahe.genetics4j.gp.Operation;
36  import net.bmahe.genetics4j.gp.math.SimplificationRules;
37  import net.bmahe.genetics4j.gp.program.Program;
38  import net.bmahe.genetics4j.gp.spec.GPEAExecutionContexts;
39  import net.bmahe.genetics4j.gp.spec.chromosome.ProgramTreeChromosomeSpec;
40  import net.bmahe.genetics4j.gp.spec.combination.ProgramRandomCombine;
41  import net.bmahe.genetics4j.gp.spec.mutation.NodeReplacement;
42  import net.bmahe.genetics4j.gp.spec.mutation.ProgramApplyRules;
43  import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomMutate;
44  import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomPrune;
45  import net.bmahe.genetics4j.gp.utils.ProgramUtils;
46  import net.bmahe.genetics4j.gp.utils.TreeNodeUtils;
47  import net.bmahe.genetics4j.moo.FitnessVector;
48  import net.bmahe.genetics4j.moo.nsga2.spec.NSGA2Selection;
49  import net.bmahe.genetics4j.moo.nsga2.spec.TournamentNSGA2Selection;
50  
51  public class SymbolicRegressionWithMOO {
52  	final static public Logger logger = LogManager.getLogger(SymbolicRegressionWithMOO.class);
53  
54  	final static public String PARAM_DEST_CSV = "d";
55  	final static public String LONG_PARAM_DEST_CSV = "csv-dest";
56  
57  	final static public String PARAM_POPULATION_SIZE = "p";
58  	final static public String LONG_PARAM_POPULATION_SIZE = "population-size";
59  
60  	final static public String DEFAULT_DEST_CSV = SymbolicRegressionWithMOO.class.getSimpleName() + ".csv";
61  
62  	final static public int DEFAULT_POPULATION_SIZE = 500;
63  
64  	public static void cliError(final Options options, final String errorMessage) {
65  		final HelpFormatter formatter = new HelpFormatter();
66  		logger.error(errorMessage);
67  		formatter.printHelp(SymbolicRegressionWithMOO.class.getSimpleName(), options);
68  		System.exit(-1);
69  	}
70  
71  	@SuppressWarnings("unchecked")
72  	public void run(String csvFilename, int populationSize) {
73  		Validate.isTrue(StringUtils.isNotBlank(csvFilename));
74  		Validate.isTrue(populationSize > 0);
75  
76  		final Random random = new Random();
77  
78  		final Program program = SymbolicRegressionUtils.buildProgram(random);
79  
80  		final Comparator<Genotype> deduplicator = (a, b) -> TreeNodeUtils.compare(a, b, 0);
81  
82  		// tag::compute_fitness[]
83  		final Fitness<FitnessVector<Double>> computeFitness = (genoType) -> {
84  			final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genoType.getChromosome(0);
85  			final Double[][] inputs = new Double[100][1];
86  			for (int i = 0; i < 100; i++) {
87  				inputs[i][0] = (i - 50) * 1.2;
88  			}
89  
90  			double mse = 0;
91  			for (final Double[] input : inputs) {
92  
93  				final double x = input[0];
94  				final double expected = SymbolicRegressionUtils.evaluate(x);
95  				final Object result = ProgramUtils.execute(chromosome, input);
96  
97  				if (Double.isFinite(expected)) {
98  					final Double resultDouble = (Double) result;
99  					if (Double.isFinite(resultDouble)) {
100 						mse += (expected - resultDouble) * (expected - resultDouble);
101 					} else {
102 						mse += 1_000_000_000;
103 					}
104 				}
105 			}
106 
107 			return Double.isFinite(mse) ? new FitnessVector<Double>(mse / 100.0,
108 					(double) chromosome.getRoot()
109 							.getSize())
110 					: new FitnessVector<Double>(Double.MAX_VALUE, Double.MAX_VALUE);
111 		};
112 		// end::compute_fitness[]
113 
114 		// tag::ea_config[]
115 		final var eaConfigurationBuilder = new EAConfiguration.Builder<FitnessVector<Double>>();
116 		eaConfigurationBuilder.chromosomeSpecs(ProgramTreeChromosomeSpec.of(program)) // <1>
117 				.parentSelectionPolicy(TournamentNSGA2Selection.ofFitnessVector(2, 3, deduplicator)) // <2>
118 				.replacementStrategy(Elitism.builder() // <3>
119 						.offspringRatio(0.995)
120 						.offspringSelectionPolicy(TournamentNSGA2Selection.ofFitnessVector(2, 3, deduplicator))
121 						.survivorSelectionPolicy(NSGA2Selection.ofFitnessVector(2, deduplicator))
122 						.build())
123 				.combinationPolicy(ProgramRandomCombine.build())
124 				.mutationPolicies(MultiMutations
125 						.of(ProgramRandomMutate.of(0.15 * 3), ProgramRandomPrune.of(0.15 * 3), NodeReplacement.of(0.15 * 3)),
126 						ProgramApplyRules.of(SimplificationRules.SIMPLIFY_RULES))
127 				.optimization(Optimization.MINIMIZE)
128 				.termination(Terminations.or(Terminations.<FitnessVector<Double>>ofMaxGeneration(200),
129 						(eaConfiguration, generation, population, fitness) -> fitness.stream()
130 								.anyMatch(fv -> fv.get(0) <= 0.000001 && fv.get(1) <= 20))) // <4>
131 				.fitness(computeFitness);
132 		final EAConfiguration<FitnessVector<Double>> eaConfiguration = eaConfigurationBuilder.build();
133 		// end::ea_config[]
134 
135 		// tag::eae_moo[]
136 		final var eaExecutionContextBuilder = GPEAExecutionContexts.<FitnessVector<Double>>forGP(random);
137 		// end::eae_moo[]
138 		eaExecutionContextBuilder.populationSize(populationSize);
139 		eaExecutionContextBuilder.numberOfPartitions(Math.max(1,
140 				Runtime.getRuntime()
141 						.availableProcessors() - 3));
142 
143 		eaExecutionContextBuilder.addEvolutionListeners(
144 				EvolutionListeners.ofLogTopN(logger,
145 						5,
146 						Comparator.<FitnessVector<Double>, Double>comparing(fv -> fv.get(0))
147 								.reversed(),
148 						(genotype) -> TreeNodeUtils.toStringTreeNode(genotype, 0)),
149 				SymbolicRegressionUtils.csvLogger(csvFilename,
150 						evolutionStep -> evolutionStep.fitness()
151 								.get(0),
152 						evolutionStep -> evolutionStep.fitness()
153 								.get(1)));
154 
155 		final EAExecutionContext<FitnessVector<Double>> eaExecutionContext = eaExecutionContextBuilder.build();
156 		final EASystem<FitnessVector<Double>> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
157 
158 		final EvolutionResult<FitnessVector<Double>> evolutionResult = eaSystem.evolve();
159 		final Genotype bestGenotype = evolutionResult.bestGenotype();
160 		final TreeChromosome<Operation<?>> bestChromosome = (TreeChromosome<Operation<?>>) bestGenotype.getChromosome(0);
161 		logger.info("Best genotype: {}", bestChromosome.getRoot());
162 		logger.info("Best genotype - pretty print: {}", TreeNodeUtils.toStringTreeNode(bestChromosome.getRoot()));
163 
164 		final int depthIdx = 1;
165 		for (int i = 0; i < 15; i++) {
166 			final int depth = i;
167 			final Optional<Integer> optIdx = IntStream.range(0,
168 					evolutionResult.fitness()
169 							.size())
170 					.boxed()
171 					.filter((idx) -> evolutionResult.fitness()
172 							.get(idx)
173 							.get(depthIdx) == depth)
174 					.sorted((a, b) -> Double.compare(evolutionResult.fitness()
175 							.get(a)
176 							.get(0),
177 							evolutionResult.fitness()
178 									.get(b)
179 									.get(0)))
180 					.findFirst();
181 
182 			optIdx.stream()
183 					.forEach((idx) -> {
184 						final TreeChromosome<Operation<?>> treeChromosome = (TreeChromosome<Operation<?>>) evolutionResult
185 								.population()
186 								.get(idx)
187 								.getChromosome(0);
188 
189 						logger.info("Best genotype for depth {} - score {} -> {}",
190 								depth,
191 								evolutionResult.fitness()
192 										.get(idx)
193 										.get(0),
194 								TreeNodeUtils.toStringTreeNode(treeChromosome.getRoot()));
195 					});
196 		}
197 	}
198 
199 	public static void main(String[] args) throws IOException {
200 
201 		/**
202 		 * Parse CLI
203 		 */
204 
205 		final CommandLineParser parser = new DefaultParser();
206 
207 		final Options options = new Options();
208 		options.addOption(PARAM_DEST_CSV, LONG_PARAM_DEST_CSV, true, "destination csv file");
209 
210 		options.addOption(PARAM_POPULATION_SIZE, LONG_PARAM_POPULATION_SIZE, true, "Population size");
211 
212 		String csvFilename = DEFAULT_DEST_CSV;
213 		int populationSize = DEFAULT_POPULATION_SIZE;
214 		try {
215 			final CommandLine line = parser.parse(options, args);
216 
217 			if (line.hasOption(PARAM_DEST_CSV)) {
218 				csvFilename = line.getOptionValue(PARAM_DEST_CSV);
219 			}
220 
221 			if (line.hasOption(PARAM_POPULATION_SIZE)) {
222 				populationSize = Integer.parseInt(line.getOptionValue(PARAM_POPULATION_SIZE));
223 			}
224 
225 		} catch (ParseException exp) {
226 			cliError(options, "Unexpected exception:" + exp.getMessage());
227 		}
228 
229 		logger.info("Population size: {}", populationSize);
230 
231 		logger.info("CSV output located at {}", csvFilename);
232 		FileUtils.forceMkdirParent(new File(csvFilename));
233 
234 		final var symbolicRegression = new SymbolicRegressionWithMOO();
235 		symbolicRegression.run(csvFilename, populationSize);
236 	}
237 }