1 package net.bmahe.genetics4j.core.selection;
2
3 import java.util.ArrayList;
4 import java.util.Comparator;
5 import java.util.List;
6 import java.util.Objects;
7 import java.util.random.RandomGenerator;
8 import java.util.stream.Collectors;
9
10 import org.apache.commons.lang3.Validate;
11 import org.apache.logging.log4j.LogManager;
12 import org.apache.logging.log4j.Logger;
13
14 import net.bmahe.genetics4j.core.Genotype;
15 import net.bmahe.genetics4j.core.Individual;
16 import net.bmahe.genetics4j.core.Population;
17 import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
18 import net.bmahe.genetics4j.core.spec.AbstractEAExecutionContext;
19 import net.bmahe.genetics4j.core.spec.selection.MultiTournaments;
20 import net.bmahe.genetics4j.core.spec.selection.SelectionPolicy;
21 import net.bmahe.genetics4j.core.spec.selection.Tournament;
22
23 public class MultiTournamentsSelectionPolicyHandler<T extends Comparable<T>> implements SelectionPolicyHandler<T> {
24 final static public Logger logger = LogManager.getLogger(MultiTournamentsSelectionPolicyHandler.class);
25
26 private final RandomGenerator randomGenerator;
27
28 private List<Individual<T>> pickRandomCandidates(final RandomGenerator randomGenerator,
29 final List<Genotype> population, final List<T> fitnessScore, final int numCandidates) {
30 Objects.requireNonNull(randomGenerator);
31 Objects.requireNonNull(population);
32 Objects.requireNonNull(fitnessScore);
33 Validate.isTrue(fitnessScore.size() > 0);
34 Validate.isTrue(numCandidates > 0);
35
36 return randomGenerator.ints(0, fitnessScore.size())
37 .boxed()
38 .limit(numCandidates)
39 .map(i -> Individual.of(population.get(i), fitnessScore.get(i)))
40 .collect(Collectors.toList());
41 }
42
43 private Individual<T> runTournament(final Tournament<T> tournament, final List<Genotype> population,
44 final List<T> fitnessScore, final List<Individual<T>> candidates) {
45 Objects.requireNonNull(tournament);
46
47 final Comparator<Individual<T>> comparator = tournament.comparator();
48
49 return candidates.stream()
50 .max(comparator)
51 .get();
52 }
53
54 private Individual<T> runTournament(final RandomGenerator randomGenerator, final List<Tournament<T>> tournaments,
55 final List<Genotype> population, final List<T> fitnessScore, final int tournamentIndex) {
56 Objects.requireNonNull(tournaments);
57 Objects.requireNonNull(population);
58 Objects.requireNonNull(fitnessScore);
59 Validate.isTrue(tournamentIndex < tournaments.size());
60 Validate.isTrue(tournamentIndex >= 0);
61
62 final Tournament<T> tournament = tournaments.get(tournamentIndex);
63 final int numCandidates = tournament.numCandidates();
64
65 List<Individual<T>> candidates;
66 if (tournamentIndex == 0) {
67 candidates = pickRandomCandidates(randomGenerator, population, fitnessScore, numCandidates);
68 } else {
69 candidates = new ArrayList<>();
70
71 for (int i = 0; i < numCandidates; i++) {
72 final Individual<T> candidate = runTournament(randomGenerator,
73 tournaments,
74 population,
75 fitnessScore,
76 tournamentIndex - 1);
77 candidates.add(candidate);
78 }
79 }
80
81 return runTournament(tournament, population, fitnessScore, candidates);
82
83 }
84
85 public MultiTournamentsSelectionPolicyHandler(final RandomGenerator _randomGenerator) {
86 Objects.requireNonNull(_randomGenerator);
87
88 this.randomGenerator = _randomGenerator;
89 }
90
91 @Override
92 public boolean canHandle(final SelectionPolicy selectionPolicy) {
93 Objects.requireNonNull(selectionPolicy);
94 return selectionPolicy instanceof MultiTournaments;
95 }
96
97 @Override
98 public Selector<T> resolve(final AbstractEAExecutionContext<T> eaExecutionContext,
99 final AbstractEAConfiguration<T> eaConfiguration,
100 final SelectionPolicyHandlerResolver<T> selectionPolicyHandlerResolver,
101 final SelectionPolicy selectionPolicy) {
102 Objects.requireNonNull(selectionPolicy);
103 Validate.isInstanceOf(MultiTournaments.class, selectionPolicy);
104
105 return new Selector<T>() {
106
107 @Override
108 public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
109 final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) {
110 Objects.requireNonNull(eaConfiguration);
111 Objects.requireNonNull(population);
112 Objects.requireNonNull(fitnessScore);
113 Validate.isTrue(numIndividuals > 0);
114 Validate.isTrue(population.size() == fitnessScore.size());
115
116 @SuppressWarnings("unchecked")
117 final MultiTournaments<T> multiTournaments = (MultiTournaments<T>) selectionPolicy;
118 final List<Tournament<T>> tournaments = multiTournaments.tournaments();
119
120 logger.debug("Selecting {} individuals", numIndividuals);
121 final Population<T> selectedIndividuals = new Population<>();
122 while (selectedIndividuals.size() < numIndividuals) {
123 final Individual<T> selectedIndividual = runTournament(randomGenerator,
124 tournaments,
125 population,
126 fitnessScore,
127 tournaments.size() - 1);
128 selectedIndividuals.add(selectedIndividual);
129 }
130
131 return selectedIndividuals;
132 }
133 };
134 }
135 }