1 package net.bmahe.genetics4j.moo.spea2.replacement;
2
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.time.DurationFormatUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.core.Population;
import net.bmahe.genetics4j.core.replacement.ReplacementStrategyImplementor;
import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
import net.bmahe.genetics4j.moo.spea2.spec.replacement.SPEA2Replacement;
27
28 public class SPEA2ReplacementStrategyImplementor<T extends Comparable<T>> implements ReplacementStrategyImplementor<T> {
/** Logger used by this implementor. It is public, so code outside this class may reference it. */
public static final Logger logger = LogManager.getLogger(SPEA2ReplacementStrategyImplementor.class);

/** Replacement specification; {@code select} reads its deduplicate, dominance, k and distance settings. */
private final SPEA2Replacement<T> spea2Replacement;

/**
 * Creates an implementor bound to the given SPEA2 replacement specification.
 *
 * @param _spea2Replacement specification to use; must not be {@code null}
 * @throws NullPointerException if {@code _spea2Replacement} is {@code null}
 */
public SPEA2ReplacementStrategyImplementor(final SPEA2Replacement<T> _spea2Replacement) {
    // Fail fast here rather than with an unexplained NullPointerException on the first select() call.
    Validate.notNull(_spea2Replacement);

    this.spea2Replacement = _spea2Replacement;
}
36
37 protected double[] computeStrength(final Comparator<T> dominance, final Population<T> population) {
38 Validate.notNull(dominance);
39 Validate.notNull(population);
40 Validate.isTrue(population.size() > 0);
41
42 final double[] strengths = new double[population.size()];
43 for (int i = 0; i < population.size(); i++) {
44 final T fitness = population.getFitness(i);
45
46 strengths[i] = SPEA2Utils.strength(dominance, i, fitness, population);
47 }
48
49 return strengths;
50 }
51
52 protected double[][] computeObjectiveDistances(final BiFunction<T, T, Double> distance,
53 final Population<T> population) {
54 Validate.notNull(distance);
55 Validate.notNull(population);
56 Validate.isTrue(population.size() > 0);
57
58 final double[][] distanceObjectives = new double[population.size()][population.size()];
59
60 for (int i = 0; i < population.size(); i++) {
61 for (int j = 0; j < i; j++) {
62 final Double distanceMeasure = distance.apply(population.getFitness(i), population.getFitness(j));
63 distanceObjectives[i][j] = distanceMeasure;
64 distanceObjectives[j][i] = distanceMeasure;
65 }
66
67 distanceObjectives[i][i] = 0.0;
68 }
69 return distanceObjectives;
70 }
71
72 protected double[] computeRawFitness(final Comparator<T> dominance, final double[] strengths,
73 final Population<T> population) {
74 Validate.notNull(dominance);
75 Validate.notNull(strengths);
76 Validate.notNull(population);
77 Validate.isTrue(population.size() == strengths.length);
78 Validate.isTrue(population.size() > 0);
79
80 final double[] rawFitness = new double[population.size()];
81 for (int i = 0; i < population.size(); i++) {
82 final T fitness = population.getFitness(i);
83
84 rawFitness[i] = SPEA2Utils.rawFitness(dominance, strengths, i, fitness, population);
85 }
86
87 return rawFitness;
88 }
89
90 protected List<List<Pair<Integer, Double>>> computeSortedDistances(final double[][] distanceObjectives,
91 final Population<T> population) {
92 Validate.notNull(distanceObjectives);
93 Validate.notNull(population);
94 Validate.isTrue(population.size() == distanceObjectives.length);
95 Validate.isTrue(population.size() > 0);
96
97 final List<List<Pair<Integer, Double>>> distances = new ArrayList<>();
98 for (int i = 0; i < population.size(); i++) {
99 final T fitness = population.getFitness(i);
100
101 final List<Pair<Integer, Double>> kthDistances = SPEA2Utils
102 .kthDistances(distanceObjectives, i, fitness, population);
103 distances.add(kthDistances);
104
105 }
106 return distances;
107 }
108
109 protected double[] computeDensity(final List<List<Pair<Integer, Double>>> distances, final int k,
110 final Population<T> population) {
111 Validate.notNull(distances);
112 Validate.isTrue(population.size() == distances.size());
113 Validate.isTrue(k > 0);
114 Validate.notNull(population);
115 Validate.isTrue(population.size() > 0);
116
117 final double[] density = new double[population.size()];
118 for (int i = 0; i < population.size(); i++) {
119 density[i] = 1.0d / (distances.get(i)
120 .get(k)
121 .getRight() + 2);
122 }
123
124 return density;
125 }
126
127 protected double[] computeFinalFitness(final double[] rawFitness, final double[] density,
128 final Population<T> population) {
129 Validate.notNull(rawFitness);
130 Validate.notNull(density);
131 Validate.isTrue(rawFitness.length == density.length);
132 Validate.notNull(population);
133 Validate.isTrue(population.size() > 0);
134 Validate.isTrue(population.size() == density.length);
135
136 final double[] finalFitness = new double[population.size()];
137 for (int i = 0; i < population.size(); i++) {
138 finalFitness[i] = rawFitness[i] + density[i];
139 }
140
141 return finalFitness;
142 }
143
144 protected int skipNull(final List<Pair<Integer, Double>> distances, final int i) {
145 Validate.notNull(distances);
146 Validate.isTrue(i >= 0);
147 Validate.isTrue(i <= distances.size());
148
149 int j = i;
150
151 while (j < distances.size() && distances.get(j) == null) {
152 j++;
153 }
154
155 return j;
156 }
157
158 protected List<Integer> computeAdditionalIndividuals(final Set<Integer> selectedIndex, final double[] rawFitness,
159 final Population<T> population, final int numIndividuals) {
160 Validate.notNull(selectedIndex);
161 Validate.notNull(rawFitness);
162 Validate.notNull(population);
163 Validate.isTrue(rawFitness.length == population.size());
164 Validate.isTrue(numIndividuals >= selectedIndex.size());
165
166 if (numIndividuals == selectedIndex.size()) {
167 return Collections.emptyList();
168 }
169
170 final List<Integer> additionalIndividuals = IntStream.range(0, population.size())
171 .boxed()
172 .filter((i) -> selectedIndex.contains(i) == false)
173 .sorted((a, b) -> Double.compare(rawFitness[a], rawFitness[b]))
174 .limit(numIndividuals - selectedIndex.size())
175 .collect(Collectors.toList());
176
177 return additionalIndividuals;
178 }
179
180 protected void truncatePopulation(final List<List<Pair<Integer, Double>>> distances, final Population<T> population,
181 final int numIndividuals, final Set<Integer> selectedIndex) {
182
183 final Map<Integer, List<Pair<Integer, Double>>> selectedDistances = new HashMap<>();
184 final Map<Integer, Map<Integer, Integer>> selectedDistancesIndex = new HashMap<>();
185
186
187
188
189
190
191
192
193
194
195
196
197 for (final int index : selectedIndex) {
198
199 final List<Pair<Integer, Double>> kthDistances = distances.get(index)
200 .stream()
201 .filter(p -> selectedIndex.contains(p.getLeft()))
202 .collect(Collectors.toList());
203
204 Validate.isTrue(kthDistances.size() == selectedIndex.size());
205 selectedDistances.put(index, kthDistances);
206
207 for (int i = 0; i < kthDistances.size(); i++) {
208 final Pair<Integer, Double> pair = kthDistances.get(i);
209
210 if (selectedDistancesIndex.containsKey(pair.getKey()) == false) {
211 selectedDistancesIndex.put(pair.getKey(), new HashMap<>());
212 }
213
214 selectedDistancesIndex.get(pair.getKey())
215 .put(index, i);
216 }
217 }
218
219 while (selectedIndex.size() > numIndividuals) {
220
221 int minIndex = -1;
222 List<Pair<Integer, Double>> minDistances = null;
223 for (final int candidateIndex : selectedIndex) {
224
225 if (minIndex < 0) {
226 minIndex = candidateIndex;
227 minDistances = selectedDistances.get(candidateIndex);
228 } else {
229 final List<Pair<Integer, Double>> distancesCandidate = selectedDistances.get(candidateIndex);
230 Validate.isTrue(minDistances.size() == distancesCandidate.size());
231
232 int result = 0;
233 int j = skipNull(minDistances, 0);
234 int l = skipNull(distancesCandidate, 0);
235
236 while (result == 0 && j < minDistances.size() && l < distancesCandidate.size()) {
237
238 result = Double.compare(minDistances.get(j)
239 .getRight(),
240 distancesCandidate.get(l)
241 .getRight());
242
243 j++;
244 j = skipNull(minDistances, j);
245
246 l++;
247 l = skipNull(distancesCandidate, l);
248 }
249
250 if (result > 0) {
251 minIndex = candidateIndex;
252 minDistances = distancesCandidate;
253 }
254 }
255 }
256
257
258
259
260
261 final Map<Integer, Integer> reverseIndex = selectedDistancesIndex.get(minIndex);
262 for (Entry<Integer, Integer> entry : reverseIndex.entrySet()) {
263 final List<Pair<Integer, Double>> distancesToClean = selectedDistances.get(entry.getKey());
264 distancesToClean.set((int) entry.getValue(), null);
265 }
266 for (Map<Integer, Integer> map : selectedDistancesIndex.values()) {
267 map.remove(minIndex);
268 }
269
270 selectedDistancesIndex.remove(minIndex);
271 selectedDistances.remove(minIndex);
272 selectedIndex.remove(minIndex);
273 }
274
275 }
276
277 protected Set<Integer> environmentalSelection(final List<List<Pair<Integer, Double>>> distances,
278 final double[] rawFitness, final double[] finalFitness, final Population<T> population,
279 final int numIndividuals) {
280
281 final Set<Integer> selectedIndex = IntStream.range(0, population.size())
282 .boxed()
283 .filter((i) -> finalFitness[i] < 1)
284 .collect(Collectors.toSet());
285
286 logger.trace("Selected index size: {}", selectedIndex.size());
287
288 if (selectedIndex.size() < numIndividuals) {
289
290 final List<Integer> additionalIndividuals = computeAdditionalIndividuals(selectedIndex,
291 rawFitness,
292 population,
293 numIndividuals);
294
295 logger.trace("Adding {} additional individuals", additionalIndividuals.size());
296 selectedIndex.addAll(additionalIndividuals);
297 }
298
299 if (selectedIndex.size() > numIndividuals) {
300 logger.trace("Need to remove {} individuals", selectedIndex.size() - numIndividuals);
301
302 truncatePopulation(distances, population, numIndividuals, selectedIndex);
303 }
304
305 return selectedIndex;
306 }
307
308 @Override
309 public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final int numIndividuals,
310 final List<Genotype> population, final List<T> populationScores, final List<Genotype> offsprings,
311 final List<T> offspringScores) {
312 Validate.notNull(eaConfiguration);
313 Validate.isTrue(numIndividuals > 0);
314 Validate.notNull(population);
315 Validate.notNull(populationScores);
316 Validate.isTrue(population.size() == populationScores.size());
317 Validate.notNull(offsprings);
318 Validate.notNull(offspringScores);
319 Validate.isTrue(offsprings.size() == offspringScores.size());
320
321 final long startTimeNanos = System.nanoTime();
322 logger.debug("Starting with requested {} individuals - {} population - {} offsprings",
323 numIndividuals,
324 population.size(),
325 offsprings.size());
326
327 final Population<T> archive = new Population<>(population, populationScores);
328 final Population<T> offspringPopulation = new Population<>(offsprings, offspringScores);
329
330 final Population<T> combinedPopulation = new Population<>();
331 if (spea2Replacement.deduplicate()
332 .isPresent()) {
333 final Comparator<Genotype> individualDeduplicator = spea2Replacement.deduplicate()
334 .get();
335 final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
336
337 for (int i = 0; i < archive.size(); i++) {
338 final Genotype genotype = archive.getGenotype(i);
339
340 if (seenGenotype.add(genotype)) {
341 final T fitness = archive.getFitness(i);
342 combinedPopulation.add(genotype, fitness);
343 }
344 }
345 final int ingestedFromArchive = combinedPopulation.size();
346 logger.debug("Ingested {} individuals from the archive out of the {} available",
347 ingestedFromArchive,
348 archive.size());
349
350 for (int i = 0; i < offspringPopulation.size(); i++) {
351 final Genotype genotype = offspringPopulation.getGenotype(i);
352
353 if (seenGenotype.add(genotype)) {
354 final T fitness = offspringPopulation.getFitness(i);
355 combinedPopulation.add(genotype, fitness);
356 }
357 }
358 if (logger.isDebugEnabled()) {
359 logger.debug("Ingested {} individuals from the offsprings out of the {} available",
360 combinedPopulation.size() - ingestedFromArchive,
361 offspringPopulation.size());
362 }
363
364 } else {
365 combinedPopulation.addAll(archive);
366 combinedPopulation.addAll(offspringPopulation);
367 }
368
369 final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
370 case MAXIMIZE -> spea2Replacement.dominance();
371 case MINIMIZE -> spea2Replacement.dominance()
372 .reversed();
373 };
374
375 final int k = spea2Replacement.k()
376 .orElseGet(() -> (int) Math.sqrt(combinedPopulation.size()));
377 logger.trace("Using k={}", k);
378 Validate.isTrue(k > 0);
379
380
381 final double[] strengths = computeStrength(dominance, combinedPopulation);
382
383 final double[][] distanceObjectives = computeObjectiveDistances(spea2Replacement.distance(), combinedPopulation);
384
385 final double[] rawFitness = computeRawFitness(dominance, strengths, combinedPopulation);
386
387 final List<List<Pair<Integer, Double>>> distances = computeSortedDistances(distanceObjectives,
388 combinedPopulation);
389
390 final double[] density = computeDensity(distances, k, combinedPopulation);
391
392 final double[] finalFitness = computeFinalFitness(rawFitness, density, combinedPopulation);
393
394
395
396 final Set<Integer> selectedIndex = environmentalSelection(distances,
397 rawFitness,
398 finalFitness,
399 combinedPopulation,
400 numIndividuals);
401
402 final Population<T> newPopulation = new Population<>();
403 for (final int i : selectedIndex) {
404 newPopulation.add(combinedPopulation.getGenotype(i), combinedPopulation.getFitness(i));
405 }
406
407 final long endTimeNanos = System.nanoTime();
408 if (logger.isDebugEnabled()) {
409 logger.debug("Finished with {} new population - Computation time: {}",
410 newPopulation.size(),
411 DurationFormatUtils.formatDurationHMS((endTimeNanos - startTimeNanos) / 1_000_000));
412 }
413
414 return newPopulation;
415 }
416 }