1 package net.bmahe.genetics4j.moo.spea2.replacement;
2
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
16
17 import org.apache.commons.lang3.Validate;
18 import org.apache.commons.lang3.time.DurationFormatUtils;
19 import org.apache.commons.lang3.tuple.Pair;
20 import org.apache.logging.log4j.LogManager;
21 import org.apache.logging.log4j.Logger;
22
23 import net.bmahe.genetics4j.core.Genotype;
24 import net.bmahe.genetics4j.core.Population;
25 import net.bmahe.genetics4j.core.replacement.ReplacementStrategyImplementor;
26 import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
27 import net.bmahe.genetics4j.moo.spea2.spec.replacement.SPEA2Replacement;
28
29 public class SPEA2ReplacementStrategyImplementor<T extends Comparable<T>> implements ReplacementStrategyImplementor<T> {
30 final static public Logger logger = LogManager.getLogger(SPEA2ReplacementStrategyImplementor.class);
31
32 private final SPEA2Replacement<T> spea2Replacement;
33
34 public SPEA2ReplacementStrategyImplementor(final SPEA2Replacement<T> _spea2Replacement) {
35 this.spea2Replacement = _spea2Replacement;
36 }
37
38 protected double[] computeStrength(final Comparator<T> dominance, final Population<T> population) {
39 Objects.requireNonNull(dominance);
40 Objects.requireNonNull(population);
41 Validate.isTrue(population.size() > 0);
42
43 final double[] strengths = new double[population.size()];
44 for (int i = 0; i < population.size(); i++) {
45 final T fitness = population.getFitness(i);
46
47 strengths[i] = SPEA2Utils.strength(dominance, i, fitness, population);
48 }
49
50 return strengths;
51 }
52
53 protected double[][] computeObjectiveDistances(final BiFunction<T, T, Double> distance,
54 final Population<T> population) {
55 Objects.requireNonNull(distance);
56 Objects.requireNonNull(population);
57 Validate.isTrue(population.size() > 0);
58
59 final double[][] distanceObjectives = new double[population.size()][population.size()];
60
61 for (int i = 0; i < population.size(); i++) {
62 for (int j = 0; j < i; j++) {
63 final Double distanceMeasure = distance.apply(population.getFitness(i), population.getFitness(j));
64 distanceObjectives[i][j] = distanceMeasure;
65 distanceObjectives[j][i] = distanceMeasure;
66 }
67
68 distanceObjectives[i][i] = 0.0;
69 }
70 return distanceObjectives;
71 }
72
73 protected double[] computeRawFitness(final Comparator<T> dominance, final double[] strengths,
74 final Population<T> population) {
75 Objects.requireNonNull(dominance);
76 Objects.requireNonNull(strengths);
77 Objects.requireNonNull(population);
78 Validate.isTrue(population.size() == strengths.length);
79 Validate.isTrue(population.size() > 0);
80
81 final double[] rawFitness = new double[population.size()];
82 for (int i = 0; i < population.size(); i++) {
83 final T fitness = population.getFitness(i);
84
85 rawFitness[i] = SPEA2Utils.rawFitness(dominance, strengths, i, fitness, population);
86 }
87
88 return rawFitness;
89 }
90
91 protected List<List<Pair<Integer, Double>>> computeSortedDistances(final double[][] distanceObjectives,
92 final Population<T> population) {
93 Objects.requireNonNull(distanceObjectives);
94 Objects.requireNonNull(population);
95 Validate.isTrue(population.size() == distanceObjectives.length);
96 Validate.isTrue(population.size() > 0);
97
98 final List<List<Pair<Integer, Double>>> distances = new ArrayList<>();
99 for (int i = 0; i < population.size(); i++) {
100 final T fitness = population.getFitness(i);
101
102 final List<Pair<Integer, Double>> kthDistances = SPEA2Utils
103 .kthDistances(distanceObjectives, i, fitness, population);
104 distances.add(kthDistances);
105
106 }
107 return distances;
108 }
109
110 protected double[] computeDensity(final List<List<Pair<Integer, Double>>> distances, final int k,
111 final Population<T> population) {
112 Objects.requireNonNull(distances);
113 Validate.isTrue(population.size() == distances.size());
114 Validate.isTrue(k > 0);
115 Objects.requireNonNull(population);
116 Validate.isTrue(population.size() > 0);
117
118 final double[] density = new double[population.size()];
119 for (int i = 0; i < population.size(); i++) {
120 density[i] = 1.0d / (distances.get(i).get(k).getRight() + 2);
121 }
122
123 return density;
124 }
125
126 protected double[] computeFinalFitness(final double[] rawFitness, final double[] density,
127 final Population<T> population) {
128 Objects.requireNonNull(rawFitness);
129 Objects.requireNonNull(density);
130 Validate.isTrue(rawFitness.length == density.length);
131 Objects.requireNonNull(population);
132 Validate.isTrue(population.size() > 0);
133 Validate.isTrue(population.size() == density.length);
134
135 final double[] finalFitness = new double[population.size()];
136 for (int i = 0; i < population.size(); i++) {
137 finalFitness[i] = rawFitness[i] + density[i];
138 }
139
140 return finalFitness;
141 }
142
143 protected int skipNull(final List<Pair<Integer, Double>> distances, final int i) {
144 Objects.requireNonNull(distances);
145 Validate.isTrue(i >= 0);
146 Validate.isTrue(i <= distances.size());
147
148 int j = i;
149
150 while (j < distances.size() && distances.get(j) == null) {
151 j++;
152 }
153
154 return j;
155 }
156
157 protected List<Integer> computeAdditionalIndividuals(final Set<Integer> selectedIndex, final double[] rawFitness,
158 final Population<T> population, final int numIndividuals) {
159 Objects.requireNonNull(selectedIndex);
160 Objects.requireNonNull(rawFitness);
161 Objects.requireNonNull(population);
162 Validate.isTrue(rawFitness.length == population.size());
163 Validate.isTrue(numIndividuals >= selectedIndex.size());
164
165 if (numIndividuals == selectedIndex.size()) {
166 return Collections.emptyList();
167 }
168
169 final List<Integer> additionalIndividuals = IntStream.range(0, population.size())
170 .boxed()
171 .filter((i) -> selectedIndex.contains(i) == false)
172 .sorted((a, b) -> Double.compare(rawFitness[a], rawFitness[b]))
173 .limit(numIndividuals - selectedIndex.size())
174 .collect(Collectors.toList());
175
176 return additionalIndividuals;
177 }
178
179 protected void truncatePopulation(final List<List<Pair<Integer, Double>>> distances, final Population<T> population,
180 final int numIndividuals, final Set<Integer> selectedIndex) {
181
182 final Map<Integer, List<Pair<Integer, Double>>> selectedDistances = new HashMap<>();
183 final Map<Integer, Map<Integer, Integer>> selectedDistancesIndex = new HashMap<>();
184
185
186
187
188
189
190
191
192
193
194 for (final int index : selectedIndex) {
195
196 final List<Pair<Integer, Double>> kthDistances = distances.get(index)
197 .stream()
198 .filter(p -> selectedIndex.contains(p.getLeft()))
199 .collect(Collectors.toList());
200
201 Validate.isTrue(kthDistances.size() == selectedIndex.size());
202 selectedDistances.put(index, kthDistances);
203
204 for (int i = 0; i < kthDistances.size(); i++) {
205 final Pair<Integer, Double> pair = kthDistances.get(i);
206
207 if (selectedDistancesIndex.containsKey(pair.getKey()) == false) {
208 selectedDistancesIndex.put(pair.getKey(), new HashMap<>());
209 }
210
211 selectedDistancesIndex.get(pair.getKey()).put(index, i);
212 }
213 }
214
215 while (selectedIndex.size() > numIndividuals) {
216
217 int minIndex = -1;
218 List<Pair<Integer, Double>> minDistances = null;
219 for (final int candidateIndex : selectedIndex) {
220
221 if (minIndex < 0) {
222 minIndex = candidateIndex;
223 minDistances = selectedDistances.get(candidateIndex);
224 } else {
225 final List<Pair<Integer, Double>> distancesCandidate = selectedDistances.get(candidateIndex);
226 Validate.isTrue(minDistances.size() == distancesCandidate.size());
227
228 int result = 0;
229 int j = skipNull(minDistances, 0);
230 int l = skipNull(distancesCandidate, 0);
231
232 while (result == 0 && j < minDistances.size() && l < distancesCandidate.size()) {
233
234 result = Double.compare(minDistances.get(j).getRight(), distancesCandidate.get(l).getRight());
235
236 j++;
237 j = skipNull(minDistances, j);
238
239 l++;
240 l = skipNull(distancesCandidate, l);
241 }
242
243 if (result > 0) {
244 minIndex = candidateIndex;
245 minDistances = distancesCandidate;
246 }
247 }
248 }
249
250
251
252
253
254 final Map<Integer, Integer> reverseIndex = selectedDistancesIndex.get(minIndex);
255 for (Entry<Integer, Integer> entry : reverseIndex.entrySet()) {
256 final List<Pair<Integer, Double>> distancesToClean = selectedDistances.get(entry.getKey());
257 distancesToClean.set((int) entry.getValue(), null);
258 }
259 for (Map<Integer, Integer> map : selectedDistancesIndex.values()) {
260 map.remove(minIndex);
261 }
262
263 selectedDistancesIndex.remove(minIndex);
264 selectedDistances.remove(minIndex);
265 selectedIndex.remove(minIndex);
266 }
267
268 }
269
270 protected Set<Integer> environmentalSelection(final List<List<Pair<Integer, Double>>> distances,
271 final double[] rawFitness, final double[] finalFitness, final Population<T> population,
272 final int numIndividuals) {
273
274 final Set<Integer> selectedIndex = IntStream.range(0, population.size())
275 .boxed()
276 .filter((i) -> finalFitness[i] < 1)
277 .collect(Collectors.toSet());
278
279 logger.trace("Selected index size: {}", selectedIndex.size());
280
281 if (selectedIndex.size() < numIndividuals) {
282
283 final List<Integer> additionalIndividuals = computeAdditionalIndividuals(
284 selectedIndex,
285 rawFitness,
286 population,
287 numIndividuals);
288
289 logger.trace("Adding {} additional individuals", additionalIndividuals.size());
290 selectedIndex.addAll(additionalIndividuals);
291 }
292
293 if (selectedIndex.size() > numIndividuals) {
294 logger.trace("Need to remove {} individuals", selectedIndex.size() - numIndividuals);
295
296 truncatePopulation(distances, population, numIndividuals, selectedIndex);
297 }
298
299 return selectedIndex;
300 }
301
302 @Override
303 public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
304 final int numIndividuals, final List<Genotype> population, final List<T> populationScores,
305 final List<Genotype> offsprings, final List<T> offspringScores) {
306 Objects.requireNonNull(eaConfiguration);
307 Validate.isTrue(generation >= 0);
308 Validate.isTrue(numIndividuals > 0);
309 Objects.requireNonNull(population);
310 Objects.requireNonNull(populationScores);
311 Validate.isTrue(population.size() == populationScores.size());
312 Objects.requireNonNull(offsprings);
313 Objects.requireNonNull(offspringScores);
314 Validate.isTrue(offsprings.size() == offspringScores.size());
315
316 final long startTimeNanos = System.nanoTime();
317 logger.debug(
318 "Starting with requested {} individuals - {} population - {} offsprings",
319 numIndividuals,
320 population.size(),
321 offsprings.size());
322
323 final Population<T> archive = new Population<>(population, populationScores);
324 final Population<T> offspringPopulation = new Population<>(offsprings, offspringScores);
325
326 final Population<T> combinedPopulation = new Population<>();
327 if (spea2Replacement.deduplicate().isPresent()) {
328 final Comparator<Genotype> individualDeduplicator = spea2Replacement.deduplicate().get();
329 final Set<Genotype> seenGenotype = new TreeSet<>(individualDeduplicator);
330
331 for (int i = 0; i < archive.size(); i++) {
332 final Genotype genotype = archive.getGenotype(i);
333
334 if (seenGenotype.add(genotype)) {
335 final T fitness = archive.getFitness(i);
336 combinedPopulation.add(genotype, fitness);
337 }
338 }
339 final int ingestedFromArchive = combinedPopulation.size();
340 logger.debug(
341 "Ingested {} individuals from the archive out of the {} available",
342 ingestedFromArchive,
343 archive.size());
344
345 for (int i = 0; i < offspringPopulation.size(); i++) {
346 final Genotype genotype = offspringPopulation.getGenotype(i);
347
348 if (seenGenotype.add(genotype)) {
349 final T fitness = offspringPopulation.getFitness(i);
350 combinedPopulation.add(genotype, fitness);
351 }
352 }
353 if (logger.isDebugEnabled()) {
354 logger.debug(
355 "Ingested {} individuals from the offsprings out of the {} available",
356 combinedPopulation.size() - ingestedFromArchive,
357 offspringPopulation.size());
358 }
359
360 } else {
361 combinedPopulation.addAll(archive);
362 combinedPopulation.addAll(offspringPopulation);
363 }
364
365 final Comparator<T> dominance = switch (eaConfiguration.optimization()) {
366 case MAXIMIZE -> spea2Replacement.dominance();
367 case MINIMIZE -> spea2Replacement.dominance().reversed();
368 };
369
370 final int k = spea2Replacement.k().orElseGet(() -> (int) Math.sqrt(combinedPopulation.size()));
371 logger.trace("Using k={}", k);
372 Validate.isTrue(k > 0);
373
374
375 final double[] strengths = computeStrength(dominance, combinedPopulation);
376
377 final double[][] distanceObjectives = computeObjectiveDistances(spea2Replacement.distance(), combinedPopulation);
378
379 final double[] rawFitness = computeRawFitness(dominance, strengths, combinedPopulation);
380
381 final List<List<Pair<Integer, Double>>> distances = computeSortedDistances(
382 distanceObjectives,
383 combinedPopulation);
384
385 final double[] density = computeDensity(distances, k, combinedPopulation);
386
387 final double[] finalFitness = computeFinalFitness(rawFitness, density, combinedPopulation);
388
389
390
391 final Set<Integer> selectedIndex = environmentalSelection(
392 distances,
393 rawFitness,
394 finalFitness,
395 combinedPopulation,
396 numIndividuals);
397
398 final Population<T> newPopulation = new Population<>();
399 for (final int i : selectedIndex) {
400 newPopulation.add(combinedPopulation.getGenotype(i), combinedPopulation.getFitness(i));
401 }
402
403 final long endTimeNanos = System.nanoTime();
404 if (logger.isDebugEnabled()) {
405 logger.debug(
406 "Finished with {} new population - Computation time: {}",
407 newPopulation.size(),
408 DurationFormatUtils.formatDurationHMS((endTimeNanos - startTimeNanos) / 1_000_000));
409 }
410
411 return newPopulation;
412 }
413 }