1 | package net.bmahe.genetics4j.neat; | |
2 | ||
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
13 | ||
14 | /** | |
15 | * Implements a feed-forward neural network for evaluating NEAT (NeuroEvolution of Augmenting Topologies) chromosomes. | |
16 | * | |
17 | * <p>FeedForwardNetwork provides a computational engine for executing neural networks evolved by the NEAT algorithm. It | |
18 | * takes a network topology defined by connections and nodes, organizes them into computational layers, and provides | |
19 | * efficient forward propagation for fitness evaluation. The network supports arbitrary topologies with variable numbers | |
20 | * of hidden layers and connections. | |
21 | * | |
22 | * <p>Key features: | |
23 | * <ul> | |
24 | * <li><strong>Dynamic topology</strong>: Supports arbitrary network structures evolved by NEAT</li> | |
25 | * <li><strong>Layer-based evaluation</strong>: Automatically computes optimal evaluation order</li> | |
26 | * <li><strong>Configurable activation</strong>: Supports any activation function for hidden and output nodes</li> | |
27 | * <li><strong>Efficient propagation</strong>: Optimized forward pass through network layers</li> | |
28 | * </ul> | |
29 | * | |
30 | * <p>Network evaluation process: | |
31 | * <ol> | |
32 | * <li><strong>Input assignment</strong>: Input values are assigned to input nodes</li> | |
33 | * <li><strong>Layer computation</strong>: Each layer is computed in topological order</li> | |
34 | * <li><strong>Node activation</strong>: Each node applies weighted sum followed by activation function</li> | |
35 | * <li><strong>Output extraction</strong>: Output values are collected from designated output nodes</li> | |
36 | * </ol> | |
37 | * | |
38 | * <p>Network construction workflow: | |
39 | * <ul> | |
40 | * <li><strong>Topology analysis</strong>: Network connections are analyzed to determine layer structure</li> | |
41 | * <li><strong>Layer partitioning</strong>: Nodes are organized into evaluation layers using topological sorting</li> | |
42 | * <li><strong>Connection mapping</strong>: Backward connections are precomputed for efficient evaluation</li> | |
43 | * <li><strong>Dead node removal</strong>: Unreachable nodes are excluded from computation</li> | |
44 | * </ul> | |
45 | * | |
46 | * <p>Common usage patterns: | |
47 | * | |
48 | * <pre>{@code | |
49 | * // Create network from NEAT chromosome | |
50 | * NeatChromosome chromosome = // ... obtain from evolution | |
51 | * Set<Integer> inputNodes = Set.of(0, 1, 2); | |
52 | * Set<Integer> outputNodes = Set.of(3, 4); | |
53 | * Function<Float, Float> activation = Activations::sigmoid; | |
54 | * | |
55 | * FeedForwardNetwork network = new FeedForwardNetwork( | |
56 | * inputNodes, outputNodes, chromosome.getConnections(), activation | |
57 | * ); | |
58 | * | |
59 | * // Evaluate network on input data | |
60 | * Map<Integer, Float> inputs = Map.of(0, 1.0f, 1, 0.5f, 2, -0.3f); | |
61 | * Map<Integer, Float> outputs = network.compute(inputs); | |
62 | * | |
63 | * // Extract specific outputs | |
64 | * float output1 = outputs.get(3); | |
65 | * float output2 = outputs.get(4); | |
66 | * }</pre> | |
67 | * | |
68 | * <p>Activation function integration: | |
69 | * <ul> | |
70 | * <li><strong>Sigmoid activation</strong>: Standard logistic function for binary classification</li> | |
71 | * <li><strong>Tanh activation</strong>: Hyperbolic tangent for continuous outputs</li> | |
72 | * <li><strong>Linear activation</strong>: Identity function for regression problems</li> | |
73 | * <li><strong>Custom functions</strong>: Any Function<Float, Float> can be used</li> | |
74 | * </ul> | |
75 | * | |
76 | * <p>Performance optimizations: | |
77 | * <ul> | |
78 | * <li><strong>Layer precomputation</strong>: Network layers are computed once during construction</li> | |
79 | * <li><strong>Connection mapping</strong>: Backward connections are precomputed for fast lookup</li> | |
80 | * <li><strong>Dead node elimination</strong>: Unreachable nodes are excluded from evaluation</li> | |
81 | * <li><strong>Efficient propagation</strong>: Only enabled connections participate in computation</li> | |
82 | * </ul> | |
83 | * | |
84 | * <p>Error handling and validation: | |
85 | * <ul> | |
86 | * <li><strong>Input validation</strong>: Ensures all input nodes receive values</li> | |
87 | * <li><strong>Output validation</strong>: Verifies all output nodes produce values</li> | |
88 | * <li><strong>Topology validation</strong>: Validates network structure during construction</li> | |
89 | * <li><strong>Connection consistency</strong>: Ensures connection endpoints reference valid nodes</li> | |
90 | * </ul> | |
91 | * | |
92 | * <p>Integration with NEAT evolution: | |
93 | * <ul> | |
94 | * <li><strong>Chromosome evaluation</strong>: Converts NEAT chromosomes to executable networks</li> | |
95 | * <li><strong>Fitness computation</strong>: Provides network output for fitness evaluation</li> | |
96 | * <li><strong>Topology evolution</strong>: Supports networks with varying structure complexity</li> | |
97 | * <li><strong>Innovation tracking</strong>: Works with networks containing historical innovations</li> | |
98 | * </ul> | |
99 | * | |
100 | * @see NeatChromosome | |
101 | * @see Connection | |
102 | * @see Activations | |
103 | * @see NeatUtils#partitionLayersNodes | |
104 | */ | |
105 | public class FeedForwardNetwork { | |
106 | public static final Logger logger = LogManager.getLogger(FeedForwardNetwork.class); | |
107 | ||
108 | private final Set<Integer> inputNodeIndices; | |
109 | private final Set<Integer> outputNodeIndices; | |
110 | private final List<Connection> connections; | |
111 | ||
112 | private final List<List<Integer>> layers; | |
113 | private final Map<Integer, Set<Connection>> backwardConnections; | |
114 | ||
115 | private final Function<Float, Float> activationFunction; | |
116 | ||
117 | /** | |
118 | * Constructs a new feed-forward network with the specified topology and activation function. | |
119 | * | |
120 | * <p>The constructor analyzes the network topology, computes evaluation layers using topological sorting, and | |
121 | * precomputes connection mappings for efficient forward propagation. The network is immediately ready for evaluation | |
122 | * after construction. | |
123 | * | |
124 | * @param _inputNodeIndices set of input node indices | |
125 | * @param _outputNodeIndices set of output node indices | |
126 | * @param _connections list of network connections defining the topology | |
127 | * @param _activationFunction activation function to apply to hidden and output nodes | |
128 | * @throws IllegalArgumentException if any parameter is null or empty | |
129 | */ | |
130 | public FeedForwardNetwork(final Set<Integer> _inputNodeIndices, final Set<Integer> _outputNodeIndices, | |
131 | final List<Connection> _connections, final Function<Float, Float> _activationFunction) { | |
132 | Validate.isTrue(CollectionUtils.isNotEmpty(_inputNodeIndices)); | |
133 | Validate.isTrue(CollectionUtils.isNotEmpty(_outputNodeIndices)); | |
134 | Validate.isTrue(CollectionUtils.isNotEmpty(_connections)); | |
135 | Validate.notNull(_activationFunction); | |
136 | ||
137 |
1
1. <init> : Removed assignment to member variable inputNodeIndices → KILLED |
this.inputNodeIndices = _inputNodeIndices; |
138 |
1
1. <init> : Removed assignment to member variable outputNodeIndices → KILLED |
this.outputNodeIndices = _outputNodeIndices; |
139 |
1
1. <init> : Removed assignment to member variable connections → KILLED |
this.connections = _connections; |
140 |
1
1. <init> : Removed assignment to member variable activationFunction → KILLED |
this.activationFunction = _activationFunction; |
141 | ||
142 |
3
1. <init> : removed call to net/bmahe/genetics4j/neat/NeatUtils::partitionLayersNodes → KILLED 2. <init> : replaced call to net/bmahe/genetics4j/neat/NeatUtils::partitionLayersNodes with argument → KILLED 3. <init> : Removed assignment to member variable layers → KILLED |
this.layers = NeatUtils.partitionLayersNodes(this.inputNodeIndices, this.outputNodeIndices, this.connections); |
143 |
2
1. <init> : Removed assignment to member variable backwardConnections → KILLED 2. <init> : removed call to net/bmahe/genetics4j/neat/NeatUtils::computeBackwardConnections → KILLED |
this.backwardConnections = NeatUtils.computeBackwardConnections(this.connections); |
144 | } | |
145 | ||
146 | /** | |
147 | * Computes the network output for the given input values. | |
148 | * | |
149 | * <p>This method performs forward propagation through the network, computing node activations layer by layer in | |
150 | * topological order. Input values are assigned to input nodes, then each subsequent layer is computed by applying | |
151 | * weighted sums and activation functions. | |
152 | * | |
153 | * <p>The computation process: | |
154 | * <ol> | |
155 | * <li>Input values are assigned to input nodes</li> | |
156 | * <li>For each layer (starting from first hidden layer):</li> | |
157 | * <li>For each node in the layer:</li> | |
158 | * <li>Compute weighted sum of inputs from previous layers</li> | |
159 | * <li>Apply activation function to the sum</li> | |
160 | * <li>Store the result for use in subsequent layers</li> | |
161 | * <li>Extract and return output values from output nodes</li> | |
162 | * </ol> | |
163 | * | |
164 | * @param inputValues mapping from input node indices to their values | |
165 | * @return mapping from output node indices to their computed values | |
166 | * @throws IllegalArgumentException if inputValues is null, has wrong size, or missing required inputs | |
167 | */ | |
168 | public Map<Integer, Float> compute(final Map<Integer, Float> inputValues) { | |
169 | Validate.notNull(inputValues); | |
170 | Validate.isTrue(inputValues.size() == inputNodeIndices.size()); | |
171 | ||
172 |
1
1. compute : removed call to java/util/HashMap::<init> → KILLED |
final Map<Integer, Float> nodeValues = new HashMap<>(); |
173 | ||
174 | for (final Integer inputNodeIndex : inputNodeIndices) { | |
175 |
2
1. compute : replaced call to java/util/Map::get with argument → KILLED 2. compute : removed call to java/util/Map::get → KILLED |
Float nodeValue = inputValues.get(inputNodeIndex); |
176 |
3
1. compute : removed conditional - replaced equality check with false → SURVIVED 2. compute : removed conditional - replaced equality check with true → KILLED 3. compute : negated conditional → KILLED |
if (nodeValue == null) { |
177 |
1
1. compute : removed call to java/lang/IllegalArgumentException::<init> → NO_COVERAGE |
throw new IllegalArgumentException("Input vector missing values for input node " + inputNodeIndex); |
178 | } | |
179 |
2
1. compute : replaced call to java/util/Map::put with argument → KILLED 2. compute : removed call to java/util/Map::put → KILLED |
nodeValues.put(inputNodeIndex, nodeValue); |
180 | } | |
181 | ||
182 |
1
1. compute : Substituted 1 with 0 → KILLED |
int layerIndex = 1; |
183 |
5
1. compute : changed conditional boundary → KILLED 2. compute : removed call to java/util/List::size → KILLED 3. compute : removed conditional - replaced comparison check with true → KILLED 4. compute : negated conditional → KILLED 5. compute : removed conditional - replaced comparison check with false → KILLED |
while (layerIndex < layers.size()) { |
184 | ||
185 |
1
1. compute : removed call to java/util/List::get → KILLED |
final List<Integer> layer = layers.get(layerIndex); |
186 | ||
187 |
4
1. compute : removed conditional - replaced equality check with true → SURVIVED 2. compute : removed conditional - replaced equality check with false → KILLED 3. compute : negated conditional → KILLED 4. compute : removed call to org/apache/commons/collections4/CollectionUtils::isNotEmpty → KILLED |
if (CollectionUtils.isNotEmpty(layer)) { |
188 | ||
189 | for (Integer nodeIndex : layer) { | |
190 |
1
1. compute : Substituted 0.0 with 1.0 → KILLED |
float sum = 0.0f; |
191 |
3
1. compute : removed call to java/util/Set::of → SURVIVED 2. compute : removed call to java/util/Map::getOrDefault → KILLED 3. compute : replaced call to java/util/Map::getOrDefault with argument → KILLED |
final var incomingNodes = backwardConnections.getOrDefault(nodeIndex, Set.of()); |
192 | for (final Connection incomingConnection : incomingNodes) { | |
193 |
5
1. compute : removed conditional - replaced equality check with false → SURVIVED 2. compute : removed call to java/lang/Integer::intValue → KILLED 3. compute : removed conditional - replaced equality check with true → KILLED 4. compute : removed call to net/bmahe/genetics4j/neat/Connection::toNodeIndex → KILLED 5. compute : negated conditional → KILLED |
if (incomingConnection.toNodeIndex() != nodeIndex) { |
194 |
1
1. compute : removed call to java/lang/IllegalStateException::<init> → NO_COVERAGE |
throw new IllegalStateException(); |
195 | } | |
196 | ||
197 | // Incoming connection may have been disabled and dangling | |
198 |
6
1. compute : removed conditional - replaced equality check with true → SURVIVED 2. compute : removed call to net/bmahe/genetics4j/neat/Connection::fromNodeIndex → SURVIVED 3. compute : negated conditional → KILLED 4. compute : removed call to java/lang/Integer::valueOf → KILLED 5. compute : removed conditional - replaced equality check with false → KILLED 6. compute : removed call to java/util/Map::containsKey → KILLED |
if (nodeValues.containsKey(incomingConnection.fromNodeIndex())) { |
199 |
1
1. compute : removed call to net/bmahe/genetics4j/neat/Connection::weight → KILLED |
final float weight = incomingConnection.weight(); |
200 |
5
1. compute : removed call to java/util/Map::get → KILLED 2. compute : removed call to java/lang/Float::floatValue → KILLED 3. compute : replaced call to java/util/Map::get with argument → KILLED 4. compute : removed call to net/bmahe/genetics4j/neat/Connection::fromNodeIndex → KILLED 5. compute : removed call to java/lang/Integer::valueOf → KILLED |
final float incomingNodeValue = nodeValues.get(incomingConnection.fromNodeIndex()); |
201 | ||
202 |
2
1. compute : Replaced float multiplication with division → KILLED 2. compute : Replaced float addition with subtraction → KILLED |
sum += weight * incomingNodeValue; |
203 | } | |
204 | } | |
205 |
3
1. compute : replaced call to java/util/function/Function::apply with argument → KILLED 2. compute : removed call to java/lang/Float::valueOf → KILLED 3. compute : removed call to java/util/function/Function::apply → KILLED |
final Float outputValue = activationFunction.apply(sum); |
206 |
2
1. compute : replaced call to java/util/Map::put with argument → KILLED 2. compute : removed call to java/util/Map::put → KILLED |
nodeValues.put(nodeIndex, outputValue); |
207 | } | |
208 | } | |
209 | ||
210 |
1
1. compute : Changed increment from 1 to -1 → KILLED |
layerIndex++; |
211 | } | |
212 | ||
213 |
1
1. compute : removed call to java/util/HashMap::<init> → KILLED |
final Map<Integer, Float> outputValues = new HashMap<>(); |
214 | for (final Integer outputNodeIndex : outputNodeIndices) { | |
215 |
2
1. compute : replaced call to java/util/Map::get with argument → KILLED 2. compute : removed call to java/util/Map::get → KILLED |
final Float value = nodeValues.get(outputNodeIndex); |
216 |
3
1. compute : removed conditional - replaced equality check with false → SURVIVED 2. compute : negated conditional → KILLED 3. compute : removed conditional - replaced equality check with true → KILLED |
if (value == null) { |
217 |
1
1. compute : removed call to java/lang/IllegalArgumentException::<init> → NO_COVERAGE |
throw new IllegalArgumentException("Missing output value for node " + outputNodeIndex); |
218 | } | |
219 |
2
1. compute : replaced call to java/util/Map::put with argument → KILLED 2. compute : removed call to java/util/Map::put → KILLED |
outputValues.put(outputNodeIndex, value); |
220 | } | |
221 |
1
1. compute : replaced return value with Collections.emptyMap for net/bmahe/genetics4j/neat/FeedForwardNetwork::compute → KILLED |
return outputValues; |
222 | } | |
223 | } | |
Mutations | ||
137 |
1.1 |
|
138 |
1.1 |
|
139 |
1.1 |
|
140 |
1.1 |
|
142 |
1.1 2.2 3.3 |
|
143 |
1.1 2.2 |
|
172 |
1.1 |
|
175 |
1.1 2.2 |
|
176 |
1.1 2.2 3.3 |
|
177 |
1.1 |
|
179 |
1.1 2.2 |
|
182 |
1.1 |
|
183 |
1.1 2.2 3.3 4.4 5.5 |
|
185 |
1.1 |
|
187 |
1.1 2.2 3.3 4.4 |
|
190 |
1.1 |
|
191 |
1.1 2.2 3.3 |
|
193 |
1.1 2.2 3.3 4.4 5.5 |
|
194 |
1.1 |
|
198 |
1.1 2.2 3.3 4.4 5.5 6.6 |
|
199 |
1.1 |
|
200 |
1.1 2.2 3.3 4.4 5.5 |
|
202 |
1.1 2.2 |
|
205 |
1.1 2.2 3.3 |
|
206 |
1.1 2.2 |
|
210 |
1.1 |
|
213 |
1.1 |
|
215 |
1.1 2.2 |
|
216 |
1.1 2.2 3.3 |
|
217 |
1.1 |
|
219 |
1.1 2.2 |
|
221 |
1.1 |