1 package net.bmahe.genetics4j.samples.symbolicregression;
2
3 import java.util.ArrayList;
4 import java.util.Comparator;
5 import java.util.List;
6 import java.util.Random;
7 import java.util.Set;
8 import java.util.function.BiFunction;
9 import java.util.function.Function;
10
11 import org.apache.commons.lang3.StringUtils;
12 import org.apache.commons.lang3.Validate;
13
14 import net.bmahe.genetics4j.core.Genotype;
15 import net.bmahe.genetics4j.core.chromosomes.TreeChromosome;
16 import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListener;
17 import net.bmahe.genetics4j.extras.evolutionlisteners.CSVEvolutionListener;
18 import net.bmahe.genetics4j.extras.evolutionlisteners.ColumnExtractor;
19 import net.bmahe.genetics4j.extras.evolutionlisteners.EvolutionStep;
20 import net.bmahe.genetics4j.gp.ImmutableInputSpec;
21 import net.bmahe.genetics4j.gp.Operation;
22 import net.bmahe.genetics4j.gp.math.Functions;
23 import net.bmahe.genetics4j.gp.math.Terminals;
24 import net.bmahe.genetics4j.gp.program.ImmutableProgram;
25 import net.bmahe.genetics4j.gp.program.ImmutableProgram.Builder;
26 import net.bmahe.genetics4j.gp.program.Program;
27 import net.bmahe.genetics4j.gp.utils.TreeNodeUtils;
28 import net.bmahe.genetics4j.moo.FitnessVector;
29 import net.bmahe.genetics4j.moo.ParetoUtils;
30
31 public class SymbolicRegressionUtils {
32
33 private SymbolicRegressionUtils() {
34 }
35
36 public static Program buildProgram(final Random random) {
37 Validate.notNull(random);
38
39
40 final Builder programBuilder = ImmutableProgram.builder();
41 programBuilder.addFunctions(Functions.ADD, Functions.MUL, Functions.DIV, Functions.SUB, Functions.POW);
42 programBuilder.addTerminal(Terminals.InputDouble(random), Terminals.CoefficientRounded(random, -10, 10));
43
44 programBuilder.inputSpec(ImmutableInputSpec.of(List.of(Double.class)));
45 programBuilder.maxDepth(4);
46 final Program program = programBuilder.build();
47
48
49 return program;
50 }
51
52 public static double evaluate(final double x) {
53 return (6.0 * x * x) - x + 8;
54 }
55
56
57 public static <T extends Comparable<T>> EvolutionListener<T> csvLogger(final String filename,
58 final Function<EvolutionStep<T, List<Set<Integer>>>, Double> computeScore,
59 final Function<EvolutionStep<T, List<Set<Integer>>>, Double> computeComplexity,
60 final BiFunction<List<Genotype>, List<T>, List<FitnessVector<Double>>> convert2FitnessVector) {
61 Validate.isTrue(StringUtils.isNotBlank(filename));
62 Validate.notNull(computeScore);
63 Validate.notNull(computeComplexity);
64
65 return CSVEvolutionListener.<T, List<Set<Integer>>>of(filename, (generation, population, fitness, isDone) -> {
66 final List<FitnessVector<Double>> fitnessAndSizeVectors = convert2FitnessVector.apply(population, fitness);
67 return ParetoUtils.rankedPopulation(Comparator.<FitnessVector<Double>>reverseOrder(),
68 fitnessAndSizeVectors);
69 },
70 List.of(ColumnExtractor.of("generation", evolutionStep -> evolutionStep.generation()),
71 ColumnExtractor.of("score", evolutionStep -> computeScore.apply(evolutionStep)),
72 ColumnExtractor.of("complexity", evolutionStep -> computeComplexity.apply(evolutionStep)),
73 ColumnExtractor.of("rank", evolutionStep -> {
74
75 final List<Set<Integer>> rankedPopulation = evolutionStep.context().get();
76 Integer rank = null;
77 for (int i = 0; i < 5 && i < rankedPopulation.size() && rank == null; i++) {
78
79 if (rankedPopulation.get(i).contains(evolutionStep.individualIndex())) {
80 rank = i;
81 }
82 }
83
84 return rank != null ? rank : -1;
85 }),
86 ColumnExtractor.of("expression",
87 evolutionStep -> TreeNodeUtils.toStringTreeNode(evolutionStep.individual(), 0)))
88
89 );
90 }
91
92
93
94
95
96
97
98
99
100
101 public static EvolutionListener<FitnessVector<Double>> csvLogger(final String filename,
102 final Function<EvolutionStep<FitnessVector<Double>, List<Set<Integer>>>, Double> computeScore,
103 final Function<EvolutionStep<FitnessVector<Double>, List<Set<Integer>>>, Double> computeComplexity) {
104 Validate.isTrue(StringUtils.isNotBlank(filename));
105 Validate.notNull(computeScore);
106 Validate.notNull(computeComplexity);
107
108 return csvLogger(filename, computeScore, computeComplexity, (population, fitness) -> fitness);
109 }
110
111
112
113
114
115
116
117
118
119
120
121 public static EvolutionListener<Double> csvLoggerDouble(final String filename,
122 final Function<EvolutionStep<Double, List<Set<Integer>>>, Double> computeScore,
123 final Function<EvolutionStep<Double, List<Set<Integer>>>, Double> computeComplexity) {
124 Validate.isTrue(StringUtils.isNotBlank(filename));
125 Validate.notNull(computeScore);
126 Validate.notNull(computeComplexity);
127
128 return csvLogger(filename, computeScore, computeComplexity, (population, fitness) -> {
129 List<FitnessVector<Double>> fvs = new ArrayList<>();
130
131 for (int i = 0; i < fitness.size(); i++) {
132 final TreeChromosome<Operation<?>> chromosome = (TreeChromosome<Operation<?>>) population.get(i)
133 .getChromosome(0);
134 final int size = chromosome.getSize();
135
136
137
138
139
140 fvs.add(new FitnessVector<Double>(fitness.get(i), (double) size));
141 }
142 return fvs;
143 });
144 }
145 }