View Javadoc
1   package net.bmahe.genetics4j.moo.nsga2.impl;
2   
3   import java.util.Collection;
4   import java.util.Comparator;
5   import java.util.List;
6   import java.util.Set;
7   import java.util.TreeSet;
8   import java.util.function.Function;
9   import java.util.stream.Collectors;
10  
11  import org.apache.commons.lang3.Validate;
12  import org.apache.logging.log4j.LogManager;
13  import org.apache.logging.log4j.Logger;
14  
15  import net.bmahe.genetics4j.core.Genotype;
16  import net.bmahe.genetics4j.core.Population;
17  import net.bmahe.genetics4j.core.selection.Selector;
18  import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
19  import net.bmahe.genetics4j.moo.ObjectiveDistance;
20  import net.bmahe.genetics4j.moo.ParetoUtils;
21  import net.bmahe.genetics4j.moo.nsga2.spec.NSGA2Selection;
22  
23  public class NSGA2Selector<T extends Comparable<T>> implements Selector<T> {
24  	final static public Logger logger = LogManager.getLogger(NSGA2Selector.class);
25  
26  	private final NSGA2Selection<T> nsga2Selection;
27  
28  	public NSGA2Selector(final NSGA2Selection<T> _nsga2Selection) {
29  		Validate.notNull(_nsga2Selection);
30  
31  		this.nsga2Selection = _nsga2Selection;
32  	}
33  
34  	@Override
35  	public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final int numIndividuals,
36  			final List<Genotype> population, final List<T> fitnessScore) {
37  		Validate.notNull(eaConfiguration);
38  		Validate.notNull(population);
39  		Validate.notNull(fitnessScore);
40  		Validate.isTrue(numIndividuals > 0);
41  		Validate.isTrue(population.size() == fitnessScore.size());
42  
43  		logger.debug("Incoming population size is {}", population.size());
44  
45  		final Population<T> individuals = new Population<>();
46  		if (nsga2Selection.deduplicate()
47  				.isPresent()) {
48  			final Comparator<Genotype> individualDeduplicator = nsga2Selection.deduplicate()
49  					.get();
50  			final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
51  
52  			for (int i = 0; i < population.size(); i++) {
53  				final Genotype genotype = population.get(i);
54  				final T fitness = fitnessScore.get(i);
55  
56  				if (seenGenotype.add(genotype)) {
57  					individuals.add(genotype, fitness);
58  				}
59  			}
60  
61  		} else {
62  			for (int i = 0; i < population.size(); i++) {
63  				final Genotype genotype = population.get(i);
64  				final T fitness = fitnessScore.get(i);
65  
66  				individuals.add(genotype, fitness);
67  			}
68  		}
69  
70  		logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size());
71  
72  		final int numberObjectives = nsga2Selection.numberObjectives();
73  
74  		final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
75  			case MAXIMIZE -> nsga2Selection.dominance();
76  			case MINIMIZE -> nsga2Selection.dominance()
77  					.reversed();
78  		};
79  
80  		final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) {
81  			case MAXIMIZE -> nsga2Selection.objectiveComparator();
82  			case MINIMIZE -> (m) -> nsga2Selection.objectiveComparator()
83  					.apply(m)
84  					.reversed();
85  		};
86  
87  		final ObjectiveDistance<T> objectiveDistance = nsga2Selection.distance();
88  
89  		logger.debug("Ranking population");
90  		final List<Set<Integer>> rankedPopulation = ParetoUtils.rankedPopulation(dominance,
91  				individuals.getAllFitnesses());
92  
93  		logger.debug("Computing crowding distance assignment");
94  		double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(numberObjectives,
95  				individuals.getAllFitnesses(),
96  				objectiveComparator,
97  				objectiveDistance);
98  
99  		logger.debug("Selecting individuals");
100 		final Population<T> selectedIndividuals = new Population<>();
101 		int currentFrontIndex = 0;
102 		while (selectedIndividuals.size() < numIndividuals && currentFrontIndex < rankedPopulation.size()
103 				&& rankedPopulation.get(currentFrontIndex)
104 						.size() > 0) {
105 
106 			final Set<Integer> currentFront = rankedPopulation.get(currentFrontIndex);
107 
108 			Collection<Integer> bestIndividuals = currentFront;
109 			if (currentFront.size() > numIndividuals - selectedIndividuals.size()) {
110 
111 				bestIndividuals = currentFront.stream()
112 						.sorted((a, b) -> Double.compare(crowdingDistanceAssignment[b], crowdingDistanceAssignment[a]))
113 						.limit(numIndividuals - selectedIndividuals.size())
114 						.collect(Collectors.toList());
115 			}
116 
117 			for (final Integer individualIndex : bestIndividuals) {
118 				if (logger.isTraceEnabled()) {
119 					logger.trace("Adding individual with index {}, fitness {}, rank {}, crowding distance {}",
120 							individualIndex,
121 							individuals.getFitness(individualIndex),
122 							currentFrontIndex,
123 							crowdingDistanceAssignment[individualIndex]);
124 				}
125 
126 				selectedIndividuals.add(individuals.getGenotype(individualIndex), individuals.getFitness(individualIndex));
127 			}
128 
129 			logger.trace("Selected {} individuals from rank {}", bestIndividuals.size(), currentFrontIndex);
130 			currentFrontIndex++;
131 		}
132 
133 		return selectedIndividuals;
134 	}
135 }