1 package net.bmahe.genetics4j.core.selection;
2
3 import java.util.ArrayList;
4 import java.util.Comparator;
5 import java.util.List;
6 import java.util.Objects;
7 import java.util.random.RandomGenerator;
8 import java.util.stream.Collectors;
9
10 import org.apache.commons.lang3.Validate;
11 import org.apache.logging.log4j.LogManager;
12 import org.apache.logging.log4j.Logger;
13
14 import net.bmahe.genetics4j.core.Genotype;
15 import net.bmahe.genetics4j.core.Individual;
16 import net.bmahe.genetics4j.core.Population;
17 import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
18 import net.bmahe.genetics4j.core.spec.AbstractEAExecutionContext;
19 import net.bmahe.genetics4j.core.spec.selection.MultiTournaments;
20 import net.bmahe.genetics4j.core.spec.selection.SelectionPolicy;
21 import net.bmahe.genetics4j.core.spec.selection.Tournament;
22
23 public class MultiTournamentsSelectionPolicyHandler<T extends Comparable<T>> implements SelectionPolicyHandler<T> {
24 final static public Logger logger = LogManager.getLogger(MultiTournamentsSelectionPolicyHandler.class);
25
26 private final RandomGenerator randomGenerator;
27
28 private List<Individual<T>> pickRandomCandidates(final RandomGenerator randomGenerator,
29 final List<Genotype> population, final List<T> fitnessScore, final int numCandidates) {
30 Objects.requireNonNull(randomGenerator);
31 Objects.requireNonNull(population);
32 Objects.requireNonNull(fitnessScore);
33 Validate.isTrue(fitnessScore.size() > 0);
34 Validate.isTrue(numCandidates > 0);
35
36 return randomGenerator.ints(0, fitnessScore.size())
37 .boxed()
38 .limit(numCandidates)
39 .map(i -> Individual.of(population.get(i), fitnessScore.get(i)))
40 .collect(Collectors.toList());
41 }
42
43 private Individual<T> runTournament(final Tournament<T> tournament, final List<Genotype> population,
44 final List<T> fitnessScore, final List<Individual<T>> candidates) {
45 Objects.requireNonNull(tournament);
46
47 final Comparator<Individual<T>> comparator = tournament.comparator();
48
49 return candidates.stream().max(comparator).get();
50 }
51
52 private Individual<T> runTournament(final RandomGenerator randomGenerator, final List<Tournament<T>> tournaments,
53 final List<Genotype> population, final List<T> fitnessScore, final int tournamentIndex) {
54 Objects.requireNonNull(tournaments);
55 Objects.requireNonNull(population);
56 Objects.requireNonNull(fitnessScore);
57 Validate.isTrue(tournamentIndex < tournaments.size());
58 Validate.isTrue(tournamentIndex >= 0);
59
60 final Tournament<T> tournament = tournaments.get(tournamentIndex);
61 final int numCandidates = tournament.numCandidates();
62
63 List<Individual<T>> candidates;
64 if (tournamentIndex == 0) {
65 candidates = pickRandomCandidates(randomGenerator, population, fitnessScore, numCandidates);
66 } else {
67 candidates = new ArrayList<>();
68
69 for (int i = 0; i < numCandidates; i++) {
70 final Individual<T> candidate = runTournament(
71 randomGenerator,
72 tournaments,
73 population,
74 fitnessScore,
75 tournamentIndex - 1);
76 candidates.add(candidate);
77 }
78 }
79
80 return runTournament(tournament, population, fitnessScore, candidates);
81
82 }
83
84 public MultiTournamentsSelectionPolicyHandler(final RandomGenerator _randomGenerator) {
85 Objects.requireNonNull(_randomGenerator);
86
87 this.randomGenerator = _randomGenerator;
88 }
89
90 @Override
91 public boolean canHandle(final SelectionPolicy selectionPolicy) {
92 Objects.requireNonNull(selectionPolicy);
93 return selectionPolicy instanceof MultiTournaments;
94 }
95
96 @Override
97 public Selector<T> resolve(final AbstractEAExecutionContext<T> eaExecutionContext,
98 final AbstractEAConfiguration<T> eaConfiguration,
99 final SelectionPolicyHandlerResolver<T> selectionPolicyHandlerResolver,
100 final SelectionPolicy selectionPolicy) {
101 Objects.requireNonNull(selectionPolicy);
102 Validate.isInstanceOf(MultiTournaments.class, selectionPolicy);
103
104 return new Selector<T>() {
105
106 @Override
107 public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
108 final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) {
109 Objects.requireNonNull(eaConfiguration);
110 Objects.requireNonNull(population);
111 Objects.requireNonNull(fitnessScore);
112 Validate.isTrue(numIndividuals > 0);
113 Validate.isTrue(population.size() == fitnessScore.size());
114
115 @SuppressWarnings("unchecked")
116 final MultiTournaments<T> multiTournaments = (MultiTournaments<T>) selectionPolicy;
117 final List<Tournament<T>> tournaments = multiTournaments.tournaments();
118
119 logger.debug("Selecting {} individuals", numIndividuals);
120 final Population<T> selectedIndividuals = new Population<>();
121 while (selectedIndividuals.size() < numIndividuals) {
122 final Individual<T> selectedIndividual = runTournament(
123 randomGenerator,
124 tournaments,
125 population,
126 fitnessScore,
127 tournaments.size() - 1);
128 selectedIndividuals.add(selectedIndividual);
129 }
130
131 return selectedIndividuals;
132 }
133 };
134 }
135 }