View Javadoc
1   package net.bmahe.genetics4j.neat.mutation.chromosome;
2   
3   import java.util.ArrayList;
4   import java.util.Comparator;
5   import java.util.List;
6   import java.util.random.RandomGenerator;
7   
8   import org.apache.commons.lang3.Validate;
9   
10  import net.bmahe.genetics4j.neat.Connection;
11  import net.bmahe.genetics4j.neat.InnovationManager;
12  import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;
13  import net.bmahe.genetics4j.neat.spec.mutation.AddNode;
14  
15  public class NeatChromosomeAddNodeMutationHandler extends AbstractNeatChromosomeConnectionMutationHandler<AddNode> {
16  
17  	private final RandomGenerator randomGenerator;
18  	private final InnovationManager innovationManager;
19  
20  	public NeatChromosomeAddNodeMutationHandler(final RandomGenerator _randomGenerator,
21  			final InnovationManager _innovationManager) {
22  		super(AddNode.class, _randomGenerator);
23  		Validate.notNull(_randomGenerator);
24  		Validate.notNull(_innovationManager);
25  
26  		this.randomGenerator = _randomGenerator;
27  		this.innovationManager = _innovationManager;
28  	}
29  
30  	@Override
31  	protected List<Connection> mutateConnection(final AddNode mutationPolicy, final NeatChromosome neatChromosome,
32  			final Connection oldConnection, final int i) {
33  
34  		final List<Connection> connections = new ArrayList<>();
35  
36  		final var disabledConnection = Connection.builder()
37  				.from(oldConnection)
38  				.isEnabled(false)
39  				.build();
40  		connections.add(disabledConnection);
41  
42  		final int maxNodeConnectionsValue = neatChromosome.getConnections()
43  				.stream()
44  				.map(connection -> Math.max(connection.fromNodeIndex(), connection.toNodeIndex()))
45  				.max(Comparator.naturalOrder())
46  				.orElse(0);
47  
48  		final int maxNodeValue = Math.max(maxNodeConnectionsValue,
49  				neatChromosome.getNumInputs() + neatChromosome.getNumOutputs() - 1);
50  
51  		final int newNodeValue = maxNodeValue + 1;
52  
53  		final int firstInnovation = innovationManager.computeNewId(oldConnection.fromNodeIndex(), newNodeValue);
54  		final var firstConnection = Connection.builder()
55  				.from(oldConnection)
56  				.weight(1.0f)
57  				.toNodeIndex(newNodeValue)
58  				.innovation(firstInnovation)
59  				.build();
60  		connections.add(firstConnection);
61  
62  		final int secondInnovation = innovationManager.computeNewId(newNodeValue, oldConnection.toNodeIndex());
63  		final var secondConnection = Connection.builder()
64  				.from(oldConnection)
65  				.fromNodeIndex(newNodeValue)
66  				.innovation(secondInnovation)
67  				.build();
68  		connections.add(secondConnection);
69  
70  		return connections;
71  	}
72  
73  }