// SelectiveRefinementTournamentSelector.java

package net.bmahe.genetics4j.core.selection;

import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.random.RandomGenerator;
import java.util.stream.IntStream;

import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.core.Individual;
import net.bmahe.genetics4j.core.Population;
import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
import net.bmahe.genetics4j.core.spec.selection.SelectiveRefinementTournament;
import net.bmahe.genetics4j.core.spec.selection.Tournament;
import net.bmahe.genetics4j.core.util.IndividualUtils;

/**
 * Selector which runs two independent tournaments per selected individual and
 * then arbitrates between the two winners, either with a dedicated refinement
 * comparator or with the regular fitness comparator.
 * <p>
 * For each individual to select:
 * <ol>
 * <li>two winners are picked with {@link #selectForFitness}, each one being the
 * greatest of {@code tournament.numCandidates()} randomly drawn individuals
 * according to the fitness comparator;</li>
 * <li>a random float is drawn; if it is strictly below the configured
 * refinement ratio, the winner is chosen with {@link #selectForRefinement},
 * otherwise the one ranked greater by the fitness comparator is kept.</li>
 * </ol>
 *
 * @param <T> Type of the fitness measurement
 */
public class SelectiveRefinementTournamentSelector<T extends Comparable<T>> implements Selector<T> {
	public static final Logger logger = LogManager.getLogger(SelectiveRefinementTournamentSelector.class);

	private final SelectiveRefinementTournament<T> selectiveRefinementTournament;
	private final RandomGenerator randomGenerator;

	/**
	 * Creates a selector for the given specification.
	 *
	 * @param _selectiveRefinementTournament specification providing the
	 *                                       tournament, the refinement comparator
	 *                                       and the refinement ratio
	 * @param _randomGenerator               source of randomness used for every
	 *                                       random draw of this selector
	 * @throws NullPointerException if any argument is {@code null}
	 */
	public SelectiveRefinementTournamentSelector(final SelectiveRefinementTournament<T> _selectiveRefinementTournament,
			final RandomGenerator _randomGenerator) {
		Objects.requireNonNull(_selectiveRefinementTournament);
		Objects.requireNonNull(_randomGenerator);

		this.selectiveRefinementTournament = _selectiveRefinementTournament;
		this.randomGenerator = _randomGenerator;
	}

	/**
	 * Draws one individual at a random index of the population.
	 *
	 * @param population   genotypes to draw from
	 * @param fitnessScore fitness of each genotype, at the same index as in
	 *                     {@code population}
	 * @return the genotype and fitness found at the drawn index
	 * @throws NullPointerException     if any argument is {@code null}
	 * @throws IllegalArgumentException if the lists are empty or of different
	 *                                  sizes
	 */
	protected Individual<T> randomIndividual(final List<Genotype> population, final List<T> fitnessScore) {
		Objects.requireNonNull(population);
		Objects.requireNonNull(fitnessScore);
		Validate.isTrue(fitnessScore.size() > 0);
		Validate.isTrue(population.size() == fitnessScore.size());

		final int candidateIndex = randomGenerator.nextInt(fitnessScore.size());
		return Individual.of(population.get(candidateIndex), fitnessScore.get(candidateIndex));
	}

	/**
	 * Runs one tournament: draws {@code numCandidates} random individuals and
	 * returns the greatest one according to {@code fitnessComparator}.
	 *
	 * @param eaConfiguration   configuration of the evolutionary algorithm; not
	 *                          used by this implementation
	 * @param fitnessComparator comparator ranking the candidates
	 * @param numCandidates     number of individuals drawn for the tournament;
	 *                          must be strictly positive
	 * @param population        genotypes to draw from
	 * @param fitnessScore      fitness of each genotype, at the same index as in
	 *                          {@code population}
	 * @return the winner of the tournament
	 * @throws NullPointerException     if {@code fitnessComparator},
	 *                                  {@code population} or {@code fitnessScore}
	 *                                  is {@code null}
	 * @throws IllegalArgumentException if {@code numCandidates} is not strictly
	 *                                  positive or {@code fitnessScore} is empty
	 */
	protected Individual<T> selectForFitness(final AbstractEAConfiguration<T> eaConfiguration,
			final Comparator<Individual<T>> fitnessComparator, final int numCandidates, final List<Genotype> population,
			final List<T> fitnessScore) {
		Objects.requireNonNull(fitnessComparator);
		Objects.requireNonNull(population);
		Objects.requireNonNull(fitnessScore);
		Validate.isTrue(fitnessScore.isEmpty() == false);
		// Without at least one candidate the stream below is empty and there is no
		// winner to return; fail with an explicit argument error instead.
		Validate.isTrue(numCandidates > 0, "numCandidates must be strictly positive");

		// mapToObj avoids boxing the index, which is only used to repeat the draw.
		return IntStream.range(0, numCandidates)
				.mapToObj(i -> randomIndividual(population, fitnessScore))
				.max(fitnessComparator)
				.orElseThrow();
	}

	/**
	 * Chooses between two candidates with the refinement comparator. The candidate
	 * ranked greater by {@code refinementComparator} is returned; when both
	 * compare as equal, one of them is picked at random.
	 *
	 * @param refinementComparator comparator used to arbitrate between the two
	 *                             candidates
	 * @param candidateA           first candidate
	 * @param candidateB           second candidate
	 * @return the retained candidate
	 * @throws NullPointerException if any argument is {@code null}
	 */
	protected Individual<T> selectForRefinement(final Comparator<Individual<T>> refinementComparator,
			final Individual<T> candidateA, final Individual<T> candidateB) {
		Objects.requireNonNull(refinementComparator);
		Objects.requireNonNull(candidateA);
		Objects.requireNonNull(candidateB);

		final int compared = refinementComparator.compare(candidateA, candidateB);
		if (compared < 0) {
			return candidateB;
		}
		if (compared > 0) {
			return candidateA;
		}

		// Tie: pick either candidate based on a single random draw.
		return randomGenerator.nextFloat() < 0.5 ? candidateA : candidateB;
	}

	/**
	 * Selects {@code numIndividuals} individuals from the population.
	 * <p>
	 * Each selected individual results from two tournaments followed by an
	 * arbitration between both winners, as described in the class documentation.
	 *
	 * @param eaConfiguration configuration used to build the fitness comparator
	 * @param generation      current generation; must not be negative
	 * @param numIndividuals  number of individuals to select; must be strictly
	 *                        positive
	 * @param population      genotypes to select from; must not be empty
	 * @param fitnessScore    fitness of each genotype, at the same index as in
	 *                        {@code population}
	 * @return a population holding exactly {@code numIndividuals} individuals
	 * @throws NullPointerException     if {@code eaConfiguration},
	 *                                  {@code population} or {@code fitnessScore}
	 *                                  is {@code null}
	 * @throws IllegalArgumentException if a numeric argument is out of range, if
	 *                                  the population is empty or if both lists
	 *                                  differ in size
	 */
	@Override
	public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final long generation,
			final int numIndividuals, final List<Genotype> population, final List<T> fitnessScore) {
		Objects.requireNonNull(eaConfiguration);
		Objects.requireNonNull(population);
		Objects.requireNonNull(fitnessScore);
		Validate.isTrue(generation >= 0);
		Validate.isTrue(numIndividuals > 0);
		Validate.isTrue(population.size() == fitnessScore.size());
		// Fail before any work is done rather than inside the first tournament.
		Validate.isTrue(population.isEmpty() == false, "population must not be empty");

		final Tournament<T> tournament = selectiveRefinementTournament.tournament();
		final Comparator<Individual<T>> refinementComparator = selectiveRefinementTournament.refinementComparator();
		final float refinementRatio = selectiveRefinementTournament.refinementRatio();

		final Comparator<Individual<T>> fitnessComparator = IndividualUtils.fitnessBasedComparator(eaConfiguration);
		final int numCandidates = tournament.numCandidates();

		logger.debug("Selecting {} individuals", numIndividuals);

		final Population<T> selectedIndividuals = new Population<>();

		while (selectedIndividuals.size() < numIndividuals) {

			final Individual<T> first = selectForFitness(eaConfiguration,
					fitnessComparator,
					numCandidates,
					population,
					fitnessScore);
			final Individual<T> second = selectForFitness(eaConfiguration,
					fitnessComparator,
					numCandidates,
					population,
					fitnessScore);

			// The draw happens after both tournaments so that the sequence of calls
			// made to the random generator is stable.
			final boolean useRefinement = randomGenerator.nextFloat() < refinementRatio;

			final Individual<T> selected = useRefinement ? selectForRefinement(refinementComparator, first, second)
					: (fitnessComparator.compare(first, second) < 0 ? second : first);

			selectedIndividuals.add(selected);
		}

		return selectedIndividuals;
	}
}