TournamentNSGA2Selector.java
package net.bmahe.genetics4j.moo.nsga2.impl;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Function;
import java.util.random.RandomGenerator;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.core.Population;
import net.bmahe.genetics4j.core.selection.Selector;
import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
import net.bmahe.genetics4j.moo.ObjectiveDistance;
import net.bmahe.genetics4j.moo.ParetoUtils;
import net.bmahe.genetics4j.moo.nsga2.spec.TournamentNSGA2Selection;
public class TournamentNSGA2Selector<T extends Comparable<T>> implements Selector<T> {
final static public Logger logger = LogManager.getLogger(TournamentNSGA2Selector.class);
private final TournamentNSGA2Selection<T> tournamentNSGA2Selection;
private final RandomGenerator randomGenerator;
public TournamentNSGA2Selector(final RandomGenerator _randomGenerator,
final TournamentNSGA2Selection<T> _tournamentNSGA2Selection) {
Validate.notNull(_randomGenerator);
Validate.notNull(_tournamentNSGA2Selection);
this.randomGenerator = _randomGenerator;
this.tournamentNSGA2Selection = _tournamentNSGA2Selection;
}
@Override
public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final int numIndividuals,
final List<Genotype> population, final List<T> fitnessScore) {
Validate.notNull(eaConfiguration);
Validate.notNull(population);
Validate.notNull(fitnessScore);
Validate.isTrue(numIndividuals > 0);
Validate.isTrue(population.size() == fitnessScore.size());
logger.debug("Incoming population size is {}", population.size());
final Population<T> individuals = new Population<>();
if (tournamentNSGA2Selection.deduplicate()
.isPresent()) {
final Comparator<Genotype> individualDeduplicator = tournamentNSGA2Selection.deduplicate()
.get();
final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
for (int i = 0; i < population.size(); i++) {
final Genotype genotype = population.get(i);
final T fitness = fitnessScore.get(i);
if (seenGenotype.add(genotype)) {
individuals.add(genotype, fitness);
}
}
} else {
for (int i = 0; i < population.size(); i++) {
final Genotype genotype = population.get(i);
final T fitness = fitnessScore.get(i);
individuals.add(genotype, fitness);
}
}
logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size());
final int numberObjectives = tournamentNSGA2Selection.numberObjectives();
final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
case MAXIMIZE -> tournamentNSGA2Selection.dominance();
case MINIMIZE -> tournamentNSGA2Selection.dominance()
.reversed();
};
final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) {
case MAXIMIZE -> tournamentNSGA2Selection.objectiveComparator();
case MINIMIZE -> (m) -> tournamentNSGA2Selection.objectiveComparator()
.apply(m)
.reversed();
};
final ObjectiveDistance<T> objectiveDistance = tournamentNSGA2Selection.distance();
final int numCandidates = tournamentNSGA2Selection.numCandidates();
logger.debug("Ranking population");
final List<Set<Integer>> rankedPopulation = ParetoUtils.rankedPopulation(dominance,
individuals.getAllFitnesses());
// Build a reverse index
final int[] individual2Rank = new int[individuals.size()];
for (int j = 0; j < rankedPopulation.size(); j++) {
final Set<Integer> set = rankedPopulation.get(j);
for (final Integer idx : set) {
individual2Rank[idx] = j;
}
}
if (logger.isTraceEnabled()) {
logger.trace("Ranked population: {}", rankedPopulation);
for (int i = 0; i < rankedPopulation.size(); i++) {
final Set<Integer> subPopulationIdx = rankedPopulation.get(i);
logger.trace("\tRank {}", i);
for (final Integer index : subPopulationIdx) {
logger.trace("\t\t{} - Fitness {}", index, individuals.getFitness(index));
}
}
}
logger.debug("Computing crowding distance assignment");
final double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(numberObjectives,
individuals.getAllFitnesses(),
objectiveComparator,
objectiveDistance);
logger.debug("Performing tournaments");
final Population<T> selectedIndividuals = new Population<>();
while (selectedIndividuals.size() < numIndividuals) {
logger.trace("Performing tournament");
Genotype bestCandidate = null;
int bestCandidateIndex = -1;
T bestFitness = null;
for (int i = 0; i < numCandidates; i++) {
final int candidateIndex = randomGenerator.nextInt(individuals.size());
logger.trace("\tCandidate - index {} - rank {} - crowding distance {} - fitness {}",
candidateIndex,
individual2Rank[candidateIndex],
crowdingDistanceAssignment[candidateIndex],
individuals.getFitness(candidateIndex));
if (bestCandidate == null || individual2Rank[candidateIndex] < individual2Rank[bestCandidateIndex]
|| (individual2Rank[candidateIndex] == individual2Rank[bestCandidateIndex]
&& crowdingDistanceAssignment[candidateIndex] > crowdingDistanceAssignment[bestCandidateIndex])) {
logger.trace("\t candidate win!");
bestCandidate = individuals.getGenotype(candidateIndex);
bestFitness = individuals.getFitness(candidateIndex);
bestCandidateIndex = candidateIndex;
}
}
selectedIndividuals.add(bestCandidate, bestFitness);
}
return selectedIndividuals;
}
}