View Javadoc
1   package net.bmahe.genetics4j.moo.nsga2.impl;
2   
3   import java.util.Collection;
4   import java.util.Comparator;
5   import java.util.List;
6   import java.util.Objects;
7   import java.util.Set;
8   import java.util.TreeSet;
9   import java.util.function.Function;
10  import java.util.stream.Collectors;
11  
12  import org.apache.commons.lang3.Validate;
13  import org.apache.logging.log4j.LogManager;
14  import org.apache.logging.log4j.Logger;
15  
16  import net.bmahe.genetics4j.core.Genotype;
17  import net.bmahe.genetics4j.core.Population;
18  import net.bmahe.genetics4j.core.selection.Selector;
19  import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
20  import net.bmahe.genetics4j.moo.ObjectiveDistance;
21  import net.bmahe.genetics4j.moo.ParetoUtils;
22  import net.bmahe.genetics4j.moo.nsga2.spec.NSGA2Selection;
23  
24  public class NSGA2Selector<T extends Comparable<T>> implements Selector<T> {
25  	final static public Logger logger = LogManager.getLogger(NSGA2Selector.class);
26  
27  	private final NSGA2Selection<T> nsga2Selection;
28  
29  	public NSGA2Selector(final NSGA2Selection<T> _nsga2Selection) {
30  		Objects.requireNonNull(_nsga2Selection);
31  
32  		this.nsga2Selection = _nsga2Selection;
33  	}
34  
35  	@Override
36  	public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
37  			final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) {
38  		Objects.requireNonNull(eaConfiguration);
39  		Objects.requireNonNull(population);
40  		Objects.requireNonNull(fitnessScore);
41  		Validate.isTrue(generation >= 0);
42  		Validate.isTrue(numIndividuals > 0);
43  		Validate.isTrue(population.size() == fitnessScore.size());
44  
45  		logger.debug("Incoming population size is {}", population.size());
46  
47  		final Population<T> individuals = new Population<>();
48  		if (nsga2Selection.deduplicate()
49  				.isPresent()) {
50  			final Comparator<Genotype> individualDeduplicator = nsga2Selection.deduplicate()
51  					.get();
52  			final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
53  
54  			for (int i = 0; i < population.size(); i++) {
55  				final Genotype genotype = population.get(i);
56  				final T fitness = fitnessScore.get(i);
57  
58  				if (seenGenotype.add(genotype)) {
59  					individuals.add(genotype, fitness);
60  				}
61  			}
62  
63  		} else {
64  			for (int i = 0; i < population.size(); i++) {
65  				final Genotype genotype = population.get(i);
66  				final T fitness = fitnessScore.get(i);
67  
68  				individuals.add(genotype, fitness);
69  			}
70  		}
71  
72  		logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size());
73  
74  		final int numberObjectives = nsga2Selection.numberObjectives();
75  
76  		final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
77  			case MAXIMIZE -> nsga2Selection.dominance();
78  			case MINIMIZE -> nsga2Selection.dominance()
79  					.reversed();
80  		};
81  
82  		final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) {
83  			case MAXIMIZE -> nsga2Selection.objectiveComparator();
84  			case MINIMIZE -> (m) -> nsga2Selection.objectiveComparator()
85  					.apply(m)
86  					.reversed();
87  		};
88  
89  		final ObjectiveDistance<T> objectiveDistance = nsga2Selection.distance();
90  
91  		logger.debug("Ranking population");
92  		final List<Set<Integer>> rankedPopulation = ParetoUtils.rankedPopulation(dominance,
93  				individuals.getAllFitnesses());
94  
95  		logger.debug("Computing crowding distance assignment");
96  		double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(numberObjectives,
97  				individuals.getAllFitnesses(),
98  				objectiveComparator,
99  				objectiveDistance);
100 
101 		logger.debug("Selecting individuals");
102 		final Population<T> selectedIndividuals = new Population<>();
103 		int currentFrontIndex = 0;
104 		while (selectedIndividuals.size() < numIndividuals && currentFrontIndex < rankedPopulation.size()
105 				&& rankedPopulation.get(currentFrontIndex)
106 						.size() > 0) {
107 
108 			final Set<Integer> currentFront = rankedPopulation.get(currentFrontIndex);
109 
110 			Collection<Integer> bestIndividuals = currentFront;
111 			if (currentFront.size() > numIndividuals - selectedIndividuals.size()) {
112 
113 				bestIndividuals = currentFront.stream()
114 						.sorted((a, b) -> Double.compare(crowdingDistanceAssignment[b], crowdingDistanceAssignment[a]))
115 						.limit(numIndividuals - selectedIndividuals.size())
116 						.collect(Collectors.toList());
117 			}
118 
119 			for (final Integer individualIndex : bestIndividuals) {
120 				if (logger.isTraceEnabled()) {
121 					logger.trace("Adding individual with index {}, fitness {}, rank {}, crowding distance {}",
122 							individualIndex,
123 							individuals.getFitness(individualIndex),
124 							currentFrontIndex,
125 							crowdingDistanceAssignment[individualIndex]);
126 				}
127 
128 				selectedIndividuals.add(individuals.getGenotype(individualIndex), individuals.getFitness(individualIndex));
129 			}
130 
131 			logger.trace("Selected {} individuals from rank {}", bestIndividuals.size(), currentFrontIndex);
132 			currentFrontIndex++;
133 		}
134 
135 		return selectedIndividuals;
136 	}
137 }