// NeatSelectorImpl.java

package net.bmahe.genetics4j.neat.selection;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.random.RandomGenerator;
import java.util.stream.IntStream;

import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.core.Individual;
import net.bmahe.genetics4j.core.Population;
import net.bmahe.genetics4j.core.selection.Selector;
import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
import net.bmahe.genetics4j.core.util.IndividualUtils;
import net.bmahe.genetics4j.neat.NeatUtils;
import net.bmahe.genetics4j.neat.Species;
import net.bmahe.genetics4j.neat.SpeciesIdGenerator;
import net.bmahe.genetics4j.neat.spec.selection.NeatSelection;

/**
 * {@link Selector} that selects individuals on a per-species basis.
 *
 * <p>
 * Each call to {@link #select} groups the population into species through
 * {@link NeatUtils#speciate}, keeps only the top-ranked members of every
 * species, and then fills the requested number of individuals by delegating to
 * the wrapped species selector, giving each species a share proportional to the
 * mean fitness of its remaining members.
 *
 * <p>
 * The instance is stateful: the trimmed species of one call are stored and
 * passed to the speciation step of the next call. No synchronization is
 * performed on that state.
 *
 * @param <T> type of the fitness values
 */
public class NeatSelectorImpl<T extends Number & Comparable<T>> implements Selector<T> {
	public static final Logger logger = LogManager.getLogger(NeatSelectorImpl.class);

	private final RandomGenerator randomGenerator;
	private final NeatSelection<T> neatSelection;
	private final SpeciesIdGenerator speciesIdGenerator;
	private final Selector<T> speciesSelector;

	/** Trimmed species produced by the latest {@link #select} call; input of the next speciation. */
	private List<Species<T>> previousSpecies;

	/**
	 * Creates a selector with no species carried over from a previous generation.
	 *
	 * @param _randomGenerator    random source handed to the speciation step
	 * @param _neatSelection      selection settings (species predicate, keep ratio,
	 *                            minimum species size)
	 * @param _speciesIdGenerator generator handed to the speciation step
	 * @param _speciesSelector    selector used to pick individuals inside a single
	 *                            species
	 * @throws NullPointerException if any argument is {@code null}
	 */
	public NeatSelectorImpl(final RandomGenerator _randomGenerator, final NeatSelection<T> _neatSelection,
			final SpeciesIdGenerator _speciesIdGenerator, final Selector<T> _speciesSelector) {
		Validate.notNull(_randomGenerator);
		Validate.notNull(_neatSelection);
		Validate.notNull(_speciesIdGenerator);
		Validate.notNull(_speciesSelector);

		this.randomGenerator = _randomGenerator;
		this.neatSelection = _neatSelection;
		this.speciesIdGenerator = _speciesIdGenerator;
		this.speciesSelector = _speciesSelector;

		this.previousSpecies = new ArrayList<>();
	}

	/**
	 * Builds a copy of a species that only holds its top-ranked members.
	 *
	 * <p>
	 * The number of members kept is the larger of {@code minSpeciesSize} and
	 * {@code size * perSpeciesKeepRatio}, truncated to an integer. Members are
	 * ranked with {@code comparator}, the greatest ones being kept. When the target
	 * exceeds the species size, every member is kept.
	 *
	 * @param species             species to trim
	 * @param comparator          ordering of the individuals; greater means kept
	 *                            first
	 * @param minSpeciesSize      lower bound on the number of members to keep
	 * @param perSpeciesKeepRatio fraction of the species to keep
	 * @return a new species with the same id and the kept members, possibly none
	 * @throws NullPointerException if {@code species} or {@code comparator} is
	 *                              {@code null}
	 */
	protected Species<T> trimSpecies(final Species<T> species, final Comparator<Individual<T>> comparator,
			final int minSpeciesSize, final float perSpeciesKeepRatio) {
		Validate.notNull(species);
		Validate.notNull(comparator);

		final List<Individual<T>> members = species.getMembers();
		final float speciesSize = members.size();
		final int numIndividualtoKeep = (int) Math.max(minSpeciesSize, speciesSize * perSpeciesKeepRatio);

		if (logger.isDebugEnabled()) {
			logger.debug(
					"Species id: {}, size: {}, perSepciesKeepRatio: {}, we want to keep {} members - best fitness: {}",
					species.getId(),
					speciesSize,
					perSpeciesKeepRatio,
					numIndividualtoKeep,
					members.stream()
							.max(comparator)
							.map(Individual::fitness));
		}

		final Species<T> trimmedSpecies = new Species<>(species.getId(), List.of());
		if (numIndividualtoKeep > 0) {
			final List<Individual<T>> keptMembers = members.stream()
					.sorted(comparator.reversed())
					.limit(numIndividualtoKeep)
					.toList();

			trimmedSpecies.addAllMembers(keptMembers);
		}
		return trimmedSpecies;
	}

	/**
	 * Trims every species with {@link #trimSpecies} using the keep ratio and
	 * minimum size of the configured {@link NeatSelection}, and drops the species
	 * left without any member.
	 *
	 * @param eaConfiguration configuration used to build the fitness comparator
	 * @param allSpecies      species to trim
	 * @return the trimmed, non-empty species, in the order of {@code allSpecies}
	 * @throws NullPointerException if any argument is {@code null}
	 */
	protected List<Species<T>> eliminateLowestPerformers(final AbstractEAConfiguration<T> eaConfiguration,
			final List<Species<T>> allSpecies) {
		Validate.notNull(eaConfiguration);
		Validate.notNull(allSpecies);

		final Comparator<Individual<T>> comparator = IndividualUtils.fitnessBasedComparator(eaConfiguration);

		final float perSpeciesKeepRatio = neatSelection.perSpeciesKeepRatio();
		logger.trace("Keeping only the best {} number of individuals per species", perSpeciesKeepRatio);

		final int minSpeciesSize = neatSelection.minSpeciesSize();

		return allSpecies.stream()
				.map(species -> trimSpecies(species, comparator, minSpeciesSize, perSpeciesKeepRatio))
				.filter(species -> species.getNumMembers() > 0)
				.toList();
	}

	/**
	 * Selects individuals species by species.
	 *
	 * <p>
	 * The population is speciated, each species is trimmed, and the trimmed species
	 * are remembered for the next call. Species are then visited by decreasing mean
	 * fitness; each one contributes a number of individuals proportional to its
	 * mean fitness relative to the sum of all mean fitnesses, capped by what is
	 * still missing. If individuals are still missing afterwards, the remainder is
	 * selected from the species with the highest mean fitness.
	 *
	 * @param eaConfiguration configuration of the evolutionary algorithm
	 * @param numIndividuals  number of individuals requested; must be positive
	 * @param genotypes       genotypes of the current population
	 * @param fitnessScore    fitness of each genotype, in the same order
	 * @return the selected population, or an empty one if no species is left after
	 *         trimming
	 * @throws NullPointerException     if a reference argument is {@code null}
	 * @throws IllegalArgumentException if {@code numIndividuals} is not positive or
	 *                                  the two lists differ in size
	 */
	@Override
	public Population<T> select(final AbstractEAConfiguration<T> eaConfiguration, final int numIndividuals,
			final List<Genotype> genotypes, final List<T> fitnessScore) {
		Validate.notNull(eaConfiguration);
		Validate.notNull(genotypes);
		Validate.notNull(fitnessScore);
		Validate.isTrue(numIndividuals > 0);
		Validate.isTrue(genotypes.size() == fitnessScore.size());

		final Population<T> population = Population.of(genotypes, fitnessScore);

		final List<Species<T>> allSpecies = NeatUtils.speciate(randomGenerator,
				speciesIdGenerator,
				previousSpecies,
				population,
				neatSelection.speciesPredicate());

		logger.debug("Number of species found: {}", allSpecies.size());
		logger.trace("Species: {}", allSpecies);

		// Remove the bottom performers of each species
		final List<Species<T>> allTrimmedSpecies = eliminateLowestPerformers(eaConfiguration, allSpecies);
		logger.debug("After trimming, we have {} species", allTrimmedSpecies.size());

		previousSpecies = allTrimmedSpecies;
		if (allTrimmedSpecies.isEmpty()) {
			return Population.empty();
		}

		// Select the next generation on a per species basis, proportionally to the
		// mean fitness of the members of each species
		final double[] meanFitnesses = computeMeanFitnesses(allTrimmedSpecies);
		double totalMeanFitness = 0;
		for (final double meanFitness : meanFitnesses) {
			totalMeanFitness += meanFitness;
		}

		// Species indices, best mean fitness first; ties keep their original order
		final List<Integer> decreasingFitnessIndex = IntStream.range(0, meanFitnesses.length)
				.boxed()
				.sorted(Comparator.comparingDouble((Integer index) -> meanFitnesses[index])
						.reversed())
				.toList();

		final List<Population<T>> trimmedPopulations = allTrimmedSpecies.stream()
				.map(species -> Population.of(species.getMembers()))
				.toList();

		final Population<T> selected = new Population<>();

		for (final int speciesIndex : decreasingFitnessIndex) {
			if (selected.size() >= numIndividuals) {
				break;
			}

			// The int cast truncates and maps NaN (e.g. 0/0 when the total is zero) to 0,
			// in which case the species is skipped and the fallback below takes over
			final int proportionalShare = (int) (numIndividuals * meanFitnesses[speciesIndex] / totalMeanFitness);
			final int numIndividualSpecies = Math.min(proportionalShare, numIndividuals - selected.size());

			if (numIndividualSpecies > 0) {
				logger.debug("sub selecting {} for index {} - species id: {}",
						numIndividualSpecies,
						speciesIndex,
						allTrimmedSpecies.get(speciesIndex)
								.getId());

				selected.addAll(
						selectFromSpecies(eaConfiguration, numIndividualSpecies, trimmedPopulations.get(speciesIndex)));
			}
		}

		if (selected.size() < numIndividuals) {
			logger.debug("There are less selected individual [{}] than desired [{}]. Will include additional invididuals",
					selected.size(),
					numIndividuals);

			// Truncated shares may not add up: take the remainder from the best species
			final Population<T> bestSpeciesPopulation = trimmedPopulations.get(decreasingFitnessIndex.get(0));

			selected.addAll(selectFromSpecies(eaConfiguration, numIndividuals - selected.size(), bestSpeciesPopulation));
		}

		return selected;
	}

	/**
	 * Computes, for each species, the arithmetic mean of the fitness of its members.
	 * Callers pass species holding at least one member.
	 */
	private double[] computeMeanFitnesses(final List<Species<T>> allTrimmedSpecies) {
		final double[] meanFitnesses = new double[allTrimmedSpecies.size()];
		for (int i = 0; i < meanFitnesses.length; i++) {
			final Species<T> species = allTrimmedSpecies.get(i);
			final double fitnessSum = species.getMembers()
					.stream()
					.mapToDouble(individual -> individual.fitness()
							.doubleValue())
					.sum();
			meanFitnesses[i] = fitnessSum / species.getNumMembers();
		}
		return meanFitnesses;
	}

	/** Delegates to the species selector to pick {@code count} individuals out of one species population. */
	private Population<T> selectFromSpecies(final AbstractEAConfiguration<T> eaConfiguration, final int count,
			final Population<T> speciesPopulation) {
		return speciesSelector.select(eaConfiguration,
				count,
				speciesPopulation.getAllGenotypes(),
				speciesPopulation.getAllFitnesses());
	}
}