View Javadoc
1   package net.bmahe.genetics4j.moo.nsga2.impl;
2   
3   import java.util.Collection;
4   import java.util.Comparator;
5   import java.util.List;
6   import java.util.Objects;
7   import java.util.Set;
8   import java.util.TreeSet;
9   import java.util.function.Function;
10  import java.util.stream.Collectors;
11  
12  import org.apache.commons.lang3.Validate;
13  import org.apache.logging.log4j.LogManager;
14  import org.apache.logging.log4j.Logger;
15  
16  import net.bmahe.genetics4j.core.Genotype;
17  import net.bmahe.genetics4j.core.Population;
18  import net.bmahe.genetics4j.core.selection.Selector;
19  import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
20  import net.bmahe.genetics4j.moo.ObjectiveDistance;
21  import net.bmahe.genetics4j.moo.ParetoUtils;
22  import net.bmahe.genetics4j.moo.nsga2.spec.NSGA2Selection;
23  
24  public class NSGA2Selector<T extends Comparable<T>> implements Selector<T> {
25  	final static public Logger logger = LogManager.getLogger(NSGA2Selector.class);
26  
27  	private final NSGA2Selection<T> nsga2Selection;
28  
29  	public NSGA2Selector(final NSGA2Selection<T> _nsga2Selection) {
30  		Objects.requireNonNull(_nsga2Selection);
31  
32  		this.nsga2Selection = _nsga2Selection;
33  	}
34  
35  	@Override
36  	public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
37  			final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) {
38  		Objects.requireNonNull(eaConfiguration);
39  		Objects.requireNonNull(population);
40  		Objects.requireNonNull(fitnessScore);
41  		Validate.isTrue(generation >= 0);
42  		Validate.isTrue(numIndividuals > 0);
43  		Validate.isTrue(population.size() == fitnessScore.size());
44  
45  		logger.debug("Incoming population size is {}", population.size());
46  
47  		final Population<T> individuals = new Population<>();
48  		if (nsga2Selection.deduplicate().isPresent()) {
49  			final Comparator<Genotype> individualDeduplicator = nsga2Selection.deduplicate().get();
50  			final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
51  
52  			for (int i = 0; i < population.size(); i++) {
53  				final Genotype genotype = population.get(i);
54  				final T fitness = fitnessScore.get(i);
55  
56  				if (seenGenotype.add(genotype)) {
57  					individuals.add(genotype, fitness);
58  				}
59  			}
60  
61  		} else {
62  			for (int i = 0; i < population.size(); i++) {
63  				final Genotype genotype = population.get(i);
64  				final T fitness = fitnessScore.get(i);
65  
66  				individuals.add(genotype, fitness);
67  			}
68  		}
69  
70  		logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size());
71  
72  		final int numberObjectives = nsga2Selection.numberObjectives();
73  
74  		final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
75  			case MAXIMIZE -> nsga2Selection.dominance();
76  			case MINIMIZE -> nsga2Selection.dominance().reversed();
77  		};
78  
79  		final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) {
80  			case MAXIMIZE -> nsga2Selection.objectiveComparator();
81  			case MINIMIZE -> (m) -> nsga2Selection.objectiveComparator().apply(m).reversed();
82  		};
83  
84  		final ObjectiveDistance<T> objectiveDistance = nsga2Selection.distance();
85  
86  		logger.debug("Ranking population");
87  		final List<Set<Integer>> rankedPopulation = ParetoUtils
88  				.rankedPopulation(dominance, individuals.getAllFitnesses());
89  
90  		logger.debug("Computing crowding distance assignment");
91  		double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(
92  				numberObjectives,
93  					individuals.getAllFitnesses(),
94  					objectiveComparator,
95  					objectiveDistance);
96  
97  		logger.debug("Selecting individuals");
98  		final Population<T> selectedIndividuals = new Population<>();
99  		int currentFrontIndex = 0;
100 		while (selectedIndividuals.size() < numIndividuals && currentFrontIndex < rankedPopulation.size()
101 				&& rankedPopulation.get(currentFrontIndex).size() > 0) {
102 
103 			final Set<Integer> currentFront = rankedPopulation.get(currentFrontIndex);
104 
105 			Collection<Integer> bestIndividuals = currentFront;
106 			if (currentFront.size() > numIndividuals - selectedIndividuals.size()) {
107 
108 				bestIndividuals = currentFront.stream()
109 						.sorted((a, b) -> Double.compare(crowdingDistanceAssignment[b], crowdingDistanceAssignment[a]))
110 						.limit(numIndividuals - selectedIndividuals.size())
111 						.collect(Collectors.toList());
112 			}
113 
114 			for (final Integer individualIndex : bestIndividuals) {
115 				if (logger.isTraceEnabled()) {
116 					logger.trace(
117 							"Adding individual with index {}, fitness {}, rank {}, crowding distance {}",
118 								individualIndex,
119 								individuals.getFitness(individualIndex),
120 								currentFrontIndex,
121 								crowdingDistanceAssignment[individualIndex]);
122 				}
123 
124 				selectedIndividuals.add(individuals.getGenotype(individualIndex), individuals.getFitness(individualIndex));
125 			}
126 
127 			logger.trace("Selected {} individuals from rank {}", bestIndividuals.size(), currentFrontIndex);
128 			currentFrontIndex++;
129 		}
130 
131 		return selectedIndividuals;
132 	}
133 }