| 1 | package net.bmahe.genetics4j.neat; | |
| 2 | ||
| 3 | import java.util.HashMap; | |
| 4 | import java.util.List; | |
| 5 | import java.util.Map; | |
| 6 | import java.util.Set; | |
| 7 | import java.util.function.Function; | |
| 8 | ||
| 9 | import org.apache.commons.collections4.CollectionUtils; | |
| 10 | import org.apache.commons.lang3.Validate; | |
| 11 | import org.apache.logging.log4j.LogManager; | |
| 12 | import org.apache.logging.log4j.Logger; | |
| 13 | ||
| 14 | /** | |
| 15 | * Implements a feed-forward neural network for evaluating NEAT (NeuroEvolution of Augmenting Topologies) chromosomes. | |
| 16 | * | |
| 17 | * <p>FeedForwardNetwork provides a computational engine for executing neural networks evolved by the NEAT algorithm. It | |
| 18 | * takes a network topology defined by connections and nodes, organizes them into computational layers, and provides | |
| 19 | * efficient forward propagation for fitness evaluation. The network supports arbitrary topologies with variable numbers | |
| 20 | * of hidden layers and connections. | |
| 21 | * | |
| 22 | * <p>Key features: | |
| 23 | * <ul> | |
| 24 | * <li><strong>Dynamic topology</strong>: Supports arbitrary network structures evolved by NEAT</li> | |
| 25 | * <li><strong>Layer-based evaluation</strong>: Automatically computes optimal evaluation order</li> | |
| 26 | * <li><strong>Configurable activation</strong>: Supports any activation function for hidden and output nodes</li> | |
| 27 | * <li><strong>Efficient propagation</strong>: Optimized forward pass through network layers</li> | |
| 28 | * </ul> | |
| 29 | * | |
| 30 | * <p>Network evaluation process: | |
| 31 | * <ol> | |
| 32 | * <li><strong>Input assignment</strong>: Input values are assigned to input nodes</li> | |
| 33 | * <li><strong>Layer computation</strong>: Each layer is computed in topological order</li> | |
| 34 | * <li><strong>Node activation</strong>: Each node applies weighted sum followed by activation function</li> | |
| 35 | * <li><strong>Output extraction</strong>: Output values are collected from designated output nodes</li> | |
| 36 | * </ol> | |
| 37 | * | |
| 38 | * <p>Network construction workflow: | |
| 39 | * <ul> | |
| 40 | * <li><strong>Topology analysis</strong>: Network connections are analyzed to determine layer structure</li> | |
| 41 | * <li><strong>Layer partitioning</strong>: Nodes are organized into evaluation layers using topological sorting</li> | |
| 42 | * <li><strong>Connection mapping</strong>: Backward connections are precomputed for efficient evaluation</li> | |
| 43 | * <li><strong>Dead node removal</strong>: Unreachable nodes are excluded from computation</li> | |
| 44 | * </ul> | |
| 45 | * | |
| 46 | * <p>Common usage patterns: | |
| 47 | * | |
| 48 | * <pre>{@code | |
| 49 | * // Create network from NEAT chromosome | |
| 50 | * NeatChromosome chromosome = // ... obtain from evolution | |
| 51 | * Set<Integer> inputNodes = Set.of(0, 1, 2); | |
| 52 | * Set<Integer> outputNodes = Set.of(3, 4); | |
| 53 | * Function<Float, Float> activation = Activations::sigmoid; | |
| 54 | * | |
| 55 | * FeedForwardNetwork network = new FeedForwardNetwork( | |
| 56 | * inputNodes, outputNodes, chromosome.getConnections(), activation | |
| 57 | * ); | |
| 58 | * | |
| 59 | * // Evaluate network on input data | |
| 60 | * Map<Integer, Float> inputs = Map.of(0, 1.0f, 1, 0.5f, 2, -0.3f); | |
| 61 | * Map<Integer, Float> outputs = network.compute(inputs); | |
| 62 | * | |
| 63 | * // Extract specific outputs | |
| 64 | * float output1 = outputs.get(3); | |
| 65 | * float output2 = outputs.get(4); | |
| 66 | * }</pre> | |
| 67 | * | |
| 68 | * <p>Activation function integration: | |
| 69 | * <ul> | |
| 70 | * <li><strong>Sigmoid activation</strong>: Standard logistic function for binary classification</li> | |
| 71 | * <li><strong>Tanh activation</strong>: Hyperbolic tangent for continuous outputs</li> | |
| 72 | * <li><strong>Linear activation</strong>: Identity function for regression problems</li> | |
| 73 | * <li><strong>Custom functions</strong>: Any Function<Float, Float> can be used</li> | |
| 74 | * </ul> | |
| 75 | * | |
| 76 | * <p>Performance optimizations: | |
| 77 | * <ul> | |
| 78 | * <li><strong>Layer precomputation</strong>: Network layers are computed once during construction</li> | |
| 79 | * <li><strong>Connection mapping</strong>: Backward connections are precomputed for fast lookup</li> | |
| 80 | * <li><strong>Dead node elimination</strong>: Unreachable nodes are excluded from evaluation</li> | |
| 81 | * <li><strong>Efficient propagation</strong>: Only enabled connections participate in computation</li> | |
| 82 | * </ul> | |
| 83 | * | |
| 84 | * <p>Error handling and validation: | |
| 85 | * <ul> | |
| 86 | * <li><strong>Input validation</strong>: Ensures all input nodes receive values</li> | |
| 87 | * <li><strong>Output validation</strong>: Verifies all output nodes produce values</li> | |
| 88 | * <li><strong>Topology validation</strong>: Validates network structure during construction</li> | |
| 89 | * <li><strong>Connection consistency</strong>: Ensures connection endpoints reference valid nodes</li> | |
| 90 | * </ul> | |
| 91 | * | |
| 92 | * <p>Integration with NEAT evolution: | |
| 93 | * <ul> | |
| 94 | * <li><strong>Chromosome evaluation</strong>: Converts NEAT chromosomes to executable networks</li> | |
| 95 | * <li><strong>Fitness computation</strong>: Provides network output for fitness evaluation</li> | |
| 96 | * <li><strong>Topology evolution</strong>: Supports networks with varying structure complexity</li> | |
| 97 | * <li><strong>Innovation tracking</strong>: Works with networks containing historical innovations</li> | |
| 98 | * </ul> | |
| 99 | * | |
| 100 | * @see NeatChromosome | |
| 101 | * @see Connection | |
| 102 | * @see Activations | |
| 103 | * @see NeatUtils#partitionLayersNodes | |
| 104 | */ | |
| 105 | public class FeedForwardNetwork { | |
| 106 | public static final Logger logger = LogManager.getLogger(FeedForwardNetwork.class); | |
| 107 | ||
| 108 | private final Set<Integer> inputNodeIndices; | |
| 109 | private final Set<Integer> outputNodeIndices; | |
| 110 | private final List<Connection> connections; | |
| 111 | ||
| 112 | private final List<List<Integer>> layers; | |
| 113 | private final Map<Integer, Set<Connection>> backwardConnections; | |
| 114 | ||
| 115 | private final Function<Float, Float> activationFunction; | |
| 116 | ||
| 117 | /** | |
| 118 | * Constructs a new feed-forward network with the specified topology and activation function. | |
| 119 | * | |
| 120 | * <p>The constructor analyzes the network topology, computes evaluation layers using topological sorting, and | |
| 121 | * precomputes connection mappings for efficient forward propagation. The network is immediately ready for evaluation | |
| 122 | * after construction. | |
| 123 | * | |
| 124 | * @param _inputNodeIndices set of input node indices | |
| 125 | * @param _outputNodeIndices set of output node indices | |
| 126 | * @param _connections list of network connections defining the topology | |
| 127 | * @param _activationFunction activation function to apply to hidden and output nodes | |
| 128 | * @throws IllegalArgumentException if any parameter is null or empty | |
| 129 | */ | |
| 130 | public FeedForwardNetwork(final Set<Integer> _inputNodeIndices, final Set<Integer> _outputNodeIndices, | |
| 131 | final List<Connection> _connections, final Function<Float, Float> _activationFunction) { | |
| 132 | Validate.isTrue(CollectionUtils.isNotEmpty(_inputNodeIndices)); | |
| 133 | Validate.isTrue(CollectionUtils.isNotEmpty(_outputNodeIndices)); | |
| 134 | Validate.isTrue(CollectionUtils.isNotEmpty(_connections)); | |
| 135 | Validate.notNull(_activationFunction); | |
| 136 | ||
| 137 |
1
1. <init> : Removed assignment to member variable inputNodeIndices → KILLED |
this.inputNodeIndices = _inputNodeIndices; |
| 138 |
1
1. <init> : Removed assignment to member variable outputNodeIndices → KILLED |
this.outputNodeIndices = _outputNodeIndices; |
| 139 |
1
1. <init> : Removed assignment to member variable connections → KILLED |
this.connections = _connections; |
| 140 |
1
1. <init> : Removed assignment to member variable activationFunction → KILLED |
this.activationFunction = _activationFunction; |
| 141 | ||
| 142 |
3
1. <init> : removed call to net/bmahe/genetics4j/neat/NeatUtils::partitionLayersNodes → KILLED 2. <init> : replaced call to net/bmahe/genetics4j/neat/NeatUtils::partitionLayersNodes with argument → KILLED 3. <init> : Removed assignment to member variable layers → KILLED |
this.layers = NeatUtils.partitionLayersNodes(this.inputNodeIndices, this.outputNodeIndices, this.connections); |
| 143 |
2
1. <init> : Removed assignment to member variable backwardConnections → KILLED 2. <init> : removed call to net/bmahe/genetics4j/neat/NeatUtils::computeBackwardConnections → KILLED |
this.backwardConnections = NeatUtils.computeBackwardConnections(this.connections); |
| 144 | } | |
| 145 | ||
| 146 | /** | |
| 147 | * Computes the network output for the given input values. | |
| 148 | * | |
| 149 | * <p>This method performs forward propagation through the network, computing node activations layer by layer in | |
| 150 | * topological order. Input values are assigned to input nodes, then each subsequent layer is computed by applying | |
| 151 | * weighted sums and activation functions. | |
| 152 | * | |
| 153 | * <p>The computation process: | |
| 154 | * <ol> | |
| 155 | * <li>Input values are assigned to input nodes</li> | |
| 156 | * <li>For each layer (starting from first hidden layer):</li> | |
| 157 | * <li>For each node in the layer:</li> | |
| 158 | * <li>Compute weighted sum of inputs from previous layers</li> | |
| 159 | * <li>Apply activation function to the sum</li> | |
| 160 | * <li>Store the result for use in subsequent layers</li> | |
| 161 | * <li>Extract and return output values from output nodes</li> | |
| 162 | * </ol> | |
| 163 | * | |
| 164 | * @param inputValues mapping from input node indices to their values | |
| 165 | * @return mapping from output node indices to their computed values | |
| 166 | * @throws IllegalArgumentException if inputValues is null, has wrong size, or missing required inputs | |
| 167 | */ | |
| 168 | public Map<Integer, Float> compute(final Map<Integer, Float> inputValues) { | |
| 169 | Validate.notNull(inputValues); | |
| 170 | Validate.isTrue(inputValues.size() == inputNodeIndices.size()); | |
| 171 | ||
| 172 |
1
1. compute : removed call to java/util/HashMap::<init> → KILLED |
final Map<Integer, Float> nodeValues = new HashMap<>(); |
| 173 | ||
| 174 | for (final Integer inputNodeIndex : inputNodeIndices) { | |
| 175 |
2
1. compute : replaced call to java/util/Map::get with argument → KILLED 2. compute : removed call to java/util/Map::get → KILLED |
Float nodeValue = inputValues.get(inputNodeIndex); |
| 176 |
3
1. compute : removed conditional - replaced equality check with false → SURVIVED 2. compute : removed conditional - replaced equality check with true → KILLED 3. compute : negated conditional → KILLED |
if (nodeValue == null) { |
| 177 |
1
1. compute : removed call to java/lang/IllegalArgumentException::<init> → NO_COVERAGE |
throw new IllegalArgumentException("Input vector missing values for input node " + inputNodeIndex); |
| 178 | } | |
| 179 |
2
1. compute : replaced call to java/util/Map::put with argument → KILLED 2. compute : removed call to java/util/Map::put → KILLED |
nodeValues.put(inputNodeIndex, nodeValue); |
| 180 | } | |
| 181 | ||
| 182 |
1
1. compute : Substituted 1 with 0 → KILLED |
int layerIndex = 1; |
| 183 |
5
1. compute : changed conditional boundary → KILLED 2. compute : removed call to java/util/List::size → KILLED 3. compute : removed conditional - replaced comparison check with true → KILLED 4. compute : negated conditional → KILLED 5. compute : removed conditional - replaced comparison check with false → KILLED |
while (layerIndex < layers.size()) { |
| 184 | ||
| 185 |
1
1. compute : removed call to java/util/List::get → KILLED |
final List<Integer> layer = layers.get(layerIndex); |
| 186 | ||
| 187 |
4
1. compute : removed conditional - replaced equality check with true → SURVIVED 2. compute : removed conditional - replaced equality check with false → KILLED 3. compute : negated conditional → KILLED 4. compute : removed call to org/apache/commons/collections4/CollectionUtils::isNotEmpty → KILLED |
if (CollectionUtils.isNotEmpty(layer)) { |
| 188 | ||
| 189 | for (Integer nodeIndex : layer) { | |
| 190 |
1
1. compute : Substituted 0.0 with 1.0 → KILLED |
float sum = 0.0f; |
| 191 |
3
1. compute : removed call to java/util/Set::of → SURVIVED 2. compute : removed call to java/util/Map::getOrDefault → KILLED 3. compute : replaced call to java/util/Map::getOrDefault with argument → KILLED |
final var incomingNodes = backwardConnections.getOrDefault(nodeIndex, Set.of()); |
| 192 | for (final Connection incomingConnection : incomingNodes) { | |
| 193 |
5
1. compute : removed conditional - replaced equality check with false → SURVIVED 2. compute : removed call to java/lang/Integer::intValue → KILLED 3. compute : removed conditional - replaced equality check with true → KILLED 4. compute : removed call to net/bmahe/genetics4j/neat/Connection::toNodeIndex → KILLED 5. compute : negated conditional → KILLED |
if (incomingConnection.toNodeIndex() != nodeIndex) { |
| 194 |
1
1. compute : removed call to java/lang/IllegalStateException::<init> → NO_COVERAGE |
throw new IllegalStateException(); |
| 195 | } | |
| 196 | ||
| 197 | // Incoming connection may have been disabled and dangling | |
| 198 |
6
1. compute : removed conditional - replaced equality check with true → SURVIVED 2. compute : removed call to net/bmahe/genetics4j/neat/Connection::fromNodeIndex → SURVIVED 3. compute : negated conditional → KILLED 4. compute : removed call to java/lang/Integer::valueOf → KILLED 5. compute : removed conditional - replaced equality check with false → KILLED 6. compute : removed call to java/util/Map::containsKey → KILLED |
if (nodeValues.containsKey(incomingConnection.fromNodeIndex())) { |
| 199 |
1
1. compute : removed call to net/bmahe/genetics4j/neat/Connection::weight → KILLED |
final float weight = incomingConnection.weight(); |
| 200 |
5
1. compute : removed call to java/util/Map::get → KILLED 2. compute : removed call to java/lang/Float::floatValue → KILLED 3. compute : replaced call to java/util/Map::get with argument → KILLED 4. compute : removed call to net/bmahe/genetics4j/neat/Connection::fromNodeIndex → KILLED 5. compute : removed call to java/lang/Integer::valueOf → KILLED |
final float incomingNodeValue = nodeValues.get(incomingConnection.fromNodeIndex()); |
| 201 | ||
| 202 |
2
1. compute : Replaced float multiplication with division → KILLED 2. compute : Replaced float addition with subtraction → KILLED |
sum += weight * incomingNodeValue; |
| 203 | } | |
| 204 | } | |
| 205 |
3
1. compute : replaced call to java/util/function/Function::apply with argument → KILLED 2. compute : removed call to java/lang/Float::valueOf → KILLED 3. compute : removed call to java/util/function/Function::apply → KILLED |
final Float outputValue = activationFunction.apply(sum); |
| 206 |
2
1. compute : replaced call to java/util/Map::put with argument → KILLED 2. compute : removed call to java/util/Map::put → KILLED |
nodeValues.put(nodeIndex, outputValue); |
| 207 | } | |
| 208 | } | |
| 209 | ||
| 210 |
1
1. compute : Changed increment from 1 to -1 → KILLED |
layerIndex++; |
| 211 | } | |
| 212 | ||
| 213 |
1
1. compute : removed call to java/util/HashMap::<init> → KILLED |
final Map<Integer, Float> outputValues = new HashMap<>(); |
| 214 | for (final Integer outputNodeIndex : outputNodeIndices) { | |
| 215 |
2
1. compute : replaced call to java/util/Map::get with argument → KILLED 2. compute : removed call to java/util/Map::get → KILLED |
final Float value = nodeValues.get(outputNodeIndex); |
| 216 |
3
1. compute : removed conditional - replaced equality check with false → SURVIVED 2. compute : negated conditional → KILLED 3. compute : removed conditional - replaced equality check with true → KILLED |
if (value == null) { |
| 217 |
1
1. compute : removed call to java/lang/IllegalArgumentException::<init> → NO_COVERAGE |
throw new IllegalArgumentException("Missing output value for node " + outputNodeIndex); |
| 218 | } | |
| 219 |
2
1. compute : replaced call to java/util/Map::put with argument → KILLED 2. compute : removed call to java/util/Map::put → KILLED |
outputValues.put(outputNodeIndex, value); |
| 220 | } | |
| 221 |
1
1. compute : replaced return value with Collections.emptyMap for net/bmahe/genetics4j/neat/FeedForwardNetwork::compute → KILLED |
return outputValues; |
| 222 | } | |
| 223 | } | |
Mutations | ||
| 137 |
1.1 |
|
| 138 |
1.1 |
|
| 139 |
1.1 |
|
| 140 |
1.1 |
|
| 142 |
1.1 2.2 3.3 |
|
| 143 |
1.1 2.2 |
|
| 172 |
1.1 |
|
| 175 |
1.1 2.2 |
|
| 176 |
1.1 2.2 3.3 |
|
| 177 |
1.1 |
|
| 179 |
1.1 2.2 |
|
| 182 |
1.1 |
|
| 183 |
1.1 2.2 3.3 4.4 5.5 |
|
| 185 |
1.1 |
|
| 187 |
1.1 2.2 3.3 4.4 |
|
| 190 |
1.1 |
|
| 191 |
1.1 2.2 3.3 |
|
| 193 |
1.1 2.2 3.3 4.4 5.5 |
|
| 194 |
1.1 |
|
| 198 |
1.1 2.2 3.3 4.4 5.5 6.6 |
|
| 199 |
1.1 |
|
| 200 |
1.1 2.2 3.3 4.4 5.5 |
|
| 202 |
1.1 2.2 |
|
| 205 |
1.1 2.2 3.3 |
|
| 206 |
1.1 2.2 |
|
| 210 |
1.1 |
|
| 213 |
1.1 |
|
| 215 |
1.1 2.2 |
|
| 216 |
1.1 2.2 3.3 |
|
| 217 |
1.1 |
|
| 219 |
1.1 2.2 |
|
| 221 |
1.1 |