NeatChromosomeAddConnection.java
package net.bmahe.genetics4j.neat.mutation.chromosome;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.random.RandomGenerator;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import net.bmahe.genetics4j.core.chromosomes.Chromosome;
import net.bmahe.genetics4j.core.mutation.chromosome.ChromosomeMutationHandler;
import net.bmahe.genetics4j.core.spec.chromosome.ChromosomeSpec;
import net.bmahe.genetics4j.core.spec.mutation.MutationPolicy;
import net.bmahe.genetics4j.neat.Connection;
import net.bmahe.genetics4j.neat.InnovationManager;
import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;
import net.bmahe.genetics4j.neat.spec.NeatChromosomeSpec;
import net.bmahe.genetics4j.neat.spec.mutation.AddConnection;
/**
 * Mutation handler for the {@link AddConnection} mutation policy applied to {@link NeatChromosome}s.
 * <p>
 * Each call draws a random source node index and a random target node index. If the pair describes an allowed new
 * connection, it appends a new enabled connection with a fresh innovation id and a random weight. If the pair is not
 * allowed, the returned chromosome carries the same connections as the input one.
 */
public class NeatChromosomeAddConnection implements ChromosomeMutationHandler<NeatChromosome> {
    public static final Logger logger = LogManager.getLogger(NeatChromosomeAddConnection.class);

    private final RandomGenerator randomGenerator;
    private final InnovationManager innovationManager;

    /**
     * Creates a handler.
     *
     * @param _randomGenerator   source of randomness for node selection and weight initialization; must not be null
     * @param _innovationManager provides innovation ids for newly created connections; must not be null
     * @throws NullPointerException if any argument is null
     */
    public NeatChromosomeAddConnection(final RandomGenerator _randomGenerator,
            final InnovationManager _innovationManager) {
        Validate.notNull(_randomGenerator);
        Validate.notNull(_innovationManager);

        this.randomGenerator = _randomGenerator;
        this.innovationManager = _innovationManager;
    }

    /**
     * Tells whether this handler applies to the given policy and chromosome spec.
     *
     * @return {@code true} iff the policy is an {@link AddConnection} and the spec is a {@link NeatChromosomeSpec}
     * @throws NullPointerException if any argument is null
     */
    @Override
    public boolean canHandle(final MutationPolicy mutationPolicy, final ChromosomeSpec chromosome) {
        Validate.notNull(mutationPolicy);
        Validate.notNull(chromosome);

        return mutationPolicy instanceof AddConnection && chromosome instanceof NeatChromosomeSpec;
    }

    /**
     * Attempts to add one random connection to a copy of the given chromosome.
     * <p>
     * Both endpoints are drawn uniformly from {@code [0, maxNodeIndex]}. {@code maxNodeIndex} is the larger of the
     * highest node index referenced by any existing connection and the last output node index
     * ({@code numInputs + numOutputs - 1}). The candidate is rejected, and the connections are returned unchanged,
     * in any of these cases:
     * <ul>
     * <li>it is a self loop ({@code fromNode == toNode});</li>
     * <li>a connection with the same {@code from -> to} pair already exists (enabled or not);</li>
     * <li>the target is an input node (index {@code < numInputs});</li>
     * <li>the source is an output node (index in {@code [numInputs, numInputs + numOutputs)}).</li>
     * </ul>
     *
     * @param mutationPolicy must be an {@link AddConnection}
     * @param chromosome     must be a {@link NeatChromosome}
     * @return a new chromosome with the same inputs, outputs and weight bounds, and possibly one extra connection
     * @throws NullPointerException     if any argument is null
     * @throws IllegalArgumentException if the arguments are not of the expected types
     */
    @Override
    public NeatChromosome mutate(final MutationPolicy mutationPolicy, final Chromosome chromosome) {
        Validate.notNull(mutationPolicy);
        Validate.notNull(chromosome);
        Validate.isInstanceOf(AddConnection.class, mutationPolicy);
        Validate.isInstanceOf(NeatChromosome.class, chromosome);

        final var neatChromosome = (NeatChromosome) chromosome;
        final var numInputs = neatChromosome.getNumInputs();
        final var numOutputs = neatChromosome.getNumOutputs();
        final var minValue = neatChromosome.getMinWeightValue();
        final var maxValue = neatChromosome.getMaxWeightValue();

        final var oldConnections = neatChromosome.getConnections();
        final List<Connection> newConnections = new ArrayList<>(oldConnections);

        // Highest node index known to this chromosome: either referenced by a connection or the last output node.
        final int maxNodeConnectionsValue = oldConnections.stream()
                .map(connection -> Math.max(connection.fromNodeIndex(), connection.toNodeIndex()))
                .max(Comparator.naturalOrder())
                .orElse(0);
        final int maxNodeValue = Math.max(maxNodeConnectionsValue, numInputs + numOutputs - 1);

        final int fromNode = randomGenerator.nextInt(maxNodeValue + 1);
        final int toNode = randomGenerator.nextInt(maxNodeValue + 1);

        final boolean isConnectionExist = oldConnections.stream()
                .anyMatch(connection -> connection.fromNodeIndex() == fromNode && connection.toNodeIndex() == toNode);

        // Node index layout: [0, numInputs) are inputs, [numInputs, numInputs + numOutputs) are outputs.
        final boolean isFromNodeAnOutput = fromNode < numInputs + numOutputs && fromNode >= numInputs;
        final boolean isToNodeAnInput = toNode < numInputs;

        // NOTE(review): no cycle check is performed, so the new connection may close a loop in the graph.
        // Confirm that downstream network evaluation tolerates recurrent links.
        if (fromNode != toNode && !isConnectionExist && !isToNodeAnInput && !isFromNodeAnOutput) {
            final int innovation = innovationManager.computeNewId(fromNode, toNode);

            // RandomGenerator.nextFloat(origin, bound) rejects origin >= bound. A degenerate [v, v] weight range
            // therefore yields v instead of throwing.
            final float weight = minValue == maxValue ? minValue : randomGenerator.nextFloat(minValue, maxValue);

            final var newConnection = Connection.builder()
                    .fromNodeIndex(fromNode)
                    .toNodeIndex(toNode)
                    .innovation(innovation)
                    .weight(weight)
                    .isEnabled(true)
                    .build();

            newConnections.add(newConnection);
        }

        return new NeatChromosome(numInputs, numOutputs, minValue, maxValue, newConnections);
    }
}