1 package net.bmahe.genetics4j.moo.nsga2.impl; 2 3 import java.util.Collection; 4 import java.util.Comparator; 5 import java.util.List; 6 import java.util.Objects; 7 import java.util.Set; 8 import java.util.TreeSet; 9 import java.util.function.Function; 10 import java.util.stream.Collectors; 11 12 import org.apache.commons.lang3.Validate; 13 import org.apache.logging.log4j.LogManager; 14 import org.apache.logging.log4j.Logger; 15 16 import net.bmahe.genetics4j.core.Genotype; 17 import net.bmahe.genetics4j.core.Population; 18 import net.bmahe.genetics4j.core.selection.Selector; 19 import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration; 20 import net.bmahe.genetics4j.moo.ObjectiveDistance; 21 import net.bmahe.genetics4j.moo.ParetoUtils; 22 import net.bmahe.genetics4j.moo.nsga2.spec.NSGA2Selection; 23 24 public class NSGA2Selector<T extends Comparable<T>> implements Selector<T> { 25 final static public Logger logger = LogManager.getLogger(NSGA2Selector.class); 26 27 private final NSGA2Selection<T> nsga2Selection; 28 29 public NSGA2Selector(final NSGA2Selection<T> _nsga2Selection) { 30 Objects.requireNonNull(_nsga2Selection); 31 32 this.nsga2Selection = _nsga2Selection; 33 } 34 35 @Override 36 public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation, 37 final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) { 38 Objects.requireNonNull(eaConfiguration); 39 Objects.requireNonNull(population); 40 Objects.requireNonNull(fitnessScore); 41 Validate.isTrue(generation >= 0); 42 Validate.isTrue(numIndividuals > 0); 43 Validate.isTrue(population.size() == fitnessScore.size()); 44 45 logger.debug("Incoming population size is {}", population.size()); 46 47 final Population<T> individuals = new Population<>(); 48 if (nsga2Selection.deduplicate() 49 .isPresent()) { 50 final Comparator<Genotype> individualDeduplicator = nsga2Selection.deduplicate() 51 .get(); 52 final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator); 53 54 for (int i = 0; i < population.size(); i++) { 55 final Genotype genotype = population.get(i); 56 final T fitness = fitnessScore.get(i); 57 58 if (seenGenotype.add(genotype)) { 59 individuals.add(genotype, fitness); 60 } 61 } 62 63 } else { 64 for (int i = 0; i < population.size(); i++) { 65 final Genotype genotype = population.get(i); 66 final T fitness = fitnessScore.get(i); 67 68 individuals.add(genotype, fitness); 69 } 70 } 71 72 logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size()); 73 74 final int numberObjectives = nsga2Selection.numberObjectives(); 75 76 final Comparator<T> dominance = switch (eaConfiguration.optimization()) { 77 case MAXIMIZE -> nsga2Selection.dominance(); 78 case MINIMIZE -> nsga2Selection.dominance() 79 .reversed(); 80 }; 81 82 final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) { 83 case MAXIMIZE -> nsga2Selection.objectiveComparator(); 84 case MINIMIZE -> (m) -> nsga2Selection.objectiveComparator() 85 .apply(m) 86 .reversed(); 87 }; 88 89 final ObjectiveDistance<T> objectiveDistance = nsga2Selection.distance(); 90 91 logger.debug("Ranking population"); 92 final List<Set<Integer>> rankedPopulation = ParetoUtils.rankedPopulation(dominance, 93 individuals.getAllFitnesses()); 94 95 logger.debug("Computing crowding distance assignment"); 96 double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(numberObjectives, 97 individuals.getAllFitnesses(), 98 objectiveComparator, 99 objectiveDistance); 100 101 logger.debug("Selecting individuals"); 102 final Population<T> selectedIndividuals = new Population<>(); 103 int currentFrontIndex = 0; 104 while (selectedIndividuals.size() < numIndividuals && currentFrontIndex < rankedPopulation.size() 105 && rankedPopulation.get(currentFrontIndex) 106 .size() > 0) { 107 108 final Set<Integer> currentFront = rankedPopulation.get(currentFrontIndex); 109 110 Collection<Integer> bestIndividuals = currentFront; 111 if (currentFront.size() > numIndividuals - selectedIndividuals.size()) { 112 113 bestIndividuals = currentFront.stream() 114 .sorted((a, b) -> Double.compare(crowdingDistanceAssignment[b], crowdingDistanceAssignment[a])) 115 .limit(numIndividuals - selectedIndividuals.size()) 116 .collect(Collectors.toList()); 117 } 118 119 for (final Integer individualIndex : bestIndividuals) { 120 if (logger.isTraceEnabled()) { 121 logger.trace("Adding individual with index {}, fitness {}, rank {}, crowding distance {}", 122 individualIndex, 123 individuals.getFitness(individualIndex), 124 currentFrontIndex, 125 crowdingDistanceAssignment[individualIndex]); 126 } 127 128 selectedIndividuals.add(individuals.getGenotype(individualIndex), individuals.getFitness(individualIndex)); 129 } 130 131 logger.trace("Selected {} individuals from rank {}", bestIndividuals.size(), currentFrontIndex); 132 currentFrontIndex++; 133 } 134 135 return selectedIndividuals; 136 } 137 }