View Javadoc
1   package net.bmahe.genetics4j.moo.spea2.replacement;
2   
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
16  
17  import org.apache.commons.lang3.Validate;
18  import org.apache.commons.lang3.time.DurationFormatUtils;
19  import org.apache.commons.lang3.tuple.Pair;
20  import org.apache.logging.log4j.LogManager;
21  import org.apache.logging.log4j.Logger;
22  
23  import net.bmahe.genetics4j.core.Genotype;
24  import net.bmahe.genetics4j.core.Population;
25  import net.bmahe.genetics4j.core.replacement.ReplacementStrategyImplementor;
26  import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
27  import net.bmahe.genetics4j.moo.spea2.spec.replacement.SPEA2Replacement;
28  
29  public class SPEA2ReplacementStrategyImplementor<T extends Comparable<T>> implements ReplacementStrategyImplementor<T> {
30  	final static public Logger logger = LogManager.getLogger(SPEA2ReplacementStrategyImplementor.class);
31  
32  	private final SPEA2Replacement<T> spea2Replacement;
33  
34  	public SPEA2ReplacementStrategyImplementor(final SPEA2Replacement<T> _spea2Replacement) {
35  		this.spea2Replacement = _spea2Replacement;
36  	}
37  
38  	protected double[] computeStrength(final Comparator<T> dominance, final Population<T> population) {
39  		Objects.requireNonNull(dominance);
40  		Objects.requireNonNull(population);
41  		Validate.isTrue(population.size() > 0);
42  
43  		final double[] strengths = new double[population.size()];
44  		for (int i = 0; i < population.size(); i++) {
45  			final T fitness = population.getFitness(i);
46  
47  			strengths[i] = SPEA2Utils.strength(dominance, i, fitness, population);
48  		}
49  
50  		return strengths;
51  	}
52  
53  	protected double[][] computeObjectiveDistances(final BiFunction<T, T, Double> distance,
54  			final Population<T> population) {
55  		Objects.requireNonNull(distance);
56  		Objects.requireNonNull(population);
57  		Validate.isTrue(population.size() > 0);
58  
59  		final double[][] distanceObjectives = new double[population.size()][population.size()];
60  
61  		for (int i = 0; i < population.size(); i++) {
62  			for (int j = 0; j < i; j++) {
63  				final Double distanceMeasure = distance.apply(population.getFitness(i), population.getFitness(j));
64  				distanceObjectives[i][j] = distanceMeasure;
65  				distanceObjectives[j][i] = distanceMeasure;
66  			}
67  
68  			distanceObjectives[i][i] = 0.0;
69  		}
70  		return distanceObjectives;
71  	}
72  
73  	protected double[] computeRawFitness(final Comparator<T> dominance, final double[] strengths,
74  			final Population<T> population) {
75  		Objects.requireNonNull(dominance);
76  		Objects.requireNonNull(strengths);
77  		Objects.requireNonNull(population);
78  		Validate.isTrue(population.size() == strengths.length);
79  		Validate.isTrue(population.size() > 0);
80  
81  		final double[] rawFitness = new double[population.size()];
82  		for (int i = 0; i < population.size(); i++) {
83  			final T fitness = population.getFitness(i);
84  
85  			rawFitness[i] = SPEA2Utils.rawFitness(dominance, strengths, i, fitness, population);
86  		}
87  
88  		return rawFitness;
89  	}
90  
91  	protected List<List<Pair<Integer, Double>>> computeSortedDistances(final double[][] distanceObjectives,
92  			final Population<T> population) {
93  		Objects.requireNonNull(distanceObjectives);
94  		Objects.requireNonNull(population);
95  		Validate.isTrue(population.size() == distanceObjectives.length); // won't test all the rows
96  		Validate.isTrue(population.size() > 0);
97  
98  		final List<List<Pair<Integer, Double>>> distances = new ArrayList<>();
99  		for (int i = 0; i < population.size(); i++) {
100 			final T fitness = population.getFitness(i);
101 
102 			final List<Pair<Integer, Double>> kthDistances = SPEA2Utils
103 					.kthDistances(distanceObjectives, i, fitness, population);
104 			distances.add(kthDistances);
105 
106 		}
107 		return distances;
108 	}
109 
110 	protected double[] computeDensity(final List<List<Pair<Integer, Double>>> distances, final int k,
111 			final Population<T> population) {
112 		Objects.requireNonNull(distances);
113 		Validate.isTrue(population.size() == distances.size());
114 		Validate.isTrue(k > 0);
115 		Objects.requireNonNull(population);
116 		Validate.isTrue(population.size() > 0);
117 
118 		final double[] density = new double[population.size()];
119 		for (int i = 0; i < population.size(); i++) {
120 			density[i] = 1.0d / (distances.get(i)
121 					.get(k)
122 					.getRight() + 2);
123 		}
124 
125 		return density;
126 	}
127 
128 	protected double[] computeFinalFitness(final double[] rawFitness, final double[] density,
129 			final Population<T> population) {
130 		Objects.requireNonNull(rawFitness);
131 		Objects.requireNonNull(density);
132 		Validate.isTrue(rawFitness.length == density.length);
133 		Objects.requireNonNull(population);
134 		Validate.isTrue(population.size() > 0);
135 		Validate.isTrue(population.size() == density.length);
136 
137 		final double[] finalFitness = new double[population.size()];
138 		for (int i = 0; i < population.size(); i++) {
139 			finalFitness[i] = rawFitness[i] + density[i];
140 		}
141 
142 		return finalFitness;
143 	}
144 
145 	protected int skipNull(final List<Pair<Integer, Double>> distances, final int i) {
146 		Objects.requireNonNull(distances);
147 		Validate.isTrue(i >= 0);
148 		Validate.isTrue(i <= distances.size());
149 
150 		int j = i;
151 
152 		while (j < distances.size() && distances.get(j) == null) {
153 			j++;
154 		}
155 
156 		return j;
157 	}
158 
159 	protected List<Integer> computeAdditionalIndividuals(final Set<Integer> selectedIndex, final double[] rawFitness,
160 			final Population<T> population, final int numIndividuals) {
161 		Objects.requireNonNull(selectedIndex);
162 		Objects.requireNonNull(rawFitness);
163 		Objects.requireNonNull(population);
164 		Validate.isTrue(rawFitness.length == population.size());
165 		Validate.isTrue(numIndividuals >= selectedIndex.size());
166 
167 		if (numIndividuals == selectedIndex.size()) {
168 			return Collections.emptyList();
169 		}
170 
171 		final List<Integer> additionalIndividuals = IntStream.range(0, population.size())
172 				.boxed()
173 				.filter((i) -> selectedIndex.contains(i) == false)
174 				.sorted((a, b) -> Double.compare(rawFitness[a], rawFitness[b]))
175 				.limit(numIndividuals - selectedIndex.size())
176 				.collect(Collectors.toList());
177 
178 		return additionalIndividuals;
179 	}
180 
181 	protected void truncatePopulation(final List<List<Pair<Integer, Double>>> distances, final Population<T> population,
182 			final int numIndividuals, final Set<Integer> selectedIndex) {
183 
184 		final Map<Integer, List<Pair<Integer, Double>>> selectedDistances = new HashMap<>();
185 		final Map<Integer, Map<Integer, Integer>> selectedDistancesIndex = new HashMap<>();
186 
187 		/**
188 		 * The goal here is two fold: - Build selectedDistances, which is a map of individual index -> ordered list of
189 		 * nearest neighbors, with only the individuals from selectedIndex. This will prevent the unnecessary processing
190 		 * of ignored individuals
191 		 * 
192 		 * - Build an inverted index selectedDistancesIndex so that we know where to delete entries in selectedDistances
193 		 * whenever an individual has been removed The index is in the form: individual -> key in selectedDistance ->
194 		 * Which position in the nearest neighbors
195 		 */
196 		for (final int index : selectedIndex) {
197 
198 			final List<Pair<Integer, Double>> kthDistances = distances.get(index)
199 					.stream()
200 					.filter(p -> selectedIndex.contains(p.getLeft()))
201 					.collect(Collectors.toList());
202 
203 			Validate.isTrue(kthDistances.size() == selectedIndex.size());
204 			selectedDistances.put(index, kthDistances);
205 
206 			for (int i = 0; i < kthDistances.size(); i++) {
207 				final Pair<Integer, Double> pair = kthDistances.get(i);
208 
209 				if (selectedDistancesIndex.containsKey(pair.getKey()) == false) {
210 					selectedDistancesIndex.put(pair.getKey(), new HashMap<>());
211 				}
212 
213 				selectedDistancesIndex.get(pair.getKey())
214 						.put(index, i);
215 			}
216 		}
217 
218 		while (selectedIndex.size() > numIndividuals) {
219 
220 			int minIndex = -1;
221 			List<Pair<Integer, Double>> minDistances = null;
222 			for (final int candidateIndex : selectedIndex) {
223 
224 				if (minIndex < 0) {
225 					minIndex = candidateIndex;
226 					minDistances = selectedDistances.get(candidateIndex);
227 				} else {
228 					final List<Pair<Integer, Double>> distancesCandidate = selectedDistances.get(candidateIndex);
229 					Validate.isTrue(minDistances.size() == distancesCandidate.size());
230 
231 					int result = 0;
232 					int j = skipNull(minDistances, 0);
233 					int l = skipNull(distancesCandidate, 0);
234 
235 					while (result == 0 && j < minDistances.size() && l < distancesCandidate.size()) {
236 
237 						result = Double.compare(minDistances.get(j)
238 								.getRight(),
239 								distancesCandidate.get(l)
240 										.getRight());
241 
242 						j++;
243 						j = skipNull(minDistances, j);
244 
245 						l++;
246 						l = skipNull(distancesCandidate, l);
247 					}
248 
249 					if (result > 0) {
250 						minIndex = candidateIndex;
251 						minDistances = distancesCandidate;
252 					}
253 				}
254 			}
255 
256 			/**
257 			 * We cannot just remove it. We have to set the entry to 'null' as to not mess up the positions recorded in
258 			 * selectedDistancesIndex.
259 			 */
260 			final Map<Integer, Integer> reverseIndex = selectedDistancesIndex.get(minIndex);
261 			for (Entry<Integer, Integer> entry : reverseIndex.entrySet()) {
262 				final List<Pair<Integer, Double>> distancesToClean = selectedDistances.get(entry.getKey());
263 				distancesToClean.set((int) entry.getValue(), null);
264 			}
265 			for (Map<Integer, Integer> map : selectedDistancesIndex.values()) {
266 				map.remove(minIndex);
267 			}
268 
269 			selectedDistancesIndex.remove(minIndex);
270 			selectedDistances.remove(minIndex);
271 			selectedIndex.remove(minIndex);
272 		}
273 
274 	}
275 
276 	protected Set<Integer> environmentalSelection(final List<List<Pair<Integer, Double>>> distances,
277 			final double[] rawFitness, final double[] finalFitness, final Population<T> population,
278 			final int numIndividuals) {
279 
280 		final Set<Integer> selectedIndex = IntStream.range(0, population.size())
281 				.boxed()
282 				.filter((i) -> finalFitness[i] < 1)
283 				.collect(Collectors.toSet());
284 
285 		logger.trace("Selected index size: {}", selectedIndex.size());
286 
287 		if (selectedIndex.size() < numIndividuals) {
288 
289 			final List<Integer> additionalIndividuals = computeAdditionalIndividuals(selectedIndex,
290 					rawFitness,
291 					population,
292 					numIndividuals);
293 
294 			logger.trace("Adding {} additional individuals", additionalIndividuals.size());
295 			selectedIndex.addAll(additionalIndividuals);
296 		}
297 
298 		if (selectedIndex.size() > numIndividuals) {
299 			logger.trace("Need to remove {} individuals", selectedIndex.size() - numIndividuals);
300 
301 			truncatePopulation(distances, population, numIndividuals, selectedIndex);
302 		}
303 
304 		return selectedIndex;
305 	}
306 
307 	@Override
308 	public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
309 			final int numIndividuals, final List<Genotype> population, final List<T> populationScores,
310 			final List<Genotype> offsprings, final List<T> offspringScores) {
311 		Objects.requireNonNull(eaConfiguration);
312 		Validate.isTrue(generation >= 0);
313 		Validate.isTrue(numIndividuals > 0);
314 		Objects.requireNonNull(population);
315 		Objects.requireNonNull(populationScores);
316 		Validate.isTrue(population.size() == populationScores.size());
317 		Objects.requireNonNull(offsprings);
318 		Objects.requireNonNull(offspringScores);
319 		Validate.isTrue(offsprings.size() == offspringScores.size());
320 
321 		final long startTimeNanos = System.nanoTime();
322 		logger.debug("Starting with requested {} individuals - {} population - {} offsprings",
323 				numIndividuals,
324 				population.size(),
325 				offsprings.size());
326 
327 		final Population<T> archive = new Population<>(population, populationScores);
328 		final Population<T> offspringPopulation = new Population<>(offsprings, offspringScores);
329 
330 		final Population<T> combinedPopulation = new Population<>();
331 		if (spea2Replacement.deduplicate()
332 				.isPresent()) {
333 			final Comparator<Genotype> individualDeduplicator = spea2Replacement.deduplicate()
334 					.get();
335 			final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
336 
337 			for (int i = 0; i < archive.size(); i++) {
338 				final Genotype genotype = archive.getGenotype(i);
339 
340 				if (seenGenotype.add(genotype)) {
341 					final T fitness = archive.getFitness(i);
342 					combinedPopulation.add(genotype, fitness);
343 				}
344 			}
345 			final int ingestedFromArchive = combinedPopulation.size();
346 			logger.debug("Ingested {} individuals from the archive out of the {} available",
347 					ingestedFromArchive,
348 					archive.size());
349 
350 			for (int i = 0; i < offspringPopulation.size(); i++) {
351 				final Genotype genotype = offspringPopulation.getGenotype(i);
352 
353 				if (seenGenotype.add(genotype)) {
354 					final T fitness = offspringPopulation.getFitness(i);
355 					combinedPopulation.add(genotype, fitness);
356 				}
357 			}
358 			if (logger.isDebugEnabled()) {
359 				logger.debug("Ingested {} individuals from the offsprings out of the {} available",
360 						combinedPopulation.size() - ingestedFromArchive,
361 						offspringPopulation.size());
362 			}
363 
364 		} else {
365 			combinedPopulation.addAll(archive);
366 			combinedPopulation.addAll(offspringPopulation);
367 		}
368 
369 		final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
370 			case MAXIMIZE -> spea2Replacement.dominance();
371 			case MINIMIZE -> spea2Replacement.dominance()
372 					.reversed();
373 		};
374 
375 		final int k = spea2Replacement.k()
376 				.orElseGet(() -> (int) Math.sqrt(combinedPopulation.size()));
377 		logger.trace("Using k={}", k);
378 		Validate.isTrue(k > 0);
379 
380 		///////////////// Fitness computation //////////////////////
381 		final double[] strengths = computeStrength(dominance, combinedPopulation);
382 
383 		final double[][] distanceObjectives = computeObjectiveDistances(spea2Replacement.distance(), combinedPopulation);
384 
385 		final double[] rawFitness = computeRawFitness(dominance, strengths, combinedPopulation);
386 
387 		final List<List<Pair<Integer, Double>>> distances = computeSortedDistances(distanceObjectives,
388 				combinedPopulation);
389 
390 		final double[] density = computeDensity(distances, k, combinedPopulation);
391 
392 		final double[] finalFitness = computeFinalFitness(rawFitness, density, combinedPopulation);
393 
394 		///////////////// Environmental Selection //////////////////
395 
396 		final Set<Integer> selectedIndex = environmentalSelection(distances,
397 				rawFitness,
398 				finalFitness,
399 				combinedPopulation,
400 				numIndividuals);
401 
402 		final Population<T> newPopulation = new Population<>();
403 		for (final int i : selectedIndex) {
404 			newPopulation.add(combinedPopulation.getGenotype(i), combinedPopulation.getFitness(i));
405 		}
406 
407 		final long endTimeNanos = System.nanoTime();
408 		if (logger.isDebugEnabled()) {
409 			logger.debug("Finished with {} new population - Computation time: {}",
410 					newPopulation.size(),
411 					DurationFormatUtils.formatDurationHMS((endTimeNanos - startTimeNanos) / 1_000_000));
412 		}
413 
414 		return newPopulation;
415 	}
416 }