NSGA2Selector.java
package net.bmahe.genetics4j.moo.nsga2.impl;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.core.Population;
import net.bmahe.genetics4j.core.selection.Selector;
import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
import net.bmahe.genetics4j.moo.ObjectiveDistance;
import net.bmahe.genetics4j.moo.ParetoUtils;
import net.bmahe.genetics4j.moo.nsga2.spec.NSGA2Selection;
/**
 * {@link Selector} that picks individuals front by front from a dominance
 * ranking of the population, using a crowding distance to break ties inside
 * the last front that does not fit entirely.
 *
 * <p>
 * The behaviour is driven by the supplied {@link NSGA2Selection}: it provides
 * the dominance comparator, the per-objective comparators, the objective
 * distance, the number of objectives and an optional comparator used to drop
 * duplicated genotypes before selecting.
 *
 * @param <T> Type of the fitness measurement
 */
public class NSGA2Selector<T extends Comparable<T>> implements Selector<T> {
    public static final Logger logger = LogManager.getLogger(NSGA2Selector.class);

    private final NSGA2Selection<T> nsga2Selection;

    /**
     * Creates a selector backed by the given selection specification.
     *
     * @param _nsga2Selection specification to use; must not be null
     * @throws NullPointerException if {@code _nsga2Selection} is null
     */
    public NSGA2Selector(final NSGA2Selection<T> _nsga2Selection) {
        Validate.notNull(_nsga2Selection);

        this.nsga2Selection = _nsga2Selection;
    }

    /**
     * Selects up to {@code numIndividuals} individuals from {@code population}.
     *
     * <p>
     * Steps: optionally deduplicate the population, rank the remaining
     * candidates with the dominance comparator, compute their crowding
     * distances, then add whole fronts in ranking order until the next front no
     * longer fits; from that front only the individuals with the largest
     * crowding distance are kept.
     *
     * <p>
     * The result may hold fewer than {@code numIndividuals} individuals when the
     * candidate fronts are exhausted first (for instance after deduplication).
     *
     * @param eaConfiguration configuration providing the optimization direction
     * @param numIndividuals  maximum number of individuals to select; must be
     *                        strictly positive
     * @param population      genotypes to select from
     * @param fitnessScore    fitness of each genotype, index-aligned with
     *                        {@code population}
     * @return the selected individuals along with their fitness
     * @throws NullPointerException     if any of the reference arguments is null
     * @throws IllegalArgumentException if {@code numIndividuals} is not positive
     *                                  or the two lists differ in size
     */
    @Override
    public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final int numIndividuals,
            final List<Genotype> population, final List<T> fitnessScore) {
        Validate.notNull(eaConfiguration);
        Validate.notNull(population);
        Validate.notNull(fitnessScore);
        Validate.isTrue(numIndividuals > 0);
        Validate.isTrue(population.size() == fitnessScore.size());

        logger.debug("Incoming population size is {}", population.size());

        final Population<T> individuals = buildCandidates(population, fitnessScore);

        logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size());

        final int numberObjectives = nsga2Selection.numberObjectives();
        final Comparator<T> dominance = orientedDominance(eaConfiguration);
        final Function<Integer, Comparator<T>> objectiveComparator = orientedObjectiveComparator(eaConfiguration);
        final ObjectiveDistance<T> objectiveDistance = nsga2Selection.distance();

        logger.debug("Ranking population");
        final List<Set<Integer>> rankedPopulation = ParetoUtils.rankedPopulation(dominance,
                individuals.getAllFitnesses());

        // NOTE(review): distances are computed once over every candidate rather
        // than separately for each front - confirm this is the intended variant.
        logger.debug("Computing crowding distance assignment");
        final double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(numberObjectives,
                individuals.getAllFitnesses(),
                objectiveComparator,
                objectiveDistance);

        logger.debug("Selecting individuals");
        return selectFromFronts(numIndividuals, individuals, rankedPopulation, crowdingDistanceAssignment);
    }

    /**
     * Pairs each genotype with its fitness, skipping genotypes already seen when
     * the specification supplies a deduplication comparator.
     */
    private Population<T> buildCandidates(final List<Genotype> population, final List<T> fitnessScore) {
        // A TreeSet ordered by the deduplication comparator refuses elements that
        // compare equal to one it already holds, so only the first occurrence of
        // each genotype is kept. Null means that no deduplication was requested.
        final Set<Genotype> seenGenotypes = nsga2Selection.deduplicate()
                .map(deduplicator -> new TreeSet<Genotype>(deduplicator))
                .orElse(null);

        final Population<T> candidates = new Population<>();
        for (int i = 0; i < population.size(); i++) {
            final Genotype genotype = population.get(i);

            if (seenGenotypes == null || seenGenotypes.add(genotype)) {
                candidates.add(genotype, fitnessScore.get(i));
            }
        }

        return candidates;
    }

    /**
     * Returns the dominance comparator, reversed when the configuration asks for
     * a minimization.
     */
    private Comparator<T> orientedDominance(final AbstractEAConfiguration<T> eaConfiguration) {
        return switch (eaConfiguration.optimization()) {
            case MAXIMIZE -> nsga2Selection.dominance();
            case MINIMIZE -> nsga2Selection.dominance()
                    .reversed();
        };
    }

    /**
     * Returns the per-objective comparator factory, with each produced
     * comparator reversed when the configuration asks for a minimization.
     */
    private Function<Integer, Comparator<T>> orientedObjectiveComparator(
            final AbstractEAConfiguration<T> eaConfiguration) {
        return switch (eaConfiguration.optimization()) {
            case MAXIMIZE -> nsga2Selection.objectiveComparator();
            case MINIMIZE -> (m) -> nsga2Selection.objectiveComparator()
                    .apply(m)
                    .reversed();
        };
    }

    /**
     * Walks the fronts in order and copies individuals into the result until
     * {@code numIndividuals} are selected, the fronts run out, or an empty front
     * is met.
     */
    private Population<T> selectFromFronts(final int numIndividuals, final Population<T> individuals,
            final List<Set<Integer>> rankedPopulation, final double[] crowdingDistanceAssignment) {
        final Population<T> selectedIndividuals = new Population<>();

        int currentFrontIndex = 0;
        while (selectedIndividuals.size() < numIndividuals && currentFrontIndex < rankedPopulation.size()) {
            final Set<Integer> currentFront = rankedPopulation.get(currentFrontIndex);
            if (currentFront.isEmpty()) {
                break;
            }

            final int remainingSlots = numIndividuals - selectedIndividuals.size();

            Collection<Integer> bestIndividuals = currentFront;
            if (currentFront.size() > remainingSlots) {
                // The front does not fit: keep the individuals with the largest
                // crowding distance, hence the descending comparison.
                bestIndividuals = currentFront.stream()
                        .sorted((a, b) -> Double.compare(crowdingDistanceAssignment[b],
                                crowdingDistanceAssignment[a]))
                        .limit(remainingSlots)
                        .collect(Collectors.toList());
            }

            for (final Integer individualIndex : bestIndividuals) {
                if (logger.isTraceEnabled()) {
                    logger.trace("Adding individual with index {}, fitness {}, rank {}, crowding distance {}",
                            individualIndex,
                            individuals.getFitness(individualIndex),
                            currentFrontIndex,
                            crowdingDistanceAssignment[individualIndex]);
                }

                selectedIndividuals.add(individuals.getGenotype(individualIndex),
                        individuals.getFitness(individualIndex));
            }

            logger.trace("Selected {} individuals from rank {}", bestIndividuals.size(), currentFrontIndex);
            currentFrontIndex++;
        }

        return selectedIndividuals;
    }
}