1 package net.bmahe.genetics4j.moo.nsga2.impl;
2
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Function;
import java.util.stream.Collectors;
11
12 import org.apache.commons.lang3.Validate;
13 import org.apache.logging.log4j.LogManager;
14 import org.apache.logging.log4j.Logger;
15
16 import net.bmahe.genetics4j.core.Genotype;
17 import net.bmahe.genetics4j.core.Population;
18 import net.bmahe.genetics4j.core.selection.Selector;
19 import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
20 import net.bmahe.genetics4j.moo.ObjectiveDistance;
21 import net.bmahe.genetics4j.moo.ParetoUtils;
22 import net.bmahe.genetics4j.moo.nsga2.spec.NSGA2Selection;
23
24 public class NSGA2Selector<T extends Comparable<T>> implements Selector<T> {
25 final static public Logger logger = LogManager.getLogger(NSGA2Selector.class);
26
27 private final NSGA2Selection<T> nsga2Selection;
28
29 public NSGA2Selector(final NSGA2Selection<T> _nsga2Selection) {
30 Objects.requireNonNull(_nsga2Selection);
31
32 this.nsga2Selection = _nsga2Selection;
33 }
34
35 @Override
36 public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
37 final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) {
38 Objects.requireNonNull(eaConfiguration);
39 Objects.requireNonNull(population);
40 Objects.requireNonNull(fitnessScore);
41 Validate.isTrue(generation >= 0);
42 Validate.isTrue(numIndividuals > 0);
43 Validate.isTrue(population.size() == fitnessScore.size());
44
45 logger.debug("Incoming population size is {}", population.size());
46
47 final Population<T> individuals = new Population<>();
48 if (nsga2Selection.deduplicate().isPresent()) {
49 final Comparator<Genotype> individualDeduplicator = nsga2Selection.deduplicate().get();
50 final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
51
52 for (int i = 0; i < population.size(); i++) {
53 final Genotype genotype = population.get(i);
54 final T fitness = fitnessScore.get(i);
55
56 if (seenGenotype.add(genotype)) {
57 individuals.add(genotype, fitness);
58 }
59 }
60
61 } else {
62 for (int i = 0; i < population.size(); i++) {
63 final Genotype genotype = population.get(i);
64 final T fitness = fitnessScore.get(i);
65
66 individuals.add(genotype, fitness);
67 }
68 }
69
70 logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size());
71
72 final int numberObjectives = nsga2Selection.numberObjectives();
73
74 final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
75 case MAXIMIZE -> nsga2Selection.dominance();
76 case MINIMIZE -> nsga2Selection.dominance().reversed();
77 };
78
79 final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) {
80 case MAXIMIZE -> nsga2Selection.objectiveComparator();
81 case MINIMIZE -> (m) -> nsga2Selection.objectiveComparator().apply(m).reversed();
82 };
83
84 final ObjectiveDistance<T> objectiveDistance = nsga2Selection.distance();
85
86 logger.debug("Ranking population");
87 final List<Set<Integer>> rankedPopulation = ParetoUtils
88 .rankedPopulation(dominance, individuals.getAllFitnesses());
89
90 logger.debug("Computing crowding distance assignment");
91 double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(
92 numberObjectives,
93 individuals.getAllFitnesses(),
94 objectiveComparator,
95 objectiveDistance);
96
97 logger.debug("Selecting individuals");
98 final Population<T> selectedIndividuals = new Population<>();
99 int currentFrontIndex = 0;
100 while (selectedIndividuals.size() < numIndividuals && currentFrontIndex < rankedPopulation.size()
101 && rankedPopulation.get(currentFrontIndex).size() > 0) {
102
103 final Set<Integer> currentFront = rankedPopulation.get(currentFrontIndex);
104
105 Collection<Integer> bestIndividuals = currentFront;
106 if (currentFront.size() > numIndividuals - selectedIndividuals.size()) {
107
108 bestIndividuals = currentFront.stream()
109 .sorted((a, b) -> Double.compare(crowdingDistanceAssignment[b], crowdingDistanceAssignment[a]))
110 .limit(numIndividuals - selectedIndividuals.size())
111 .collect(Collectors.toList());
112 }
113
114 for (final Integer individualIndex : bestIndividuals) {
115 if (logger.isTraceEnabled()) {
116 logger.trace(
117 "Adding individual with index {}, fitness {}, rank {}, crowding distance {}",
118 individualIndex,
119 individuals.getFitness(individualIndex),
120 currentFrontIndex,
121 crowdingDistanceAssignment[individualIndex]);
122 }
123
124 selectedIndividuals.add(individuals.getGenotype(individualIndex), individuals.getFitness(individualIndex));
125 }
126
127 logger.trace("Selected {} individuals from rank {}", bestIndividuals.size(), currentFrontIndex);
128 currentFrontIndex++;
129 }
130
131 return selectedIndividuals;
132 }
133 }