1 | package net.bmahe.genetics4j.neat; | |
2 | ||
3 | import java.util.HashMap; | |
4 | import java.util.List; | |
5 | import java.util.Map; | |
6 | import java.util.Set; | |
7 | import java.util.function.Function; | |
8 | ||
9 | import org.apache.commons.collections4.CollectionUtils; | |
10 | import org.apache.commons.lang3.Validate; | |
11 | import org.apache.logging.log4j.LogManager; | |
12 | import org.apache.logging.log4j.Logger; | |
13 | ||
14 | /** | |
15 | * Implements a feed-forward neural network for evaluating NEAT (NeuroEvolution of Augmenting Topologies) chromosomes. | |
16 | * | |
17 | * <p>FeedForwardNetwork provides a computational engine for executing neural networks evolved by the NEAT algorithm. | |
18 | * It takes a network topology defined by connections and nodes, organizes them into computational layers, and | |
19 | * provides efficient forward propagation for fitness evaluation. The network supports arbitrary topologies with | |
20 | * variable numbers of hidden layers and connections. | |
21 | * | |
22 | * <p>Key features: | |
23 | * <ul> | |
24 | * <li><strong>Dynamic topology</strong>: Supports arbitrary network structures evolved by NEAT</li> | |
25 | * <li><strong>Layer-based evaluation</strong>: Automatically computes optimal evaluation order</li> | |
26 | * <li><strong>Configurable activation</strong>: Supports any activation function for hidden and output nodes</li> | |
27 | * <li><strong>Efficient propagation</strong>: Optimized forward pass through network layers</li> | |
28 | * </ul> | |
29 | * | |
30 | * <p>Network evaluation process: | |
31 | * <ol> | |
32 | * <li><strong>Input assignment</strong>: Input values are assigned to input nodes</li> | |
33 | * <li><strong>Layer computation</strong>: Each layer is computed in topological order</li> | |
34 | * <li><strong>Node activation</strong>: Each node applies weighted sum followed by activation function</li> | |
35 | * <li><strong>Output extraction</strong>: Output values are collected from designated output nodes</li> | |
36 | * </ol> | |
37 | * | |
38 | * <p>Network construction workflow: | |
39 | * <ul> | |
40 | * <li><strong>Topology analysis</strong>: Network connections are analyzed to determine layer structure</li> | |
41 | * <li><strong>Layer partitioning</strong>: Nodes are organized into evaluation layers using topological sorting</li> | |
42 | * <li><strong>Connection mapping</strong>: Backward connections are precomputed for efficient evaluation</li> | |
43 | * <li><strong>Dead node removal</strong>: Unreachable nodes are excluded from computation</li> | |
44 | * </ul> | |
45 | * | |
46 | * <p>Common usage patterns: | |
47 | * <pre>{@code | |
48 | * // Create network from NEAT chromosome | |
49 | * NeatChromosome chromosome = // ... obtain from evolution | |
50 | * Set<Integer> inputNodes = Set.of(0, 1, 2); | |
51 | * Set<Integer> outputNodes = Set.of(3, 4); | |
52 | * Function<Float, Float> activation = Activations::sigmoid; | |
53 | * | |
54 | * FeedForwardNetwork network = new FeedForwardNetwork( | |
55 | * inputNodes, outputNodes, chromosome.getConnections(), activation | |
56 | * ); | |
57 | * | |
58 | * // Evaluate network on input data | |
59 | * Map<Integer, Float> inputs = Map.of(0, 1.0f, 1, 0.5f, 2, -0.3f); | |
60 | * Map<Integer, Float> outputs = network.compute(inputs); | |
61 | * | |
62 | * // Extract specific outputs | |
63 | * float output1 = outputs.get(3); | |
64 | * float output2 = outputs.get(4); | |
65 | * }</pre> | |
66 | * | |
67 | * <p>Activation function integration: | |
68 | * <ul> | |
69 | * <li><strong>Sigmoid activation</strong>: Standard logistic function for binary classification</li> | |
70 | * <li><strong>Tanh activation</strong>: Hyperbolic tangent for continuous outputs</li> | |
71 | * <li><strong>Linear activation</strong>: Identity function for regression problems</li> | |
72 | * <li><strong>Custom functions</strong>: Any Function<Float, Float> can be used</li> | |
73 | * </ul> | |
74 | * | |
75 | * <p>Performance optimizations: | |
76 | * <ul> | |
77 | * <li><strong>Layer precomputation</strong>: Network layers are computed once during construction</li> | |
78 | * <li><strong>Connection mapping</strong>: Backward connections are precomputed for fast lookup</li> | |
79 | * <li><strong>Dead node elimination</strong>: Unreachable nodes are excluded from evaluation</li> | |
80 | * <li><strong>Efficient propagation</strong>: Only enabled connections participate in computation</li> | |
81 | * </ul> | |
82 | * | |
83 | * <p>Error handling and validation: | |
84 | * <ul> | |
85 | * <li><strong>Input validation</strong>: Ensures all input nodes receive values</li> | |
86 | * <li><strong>Output validation</strong>: Verifies all output nodes produce values</li> | |
87 | * <li><strong>Topology validation</strong>: Validates network structure during construction</li> | |
88 | * <li><strong>Connection consistency</strong>: Ensures connection endpoints reference valid nodes</li> | |
89 | * </ul> | |
90 | * | |
91 | * <p>Integration with NEAT evolution: | |
92 | * <ul> | |
93 | * <li><strong>Chromosome evaluation</strong>: Converts NEAT chromosomes to executable networks</li> | |
94 | * <li><strong>Fitness computation</strong>: Provides network output for fitness evaluation</li> | |
95 | * <li><strong>Topology evolution</strong>: Supports networks with varying structure complexity</li> | |
96 | * <li><strong>Innovation tracking</strong>: Works with networks containing historical innovations</li> | |
97 | * </ul> | |
98 | * | |
99 | * @see NeatChromosome | |
100 | * @see Connection | |
101 | * @see Activations | |
102 | * @see NeatUtils#partitionLayersNodes | |
103 | */ | |
104 | public class FeedForwardNetwork { | |
105 | public static final Logger logger = LogManager.getLogger(FeedForwardNetwork.class); | |
106 | ||
107 | private final Set<Integer> inputNodeIndices; | |
108 | private final Set<Integer> outputNodeIndices; | |
109 | private final List<Connection> connections; | |
110 | ||
111 | private final List<List<Integer>> layers; | |
112 | private final Map<Integer, Set<Connection>> backwardConnections; | |
113 | ||
114 | private final Function<Float, Float> activationFunction; | |
115 | ||
116 | /** | |
117 | * Constructs a new feed-forward network with the specified topology and activation function. | |
118 | * | |
119 | * <p>The constructor analyzes the network topology, computes evaluation layers using topological | |
120 | * sorting, and precomputes connection mappings for efficient forward propagation. The network | |
121 | * is immediately ready for evaluation after construction. | |
122 | * | |
123 | * @param _inputNodeIndices set of input node indices | |
124 | * @param _outputNodeIndices set of output node indices | |
125 | * @param _connections list of network connections defining the topology | |
126 | * @param _activationFunction activation function to apply to hidden and output nodes | |
127 | * @throws IllegalArgumentException if any parameter is null or empty | |
128 | */ | |
129 | public FeedForwardNetwork(final Set<Integer> _inputNodeIndices, final Set<Integer> _outputNodeIndices, | |
130 | final List<Connection> _connections, final Function<Float, Float> _activationFunction) { | |
131 | Validate.isTrue(CollectionUtils.isNotEmpty(_inputNodeIndices)); | |
132 | Validate.isTrue(CollectionUtils.isNotEmpty(_outputNodeIndices)); | |
133 | Validate.isTrue(CollectionUtils.isNotEmpty(_connections)); | |
134 | Validate.notNull(_activationFunction); | |
135 | ||
136 |
1
1. <init> : Removed assignment to member variable inputNodeIndices → KILLED |
this.inputNodeIndices = _inputNodeIndices; |
137 |
1
1. <init> : Removed assignment to member variable outputNodeIndices → KILLED |
this.outputNodeIndices = _outputNodeIndices; |
138 |
1
1. <init> : Removed assignment to member variable connections → KILLED |
this.connections = _connections; |
139 |
1
1. <init> : Removed assignment to member variable activationFunction → KILLED |
this.activationFunction = _activationFunction; |
140 | ||
141 |
3
1. <init> : removed call to net/bmahe/genetics4j/neat/NeatUtils::partitionLayersNodes → KILLED 2. <init> : replaced call to net/bmahe/genetics4j/neat/NeatUtils::partitionLayersNodes with argument → KILLED 3. <init> : Removed assignment to member variable layers → KILLED |
this.layers = NeatUtils.partitionLayersNodes(this.inputNodeIndices, this.outputNodeIndices, this.connections); |
142 |
2
1. <init> : Removed assignment to member variable backwardConnections → KILLED 2. <init> : removed call to net/bmahe/genetics4j/neat/NeatUtils::computeBackwardConnections → KILLED |
this.backwardConnections = NeatUtils.computeBackwardConnections(this.connections); |
143 | } | |
144 | ||
145 | /** | |
146 | * Computes the network output for the given input values. | |
147 | * | |
148 | * <p>This method performs forward propagation through the network, computing node activations | |
149 | * layer by layer in topological order. Input values are assigned to input nodes, then each | |
150 | * subsequent layer is computed by applying weighted sums and activation functions. | |
151 | * | |
152 | * <p>The computation process: | |
153 | * <ol> | |
154 | * <li>Input values are assigned to input nodes</li> | |
155 | * <li>For each layer (starting from first hidden layer):</li> | |
156 | * <li> For each node in the layer:</li> | |
157 | * <li> Compute weighted sum of inputs from previous layers</li> | |
158 | * <li> Apply activation function to the sum</li> | |
159 | * <li> Store the result for use in subsequent layers</li> | |
160 | * <li>Extract and return output values from output nodes</li> | |
161 | * </ol> | |
162 | * | |
163 | * @param inputValues mapping from input node indices to their values | |
164 | * @return mapping from output node indices to their computed values | |
165 | * @throws IllegalArgumentException if inputValues is null, has wrong size, or missing required inputs | |
166 | */ | |
167 | public Map<Integer, Float> compute(final Map<Integer, Float> inputValues) { | |
168 | Validate.notNull(inputValues); | |
169 | Validate.isTrue(inputValues.size() == inputNodeIndices.size()); | |
170 | ||
171 |
1
1. compute : removed call to java/util/HashMap::<init> → KILLED |
final Map<Integer, Float> nodeValues = new HashMap<>(); |
172 | ||
173 | for (final Integer inputNodeIndex : inputNodeIndices) { | |
174 |
2
1. compute : replaced call to java/util/Map::get with argument → KILLED 2. compute : removed call to java/util/Map::get → KILLED |
Float nodeValue = inputValues.get(inputNodeIndex); |
175 |
3
1. compute : removed conditional - replaced equality check with false → SURVIVED 2. compute : removed conditional - replaced equality check with true → KILLED 3. compute : negated conditional → KILLED |
if (nodeValue == null) { |
176 |
1
1. compute : removed call to java/lang/IllegalArgumentException::<init> → NO_COVERAGE |
throw new IllegalArgumentException("Input vector missing values for input node " + inputNodeIndex); |
177 | } | |
178 |
2
1. compute : replaced call to java/util/Map::put with argument → KILLED 2. compute : removed call to java/util/Map::put → KILLED |
nodeValues.put(inputNodeIndex, nodeValue); |
179 | } | |
180 | ||
181 |
1
1. compute : Substituted 1 with 0 → KILLED |
int layerIndex = 1; |
182 |
5
1. compute : changed conditional boundary → KILLED 2. compute : removed call to java/util/List::size → KILLED 3. compute : removed conditional - replaced comparison check with true → KILLED 4. compute : negated conditional → KILLED 5. compute : removed conditional - replaced comparison check with false → KILLED |
while (layerIndex < layers.size()) { |
183 | ||
184 |
1
1. compute : removed call to java/util/List::get → KILLED |
final List<Integer> layer = layers.get(layerIndex); |
185 | ||
186 |
4
1. compute : removed conditional - replaced equality check with true → SURVIVED 2. compute : removed conditional - replaced equality check with false → KILLED 3. compute : negated conditional → KILLED 4. compute : removed call to org/apache/commons/collections4/CollectionUtils::isNotEmpty → KILLED |
if (CollectionUtils.isNotEmpty(layer)) { |
187 | ||
188 | for (Integer nodeIndex : layer) { | |
189 |
1
1. compute : Substituted 0.0 with 1.0 → KILLED |
float sum = 0.0f; |
190 |
3
1. compute : removed call to java/util/Set::of → SURVIVED 2. compute : removed call to java/util/Map::getOrDefault → KILLED 3. compute : replaced call to java/util/Map::getOrDefault with argument → KILLED |
final var incomingNodes = backwardConnections.getOrDefault(nodeIndex, Set.of()); |
191 | for (final Connection incomingConnection : incomingNodes) { | |
192 |
5
1. compute : removed conditional - replaced equality check with false → SURVIVED 2. compute : removed call to java/lang/Integer::intValue → KILLED 3. compute : removed conditional - replaced equality check with true → KILLED 4. compute : removed call to net/bmahe/genetics4j/neat/Connection::toNodeIndex → KILLED 5. compute : negated conditional → KILLED |
if (incomingConnection.toNodeIndex() != nodeIndex) { |
193 |
1
1. compute : removed call to java/lang/IllegalStateException::<init> → NO_COVERAGE |
throw new IllegalStateException(); |
194 | } | |
195 | ||
196 | // Incoming connection may have been disabled and dangling | |
197 |
6
1. compute : removed conditional - replaced equality check with true → SURVIVED 2. compute : removed call to net/bmahe/genetics4j/neat/Connection::fromNodeIndex → SURVIVED 3. compute : negated conditional → KILLED 4. compute : removed call to java/lang/Integer::valueOf → KILLED 5. compute : removed conditional - replaced equality check with false → KILLED 6. compute : removed call to java/util/Map::containsKey → KILLED |
if (nodeValues.containsKey(incomingConnection.fromNodeIndex())) { |
198 |
1
1. compute : removed call to net/bmahe/genetics4j/neat/Connection::weight → KILLED |
final float weight = incomingConnection.weight(); |
199 |
5
1. compute : removed call to java/util/Map::get → KILLED 2. compute : removed call to java/lang/Float::floatValue → KILLED 3. compute : replaced call to java/util/Map::get with argument → KILLED 4. compute : removed call to net/bmahe/genetics4j/neat/Connection::fromNodeIndex → KILLED 5. compute : removed call to java/lang/Integer::valueOf → KILLED |
final float incomingNodeValue = nodeValues.get(incomingConnection.fromNodeIndex()); |
200 | ||
201 |
2
1. compute : Replaced float multiplication with division → KILLED 2. compute : Replaced float addition with subtraction → KILLED |
sum += weight * incomingNodeValue; |
202 | } | |
203 | } | |
204 |
3
1. compute : replaced call to java/util/function/Function::apply with argument → KILLED 2. compute : removed call to java/lang/Float::valueOf → KILLED 3. compute : removed call to java/util/function/Function::apply → KILLED |
final Float outputValue = activationFunction.apply(sum); |
205 |
2
1. compute : replaced call to java/util/Map::put with argument → KILLED 2. compute : removed call to java/util/Map::put → KILLED |
nodeValues.put(nodeIndex, outputValue); |
206 | } | |
207 | } | |
208 | ||
209 |
1
1. compute : Changed increment from 1 to -1 → KILLED |
layerIndex++; |
210 | } | |
211 | ||
212 |
1
1. compute : removed call to java/util/HashMap::<init> → KILLED |
final Map<Integer, Float> outputValues = new HashMap<>(); |
213 | for (final Integer outputNodeIndex : outputNodeIndices) { | |
214 |
2
1. compute : replaced call to java/util/Map::get with argument → KILLED 2. compute : removed call to java/util/Map::get → KILLED |
final Float value = nodeValues.get(outputNodeIndex); |
215 |
3
1. compute : removed conditional - replaced equality check with false → SURVIVED 2. compute : negated conditional → KILLED 3. compute : removed conditional - replaced equality check with true → KILLED |
if (value == null) { |
216 |
1
1. compute : removed call to java/lang/IllegalArgumentException::<init> → NO_COVERAGE |
throw new IllegalArgumentException("Missing output value for node " + outputNodeIndex); |
217 | } | |
218 |
2
1. compute : replaced call to java/util/Map::put with argument → KILLED 2. compute : removed call to java/util/Map::put → KILLED |
outputValues.put(outputNodeIndex, value); |
219 | } | |
220 |
1
1. compute : replaced return value with Collections.emptyMap for net/bmahe/genetics4j/neat/FeedForwardNetwork::compute → KILLED |
return outputValues; |
221 | } | |
222 | } | |
Mutations | ||
136 |
1.1 |
|
137 |
1.1 |
|
138 |
1.1 |
|
139 |
1.1 |
|
141 |
1.1 2.2 3.3 |
|
142 |
1.1 2.2 |
|
171 |
1.1 |
|
174 |
1.1 2.2 |
|
175 |
1.1 2.2 3.3 |
|
176 |
1.1 |
|
178 |
1.1 2.2 |
|
181 |
1.1 |
|
182 |
1.1 2.2 3.3 4.4 5.5 |
|
184 |
1.1 |
|
186 |
1.1 2.2 3.3 4.4 |
|
189 |
1.1 |
|
190 |
1.1 2.2 3.3 |
|
192 |
1.1 2.2 3.3 4.4 5.5 |
|
193 |
1.1 |
|
197 |
1.1 2.2 3.3 4.4 5.5 6.6 |
|
198 |
1.1 |
|
199 |
1.1 2.2 3.3 4.4 5.5 |
|
201 |
1.1 2.2 |
|
204 |
1.1 2.2 3.3 |
|
205 |
1.1 2.2 |
|
209 |
1.1 |
|
212 |
1.1 |
|
214 |
1.1 2.2 |
|
215 |
1.1 2.2 3.3 |
|
216 |
1.1 |
|
218 |
1.1 2.2 |
|
220 |
1.1 |