1 package net.bmahe.genetics4j.moo.nsga2.impl; 2 3 import java.util.Collection; 4 import java.util.Comparator; 5 import java.util.List; 6 import java.util.Set; 7 import java.util.TreeSet; 8 import java.util.function.Function; 9 import java.util.stream.Collectors; 10 11 import org.apache.commons.lang3.Validate; 12 import org.apache.logging.log4j.LogManager; 13 import org.apache.logging.log4j.Logger; 14 15 import net.bmahe.genetics4j.core.Genotype; 16 import net.bmahe.genetics4j.core.Population; 17 import net.bmahe.genetics4j.core.selection.Selector; 18 import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration; 19 import net.bmahe.genetics4j.moo.ObjectiveDistance; 20 import net.bmahe.genetics4j.moo.ParetoUtils; 21 import net.bmahe.genetics4j.moo.nsga2.spec.NSGA2Selection; 22 23 public class NSGA2Selector<T extends Comparable<T>> implements Selector<T> { 24 final static public Logger logger = LogManager.getLogger(NSGA2Selector.class); 25 26 private final NSGA2Selection<T> nsga2Selection; 27 28 public NSGA2Selector(final NSGA2Selection<T> _nsga2Selection) { 29 Validate.notNull(_nsga2Selection); 30 31 this.nsga2Selection = _nsga2Selection; 32 } 33 34 @Override 35 public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final int numIndividuals, 36 final List<Genotype> population, final List<T> fitnessScore) { 37 Validate.notNull(eaConfiguration); 38 Validate.notNull(population); 39 Validate.notNull(fitnessScore); 40 Validate.isTrue(numIndividuals > 0); 41 Validate.isTrue(population.size() == fitnessScore.size()); 42 43 logger.debug("Incoming population size is {}", population.size()); 44 45 final Population<T> individuals = new Population<>(); 46 if (nsga2Selection.deduplicate() 47 .isPresent()) { 48 final Comparator<Genotype> individualDeduplicator = nsga2Selection.deduplicate() 49 .get(); 50 final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator); 51 52 for (int i = 0; i < population.size(); i++) { 53 final Genotype genotype = population.get(i); 54 final T fitness = fitnessScore.get(i); 55 56 if (seenGenotype.add(genotype)) { 57 individuals.add(genotype, fitness); 58 } 59 } 60 61 } else { 62 for (int i = 0; i < population.size(); i++) { 63 final Genotype genotype = population.get(i); 64 final T fitness = fitnessScore.get(i); 65 66 individuals.add(genotype, fitness); 67 } 68 } 69 70 logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size()); 71 72 final int numberObjectives = nsga2Selection.numberObjectives(); 73 74 final Comparator<T> dominance = switch (eaConfiguration.optimization()) { 75 case MAXIMIZE -> nsga2Selection.dominance(); 76 case MINIMIZE -> nsga2Selection.dominance() 77 .reversed(); 78 }; 79 80 final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) { 81 case MAXIMIZE -> nsga2Selection.objectiveComparator(); 82 case MINIMIZE -> (m) -> nsga2Selection.objectiveComparator() 83 .apply(m) 84 .reversed(); 85 }; 86 87 final ObjectiveDistance<T> objectiveDistance = nsga2Selection.distance(); 88 89 logger.debug("Ranking population"); 90 final List<Set<Integer>> rankedPopulation = ParetoUtils.rankedPopulation(dominance, 91 individuals.getAllFitnesses()); 92 93 logger.debug("Computing crowding distance assignment"); 94 double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(numberObjectives, 95 individuals.getAllFitnesses(), 96 objectiveComparator, 97 objectiveDistance); 98 99 logger.debug("Selecting individuals"); 100 final Population<T> selectedIndividuals = new Population<>(); 101 int currentFrontIndex = 0; 102 while (selectedIndividuals.size() < numIndividuals && currentFrontIndex < rankedPopulation.size() 103 && rankedPopulation.get(currentFrontIndex) 104 .size() > 0) { 105 106 final Set<Integer> currentFront = rankedPopulation.get(currentFrontIndex); 107 108 Collection<Integer> bestIndividuals = currentFront; 109 if (currentFront.size() > numIndividuals - selectedIndividuals.size()) { 110 111 bestIndividuals = currentFront.stream() 112 .sorted((a, b) -> Double.compare(crowdingDistanceAssignment[b], crowdingDistanceAssignment[a])) 113 .limit(numIndividuals - selectedIndividuals.size()) 114 .collect(Collectors.toList()); 115 } 116 117 for (final Integer individualIndex : bestIndividuals) { 118 if (logger.isTraceEnabled()) { 119 logger.trace("Adding individual with index {}, fitness {}, rank {}, crowding distance {}", 120 individualIndex, 121 individuals.getFitness(individualIndex), 122 currentFrontIndex, 123 crowdingDistanceAssignment[individualIndex]); 124 } 125 126 selectedIndividuals.add(individuals.getGenotype(individualIndex), individuals.getFitness(individualIndex)); 127 } 128 129 logger.trace("Selected {} individuals from rank {}", bestIndividuals.size(), currentFrontIndex); 130 currentFrontIndex++; 131 } 132 133 return selectedIndividuals; 134 } 135 }