View Javadoc
1   package net.bmahe.genetics4j.samples.symbolicregression;
2   
3   import static net.bmahe.genetics4j.core.termination.Terminations.ofFitnessAtMost;
4   import static net.bmahe.genetics4j.core.termination.Terminations.ofMaxGeneration;
5   import static net.bmahe.genetics4j.core.termination.Terminations.or;
6   
7   import java.io.File;
8   import java.io.IOException;
9   import java.util.Comparator;
10  import java.util.Random;
11  
12  import org.apache.commons.cli.CommandLine;
13  import org.apache.commons.cli.CommandLineParser;
14  import org.apache.commons.cli.DefaultParser;
15  import org.apache.commons.cli.HelpFormatter;
16  import org.apache.commons.cli.Options;
17  import org.apache.commons.cli.ParseException;
18  import org.apache.commons.io.FileUtils;
19  import org.apache.commons.lang3.StringUtils;
20  import org.apache.commons.lang3.Validate;
21  import org.apache.logging.log4j.LogManager;
22  import org.apache.logging.log4j.Logger;
23  
24  import net.bmahe.genetics4j.core.EASystem;
25  import net.bmahe.genetics4j.core.EASystemFactory;
26  import net.bmahe.genetics4j.core.Fitness;
27  import net.bmahe.genetics4j.core.Genotype;
28  import net.bmahe.genetics4j.core.Individual;
29  import net.bmahe.genetics4j.core.chromosomes.TreeChromosome;
30  import net.bmahe.genetics4j.core.chromosomes.TreeNode;
31  import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListeners;
32  import net.bmahe.genetics4j.core.spec.EAConfiguration;
33  import net.bmahe.genetics4j.core.spec.EAExecutionContext;
34  import net.bmahe.genetics4j.core.spec.EAExecutionContexts;
35  import net.bmahe.genetics4j.core.spec.EvolutionResult;
36  import net.bmahe.genetics4j.core.spec.Optimization;
37  import net.bmahe.genetics4j.core.spec.replacement.Elitism;
38  import net.bmahe.genetics4j.core.spec.selection.SelectiveRefinementTournament;
39  import net.bmahe.genetics4j.core.spec.selection.Tournament;
40  import net.bmahe.genetics4j.gp.Operation;
41  import net.bmahe.genetics4j.gp.math.SimplificationRules;
42  import net.bmahe.genetics4j.gp.program.Program;
43  import net.bmahe.genetics4j.gp.spec.GPEAExecutionContexts;
44  import net.bmahe.genetics4j.gp.spec.chromosome.ProgramTreeChromosomeSpec;
45  import net.bmahe.genetics4j.gp.spec.combination.ProgramRandomCombine;
46  import net.bmahe.genetics4j.gp.spec.mutation.NodeReplacement;
47  import net.bmahe.genetics4j.gp.spec.mutation.ProgramApplyRules;
48  import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomMutate;
49  import net.bmahe.genetics4j.gp.spec.mutation.ProgramRandomPrune;
50  import net.bmahe.genetics4j.gp.utils.ProgramUtils;
51  import net.bmahe.genetics4j.gp.utils.TreeNodeUtils;
52  
53  public class SymbolicRegressionWithSRT {
54  	final static public Logger logger = LogManager.getLogger(SymbolicRegressionWithSRT.class);
55  
56  	final static public String PARAM_DEST_CSV = "d";
57  	final static public String LONG_PARAM_DEST_CSV = "csv-dest";
58  
59  	final static public String PARAM_POPULATION_SIZE = "p";
60  	final static public String LONG_PARAM_POPULATION_SIZE = "population-size";
61  
62  	final static public String DEFAULT_DEST_CSV = SymbolicRegressionWithSRT.class.getSimpleName() + ".csv";
63  
64  	final static public int DEFAULT_POPULATION_SIZE = 500;
65  
66  	public static void cliError(final Options options, final String errorMessage) {
67  		final HelpFormatter formatter = new HelpFormatter();
68  		logger.error(errorMessage);
69  		formatter.printHelp(SymbolicRegressionWithSRT.class.getSimpleName(), options);
70  		System.exit(-1);
71  	}
72  
73  	@SuppressWarnings("unchecked")
74  	public void run(String csvFilename, int populationSize) {
75  		Validate.isTrue(StringUtils.isNotBlank(csvFilename));
76  		Validate.isTrue(populationSize > 0);
77  
78  		final Random random = new Random();
79  
80  		final Program program = SymbolicRegressionUtils.buildProgram(random);
81  
82  		// tag::compute_fitness[]
83  		final Double[][] inputs = new Double[100][1];
84  		for (int i = 0; i < 100; i++) {
85  			inputs[i][0] = (i - 50) * 1.2;
86  		}
87  
88  		final Fitness<Double> computeFitness = (genoType) -> {
89  			final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genoType.getChromosome(0);
90  
91  			double mse = 0;
92  			for (final Double[] input : inputs) {
93  
94  				final double x = input[0];
95  				final double expected = SymbolicRegressionUtils.evaluate(x);
96  				final Object result = ProgramUtils.execute(chromosome, input);
97  
98  				if (Double.isFinite(expected)) {
99  					final Double resultDouble = (Double) result;
100 					mse += Double.isFinite(resultDouble) ? (expected - resultDouble) * (expected - resultDouble)
101 							: 1_000_000_000;
102 				}
103 			}
104 			return Double.isFinite(mse) ? mse / 100.0d : Double.MAX_VALUE;
105 		};
106 		// end::compute_fitness[]
107 
108 		// tag::srt_tournament[]
109 		final Comparator<Individual<Double>> parsimonyComparator = (a, b) -> {
110 			final var treeChromosomeA = a.genotype()
111 					.getChromosome(0, TreeChromosome.class);
112 			final var treeChromosomeB = b.genotype()
113 					.getChromosome(0, TreeChromosome.class);
114 
115 			return -Integer.compare(treeChromosomeA.getSize(), treeChromosomeB.getSize());
116 		};
117 
118 		final SelectiveRefinementTournament<Double> selectiveRefinementTournament = SelectiveRefinementTournament
119 				.<Double>builder()
120 				.tournament(Tournament.of(3))
121 				.refinementComparator(parsimonyComparator)
122 				.refinementRatio(0.65f)
123 				.build();
124 
125 		// end::srt_tournament[]
126 
127 		// tag::ea_config[]
128 		final var eaConfigurationBuilder = new EAConfiguration.Builder<Double>();
129 		eaConfigurationBuilder.chromosomeSpecs(ProgramTreeChromosomeSpec.of(program)) // <1>
130 				.parentSelectionPolicy(selectiveRefinementTournament)
131 				.replacementStrategy(Elitism.builder() // <2>
132 						.offspringRatio(0.99)
133 						.offspringSelectionPolicy(selectiveRefinementTournament)
134 						.survivorSelectionPolicy(selectiveRefinementTournament)
135 						.build())
136 				.combinationPolicy(ProgramRandomCombine.build())
137 				.mutationPolicies(ProgramRandomMutate.of(0.10),
138 						ProgramRandomPrune.of(0.12),
139 						NodeReplacement.of(0.05),
140 						ProgramApplyRules.of(SimplificationRules.SIMPLIFY_RULES))
141 				.optimization(Optimization.MINIMIZE) // <3>
142 				.termination(or(ofMaxGeneration(200), ofFitnessAtMost(0.00001)))
143 				.fitness(computeFitness);
144 		final EAConfiguration<Double> eaConfiguration = eaConfigurationBuilder.build();
145 		// end::ea_config[]
146 
147 		final var eaExecutionContextBuilder = GPEAExecutionContexts.<Double>forGP(random);
148 		EAExecutionContexts.enrichForScalarFitness(eaExecutionContextBuilder);
149 
150 		eaExecutionContextBuilder.populationSize(populationSize);
151 		eaExecutionContextBuilder.numberOfPartitions(Math.max(1,
152 				Runtime.getRuntime()
153 						.availableProcessors() - 1));
154 
155 		eaExecutionContextBuilder.addEvolutionListeners(
156 				EvolutionListeners.ofLogTopN(logger, 5, Comparator.<Double>reverseOrder(), (genotype) -> {
157 					final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) genotype.getChromosome(0);
158 					final TreeNode<Operation<?>> root = chromosome.getRoot();
159 
160 					return TreeNodeUtils.toStringTreeNode(root);
161 				}),
162 				SymbolicRegressionUtils.csvLoggerDouble(csvFilename,
163 						evolutionStep -> evolutionStep.fitness(),
164 						evolutionStep -> (double) evolutionStep.individual()
165 								.getChromosome(0, TreeChromosome.class)
166 								.getSize()));
167 
168 		final EAExecutionContext<Double> eaExecutionContext = eaExecutionContextBuilder.build();
169 		final EASystem<Double> eaSystem = EASystemFactory.from(eaConfiguration, eaExecutionContext);
170 
171 		final EvolutionResult<Double> evolutionResult = eaSystem.evolve();
172 		final Genotype bestGenotype = evolutionResult.bestGenotype();
173 		final TreeChromosome<Operation<?>> bestChromosome = (TreeChromosome<Operation<?>>) bestGenotype.getChromosome(0);
174 		logger.info("Best genotype: {}", bestChromosome.getRoot());
175 		logger.info("Best genotype - pretty print: {}", TreeNodeUtils.toStringTreeNode(bestChromosome.getRoot()));
176 	}
177 
178 	public static void main(String[] args) throws IOException {
179 
180 		/**
181 		 * Parse CLI
182 		 */
183 
184 		final CommandLineParser parser = new DefaultParser();
185 
186 		final Options options = new Options();
187 		options.addOption(PARAM_DEST_CSV, LONG_PARAM_DEST_CSV, true, "destination csv file");
188 
189 		options.addOption(PARAM_POPULATION_SIZE, LONG_PARAM_POPULATION_SIZE, true, "Population size");
190 
191 		String csvFilename = DEFAULT_DEST_CSV;
192 		int populationSize = DEFAULT_POPULATION_SIZE;
193 		try {
194 			final CommandLine line = parser.parse(options, args);
195 
196 			if (line.hasOption(PARAM_DEST_CSV)) {
197 				csvFilename = line.getOptionValue(PARAM_DEST_CSV);
198 			}
199 
200 			if (line.hasOption(PARAM_POPULATION_SIZE)) {
201 				populationSize = Integer.parseInt(line.getOptionValue(PARAM_POPULATION_SIZE));
202 			}
203 
204 		} catch (ParseException exp) {
205 			cliError(options, "Unexpected exception:" + exp.getMessage());
206 		}
207 
208 		logger.info("Population size: {}", populationSize);
209 
210 		logger.info("CSV output located at {}", csvFilename);
211 		FileUtils.forceMkdirParent(new File(csvFilename));
212 
213 		final var symbolicRegression = new SymbolicRegressionWithSRT();
214 		symbolicRegression.run(csvFilename, populationSize);
215 	}
216 }