View Javadoc
1   package net.bmahe.genetics4j.samples.symbolicregression;
2   
3   import java.util.ArrayList;
4   import java.util.Comparator;
5   import java.util.List;
6   import java.util.Random;
7   import java.util.Set;
8   import java.util.function.BiFunction;
9   import java.util.function.Function;
10  
11  import org.apache.commons.lang3.StringUtils;
12  import org.apache.commons.lang3.Validate;
13  
14  import net.bmahe.genetics4j.core.Genotype;
15  import net.bmahe.genetics4j.core.chromosomes.TreeChromosome;
16  import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListener;
17  import net.bmahe.genetics4j.extras.evolutionlisteners.CSVEvolutionListener;
18  import net.bmahe.genetics4j.extras.evolutionlisteners.ColumnExtractor;
19  import net.bmahe.genetics4j.extras.evolutionlisteners.EvolutionStep;
20  import net.bmahe.genetics4j.gp.ImmutableInputSpec;
21  import net.bmahe.genetics4j.gp.Operation;
22  import net.bmahe.genetics4j.gp.math.Functions;
23  import net.bmahe.genetics4j.gp.math.Terminals;
24  import net.bmahe.genetics4j.gp.program.ImmutableProgram;
25  import net.bmahe.genetics4j.gp.program.ImmutableProgram.Builder;
26  import net.bmahe.genetics4j.gp.program.Program;
27  import net.bmahe.genetics4j.gp.utils.TreeNodeUtils;
28  import net.bmahe.genetics4j.moo.FitnessVector;
29  import net.bmahe.genetics4j.moo.ParetoUtils;
30  
31  public class SymbolicRegressionUtils {
32  
33  	private SymbolicRegressionUtils() {
34  	}
35  
36  	public static Program buildProgram(final Random random) {
37  		Validate.notNull(random);
38  
39  		// tag::program_def[]
40  		final Builder programBuilder = ImmutableProgram.builder();
41  		programBuilder.addFunctions(Functions.ADD, Functions.MUL, Functions.DIV, Functions.SUB, Functions.POW);
42  		programBuilder.addTerminal(Terminals.InputDouble(random), Terminals.CoefficientRounded(random, -10, 10));
43  
44  		programBuilder.inputSpec(ImmutableInputSpec.of(List.of(Double.class)));
45  		programBuilder.maxDepth(4);
46  		final Program program = programBuilder.build();
47  		// end::program_def[]
48  
49  		return program;
50  	}
51  
52  	public static double evaluate(final double x) {
53  		return (6.0 * x * x) - x + 8;
54  	}
55  
56  	// tag::csv_logger[]
57  	public static <T extends Comparable<T>> EvolutionListener<T> csvLogger(final String filename,
58  			final Function<EvolutionStep<T, List<Set<Integer>>>, Double> computeScore,
59  			final Function<EvolutionStep<T, List<Set<Integer>>>, Double> computeComplexity,
60  			final BiFunction<List<Genotype>, List<T>, List<FitnessVector<Double>>> convert2FitnessVector) {
61  		Validate.isTrue(StringUtils.isNotBlank(filename));
62  		Validate.notNull(computeScore);
63  		Validate.notNull(computeComplexity);
64  
65  		return CSVEvolutionListener.<T, List<Set<Integer>>>of(filename, (generation, population, fitness, isDone) -> {
66  			final List<FitnessVector<Double>> fitnessAndSizeVectors = convert2FitnessVector.apply(population, fitness);
67  			return ParetoUtils.rankedPopulation(Comparator.<FitnessVector<Double>>reverseOrder(),
68  					fitnessAndSizeVectors); // <1>
69  		},
70  				List.of(ColumnExtractor.of("generation", evolutionStep -> evolutionStep.generation()),
71  						ColumnExtractor.of("score", evolutionStep -> computeScore.apply(evolutionStep)),
72  						ColumnExtractor.of("complexity", evolutionStep -> computeComplexity.apply(evolutionStep)),
73  						ColumnExtractor.of("rank", evolutionStep -> {
74  
75  							final List<Set<Integer>> rankedPopulation = evolutionStep.context().get();
76  							Integer rank = null;
77  							for (int i = 0; i < 5 && i < rankedPopulation.size() && rank == null; i++) {
78  
79  								if (rankedPopulation.get(i).contains(evolutionStep.individualIndex())) {
80  									rank = i;
81  								}
82  							}
83  
84  							return rank != null ? rank : -1;
85  						}),
86  						ColumnExtractor.of("expression",
87  								evolutionStep -> TreeNodeUtils.toStringTreeNode(evolutionStep.individual(), 0)))
88  
89  		);
90  	}
91  	// end::csv_logger[]
92  
93  	/**
94  	 * Sepcialization for FitnessVector<Double>
95  	 * 
96  	 * @param filename
97  	 * @param computeScore
98  	 * @param computeComplexity
99  	 * @return
100 	 */
101 	public static EvolutionListener<FitnessVector<Double>> csvLogger(final String filename,
102 			final Function<EvolutionStep<FitnessVector<Double>, List<Set<Integer>>>, Double> computeScore,
103 			final Function<EvolutionStep<FitnessVector<Double>, List<Set<Integer>>>, Double> computeComplexity) {
104 		Validate.isTrue(StringUtils.isNotBlank(filename));
105 		Validate.notNull(computeScore);
106 		Validate.notNull(computeComplexity);
107 
108 		return csvLogger(filename, computeScore, computeComplexity, (population, fitness) -> fitness);
109 	}
110 
111 	/**
112 	 * Sepcialization for Double
113 	 * <p>
114 	 * We can't have the same method name as type erasure wouldn't allow it :(
115 	 * 
116 	 * @param filename
117 	 * @param computeScore
118 	 * @param computeComplexity
119 	 * @return
120 	 */
121 	public static EvolutionListener<Double> csvLoggerDouble(final String filename,
122 			final Function<EvolutionStep<Double, List<Set<Integer>>>, Double> computeScore,
123 			final Function<EvolutionStep<Double, List<Set<Integer>>>, Double> computeComplexity) {
124 		Validate.isTrue(StringUtils.isNotBlank(filename));
125 		Validate.notNull(computeScore);
126 		Validate.notNull(computeComplexity);
127 
128 		return csvLogger(filename, computeScore, computeComplexity, (population, fitness) -> {
129 			List<FitnessVector<Double>> fvs = new ArrayList<>();
130 
131 			for (int i = 0; i < fitness.size(); i++) {
132 				final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) population.get(i)
133 						.getChromosome(0);
134 				final int size = chromosome.getSize();
135 
136 				/**
137 				 * Ideally we would re-compute the pure fitness but that would end up too
138 				 * expensive
139 				 */
140 				fvs.add(new FitnessVector<Double>(fitness.get(i), (double) size));
141 			}
142 			return fvs;
143 		});
144 	}
145 }