1 package net.bmahe.genetics4j.moo.nsga2.impl;
2
3 import java.util.Comparator;
4 import java.util.List;
5 import java.util.Objects;
6 import java.util.Set;
7 import java.util.TreeSet;
8 import java.util.function.Function;
9 import java.util.random.RandomGenerator;
10
11 import org.apache.commons.lang3.Validate;
12 import org.apache.logging.log4j.LogManager;
13 import org.apache.logging.log4j.Logger;
14
15 import net.bmahe.genetics4j.core.Genotype;
16 import net.bmahe.genetics4j.core.Population;
17 import net.bmahe.genetics4j.core.selection.Selector;
18 import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
19 import net.bmahe.genetics4j.moo.ObjectiveDistance;
20 import net.bmahe.genetics4j.moo.ParetoUtils;
21 import net.bmahe.genetics4j.moo.nsga2.spec.TournamentNSGA2Selection;
22
23 public class TournamentNSGA2Selector<T extends Comparable<T>> implements Selector<T> {
24 final static public Logger logger = LogManager.getLogger(TournamentNSGA2Selector.class);
25
26 private final TournamentNSGA2Selection<T> tournamentNSGA2Selection;
27 private final RandomGenerator randomGenerator;
28
29 public TournamentNSGA2Selector(final RandomGenerator _randomGenerator,
30 final TournamentNSGA2Selection<T> _tournamentNSGA2Selection) {
31 Objects.requireNonNull(_randomGenerator);
32 Objects.requireNonNull(_tournamentNSGA2Selection);
33
34 this.randomGenerator = _randomGenerator;
35 this.tournamentNSGA2Selection = _tournamentNSGA2Selection;
36
37 }
38
39 @Override
40 public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
41 final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) {
42 Objects.requireNonNull(eaConfiguration);
43 Objects.requireNonNull(population);
44 Objects.requireNonNull(fitnessScore);
45 Validate.isTrue(generation >= 0);
46 Validate.isTrue(numIndividuals > 0);
47 Validate.isTrue(population.size() == fitnessScore.size());
48
49 logger.debug("Incoming population size is {}", population.size());
50
51 final Population<T> individuals = new Population<>();
52 if (tournamentNSGA2Selection.deduplicate().isPresent()) {
53 final Comparator<Genotype> individualDeduplicator = tournamentNSGA2Selection.deduplicate().get();
54 final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
55
56 for (int i = 0; i < population.size(); i++) {
57 final Genotype genotype = population.get(i);
58 final T fitness = fitnessScore.get(i);
59
60 if (seenGenotype.add(genotype)) {
61 individuals.add(genotype, fitness);
62 }
63 }
64
65 } else {
66 for (int i = 0; i < population.size(); i++) {
67 final Genotype genotype = population.get(i);
68 final T fitness = fitnessScore.get(i);
69
70 individuals.add(genotype, fitness);
71 }
72 }
73
74 logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size());
75
76 final int numberObjectives = tournamentNSGA2Selection.numberObjectives();
77
78 final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
79 case MAXIMIZE -> tournamentNSGA2Selection.dominance();
80 case MINIMIZE -> tournamentNSGA2Selection.dominance().reversed();
81 };
82
83 final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) {
84 case MAXIMIZE -> tournamentNSGA2Selection.objectiveComparator();
85 case MINIMIZE -> (m) -> tournamentNSGA2Selection.objectiveComparator().apply(m).reversed();
86 };
87
88 final ObjectiveDistance<T> objectiveDistance = tournamentNSGA2Selection.distance();
89 final int numCandidates = tournamentNSGA2Selection.numCandidates();
90
91 logger.debug("Ranking population");
92 final List<Set<Integer>> rankedPopulation = ParetoUtils
93 .rankedPopulation(dominance, individuals.getAllFitnesses());
94
95 final int[] individual2Rank = new int[individuals.size()];
96 for (int j = 0; j < rankedPopulation.size(); j++) {
97 final Set<Integer> set = rankedPopulation.get(j);
98
99 for (final Integer idx : set) {
100 individual2Rank[idx] = j;
101 }
102 }
103
104 if (logger.isTraceEnabled()) {
105 logger.trace("Ranked population: {}", rankedPopulation);
106 for (int i = 0; i < rankedPopulation.size(); i++) {
107 final Set<Integer> subPopulationIdx = rankedPopulation.get(i);
108 logger.trace("\tRank {}", i);
109 for (final Integer index : subPopulationIdx) {
110 logger.trace("\t\t{} - Fitness {}", index, individuals.getFitness(index));
111 }
112 }
113 }
114 logger.debug("Computing crowding distance assignment");
115 final double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(
116 numberObjectives,
117 individuals.getAllFitnesses(),
118 objectiveComparator,
119 objectiveDistance);
120
121 logger.debug("Performing tournaments");
122 final Population<T> selectedIndividuals = new Population<>();
123 while (selectedIndividuals.size() < numIndividuals) {
124
125 logger.trace("Performing tournament");
126 Genotype bestCandidate = null;
127 int bestCandidateIndex = -1;
128 T bestFitness = null;
129
130 for (int i = 0; i < numCandidates; i++) {
131 final int candidateIndex = randomGenerator.nextInt(individuals.size());
132
133 logger.trace(
134 "\tCandidate - index {} - rank {} - crowding distance {} - fitness {}",
135 candidateIndex,
136 individual2Rank[candidateIndex],
137 crowdingDistanceAssignment[candidateIndex],
138 individuals.getFitness(candidateIndex));
139
140 if (bestCandidate == null || individual2Rank[candidateIndex] < individual2Rank[bestCandidateIndex]
141 || (individual2Rank[candidateIndex] == individual2Rank[bestCandidateIndex]
142 && crowdingDistanceAssignment[candidateIndex] > crowdingDistanceAssignment[bestCandidateIndex])) {
143
144 logger.trace("\t candidate win!");
145 bestCandidate = individuals.getGenotype(candidateIndex);
146 bestFitness = individuals.getFitness(candidateIndex);
147 bestCandidateIndex = candidateIndex;
148 }
149 }
150
151 selectedIndividuals.add(bestCandidate, bestFitness);
152 }
153
154 return selectedIndividuals;
155 }
156 }