View Javadoc
1   package net.bmahe.genetics4j.moo.nsga2.impl;
2   
3   import java.util.Comparator;
4   import java.util.List;
5   import java.util.Objects;
6   import java.util.Set;
7   import java.util.TreeSet;
8   import java.util.function.Function;
9   import java.util.random.RandomGenerator;
10  
11  import org.apache.commons.lang3.Validate;
12  import org.apache.logging.log4j.LogManager;
13  import org.apache.logging.log4j.Logger;
14  
15  import net.bmahe.genetics4j.core.Genotype;
16  import net.bmahe.genetics4j.core.Population;
17  import net.bmahe.genetics4j.core.selection.Selector;
18  import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
19  import net.bmahe.genetics4j.moo.ObjectiveDistance;
20  import net.bmahe.genetics4j.moo.ParetoUtils;
21  import net.bmahe.genetics4j.moo.nsga2.spec.TournamentNSGA2Selection;
22  
23  public class TournamentNSGA2Selector<T extends Comparable<T>> implements Selector<T> {
24  	final static public Logger logger = LogManager.getLogger(TournamentNSGA2Selector.class);
25  
26  	private final TournamentNSGA2Selection<T> tournamentNSGA2Selection;
27  	private final RandomGenerator randomGenerator;
28  
29  	public TournamentNSGA2Selector(final RandomGenerator _randomGenerator,
30  			final TournamentNSGA2Selection<T> _tournamentNSGA2Selection) {
31  		Objects.requireNonNull(_randomGenerator);
32  		Objects.requireNonNull(_tournamentNSGA2Selection);
33  
34  		this.randomGenerator = _randomGenerator;
35  		this.tournamentNSGA2Selection = _tournamentNSGA2Selection;
36  
37  	}
38  
39  	@Override
40  	public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
41  			final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) {
42  		Objects.requireNonNull(eaConfiguration);
43  		Objects.requireNonNull(population);
44  		Objects.requireNonNull(fitnessScore);
45  		Validate.isTrue(generation >= 0);
46  		Validate.isTrue(numIndividuals > 0);
47  		Validate.isTrue(population.size() == fitnessScore.size());
48  
49  		logger.debug("Incoming population size is {}", population.size());
50  
51  		final Population<T> individuals = new Population<>();
52  		if (tournamentNSGA2Selection.deduplicate()
53  				.isPresent()) {
54  			final Comparator<Genotype> individualDeduplicator = tournamentNSGA2Selection.deduplicate()
55  					.get();
56  			final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
57  
58  			for (int i = 0; i < population.size(); i++) {
59  				final Genotype genotype = population.get(i);
60  				final T fitness = fitnessScore.get(i);
61  
62  				if (seenGenotype.add(genotype)) {
63  					individuals.add(genotype, fitness);
64  				}
65  			}
66  
67  		} else {
68  			for (int i = 0; i < population.size(); i++) {
69  				final Genotype genotype = population.get(i);
70  				final T fitness = fitnessScore.get(i);
71  
72  				individuals.add(genotype, fitness);
73  			}
74  		}
75  
76  		logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size());
77  
78  		final int numberObjectives = tournamentNSGA2Selection.numberObjectives();
79  
80  		final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
81  			case MAXIMIZE -> tournamentNSGA2Selection.dominance();
82  			case MINIMIZE -> tournamentNSGA2Selection.dominance()
83  					.reversed();
84  		};
85  
86  		final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) {
87  			case MAXIMIZE -> tournamentNSGA2Selection.objectiveComparator();
88  			case MINIMIZE -> (m) -> tournamentNSGA2Selection.objectiveComparator()
89  					.apply(m)
90  					.reversed();
91  		};
92  
93  		final ObjectiveDistance<T> objectiveDistance = tournamentNSGA2Selection.distance();
94  		final int numCandidates = tournamentNSGA2Selection.numCandidates();
95  
96  		logger.debug("Ranking population");
97  		final List<Set<Integer>> rankedPopulation = ParetoUtils.rankedPopulation(dominance,
98  				individuals.getAllFitnesses());
99  		// Build a reverse index
100 		final int[] individual2Rank = new int[individuals.size()];
101 		for (int j = 0; j < rankedPopulation.size(); j++) {
102 			final Set<Integer> set = rankedPopulation.get(j);
103 
104 			for (final Integer idx : set) {
105 				individual2Rank[idx] = j;
106 			}
107 		}
108 
109 		if (logger.isTraceEnabled()) {
110 			logger.trace("Ranked population: {}", rankedPopulation);
111 			for (int i = 0; i < rankedPopulation.size(); i++) {
112 				final Set<Integer> subPopulationIdx = rankedPopulation.get(i);
113 				logger.trace("\tRank {}", i);
114 				for (final Integer index : subPopulationIdx) {
115 					logger.trace("\t\t{} - Fitness {}", index, individuals.getFitness(index));
116 				}
117 			}
118 		}
119 		logger.debug("Computing crowding distance assignment");
120 		final double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(numberObjectives,
121 				individuals.getAllFitnesses(),
122 				objectiveComparator,
123 				objectiveDistance);
124 
125 		logger.debug("Performing tournaments");
126 		final Population<T> selectedIndividuals = new Population<>();
127 		while (selectedIndividuals.size() < numIndividuals) {
128 
129 			logger.trace("Performing tournament");
130 			Genotype bestCandidate = null;
131 			int bestCandidateIndex = -1;
132 			T bestFitness = null;
133 
134 			for (int i = 0; i < numCandidates; i++) {
135 				final int candidateIndex = randomGenerator.nextInt(individuals.size());
136 
137 				logger.trace("\tCandidate - index {} - rank {} - crowding distance {} - fitness {}",
138 						candidateIndex,
139 						individual2Rank[candidateIndex],
140 						crowdingDistanceAssignment[candidateIndex],
141 						individuals.getFitness(candidateIndex));
142 
143 				if (bestCandidate == null || individual2Rank[candidateIndex] < individual2Rank[bestCandidateIndex]
144 						|| (individual2Rank[candidateIndex] == individual2Rank[bestCandidateIndex]
145 								&& crowdingDistanceAssignment[candidateIndex] > crowdingDistanceAssignment[bestCandidateIndex])) {
146 
147 					logger.trace("\t candidate win!");
148 					bestCandidate = individuals.getGenotype(candidateIndex);
149 					bestFitness = individuals.getFitness(candidateIndex);
150 					bestCandidateIndex = candidateIndex;
151 				}
152 			}
153 
154 			selectedIndividuals.add(bestCandidate, bestFitness);
155 		}
156 
157 		return selectedIndividuals;
158 	}
159 }