NeatChromosomeDeleteNodeMutationHandler.java

package net.bmahe.genetics4j.neat.mutation.chromosome;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.random.RandomGenerator;
import java.util.stream.Stream;

import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import net.bmahe.genetics4j.core.chromosomes.Chromosome;
import net.bmahe.genetics4j.core.mutation.chromosome.ChromosomeMutationHandler;
import net.bmahe.genetics4j.core.spec.chromosome.ChromosomeSpec;
import net.bmahe.genetics4j.core.spec.mutation.MutationPolicy;
import net.bmahe.genetics4j.neat.Connection;
import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;
import net.bmahe.genetics4j.neat.spec.NeatChromosomeSpec;
import net.bmahe.genetics4j.neat.spec.mutation.DeleteNode;

public class NeatChromosomeDeleteNodeMutationHandler implements ChromosomeMutationHandler<NeatChromosome> {

	public static final Logger logger = LogManager.getLogger(NeatChromosomeDeleteNodeMutationHandler.class);

	private final RandomGenerator randomGenerator;

	public NeatChromosomeDeleteNodeMutationHandler(final RandomGenerator _randomGenerator) {
		Validate.notNull(_randomGenerator);

		this.randomGenerator = _randomGenerator;
	}

	@Override
	public boolean canHandle(final MutationPolicy mutationPolicy, final ChromosomeSpec chromosome) {
		Validate.notNull(mutationPolicy);
		Validate.notNull(chromosome);

		return mutationPolicy instanceof DeleteNode && chromosome instanceof NeatChromosomeSpec;
	}

	@Override
	public NeatChromosome mutate(final MutationPolicy mutationPolicy, final Chromosome chromosome) {
		Validate.notNull(mutationPolicy);
		Validate.notNull(chromosome);
		Validate.isInstanceOf(DeleteNode.class, mutationPolicy);
		Validate.isInstanceOf(NeatChromosome.class, chromosome);

		final var neatChromosome = (NeatChromosome) chromosome;
		final var numInputs = neatChromosome.getNumInputs();
		final var numOutputs = neatChromosome.getNumOutputs();
		final var minValue = neatChromosome.getMinWeightValue();
		final var maxValue = neatChromosome.getMaxWeightValue();

		final var oldConnections = neatChromosome.getConnections();

		final Set<Integer> inoutNodes = new HashSet<>();
		inoutNodes.addAll(neatChromosome.getInputNodeIndices());
		inoutNodes.addAll(neatChromosome.getOutputNodeIndices());

		final List<Integer> allNodeValues = neatChromosome.getConnections()
				.stream()
				.flatMap(connection -> Stream.of(connection.fromNodeIndex(), connection.toNodeIndex()))
				.filter(nodeIndex -> inoutNodes.contains(nodeIndex) == false)
				.toList();

		final Set<Integer> nodeValues = Set.copyOf(allNodeValues);

		final List<Connection> newConnections = switch (nodeValues.size()) {
			case 0 -> new ArrayList<>(oldConnections);
			default -> {
				final int nodeIndexToRemove = nodeValues.size() > 1 ? randomGenerator.nextInt(nodeValues.size() - 1) : 0;

				final int nodeValueToRemove = nodeValues.stream()
						.skip(nodeIndexToRemove)
						.findFirst()
						.get();

				yield oldConnections.stream()
						.filter(connection -> connection.fromNodeIndex() != nodeValueToRemove
								&& connection.toNodeIndex() != nodeValueToRemove)
						.toList();
			}
		};

		return new NeatChromosome(numInputs, numOutputs, minValue, maxValue, newConnections);
	}
}