NeatChromosomeDeleteNodeMutationHandler.java

  1. package net.bmahe.genetics4j.neat.mutation.chromosome;

  2. import java.util.ArrayList;
  3. import java.util.HashSet;
  4. import java.util.List;
  5. import java.util.Set;
  6. import java.util.random.RandomGenerator;
  7. import java.util.stream.Stream;

  8. import org.apache.commons.lang3.Validate;
  9. import org.apache.logging.log4j.LogManager;
  10. import org.apache.logging.log4j.Logger;

  11. import net.bmahe.genetics4j.core.chromosomes.Chromosome;
  12. import net.bmahe.genetics4j.core.mutation.chromosome.ChromosomeMutationHandler;
  13. import net.bmahe.genetics4j.core.spec.chromosome.ChromosomeSpec;
  14. import net.bmahe.genetics4j.core.spec.mutation.MutationPolicy;
  15. import net.bmahe.genetics4j.neat.Connection;
  16. import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;
  17. import net.bmahe.genetics4j.neat.spec.NeatChromosomeSpec;
  18. import net.bmahe.genetics4j.neat.spec.mutation.DeleteNode;

  19. public class NeatChromosomeDeleteNodeMutationHandler implements ChromosomeMutationHandler<NeatChromosome> {

  20.     public static final Logger logger = LogManager.getLogger(NeatChromosomeDeleteNodeMutationHandler.class);

  21.     private final RandomGenerator randomGenerator;

  22.     public NeatChromosomeDeleteNodeMutationHandler(final RandomGenerator _randomGenerator) {
  23.         Validate.notNull(_randomGenerator);

  24.         this.randomGenerator = _randomGenerator;
  25.     }

  26.     @Override
  27.     public boolean canHandle(final MutationPolicy mutationPolicy, final ChromosomeSpec chromosome) {
  28.         Validate.notNull(mutationPolicy);
  29.         Validate.notNull(chromosome);

  30.         return mutationPolicy instanceof DeleteNode && chromosome instanceof NeatChromosomeSpec;
  31.     }

  32.     @Override
  33.     public NeatChromosome mutate(final MutationPolicy mutationPolicy, final Chromosome chromosome) {
  34.         Validate.notNull(mutationPolicy);
  35.         Validate.notNull(chromosome);
  36.         Validate.isInstanceOf(DeleteNode.class, mutationPolicy);
  37.         Validate.isInstanceOf(NeatChromosome.class, chromosome);

  38.         final var neatChromosome = (NeatChromosome) chromosome;
  39.         final var numInputs = neatChromosome.getNumInputs();
  40.         final var numOutputs = neatChromosome.getNumOutputs();
  41.         final var minValue = neatChromosome.getMinWeightValue();
  42.         final var maxValue = neatChromosome.getMaxWeightValue();

  43.         final var oldConnections = neatChromosome.getConnections();

  44.         final Set<Integer> inoutNodes = new HashSet<>();
  45.         inoutNodes.addAll(neatChromosome.getInputNodeIndices());
  46.         inoutNodes.addAll(neatChromosome.getOutputNodeIndices());

  47.         final List<Integer> allNodeValues = neatChromosome.getConnections()
  48.                 .stream()
  49.                 .flatMap(connection -> Stream.of(connection.fromNodeIndex(), connection.toNodeIndex()))
  50.                 .filter(nodeIndex -> inoutNodes.contains(nodeIndex) == false)
  51.                 .toList();

  52.         final Set<Integer> nodeValues = Set.copyOf(allNodeValues);

  53.         final List<Connection> newConnections = switch (nodeValues.size()) {
  54.             case 0 -> new ArrayList<>(oldConnections);
  55.             default -> {
  56.                 final int nodeIndexToRemove = nodeValues.size() > 1 ? randomGenerator.nextInt(nodeValues.size() - 1) : 0;

  57.                 final int nodeValueToRemove = nodeValues.stream()
  58.                         .skip(nodeIndexToRemove)
  59.                         .findFirst()
  60.                         .get();

  61.                 yield oldConnections.stream()
  62.                         .filter(connection -> connection.fromNodeIndex() != nodeValueToRemove
  63.                                 && connection.toNodeIndex() != nodeValueToRemove)
  64.                         .toList();
  65.             }
  66.         };

  67.         return new NeatChromosome(numInputs, numOutputs, minValue, maxValue, newConnections);
  68.     }
  69. }