View Javadoc
1   package net.bmahe.genetics4j.moo.spea2.replacement;
2   
3   import java.util.ArrayList;
4   import java.util.Collections;
5   import java.util.Comparator;
6   import java.util.HashMap;
7   import java.util.List;
8   import java.util.Map;
9   import java.util.Map.Entry;
10  import java.util.Set;
11  import java.util.TreeSet;
12  import java.util.function.BiFunction;
13  import java.util.stream.Collectors;
14  import java.util.stream.IntStream;
15  
16  import org.apache.commons.lang3.Validate;
17  import org.apache.commons.lang3.time.DurationFormatUtils;
18  import org.apache.commons.lang3.tuple.Pair;
19  import org.apache.logging.log4j.LogManager;
20  import org.apache.logging.log4j.Logger;
21  
22  import net.bmahe.genetics4j.core.Genotype;
23  import net.bmahe.genetics4j.core.Population;
24  import net.bmahe.genetics4j.core.replacement.ReplacementStrategyImplementor;
25  import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
26  import net.bmahe.genetics4j.moo.spea2.spec.replacement.SPEA2Replacement;
27  
28  public class SPEA2ReplacementStrategyImplementor<T extends Comparable<T>> implements ReplacementStrategyImplementor<T> {
29  	final static public Logger logger = LogManager.getLogger(SPEA2ReplacementStrategyImplementor.class);
30  
31  	private final SPEA2Replacement<T> spea2Replacement;
32  
33  	public SPEA2ReplacementStrategyImplementor(final SPEA2Replacement<T> _spea2Replacement) {
34  		this.spea2Replacement = _spea2Replacement;
35  	}
36  
37  	protected double[] computeStrength(final Comparator<T> dominance, final Population<T> population) {
38  		Validate.notNull(dominance);
39  		Validate.notNull(population);
40  		Validate.isTrue(population.size() > 0);
41  
42  		final double[] strengths = new double[population.size()];
43  		for (int i = 0; i < population.size(); i++) {
44  			final T fitness = population.getFitness(i);
45  
46  			strengths[i] = SPEA2Utils.strength(dominance, i, fitness, population);
47  		}
48  
49  		return strengths;
50  	}
51  
52  	protected double[][] computeObjectiveDistances(final BiFunction<T, T, Double> distance,
53  			final Population<T> population) {
54  		Validate.notNull(distance);
55  		Validate.notNull(population);
56  		Validate.isTrue(population.size() > 0);
57  
58  		final double[][] distanceObjectives = new double[population.size()][population.size()];
59  
60  		for (int i = 0; i < population.size(); i++) {
61  			for (int j = 0; j < i; j++) {
62  				final Double distanceMeasure = distance.apply(population.getFitness(i), population.getFitness(j));
63  				distanceObjectives[i][j] = distanceMeasure;
64  				distanceObjectives[j][i] = distanceMeasure;
65  			}
66  
67  			distanceObjectives[i][i] = 0.0;
68  		}
69  		return distanceObjectives;
70  	}
71  
72  	protected double[] computeRawFitness(final Comparator<T> dominance, final double[] strengths,
73  			final Population<T> population) {
74  		Validate.notNull(dominance);
75  		Validate.notNull(strengths);
76  		Validate.notNull(population);
77  		Validate.isTrue(population.size() == strengths.length);
78  		Validate.isTrue(population.size() > 0);
79  
80  		final double[] rawFitness = new double[population.size()];
81  		for (int i = 0; i < population.size(); i++) {
82  			final T fitness = population.getFitness(i);
83  
84  			rawFitness[i] = SPEA2Utils.rawFitness(dominance, strengths, i, fitness, population);
85  		}
86  
87  		return rawFitness;
88  	}
89  
90  	protected List<List<Pair<Integer, Double>>> computeSortedDistances(final double[][] distanceObjectives,
91  			final Population<T> population) {
92  		Validate.notNull(distanceObjectives);
93  		Validate.notNull(population);
94  		Validate.isTrue(population.size() == distanceObjectives.length); // won't test all the rows
95  		Validate.isTrue(population.size() > 0);
96  
97  		final List<List<Pair<Integer, Double>>> distances = new ArrayList<>();
98  		for (int i = 0; i < population.size(); i++) {
99  			final T fitness = population.getFitness(i);
100 
101 			final List<Pair<Integer, Double>> kthDistances = SPEA2Utils
102 					.kthDistances(distanceObjectives, i, fitness, population);
103 			distances.add(kthDistances);
104 
105 		}
106 		return distances;
107 	}
108 
109 	protected double[] computeDensity(final List<List<Pair<Integer, Double>>> distances, final int k,
110 			final Population<T> population) {
111 		Validate.notNull(distances);
112 		Validate.isTrue(population.size() == distances.size());
113 		Validate.isTrue(k > 0);
114 		Validate.notNull(population);
115 		Validate.isTrue(population.size() > 0);
116 
117 		final double[] density = new double[population.size()];
118 		for (int i = 0; i < population.size(); i++) {
119 			density[i] = 1.0d / (distances.get(i)
120 					.get(k)
121 					.getRight() + 2);
122 		}
123 
124 		return density;
125 	}
126 
127 	protected double[] computeFinalFitness(final double[] rawFitness, final double[] density,
128 			final Population<T> population) {
129 		Validate.notNull(rawFitness);
130 		Validate.notNull(density);
131 		Validate.isTrue(rawFitness.length == density.length);
132 		Validate.notNull(population);
133 		Validate.isTrue(population.size() > 0);
134 		Validate.isTrue(population.size() == density.length);
135 
136 		final double[] finalFitness = new double[population.size()];
137 		for (int i = 0; i < population.size(); i++) {
138 			finalFitness[i] = rawFitness[i] + density[i];
139 		}
140 
141 		return finalFitness;
142 	}
143 
144 	protected int skipNull(final List<Pair<Integer, Double>> distances, final int i) {
145 		Validate.notNull(distances);
146 		Validate.isTrue(i >= 0);
147 		Validate.isTrue(i <= distances.size());
148 
149 		int j = i;
150 
151 		while (j < distances.size() && distances.get(j) == null) {
152 			j++;
153 		}
154 
155 		return j;
156 	}
157 
158 	protected List<Integer> computeAdditionalIndividuals(final Set<Integer> selectedIndex, final double[] rawFitness,
159 			final Population<T> population, final int numIndividuals) {
160 		Validate.notNull(selectedIndex);
161 		Validate.notNull(rawFitness);
162 		Validate.notNull(population);
163 		Validate.isTrue(rawFitness.length == population.size());
164 		Validate.isTrue(numIndividuals >= selectedIndex.size());
165 
166 		if (numIndividuals == selectedIndex.size()) {
167 			return Collections.emptyList();
168 		}
169 
170 		final List<Integer> additionalIndividuals = IntStream.range(0, population.size())
171 				.boxed()
172 				.filter((i) -> selectedIndex.contains(i) == false)
173 				.sorted((a, b) -> Double.compare(rawFitness[a], rawFitness[b]))
174 				.limit(numIndividuals - selectedIndex.size())
175 				.collect(Collectors.toList());
176 
177 		return additionalIndividuals;
178 	}
179 
180 	protected void truncatePopulation(final List<List<Pair<Integer, Double>>> distances, final Population<T> population,
181 			final int numIndividuals, final Set<Integer> selectedIndex) {
182 
183 		final Map<Integer, List<Pair<Integer, Double>>> selectedDistances = new HashMap<>();
184 		final Map<Integer, Map<Integer, Integer>> selectedDistancesIndex = new HashMap<>();
185 
186 		/**
187 		 * The goal here is two fold:
188 		 * - Build selectedDistances, which is a map of individual index -> ordered list
189 		 * of nearest neighbors, with only the individuals from selectedIndex. This will
190 		 * prevent the unnecessary processing of ignored individuals
191 		 * 
192 		 * - Build an inverted index selectedDistancesIndex so that we know where to
193 		 * delete entries in selectedDistances whenever an individual has been removed
194 		 * The index is in the form: individual -> key in selectedDistance -> Which
195 		 * position in the nearest neighbors
196 		 */
197 		for (final int index : selectedIndex) {
198 
199 			final List<Pair<Integer, Double>> kthDistances = distances.get(index)
200 					.stream()
201 					.filter(p -> selectedIndex.contains(p.getLeft()))
202 					.collect(Collectors.toList());
203 
204 			Validate.isTrue(kthDistances.size() == selectedIndex.size());
205 			selectedDistances.put(index, kthDistances);
206 
207 			for (int i = 0; i < kthDistances.size(); i++) {
208 				final Pair<Integer, Double> pair = kthDistances.get(i);
209 
210 				if (selectedDistancesIndex.containsKey(pair.getKey()) == false) {
211 					selectedDistancesIndex.put(pair.getKey(), new HashMap<>());
212 				}
213 
214 				selectedDistancesIndex.get(pair.getKey())
215 						.put(index, i);
216 			}
217 		}
218 
219 		while (selectedIndex.size() > numIndividuals) {
220 
221 			int minIndex = -1;
222 			List<Pair<Integer, Double>> minDistances = null;
223 			for (final int candidateIndex : selectedIndex) {
224 
225 				if (minIndex < 0) {
226 					minIndex = candidateIndex;
227 					minDistances = selectedDistances.get(candidateIndex);
228 				} else {
229 					final List<Pair<Integer, Double>> distancesCandidate = selectedDistances.get(candidateIndex);
230 					Validate.isTrue(minDistances.size() == distancesCandidate.size());
231 
232 					int result = 0;
233 					int j = skipNull(minDistances, 0);
234 					int l = skipNull(distancesCandidate, 0);
235 
236 					while (result == 0 && j < minDistances.size() && l < distancesCandidate.size()) {
237 
238 						result = Double.compare(minDistances.get(j)
239 								.getRight(),
240 								distancesCandidate.get(l)
241 										.getRight());
242 
243 						j++;
244 						j = skipNull(minDistances, j);
245 
246 						l++;
247 						l = skipNull(distancesCandidate, l);
248 					}
249 
250 					if (result > 0) {
251 						minIndex = candidateIndex;
252 						minDistances = distancesCandidate;
253 					}
254 				}
255 			}
256 
257 			/**
258 			 * We cannot just remove it. We have to set the entry to 'null' as to not mess
259 			 * up the positions recorded in selectedDistancesIndex.
260 			 */
261 			final Map<Integer, Integer> reverseIndex = selectedDistancesIndex.get(minIndex);
262 			for (Entry<Integer, Integer> entry : reverseIndex.entrySet()) {
263 				final List<Pair<Integer, Double>> distancesToClean = selectedDistances.get(entry.getKey());
264 				distancesToClean.set((int) entry.getValue(), null);
265 			}
266 			for (Map<Integer, Integer> map : selectedDistancesIndex.values()) {
267 				map.remove(minIndex);
268 			}
269 
270 			selectedDistancesIndex.remove(minIndex);
271 			selectedDistances.remove(minIndex);
272 			selectedIndex.remove(minIndex);
273 		}
274 
275 	}
276 
277 	protected Set<Integer> environmentalSelection(final List<List<Pair<Integer, Double>>> distances,
278 			final double[] rawFitness, final double[] finalFitness, final Population<T> population,
279 			final int numIndividuals) {
280 
281 		final Set<Integer> selectedIndex = IntStream.range(0, population.size())
282 				.boxed()
283 				.filter((i) -> finalFitness[i] < 1)
284 				.collect(Collectors.toSet());
285 
286 		logger.trace("Selected index size: {}", selectedIndex.size());
287 
288 		if (selectedIndex.size() < numIndividuals) {
289 
290 			final List<Integer> additionalIndividuals = computeAdditionalIndividuals(selectedIndex,
291 					rawFitness,
292 					population,
293 					numIndividuals);
294 
295 			logger.trace("Adding {} additional individuals", additionalIndividuals.size());
296 			selectedIndex.addAll(additionalIndividuals);
297 		}
298 
299 		if (selectedIndex.size() > numIndividuals) {
300 			logger.trace("Need to remove {} individuals", selectedIndex.size() - numIndividuals);
301 
302 			truncatePopulation(distances, population, numIndividuals, selectedIndex);
303 		}
304 
305 		return selectedIndex;
306 	}
307 
308 	@Override
309 	public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final int numIndividuals,
310 			final List<Genotype> population, final List<T> populationScores, final List<Genotype> offsprings,
311 			final List<T> offspringScores) {
312 		Validate.notNull(eaConfiguration);
313 		Validate.isTrue(numIndividuals > 0);
314 		Validate.notNull(population);
315 		Validate.notNull(populationScores);
316 		Validate.isTrue(population.size() == populationScores.size());
317 		Validate.notNull(offsprings);
318 		Validate.notNull(offspringScores);
319 		Validate.isTrue(offsprings.size() == offspringScores.size());
320 
321 		final long startTimeNanos = System.nanoTime();
322 		logger.debug("Starting with requested {} individuals - {} population - {} offsprings",
323 				numIndividuals,
324 				population.size(),
325 				offsprings.size());
326 
327 		final Population<T> archive = new Population<>(population, populationScores);
328 		final Population<T> offspringPopulation = new Population<>(offsprings, offspringScores);
329 
330 		final Population<T> combinedPopulation = new Population<>();
331 		if (spea2Replacement.deduplicate()
332 				.isPresent()) {
333 			final Comparator<Genotype> individualDeduplicator = spea2Replacement.deduplicate()
334 					.get();
335 			final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
336 
337 			for (int i = 0; i < archive.size(); i++) {
338 				final Genotype genotype = archive.getGenotype(i);
339 
340 				if (seenGenotype.add(genotype)) {
341 					final T fitness = archive.getFitness(i);
342 					combinedPopulation.add(genotype, fitness);
343 				}
344 			}
345 			final int ingestedFromArchive = combinedPopulation.size();
346 			logger.debug("Ingested {} individuals from the archive out of the {} available",
347 					ingestedFromArchive,
348 					archive.size());
349 
350 			for (int i = 0; i < offspringPopulation.size(); i++) {
351 				final Genotype genotype = offspringPopulation.getGenotype(i);
352 
353 				if (seenGenotype.add(genotype)) {
354 					final T fitness = offspringPopulation.getFitness(i);
355 					combinedPopulation.add(genotype, fitness);
356 				}
357 			}
358 			if (logger.isDebugEnabled()) {
359 				logger.debug("Ingested {} individuals from the offsprings out of the {} available",
360 						combinedPopulation.size() - ingestedFromArchive,
361 						offspringPopulation.size());
362 			}
363 
364 		} else {
365 			combinedPopulation.addAll(archive);
366 			combinedPopulation.addAll(offspringPopulation);
367 		}
368 
369 		final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
370 			case MAXIMIZE -> spea2Replacement.dominance();
371 			case MINIMIZE -> spea2Replacement.dominance()
372 					.reversed();
373 		};
374 
375 		final int k = spea2Replacement.k()
376 				.orElseGet(() -> (int) Math.sqrt(combinedPopulation.size()));
377 		logger.trace("Using k={}", k);
378 		Validate.isTrue(k > 0);
379 
380 		///////////////// Fitness computation //////////////////////
381 		final double[] strengths = computeStrength(dominance, combinedPopulation);
382 
383 		final double[][] distanceObjectives = computeObjectiveDistances(spea2Replacement.distance(), combinedPopulation);
384 
385 		final double[] rawFitness = computeRawFitness(dominance, strengths, combinedPopulation);
386 
387 		final List<List<Pair<Integer, Double>>> distances = computeSortedDistances(distanceObjectives,
388 				combinedPopulation);
389 
390 		final double[] density = computeDensity(distances, k, combinedPopulation);
391 
392 		final double[] finalFitness = computeFinalFitness(rawFitness, density, combinedPopulation);
393 
394 		///////////////// Environmental Selection //////////////////
395 
396 		final Set<Integer> selectedIndex = environmentalSelection(distances,
397 				rawFitness,
398 				finalFitness,
399 				combinedPopulation,
400 				numIndividuals);
401 
402 		final Population<T> newPopulation = new Population<>();
403 		for (final int i : selectedIndex) {
404 			newPopulation.add(combinedPopulation.getGenotype(i), combinedPopulation.getFitness(i));
405 		}
406 
407 		final long endTimeNanos = System.nanoTime();
408 		if (logger.isDebugEnabled()) {
409 			logger.debug("Finished with {} new population - Computation time: {}",
410 					newPopulation.size(),
411 					DurationFormatUtils.formatDurationHMS((endTimeNanos - startTimeNanos) / 1_000_000));
412 		}
413 
414 		return newPopulation;
415 	}
416 }