1 package net.bmahe.genetics4j.core.selection; 2 3 import java.util.ArrayList; 4 import java.util.Comparator; 5 import java.util.List; 6 import java.util.Objects; 7 import java.util.random.RandomGenerator; 8 import java.util.stream.Collectors; 9 10 import org.apache.commons.lang3.Validate; 11 import org.apache.logging.log4j.LogManager; 12 import org.apache.logging.log4j.Logger; 13 14 import net.bmahe.genetics4j.core.Genotype; 15 import net.bmahe.genetics4j.core.Individual; 16 import net.bmahe.genetics4j.core.Population; 17 import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration; 18 import net.bmahe.genetics4j.core.spec.AbstractEAExecutionContext; 19 import net.bmahe.genetics4j.core.spec.selection.MultiTournaments; 20 import net.bmahe.genetics4j.core.spec.selection.SelectionPolicy; 21 import net.bmahe.genetics4j.core.spec.selection.Tournament; 22 23 public class MultiTournamentsSelectionPolicyHandler<T extends Comparable<T>> implements SelectionPolicyHandler<T> { 24 final static public Logger logger = LogManager.getLogger(MultiTournamentsSelectionPolicyHandler.class); 25 26 private final RandomGenerator randomGenerator; 27 28 private List<Individual<T>> pickRandomCandidates(final RandomGenerator randomGenerator, 29 final List<Genotype> population, final List<T> fitnessScore, final int numCandidates) { 30 Objects.requireNonNull(randomGenerator); 31 Objects.requireNonNull(population); 32 Objects.requireNonNull(fitnessScore); 33 Validate.isTrue(fitnessScore.size() > 0); 34 Validate.isTrue(numCandidates > 0); 35 36 return randomGenerator.ints(0, fitnessScore.size()) 37 .boxed() 38 .limit(numCandidates) 39 .map(i -> Individual.of(population.get(i), fitnessScore.get(i))) 40 .collect(Collectors.toList()); 41 } 42 43 private Individual<T> runTournament(final Tournament<T> tournament, final List<Genotype> population, 44 final List<T> fitnessScore, final List<Individual<T>> candidates) { 45 Objects.requireNonNull(tournament); 46 47 final Comparator<Individual<T>> comparator = tournament.comparator(); 48 49 return candidates.stream() 50 .max(comparator) 51 .get(); 52 } 53 54 private Individual<T> runTournament(final RandomGenerator randomGenerator, final List<Tournament<T>> tournaments, 55 final List<Genotype> population, final List<T> fitnessScore, final int tournamentIndex) { 56 Objects.requireNonNull(tournaments); 57 Objects.requireNonNull(population); 58 Objects.requireNonNull(fitnessScore); 59 Validate.isTrue(tournamentIndex < tournaments.size()); 60 Validate.isTrue(tournamentIndex >= 0); 61 62 final Tournament<T> tournament = tournaments.get(tournamentIndex); 63 final int numCandidates = tournament.numCandidates(); 64 65 List<Individual<T>> candidates; 66 if (tournamentIndex == 0) { 67 candidates = pickRandomCandidates(randomGenerator, population, fitnessScore, numCandidates); 68 } else { 69 candidates = new ArrayList<>(); 70 71 for (int i = 0; i < numCandidates; i++) { 72 final Individual<T> candidate = runTournament(randomGenerator, 73 tournaments, 74 population, 75 fitnessScore, 76 tournamentIndex - 1); 77 candidates.add(candidate); 78 } 79 } 80 81 return runTournament(tournament, population, fitnessScore, candidates); 82 83 } 84 85 public MultiTournamentsSelectionPolicyHandler(final RandomGenerator _randomGenerator) { 86 Objects.requireNonNull(_randomGenerator); 87 88 this.randomGenerator = _randomGenerator; 89 } 90 91 @Override 92 public boolean canHandle(final SelectionPolicy selectionPolicy) { 93 Objects.requireNonNull(selectionPolicy); 94 return selectionPolicy instanceof MultiTournaments; 95 } 96 97 @Override 98 public Selector<T> resolve(final AbstractEAExecutionContext<T> eaExecutionContext, 99 final AbstractEAConfiguration<T> eaConfiguration, 100 final SelectionPolicyHandlerResolver<T> selectionPolicyHandlerResolver, 101 final SelectionPolicy selectionPolicy) { 102 Objects.requireNonNull(selectionPolicy); 103 Validate.isInstanceOf(MultiTournaments.class, selectionPolicy); 104 105 return new Selector<T>() { 106 107 @Override 108 public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation, 109 final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) { 110 Objects.requireNonNull(eaConfiguration); 111 Objects.requireNonNull(population); 112 Objects.requireNonNull(fitnessScore); 113 Validate.isTrue(numIndividuals > 0); 114 Validate.isTrue(population.size() == fitnessScore.size()); 115 116 @SuppressWarnings("unchecked") 117 final MultiTournaments<T> multiTournaments = (MultiTournaments<T>) selectionPolicy; 118 final List<Tournament<T>> tournaments = multiTournaments.tournaments(); 119 120 logger.debug("Selecting {} individuals", numIndividuals); 121 final Population<T> selectedIndividuals = new Population<>(); 122 while (selectedIndividuals.size() < numIndividuals) { 123 final Individual<T> selectedIndividual = runTournament(randomGenerator, 124 tournaments, 125 population, 126 fitnessScore, 127 tournaments.size() - 1); 128 selectedIndividuals.add(selectedIndividual); 129 } 130 131 return selectedIndividuals; 132 } 133 }; 134 } 135 }