1 package net.bmahe.genetics4j.neat.mutation.chromosome;
2
3 import java.util.ArrayList;
4 import java.util.Comparator;
5 import java.util.List;
6 import java.util.random.RandomGenerator;
7
8 import org.apache.commons.lang3.Validate;
9
10 import net.bmahe.genetics4j.neat.Connection;
11 import net.bmahe.genetics4j.neat.InnovationManager;
12 import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;
13 import net.bmahe.genetics4j.neat.spec.mutation.AddNode;
14
15 public class NeatChromosomeAddNodeMutationHandler extends AbstractNeatChromosomeConnectionMutationHandler<AddNode> {
16
17 private final RandomGenerator randomGenerator;
18 private final InnovationManager innovationManager;
19
20 public NeatChromosomeAddNodeMutationHandler(final RandomGenerator _randomGenerator,
21 final InnovationManager _innovationManager) {
22 super(AddNode.class, _randomGenerator);
23 Validate.notNull(_randomGenerator);
24 Validate.notNull(_innovationManager);
25
26 this.randomGenerator = _randomGenerator;
27 this.innovationManager = _innovationManager;
28 }
29
30 @Override
31 protected List<Connection> mutateConnection(final AddNode mutationPolicy, final NeatChromosome neatChromosome,
32 final Connection oldConnection, final int i) {
33
34 final List<Connection> connections = new ArrayList<>();
35
36 final var disabledConnection = Connection.builder()
37 .from(oldConnection)
38 .isEnabled(false)
39 .build();
40 connections.add(disabledConnection);
41
42 final int maxNodeConnectionsValue = neatChromosome.getConnections()
43 .stream()
44 .map(connection -> Math.max(connection.fromNodeIndex(), connection.toNodeIndex()))
45 .max(Comparator.naturalOrder())
46 .orElse(0);
47
48 final int maxNodeValue = Math.max(maxNodeConnectionsValue,
49 neatChromosome.getNumInputs() + neatChromosome.getNumOutputs() - 1);
50
51 final int newNodeValue = maxNodeValue + 1;
52
53 final int firstInnovation = innovationManager.computeNewId(oldConnection.fromNodeIndex(), newNodeValue);
54 final var firstConnection = Connection.builder()
55 .from(oldConnection)
56 .weight(1.0f)
57 .toNodeIndex(newNodeValue)
58 .innovation(firstInnovation)
59 .build();
60 connections.add(firstConnection);
61
62 final int secondInnovation = innovationManager.computeNewId(newNodeValue, oldConnection.toNodeIndex());
63 final var secondConnection = Connection.builder()
64 .from(oldConnection)
65 .fromNodeIndex(newNodeValue)
66 .innovation(secondInnovation)
67 .build();
68 connections.add(secondConnection);
69
70 return connections;
71 }
72
73 }