View Javadoc
1   package net.bmahe.genetics4j.samples.symbolicregression;
2   
3   import java.io.File;
4   import java.io.IOException;
5   import java.util.Comparator;
6   import java.util.List;
7   import java.util.Optional;
8   import java.util.Random;
9   import java.util.stream.IntStream;
10  
11  import org.apache.commons.cli.CommandLine;
12  import org.apache.commons.cli.CommandLineParser;
13  import org.apache.commons.cli.DefaultParser;
14  import org.apache.commons.cli.HelpFormatter;
15  import org.apache.commons.cli.Options;
16  import org.apache.commons.cli.ParseException;
17  import org.apache.commons.io.FileUtils;
18  import org.apache.commons.lang3.StringUtils;
19  import org.apache.commons.lang3.Validate;
20  import org.apache.commons.lang3.time.DurationFormatUtils;
21  import org.apache.logging.log4j.LogManager;
22  import org.apache.logging.log4j.Logger;
23  
24  import net.bmahe.genetics4j.core.EASystem;
25  import net.bmahe.genetics4j.core.EASystemFactory;
26  import net.bmahe.genetics4j.core.Fitness;
27  import net.bmahe.genetics4j.core.Genotype;
28  import net.bmahe.genetics4j.core.chromosomes.TreeChromosome;
29  import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListener;
30  import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
31  import net.bmahe.genetics4j.core.spec.EAConfiguration;
32  import net.bmahe.genetics4j.core.spec.EAExecutionContext;
33  import net.bmahe.genetics4j.core.spec.EvolutionResult;
34  import net.bmahe.genetics4j.core.spec.Optimization;
35  import net.bmahe.genetics4j.core.spec.mutation.MultiMutations;
36  import net.bmahe.genetics4j.core.termination.Terminations;
37  import net.bmahe.genetics4j.gp.Operation;
38  import net.bmahe.genetics4j.gp.math.SimplificationRules;
39  import net.bmahe.genetics4j.gp.program.Program;
40  import net.bmahe.genetics4j.gp.spec.GPEAExecutionContexts;
41  import net.bmahe.genetics4j.gp.spec.chromosome.ProgramTreeChromosomeSpec;
42  import net.bmahe.genetics4j.gp.spec.combination.ProgramRandomCombine;
43  import net.bmahe.genetics4j.gp.spec.mutation.NodeReplacement;
44  import net.bmahe.genetics4j.gp.spec.mutation.ProgramApplyRules;
45  import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomMutate;
46  import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomPrune;
47  import net.bmahe.genetics4j.gp.utils.ProgramUtils;
48  import net.bmahe.genetics4j.gp.utils.TreeNodeUtils;
49  import net.bmahe.genetics4j.moo.FitnessVector;
50  import net.bmahe.genetics4j.moo.nsga2.spec.TournamentNSGA2Selection;
51  import net.bmahe.genetics4j.moo.spea2.spec.replacement.SPEA2Replacement;
52  
53  public class SymbolicRegressionWithMOOSPEA2 {
54  	final static public Logger logger = LogManager.getLogger(SymbolicRegressionWithMOOSPEA2.class);
55  
56  	final static public String PARAM_DEST_CSV = "d";
57  	final static public String LONG_PARAM_DEST_CSV = "csv-dest";
58  
59  	final static public String PARAM_POPULATION_SIZE = "p";
60  	final static public String LONG_PARAM_POPULATION_SIZE = "population-size";
61  
62  	final static public String DEFAULT_DEST_CSV = SymbolicRegressionWithMOOSPEA2.class.getSimpleName() + ".csv";
63  
64  	final static public int DEFAULT_POPULATION_SIZE = 500;
65  
66  	public static void cliError(final Options options, final String errorMessage) {
67  		final HelpFormatter formatter = new HelpFormatter();
68  		logger.error(errorMessage);
69  		formatter.printHelp(SymbolicRegressionWithMOOSPEA2.class.getSimpleName(), options);
70  		System.exit(-1);
71  	}
72  
73  	@SuppressWarnings("unchecked")
74  	public void run(String csvFilename, int populationSize) {
75  		Validate.isTrue(StringUtils.isNotBlank(csvFilename));
76  		Validate.isTrue(populationSize > 0);
77  
78  		final Random random = new Random();
79  
80  		final Program program = SymbolicRegressionUtils.buildProgram(random);
81  
82  		final Comparator<Genotype> deduplicator = (a, b) -> TreeNodeUtils.compare(a, b, 0);
83  
84  		// tag::compute_fitness[]
85  		final Fitness<FitnessVector<Double>> computeFitness = (genoType) -> {
86  			final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genoType.getChromosome(0);
87  			final Double[][] inputs = new Double[100][1];
88  			for (int i = 0; i < 100; i++) {
89  				inputs[i][0] = (i - 50) * 1.2;
90  			}
91  
92  			double mse = 0;
93  			for (final Double[] input : inputs) {
94  
95  				final double x = input[0];
96  				final double expected = SymbolicRegressionUtils.evaluate(x);
97  				final Object result = ProgramUtils.execute(chromosome, input);
98  
99  				if (Double.isFinite(expected)) {
100 					final Double resultDouble = (Double) result;
101 					if (Double.isFinite(resultDouble)) {
102 						mse += (expected - resultDouble) * (expected - resultDouble);
103 					} else {
104 						mse += 1_000_000_000;
105 					}
106 				}
107 			}
108 
109 			return Double.isFinite(mse) ? new FitnessVector<Double>(mse / 100.0,
110 					(double) chromosome.getRoot()
111 							.getSize())
112 					: new FitnessVector<Double>(Double.MAX_VALUE, Double.MAX_VALUE);
113 		};
114 		// end::compute_fitness[]
115 
116 		// tag::ea_config[]
117 		final var eaConfigurationBuilder = new EAConfiguration.Builder<FitnessVector<Double>>();
118 		eaConfigurationBuilder.chromosomeSpecs(ProgramTreeChromosomeSpec.of(program)) // <1>
119 				.parentSelectionPolicy(TournamentNSGA2Selection.ofFitnessVector(2, 3, deduplicator)) // <2>
120 				.replacementStrategy(SPEA2Replacement.ofFitnessVector(deduplicator)) // <3>
121 				.combinationPolicy(ProgramRandomCombine.build())
122 				.mutationPolicies(MultiMutations
123 						.of(ProgramRandomMutate.of(0.15 * 3), ProgramRandomPrune.of(0.15 * 3), NodeReplacement.of(0.15 * 3)),
124 						ProgramApplyRules.of(SimplificationRules.SIMPLIFY_RULES))
125 				.optimization(Optimization.MINIMIZE)
126 				.termination(Terminations.or(Terminations.<FitnessVector<Double>>ofMaxGeneration(200),
127 						(eaConfiguration, generation, population, fitness) -> fitness.stream()
128 								.anyMatch(fv -> fv.get(0) <= 0.00001 && fv.get(1) <= 20)))
129 				.fitness(computeFitness);
130 		final EAConfiguration<FitnessVector<Double>> eaConfiguration = eaConfigurationBuilder.build();
131 		// end::ea_config[]
132 
133 		final var eaExecutionContextBuilder = GPEAExecutionContexts.<FitnessVector<Double>>forGP(random);
134 		eaExecutionContextBuilder.populationSize(populationSize);
135 		eaExecutionContextBuilder.numberOfPartitions(Math.max(1,
136 				Runtime.getRuntime()
137 						.availableProcessors() - 3));
138 
139 		eaExecutionContextBuilder.addEvolutionListeners(
140 				EvolutionListeners.ofLogTopN(logger,
141 						5,
142 						Comparator.<FitnessVector<Double>, Double>comparing(fv -> fv.get(0))
143 								.reversed(),
144 						(genotype) -> TreeNodeUtils.toStringTreeNode(genotype, 0)),
145 				SymbolicRegressionUtils.csvLogger(csvFilename,
146 						evolutionStep -> evolutionStep.fitness()
147 								.get(0),
148 						evolutionStep -> evolutionStep.fitness()
149 								.get(1)),
150 				new EvolutionListener<FitnessVector<Double>>() {
151 
152 					long startTime = -1;
153 					long previousTime = -1;
154 
155 					@Override
156 					public void onEvolution(final long generation, final List<Genotype> population,
157 							final List<FitnessVector<Double>> fitness, final boolean isDone) {
158 						final long now = System.currentTimeMillis();
159 
160 						if (startTime < 0) {
161 							startTime = now;
162 						}
163 
164 						if (previousTime < 0) {
165 							previousTime = now;
166 						}
167 
168 						if (generation > 1) {
169 							logger.info("Execution time:");
170 							logger.info("\tCurrent generation duration: {}",
171 									DurationFormatUtils.formatDurationHMS(now - previousTime));
172 							logger.info("\tAverage duration: {}",
173 									DurationFormatUtils.formatDurationHMS((now - startTime) / (generation - 1)));
174 						}
175 
176 						previousTime = now;
177 					}
178 				});
179 
180 		final EAExecutionContext<FitnessVector<Double>> eaExecutionContext = eaExecutionContextBuilder.build();
181 		final EASystem<FitnessVector<Double>> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
182 
183 		final EvolutionResult<FitnessVector<Double>> evolutionResult = eaSystem.evolve();
184 		final Genotype bestGenotype = evolutionResult.bestGenotype();
185 		final TreeChromosome<Operation<?>> bestChromosome = (TreeChromosome<Operation<?>>) bestGenotype.getChromosome(0);
186 		logger.info("Best genotype: {}", bestChromosome.getRoot());
187 		logger.info("Best genotype - pretty print: {}", TreeNodeUtils.toStringTreeNode(bestChromosome.getRoot()));
188 
189 		final int depthIdx = 1;
190 		for (int i = 0; i < 15; i++) {
191 			final int depth = i;
192 			final Optional<Integer> optIdx = IntStream.range(0,
193 					evolutionResult.fitness()
194 							.size())
195 					.boxed()
196 					.filter((idx) -> evolutionResult.fitness()
197 							.get(idx)
198 							.get(depthIdx) == depth)
199 					.sorted((a, b) -> Double.compare(evolutionResult.fitness()
200 							.get(a)
201 							.get(0),
202 							evolutionResult.fitness()
203 									.get(b)
204 									.get(0)))
205 					.findFirst();
206 
207 			optIdx.stream()
208 					.forEach((idx) -> {
209 						final TreeChromosome<Operation<?>> treeChromosome = (TreeChromosome<Operation<?>>) evolutionResult
210 								.population()
211 								.get(idx)
212 								.getChromosome(0);
213 
214 						logger.info("Best genotype for depth {} - score {} -> {}",
215 								depth,
216 								evolutionResult.fitness()
217 										.get(idx)
218 										.get(0),
219 								TreeNodeUtils.toStringTreeNode(treeChromosome.getRoot()));
220 					});
221 		}
222 	}
223 
224 	public static void main(String[] args) throws IOException {
225 
226 		/**
227 		 * Parse CLI
228 		 */
229 
230 		final CommandLineParser parser = new DefaultParser();
231 
232 		final Options options = new Options();
233 		options.addOption(PARAM_DEST_CSV, LONG_PARAM_DEST_CSV, true, "destination csv file");
234 
235 		options.addOption(PARAM_POPULATION_SIZE, LONG_PARAM_POPULATION_SIZE, true, "Population size");
236 
237 		String csvFilename = DEFAULT_DEST_CSV;
238 		int populationSize = DEFAULT_POPULATION_SIZE;
239 		try {
240 			final CommandLine line = parser.parse(options, args);
241 
242 			if (line.hasOption(PARAM_DEST_CSV)) {
243 				csvFilename = line.getOptionValue(PARAM_DEST_CSV);
244 			}
245 
246 			if (line.hasOption(PARAM_POPULATION_SIZE)) {
247 				populationSize = Integer.parseInt(line.getOptionValue(PARAM_POPULATION_SIZE));
248 			}
249 
250 		} catch (ParseException exp) {
251 			cliError(options, "Unexpected exception:" + exp.getMessage());
252 		}
253 
254 		logger.info("Population size: {}", populationSize);
255 
256 		logger.info("CSV output located at {}", csvFilename);
257 		FileUtils.forceMkdirParent(new File(csvFilename));
258 
259 		final var symbolicRegression = new SymbolicRegressionWithMOOSPEA2();
260 		symbolicRegression.run(csvFilename, populationSize);
261 	}
262 }