NeatUtils.java

package net.bmahe.genetics4j.neat;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.BiPredicate;
import java.util.random.RandomGenerator;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.Validate;

import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.core.Individual;
import net.bmahe.genetics4j.core.Population;
import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;

/**
 * Static helpers operating on NEAT connection lists: building link maps,
 * detecting dead nodes, partitioning nodes into layers, computing the
 * compatibility distance between two genomes and grouping a population into
 * species.
 */
public class NeatUtils {

	private NeatUtils() {
		// Utility class, not meant to be instantiated
	}

	/**
	 * Working backward from the output nodes, we identify the nodes that did not
	 * get visited as dead nodes.
	 * <p>
	 * The candidates are all the nodes referenced by {@code connections},
	 * regardless of whether the connection is enabled. A candidate is cleared when
	 * it is an output node or when it can be reached from an output node by
	 * following {@code backwardConnections}.
	 * 
	 * @param connections         connections whose end points form the candidate
	 *                            nodes; must not be null
	 * @param forwardConnections  not used by this computation; kept for interface
	 *                            compatibility
	 * @param backwardConnections for a given node, the nodes having a link pointing
	 *                            to it
	 * @param outputNodeIndices   indices of the output nodes, from which the
	 *                            backward traversal starts
	 * @return a new mutable set with the indices of the nodes that were never
	 *         reached from any output node
	 */
	public static Set<Integer> computeDeadNodes(final List<Connection> connections,
			final Map<Integer, Set<Integer>> forwardConnections, final Map<Integer, Set<Integer>> backwardConnections,
			final Set<Integer> outputNodeIndices) {
		Validate.notNull(connections);

		// Every referenced node is assumed dead until proven otherwise
		final Set<Integer> deadNodes = new HashSet<>();
		for (final Connection connection : connections) {
			deadNodes.add(connection.fromNodeIndex());
			deadNodes.add(connection.toNodeIndex());
		}
		deadNodes.removeAll(outputNodeIndices);

		final Set<Integer> visited = new HashSet<>();
		final Deque<Integer> toVisit = new ArrayDeque<>(outputNodeIndices);
		while (!toVisit.isEmpty()) {
			final Integer currentNode = toVisit.poll();

			deadNodes.remove(currentNode);

			// Set.add returns false for an already visited node, which guards against
			// looping forever on cyclic links
			if (visited.add(currentNode)) {
				toVisit.addAll(backwardConnections.getOrDefault(currentNode, Set.of()));
			}
		}

		return deadNodes;
	}

	/**
	 * Builds, for each source node, the set of nodes it points to. Disabled
	 * connections are ignored.
	 * 
	 * @param connections connections to index; must not be null
	 * @return map from a source node index to the indices of its target nodes
	 * @throws IllegalArgumentException if two enabled connections share the same
	 *                                  source and target nodes
	 */
	public static Map<Integer, Set<Integer>> computeForwardLinks(final List<Connection> connections) {
		Validate.notNull(connections);

		final Map<Integer, Set<Integer>> forwardConnections = new HashMap<>();
		for (final Connection connection : connections) {
			if (connection.isEnabled() == false) {
				continue;
			}

			final var fromNodeIndex = connection.fromNodeIndex();
			final var toNodeIndex = connection.toNodeIndex();

			final var toNodes = forwardConnections.computeIfAbsent(fromNodeIndex, k -> new HashSet<>());
			if (toNodes.add(toNodeIndex) == false) {
				throw new IllegalArgumentException(
						"Found duplicate entries for nodes defined in connection " + connection);
			}
		}

		return forwardConnections;
	}

	/**
	 * Builds, for each target node, the set of nodes pointing to it. Disabled
	 * connections are ignored.
	 * 
	 * @param connections connections to index; must not be null
	 * @return map from a target node index to the indices of its source nodes
	 * @throws IllegalArgumentException if two enabled connections share the same
	 *                                  source and target nodes
	 */
	public static Map<Integer, Set<Integer>> computeBackwardLinks(final List<Connection> connections) {
		Validate.notNull(connections);

		final Map<Integer, Set<Integer>> backwardConnections = new HashMap<>();
		for (final Connection connection : connections) {
			if (connection.isEnabled() == false) {
				continue;
			}

			final var fromNodeIndex = connection.fromNodeIndex();
			final var toNodeIndex = connection.toNodeIndex();

			final var fromNodes = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());
			if (fromNodes.add(fromNodeIndex) == false) {
				throw new IllegalArgumentException(
						"Found duplicate entries for nodes defined in connection " + connection);
			}
		}
		return backwardConnections;
	}

	/**
	 * Builds, for each target node, the set of enabled connections ending at it.
	 * 
	 * @param connections connections to index; must not be null
	 * @return map from a target node index to the enabled connections pointing to
	 *         it
	 * @throws IllegalArgumentException if two enabled connections share the same
	 *                                  source and target nodes
	 */
	public static Map<Integer, Set<Connection>> computeBackwardConnections(final List<Connection> connections) {
		Validate.notNull(connections);

		final Map<Integer, Set<Connection>> backwardConnections = new HashMap<>();
		for (final Connection connection : connections) {
			if (connection.isEnabled() == false) {
				continue;
			}

			final var toNodeIndex = connection.toNodeIndex();
			final var fromConnections = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());

			// Duplicates are detected on the source node only, so that two connections
			// with the same end points but different attributes are still rejected
			final boolean sameSourceAlreadyPresent = fromConnections.stream()
					.anyMatch(existingConnection -> existingConnection.fromNodeIndex() == connection.fromNodeIndex());
			if (sameSourceAlreadyPresent) {
				throw new IllegalArgumentException(
						"Found duplicate entries for nodes defined in connection " + connection);
			}
			fromConnections.add(connection);
		}
		return backwardConnections;
	}

	/**
	 * Partitions the nodes into successive layers.
	 * <p>
	 * The first layer holds the input nodes and the last layer holds the output
	 * nodes. In between, a node is placed in a layer once every node pointing to
	 * it, through an enabled connection, is either part of a previous layer or a
	 * dead node. Dead nodes and nodes that never satisfy that condition are not
	 * part of any layer.
	 * 
	 * @param inputNodeIndices  indices of the input nodes; must not be empty
	 * @param outputNodeIndices indices of the output nodes; must not be empty
	 * @param connections       connections of the network; must not be empty
	 * @return the layers, from the input layer to the output layer
	 * @throws IllegalArgumentException if any of the arguments is null or empty,
	 *                                  or if the enabled connections contain
	 *                                  duplicates
	 */
	public static List<List<Integer>> partitionLayersNodes(final Set<Integer> inputNodeIndices,
			final Set<Integer> outputNodeIndices, final List<Connection> connections) {
		Validate.isTrue(CollectionUtils.isNotEmpty(inputNodeIndices));
		Validate.isTrue(CollectionUtils.isNotEmpty(outputNodeIndices));
		Validate.isTrue(CollectionUtils.isNotEmpty(connections));

		final Map<Integer, Set<Integer>> forwardConnections = computeForwardLinks(connections);
		final Map<Integer, Set<Integer>> backwardConnections = computeBackwardLinks(connections);

		// Is it useful? If it's connected to the input node, it's not dead
		final var deadNodes = computeDeadNodes(connections, forwardConnections, backwardConnections, outputNodeIndices);

		final Set<Integer> processedSet = new HashSet<>(inputNodeIndices);
		final List<List<Integer>> layers = new ArrayList<>();
		layers.add(new ArrayList<>(inputNodeIndices));

		boolean done = false;
		while (done == false) {
			final List<Integer> layer = new ArrayList<>();

			// Candidates are the not yet processed, live, non output nodes directly
			// reachable from the nodes processed so far
			final Set<Integer> layerCandidates = new HashSet<>();
			for (final Entry<Integer, Set<Integer>> entry : forwardConnections.entrySet()) {
				if (processedSet.contains(entry.getKey()) == false) {
					continue;
				}

				for (final Integer candidate : entry.getValue()) {
					if (deadNodes.contains(candidate) == false && processedSet.contains(candidate) == false
							&& outputNodeIndices.contains(candidate) == false) {
						layerCandidates.add(candidate);
					}
				}
			}

			/*
			 * We need to ensure that all the nodes pointed at the candidate are either a
			 * dead node (and we don't care) or is already in the processedSet
			 */
			for (final Integer candidate : layerCandidates) {
				final var backwardLinks = backwardConnections.getOrDefault(candidate, Set.of());

				final boolean allBackwardInEndSet = backwardLinks.stream()
						.allMatch(next -> processedSet.contains(next) || deadNodes.contains(next));

				if (allBackwardInEndSet) {
					layer.add(candidate);
				}
			}

			if (layer.isEmpty()) {
				// No further progress can be made: close with the output layer
				done = true;
				layer.addAll(outputNodeIndices);
			} else {
				processedSet.addAll(layer);
			}
			layers.add(layer);
		}
		return layers;
	}

	/**
	 * Computes the compatibility distance between two lists of connections as
	 * {@code c1 * excess / n + c2 * disjoint / n + c3 * averageWeightDifference}.
	 * <p>
	 * Connections are matched on their innovation number. Both lists are expected
	 * to already be sorted by increasing innovation number. {@code n} is the size
	 * of the largest list, or 1 when that size is below 20. The average weight
	 * difference is computed over the matching connections and is 0 when there are
	 * none.
	 * 
	 * @param firstConnections  connections of the first genome
	 * @param secondConnections connections of the second genome
	 * @param c1                coefficient applied to the excess genes
	 * @param c2                coefficient applied to the disjoint genes
	 * @param c3                coefficient applied to the average weight
	 *                          difference of the matching genes
	 * @return the compatibility distance, or {@link Float#MAX_VALUE} if any of the
	 *         lists is null
	 */
	public static float compatibilityDistance(final List<Connection> firstConnections,
			final List<Connection> secondConnections, final float c1, final float c2, final float c3) {
		if (firstConnections == null || secondConnections == null) {
			return Float.MAX_VALUE;
		}

		final int firstSize = firstConnections.size();
		final int secondSize = secondConnections.size();

		final int maxConnectionSize = Math.max(firstSize, secondSize);
		final float n = maxConnectionSize < 20 ? 1.0f : maxConnectionSize;

		int disjointGenes = 0;

		float sumWeightDifference = 0;
		int numMatchingGenes = 0;

		int indexFirst = 0;
		int indexSecond = 0;

		// Merge-like walk over both sorted lists
		while (indexFirst < firstSize && indexSecond < secondSize) {

			final Connection firstConnection = firstConnections.get(indexFirst);
			final int firstInnovation = firstConnection.innovation();

			final Connection secondConnection = secondConnections.get(indexSecond);
			final int secondInnovation = secondConnection.innovation();

			if (firstInnovation == secondInnovation) {
				sumWeightDifference += Math.abs(secondConnection.weight() - firstConnection.weight());
				numMatchingGenes++;

				indexFirst++;
				indexSecond++;
			} else {

				disjointGenes++;

				if (firstInnovation < secondInnovation) {
					indexFirst++;
				} else {
					indexSecond++;
				}
			}
		}

		/*
		 * At least one of the lists has been fully consumed. Whatever remains in the
		 * other one lies beyond the last innovation of the consumed list and counts
		 * as excess genes. Each remainder is measured against the index of its own
		 * list, as both indices diverge as soon as a disjoint gene is met.
		 */
		final int excessGenes = (firstSize - indexFirst) + (secondSize - indexSecond);

		final float averageWeightDifference = sumWeightDifference / Math.max(1, numMatchingGenes);

		return (c1 * excessGenes) / n + (c2 * disjointGenes) / n + c3 * averageWeightDifference;
	}

	/**
	 * Computes the compatibility distance between the NEAT chromosomes located at
	 * the same index in two genotypes.
	 * 
	 * @param genotype1       first genotype; must not be null
	 * @param genotype2       second genotype; must not be null
	 * @param chromosomeIndex index of the NEAT chromosome in both genotypes
	 * @param c1              coefficient applied to the excess genes
	 * @param c2              coefficient applied to the disjoint genes
	 * @param c3              coefficient applied to the average weight difference
	 *                        of the matching genes
	 * @return the compatibility distance between the connections of both
	 *         chromosomes
	 * @throws IllegalArgumentException if {@code chromosomeIndex} is negative or
	 *                                  not smaller than the size of both genotypes
	 */
	public static float compatibilityDistance(final Genotype genotype1, final Genotype genotype2,
			final int chromosomeIndex, final float c1, final float c2, final float c3) {
		Validate.notNull(genotype1);
		Validate.notNull(genotype2);
		Validate.isTrue(chromosomeIndex >= 0);
		Validate.isTrue(chromosomeIndex < genotype1.getSize());
		Validate.isTrue(chromosomeIndex < genotype2.getSize());

		final var neatChromosome1 = genotype1.getChromosome(chromosomeIndex, NeatChromosome.class);
		final var connections1 = neatChromosome1.getConnections();

		final var neatChromosome2 = genotype2.getChromosome(chromosomeIndex, NeatChromosome.class);
		final var connections2 = neatChromosome2.getConnections();

		return compatibilityDistance(connections1, connections2, c1, c2, c3);
	}

	/**
	 * Groups the individuals of a population into species.
	 * <p>
	 * Each seed species having at least one member is carried over with the same
	 * id and a single, randomly picked, former member passed as its ancestors.
	 * Each individual then joins the first species for which
	 * {@code speciesPredicate} accepts it against any ancestor or current member.
	 * When no species matches, a new one is created with an id provided by
	 * {@code speciesIdGenerator}.
	 * 
	 * @param <T>                type of the fitness of the individuals
	 * @param random             source of randomness used to pick the
	 *                           representative of each seed species
	 * @param speciesIdGenerator generator of ids for the newly created species
	 * @param seedSpecies        species from which representatives are drawn
	 * @param population         individuals to assign to species
	 * @param speciesPredicate   tests whether the individual, passed first, belongs
	 *                           with the candidate, passed second
	 * @return an unmodifiable list of the species having at least one member
	 */
	public static <T extends Comparable<T>> List<Species<T>> speciate(final RandomGenerator random,
			final SpeciesIdGenerator speciesIdGenerator, final List<Species<T>> seedSpecies,
			final Population<T> population, final BiPredicate<Individual<T>, Individual<T>> speciesPredicate) {
		Validate.notNull(random);
		Validate.notNull(speciesIdGenerator);
		Validate.notNull(seedSpecies);
		Validate.notNull(population);
		Validate.notNull(speciesPredicate);

		final List<Species<T>> species = new ArrayList<>();

		for (final Species<T> speciesIterator : seedSpecies) {
			final var speciesId = speciesIterator.getId();
			final int numMembers = speciesIterator.getNumMembers();
			if (numMembers > 0) {
				final int randomIndex = random.nextInt(numMembers);
				final var newAncestors = List.of(speciesIterator.getMembers()
						.get(randomIndex));
				final var newSpecies = new Species<>(speciesId, newAncestors);
				species.add(newSpecies);
			}
		}

		for (final Individual<T> individual : population) {

			boolean existingSpeciesFound = false;
			int currentSpeciesIndex = 0;
			while (existingSpeciesFound == false && currentSpeciesIndex < species.size()) {

				final var currentSpecies = species.get(currentSpeciesIndex);

				// Members are only inspected when no ancestor matches, which saves
				// evaluations of the predicate
				final boolean matchesSpecies = currentSpecies.getAncestors()
						.stream()
						.anyMatch(candidate -> speciesPredicate.test(individual, candidate))
						|| currentSpecies.getMembers()
								.stream()
								.anyMatch(candidate -> speciesPredicate.test(individual, candidate));

				if (matchesSpecies) {
					currentSpecies.addMember(individual);
					existingSpeciesFound = true;
				} else {
					currentSpeciesIndex++;
				}
			}

			if (existingSpeciesFound == false) {
				final int newSpeciesId = speciesIdGenerator.computeNewId();
				final var newSpecies = new Species<T>(newSpeciesId, List.of());
				newSpecies.addMember(individual);
				species.add(newSpecies);
			}
		}

		// Seed species which did not attract any individual are dropped
		return species.stream()
				.filter(sp -> sp.getNumMembers() > 0)
				.toList();
	}
}