1 package net.bmahe.genetics4j.moo.nsga2.impl;
2
3 import java.util.Collection;
4 import java.util.Comparator;
5 import java.util.List;
6 import java.util.Objects;
7 import java.util.Set;
8 import java.util.TreeSet;
9 import java.util.function.Function;
10 import java.util.stream.Collectors;
11
12 import org.apache.commons.lang3.Validate;
13 import org.apache.logging.log4j.LogManager;
14 import org.apache.logging.log4j.Logger;
15
16 import net.bmahe.genetics4j.core.Genotype;
17 import net.bmahe.genetics4j.core.Population;
18 import net.bmahe.genetics4j.core.selection.Selector;
19 import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
20 import net.bmahe.genetics4j.moo.ObjectiveDistance;
21 import net.bmahe.genetics4j.moo.ParetoUtils;
22 import net.bmahe.genetics4j.moo.nsga2.spec.NSGA2Selection;
23
24 public class NSGA2Selector<T extends Comparable<T>> implements Selector<T> {
25 final static public Logger logger = LogManager.getLogger(NSGA2Selector.class);
26
27 private final NSGA2Selection<T> nsga2Selection;
28
29 public NSGA2Selector(final NSGA2Selection<T> _nsga2Selection) {
30 Objects.requireNonNull(_nsga2Selection);
31
32 this.nsga2Selection = _nsga2Selection;
33 }
34
35 @Override
36 public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
37 final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) {
38 Objects.requireNonNull(eaConfiguration);
39 Objects.requireNonNull(population);
40 Objects.requireNonNull(fitnessScore);
41 Validate.isTrue(generation >= 0);
42 Validate.isTrue(numIndividuals > 0);
43 Validate.isTrue(population.size() == fitnessScore.size());
44
45 logger.debug("Incoming population size is {}", population.size());
46
47 final Population<T> individuals = new Population<>();
48 if (nsga2Selection.deduplicate()
49 .isPresent()) {
50 final Comparator<Genotype> individualDeduplicator = nsga2Selection.deduplicate()
51 .get();
52 final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
53
54 for (int i = 0; i < population.size(); i++) {
55 final Genotype genotype = population.get(i);
56 final T fitness = fitnessScore.get(i);
57
58 if (seenGenotype.add(genotype)) {
59 individuals.add(genotype, fitness);
60 }
61 }
62
63 } else {
64 for (int i = 0; i < population.size(); i++) {
65 final Genotype genotype = population.get(i);
66 final T fitness = fitnessScore.get(i);
67
68 individuals.add(genotype, fitness);
69 }
70 }
71
72 logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size());
73
74 final int numberObjectives = nsga2Selection.numberObjectives();
75
76 final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
77 case MAXIMIZE -> nsga2Selection.dominance();
78 case MINIMIZE -> nsga2Selection.dominance()
79 .reversed();
80 };
81
82 final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) {
83 case MAXIMIZE -> nsga2Selection.objectiveComparator();
84 case MINIMIZE -> (m) -> nsga2Selection.objectiveComparator()
85 .apply(m)
86 .reversed();
87 };
88
89 final ObjectiveDistance<T> objectiveDistance = nsga2Selection.distance();
90
91 logger.debug("Ranking population");
92 final List<Set<Integer>> rankedPopulation = ParetoUtils.rankedPopulation(dominance,
93 individuals.getAllFitnesses());
94
95 logger.debug("Computing crowding distance assignment");
96 double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(numberObjectives,
97 individuals.getAllFitnesses(),
98 objectiveComparator,
99 objectiveDistance);
100
101 logger.debug("Selecting individuals");
102 final Population<T> selectedIndividuals = new Population<>();
103 int currentFrontIndex = 0;
104 while (selectedIndividuals.size() < numIndividuals && currentFrontIndex < rankedPopulation.size()
105 && rankedPopulation.get(currentFrontIndex)
106 .size() > 0) {
107
108 final Set<Integer> currentFront = rankedPopulation.get(currentFrontIndex);
109
110 Collection<Integer> bestIndividuals = currentFront;
111 if (currentFront.size() > numIndividuals - selectedIndividuals.size()) {
112
113 bestIndividuals = currentFront.stream()
114 .sorted((a, b) -> Double.compare(crowdingDistanceAssignment[b], crowdingDistanceAssignment[a]))
115 .limit(numIndividuals - selectedIndividuals.size())
116 .collect(Collectors.toList());
117 }
118
119 for (final Integer individualIndex : bestIndividuals) {
120 if (logger.isTraceEnabled()) {
121 logger.trace("Adding individual with index {}, fitness {}, rank {}, crowding distance {}",
122 individualIndex,
123 individuals.getFitness(individualIndex),
124 currentFrontIndex,
125 crowdingDistanceAssignment[individualIndex]);
126 }
127
128 selectedIndividuals.add(individuals.getGenotype(individualIndex), individuals.getFitness(individualIndex));
129 }
130
131 logger.trace("Selected {} individuals from rank {}", bestIndividuals.size(), currentFrontIndex);
132 currentFrontIndex++;
133 }
134
135 return selectedIndividuals;
136 }
137 }