1 package net.bmahe.genetics4j.moo.nsga2.impl;
2
3 import java.util.Comparator;
4 import java.util.List;
5 import java.util.Objects;
6 import java.util.Set;
7 import java.util.TreeSet;
8 import java.util.function.Function;
9 import java.util.random.RandomGenerator;
10
11 import org.apache.commons.lang3.Validate;
12 import org.apache.logging.log4j.LogManager;
13 import org.apache.logging.log4j.Logger;
14
15 import net.bmahe.genetics4j.core.Genotype;
16 import net.bmahe.genetics4j.core.Population;
17 import net.bmahe.genetics4j.core.selection.Selector;
18 import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
19 import net.bmahe.genetics4j.moo.ObjectiveDistance;
20 import net.bmahe.genetics4j.moo.ParetoUtils;
21 import net.bmahe.genetics4j.moo.nsga2.spec.TournamentNSGA2Selection;
22
23 public class TournamentNSGA2Selector<T extends Comparable<T>> implements Selector<T> {
24 final static public Logger logger = LogManager.getLogger(TournamentNSGA2Selector.class);
25
26 private final TournamentNSGA2Selection<T> tournamentNSGA2Selection;
27 private final RandomGenerator randomGenerator;
28
29 public TournamentNSGA2Selector(final RandomGenerator _randomGenerator,
30 final TournamentNSGA2Selection<T> _tournamentNSGA2Selection) {
31 Objects.requireNonNull(_randomGenerator);
32 Objects.requireNonNull(_tournamentNSGA2Selection);
33
34 this.randomGenerator = _randomGenerator;
35 this.tournamentNSGA2Selection = _tournamentNSGA2Selection;
36
37 }
38
39 @Override
40 public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
41 final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) {
42 Objects.requireNonNull(eaConfiguration);
43 Objects.requireNonNull(population);
44 Objects.requireNonNull(fitnessScore);
45 Validate.isTrue(generation >= 0);
46 Validate.isTrue(numIndividuals > 0);
47 Validate.isTrue(population.size() == fitnessScore.size());
48
49 logger.debug("Incoming population size is {}", population.size());
50
51 final Population<T> individuals = new Population<>();
52 if (tournamentNSGA2Selection.deduplicate()
53 .isPresent()) {
54 final Comparator<Genotype> individualDeduplicator = tournamentNSGA2Selection.deduplicate()
55 .get();
56 final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
57
58 for (int i = 0; i < population.size(); i++) {
59 final Genotype genotype = population.get(i);
60 final T fitness = fitnessScore.get(i);
61
62 if (seenGenotype.add(genotype)) {
63 individuals.add(genotype, fitness);
64 }
65 }
66
67 } else {
68 for (int i = 0; i < population.size(); i++) {
69 final Genotype genotype = population.get(i);
70 final T fitness = fitnessScore.get(i);
71
72 individuals.add(genotype, fitness);
73 }
74 }
75
76 logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size());
77
78 final int numberObjectives = tournamentNSGA2Selection.numberObjectives();
79
80 final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
81 case MAXIMIZE -> tournamentNSGA2Selection.dominance();
82 case MINIMIZE -> tournamentNSGA2Selection.dominance()
83 .reversed();
84 };
85
86 final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) {
87 case MAXIMIZE -> tournamentNSGA2Selection.objectiveComparator();
88 case MINIMIZE -> (m) -> tournamentNSGA2Selection.objectiveComparator()
89 .apply(m)
90 .reversed();
91 };
92
93 final ObjectiveDistance<T> objectiveDistance = tournamentNSGA2Selection.distance();
94 final int numCandidates = tournamentNSGA2Selection.numCandidates();
95
96 logger.debug("Ranking population");
97 final List<Set<Integer>> rankedPopulation = ParetoUtils.rankedPopulation(dominance,
98 individuals.getAllFitnesses());
99
100 final int[] individual2Rank = new int[individuals.size()];
101 for (int j = 0; j < rankedPopulation.size(); j++) {
102 final Set<Integer> set = rankedPopulation.get(j);
103
104 for (final Integer idx : set) {
105 individual2Rank[idx] = j;
106 }
107 }
108
109 if (logger.isTraceEnabled()) {
110 logger.trace("Ranked population: {}", rankedPopulation);
111 for (int i = 0; i < rankedPopulation.size(); i++) {
112 final Set<Integer> subPopulationIdx = rankedPopulation.get(i);
113 logger.trace("\tRank {}", i);
114 for (final Integer index : subPopulationIdx) {
115 logger.trace("\t\t{} - Fitness {}", index, individuals.getFitness(index));
116 }
117 }
118 }
119 logger.debug("Computing crowding distance assignment");
120 final double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(numberObjectives,
121 individuals.getAllFitnesses(),
122 objectiveComparator,
123 objectiveDistance);
124
125 logger.debug("Performing tournaments");
126 final Population<T> selectedIndividuals = new Population<>();
127 while (selectedIndividuals.size() < numIndividuals) {
128
129 logger.trace("Performing tournament");
130 Genotype bestCandidate = null;
131 int bestCandidateIndex = -1;
132 T bestFitness = null;
133
134 for (int i = 0; i < numCandidates; i++) {
135 final int candidateIndex = randomGenerator.nextInt(individuals.size());
136
137 logger.trace("\tCandidate - index {} - rank {} - crowding distance {} - fitness {}",
138 candidateIndex,
139 individual2Rank[candidateIndex],
140 crowdingDistanceAssignment[candidateIndex],
141 individuals.getFitness(candidateIndex));
142
143 if (bestCandidate == null || individual2Rank[candidateIndex] < individual2Rank[bestCandidateIndex]
144 || (individual2Rank[candidateIndex] == individual2Rank[bestCandidateIndex]
145 && crowdingDistanceAssignment[candidateIndex] > crowdingDistanceAssignment[bestCandidateIndex])) {
146
147 logger.trace("\t candidate win!");
148 bestCandidate = individuals.getGenotype(candidateIndex);
149 bestFitness = individuals.getFitness(candidateIndex);
150 bestCandidateIndex = candidateIndex;
151 }
152 }
153
154 selectedIndividuals.add(bestCandidate, bestFitness);
155 }
156
157 return selectedIndividuals;
158 }
159 }