View Javadoc
1   package net.bmahe.genetics4j.moo.nsga2.impl;
2   
3   import java.util.Comparator;
4   import java.util.List;
5   import java.util.Objects;
6   import java.util.Set;
7   import java.util.TreeSet;
8   import java.util.function.Function;
9   import java.util.random.RandomGenerator;
10  
11  import org.apache.commons.lang3.Validate;
12  import org.apache.logging.log4j.LogManager;
13  import org.apache.logging.log4j.Logger;
14  
15  import net.bmahe.genetics4j.core.Genotype;
16  import net.bmahe.genetics4j.core.Population;
17  import net.bmahe.genetics4j.core.selection.Selector;
18  import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
19  import net.bmahe.genetics4j.moo.ObjectiveDistance;
20  import net.bmahe.genetics4j.moo.ParetoUtils;
21  import net.bmahe.genetics4j.moo.nsga2.spec.TournamentNSGA2Selection;
22  
23  public class TournamentNSGA2Selector<T extends Comparable<T>> implements Selector<T> {
24  	final static public Logger logger = LogManager.getLogger(TournamentNSGA2Selector.class);
25  
26  	private final TournamentNSGA2Selection<T> tournamentNSGA2Selection;
27  	private final RandomGenerator randomGenerator;
28  
29  	public TournamentNSGA2Selector(final RandomGenerator _randomGenerator,
30  			final TournamentNSGA2Selection<T> _tournamentNSGA2Selection) {
31  		Objects.requireNonNull(_randomGenerator);
32  		Objects.requireNonNull(_tournamentNSGA2Selection);
33  
34  		this.randomGenerator = _randomGenerator;
35  		this.tournamentNSGA2Selection = _tournamentNSGA2Selection;
36  
37  	}
38  
39  	@Override
40  	public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
41  			final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) {
42  		Objects.requireNonNull(eaConfiguration);
43  		Objects.requireNonNull(population);
44  		Objects.requireNonNull(fitnessScore);
45  		Validate.isTrue(generation >= 0);
46  		Validate.isTrue(numIndividuals > 0);
47  		Validate.isTrue(population.size() == fitnessScore.size());
48  
49  		logger.debug("Incoming population size is {}", population.size());
50  
51  		final Population<T> individuals = new Population<>();
52  		if (tournamentNSGA2Selection.deduplicate().isPresent()) {
53  			final Comparator<Genotype> individualDeduplicator = tournamentNSGA2Selection.deduplicate().get();
54  			final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
55  
56  			for (int i = 0; i < population.size(); i++) {
57  				final Genotype genotype = population.get(i);
58  				final T fitness = fitnessScore.get(i);
59  
60  				if (seenGenotype.add(genotype)) {
61  					individuals.add(genotype, fitness);
62  				}
63  			}
64  
65  		} else {
66  			for (int i = 0; i < population.size(); i++) {
67  				final Genotype genotype = population.get(i);
68  				final T fitness = fitnessScore.get(i);
69  
70  				individuals.add(genotype, fitness);
71  			}
72  		}
73  
74  		logger.debug("Selecting {} individuals from a population of {}", numIndividuals, individuals.size());
75  
76  		final int numberObjectives = tournamentNSGA2Selection.numberObjectives();
77  
78  		final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
79  			case MAXIMIZE -> tournamentNSGA2Selection.dominance();
80  			case MINIMIZE -> tournamentNSGA2Selection.dominance().reversed();
81  		};
82  
83  		final Function<Integer, Comparator<T>> objectiveComparator = switch (eaConfiguration.optimization()) {
84  			case MAXIMIZE -> tournamentNSGA2Selection.objectiveComparator();
85  			case MINIMIZE -> (m) -> tournamentNSGA2Selection.objectiveComparator().apply(m).reversed();
86  		};
87  
88  		final ObjectiveDistance<T> objectiveDistance = tournamentNSGA2Selection.distance();
89  		final int numCandidates = tournamentNSGA2Selection.numCandidates();
90  
91  		logger.debug("Ranking population");
92  		final List<Set<Integer>> rankedPopulation = ParetoUtils
93  				.rankedPopulation(dominance, individuals.getAllFitnesses());
94  		// Build a reverse index
95  		final int[] individual2Rank = new int[individuals.size()];
96  		for (int j = 0; j < rankedPopulation.size(); j++) {
97  			final Set<Integer> set = rankedPopulation.get(j);
98  
99  			for (final Integer idx : set) {
100 				individual2Rank[idx] = j;
101 			}
102 		}
103 
104 		if (logger.isTraceEnabled()) {
105 			logger.trace("Ranked population: {}", rankedPopulation);
106 			for (int i = 0; i < rankedPopulation.size(); i++) {
107 				final Set<Integer> subPopulationIdx = rankedPopulation.get(i);
108 				logger.trace("\tRank {}", i);
109 				for (final Integer index : subPopulationIdx) {
110 					logger.trace("\t\t{} - Fitness {}", index, individuals.getFitness(index));
111 				}
112 			}
113 		}
114 		logger.debug("Computing crowding distance assignment");
115 		final double[] crowdingDistanceAssignment = NSGA2Utils.crowdingDistanceAssignment(
116 				numberObjectives,
117 					individuals.getAllFitnesses(),
118 					objectiveComparator,
119 					objectiveDistance);
120 
121 		logger.debug("Performing tournaments");
122 		final Population<T> selectedIndividuals = new Population<>();
123 		while (selectedIndividuals.size() < numIndividuals) {
124 
125 			logger.trace("Performing tournament");
126 			Genotype bestCandidate = null;
127 			int bestCandidateIndex = -1;
128 			T bestFitness = null;
129 
130 			for (int i = 0; i < numCandidates; i++) {
131 				final int candidateIndex = randomGenerator.nextInt(individuals.size());
132 
133 				logger.trace(
134 						"\tCandidate - index {} - rank {} - crowding distance {} - fitness {}",
135 							candidateIndex,
136 							individual2Rank[candidateIndex],
137 							crowdingDistanceAssignment[candidateIndex],
138 							individuals.getFitness(candidateIndex));
139 
140 				if (bestCandidate == null || individual2Rank[candidateIndex] < individual2Rank[bestCandidateIndex]
141 						|| (individual2Rank[candidateIndex] == individual2Rank[bestCandidateIndex]
142 								&& crowdingDistanceAssignment[candidateIndex] > crowdingDistanceAssignment[bestCandidateIndex])) {
143 
144 					logger.trace("\t candidate win!");
145 					bestCandidate = individuals.getGenotype(candidateIndex);
146 					bestFitness = individuals.getFitness(candidateIndex);
147 					bestCandidateIndex = candidateIndex;
148 				}
149 			}
150 
151 			selectedIndividuals.add(bestCandidate, bestFitness);
152 		}
153 
154 		return selectedIndividuals;
155 	}
156 }