FeedForwardNetwork.java
package net.bmahe.genetics4j.neat;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
/**
* Implements a feed-forward neural network for evaluating NEAT (NeuroEvolution of Augmenting Topologies) chromosomes.
*
* <p>FeedForwardNetwork provides a computational engine for executing neural networks evolved by the NEAT algorithm.
* It takes a network topology defined by connections and nodes, organizes them into computational layers, and
* provides efficient forward propagation for fitness evaluation. The network supports arbitrary topologies with
* variable numbers of hidden layers and connections.
*
* <p>Key features:
* <ul>
* <li><strong>Dynamic topology</strong>: Supports arbitrary network structures evolved by NEAT</li>
* <li><strong>Layer-based evaluation</strong>: Automatically computes optimal evaluation order</li>
* <li><strong>Configurable activation</strong>: Supports any activation function for hidden and output nodes</li>
* <li><strong>Efficient propagation</strong>: Optimized forward pass through network layers</li>
* </ul>
*
* <p>Network evaluation process:
* <ol>
* <li><strong>Input assignment</strong>: Input values are assigned to input nodes</li>
* <li><strong>Layer computation</strong>: Each layer is computed in topological order</li>
* <li><strong>Node activation</strong>: Each node applies weighted sum followed by activation function</li>
* <li><strong>Output extraction</strong>: Output values are collected from designated output nodes</li>
* </ol>
*
* <p>Network construction workflow:
* <ul>
* <li><strong>Topology analysis</strong>: Network connections are analyzed to determine layer structure</li>
* <li><strong>Layer partitioning</strong>: Nodes are organized into evaluation layers using topological sorting</li>
* <li><strong>Connection mapping</strong>: Backward connections are precomputed for efficient evaluation</li>
* <li><strong>Dead node removal</strong>: Unreachable nodes are excluded from computation</li>
* </ul>
*
* <p>Common usage patterns:
* <pre>{@code
* // Create network from NEAT chromosome
* NeatChromosome chromosome = // ... obtain from evolution
* Set<Integer> inputNodes = Set.of(0, 1, 2);
* Set<Integer> outputNodes = Set.of(3, 4);
* Function<Float, Float> activation = Activations::sigmoid;
*
* FeedForwardNetwork network = new FeedForwardNetwork(
*         inputNodes, outputNodes, chromosome.getConnections(), activation);
*
* // Evaluate network on input data
* Map<Integer, Float> inputs = Map.of(0, 1.0f, 1, 0.5f, 2, -0.3f);
* Map<Integer, Float> outputs = network.compute(inputs);
*
* // Extract specific outputs
* float output1 = outputs.get(3);
* float output2 = outputs.get(4);
* }</pre>
*
* <p>Activation function integration:
* <ul>
* <li><strong>Sigmoid activation</strong>: Standard logistic function for binary classification</li>
* <li><strong>Tanh activation</strong>: Hyperbolic tangent for continuous outputs</li>
* <li><strong>Linear activation</strong>: Identity function for regression problems</li>
* <li><strong>Custom functions</strong>: Any {@code Function<Float, Float>} can be used, as shown in the example after this list</li>
* </ul>
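*
* <p>For instance, a custom activation can be supplied as a plain lambda. This is a minimal sketch that
* reuses the {@code inputNodes}, {@code outputNodes} and {@code chromosome} variables from the usage
* example above:
* <pre>{@code
* // Hyperbolic tangent, computed in double precision and narrowed back to float
* Function<Float, Float> tanh = x -> (float) Math.tanh(x);
*
* FeedForwardNetwork tanhNetwork = new FeedForwardNetwork(
*         inputNodes, outputNodes, chromosome.getConnections(), tanh);
* }</pre>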
*
* <p>Performance optimizations:
* <ul>
* <li><strong>Layer precomputation</strong>: Network layers are computed once during construction</li>
* <li><strong>Connection mapping</strong>: Backward connections are precomputed for fast lookup</li>
* <li><strong>Dead node elimination</strong>: Unreachable nodes are excluded from evaluation</li>
* <li><strong>Efficient propagation</strong>: Only enabled connections participate in computation</li>
* </ul>
*
* <p>Error handling and validation (illustrated after this list):
* <ul>
* <li><strong>Input validation</strong>: Ensures all input nodes receive values</li>
* <li><strong>Output validation</strong>: Verifies all output nodes produce values</li>
* <li><strong>Topology validation</strong>: Validates network structure during construction</li>
* <li><strong>Connection consistency</strong>: Ensures connection endpoints reference valid nodes</li>
* </ul>
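*
* <p>Validation failures surface as runtime exceptions. A minimal sketch, reusing the {@code network}
* and input layout from the usage example above:
* <pre>{@code
* try {
*     // Only two of the three declared input nodes are supplied -> IllegalArgumentException
*     network.compute(Map.of(0, 1.0f, 1, 0.5f));
* } catch (IllegalArgumentException e) {
*     // Reject or log the malformed input vector
* }
* }</pre>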
*
* <p>Integration with NEAT evolution:
* <ul>
* <li><strong>Chromosome evaluation</strong>: Converts NEAT chromosomes to executable networks</li>
* <li><strong>Fitness computation</strong>: Provides network outputs for fitness evaluation, as sketched after this list</li>
* <li><strong>Topology evolution</strong>: Supports networks with varying structure complexity</li>
* <li><strong>Innovation tracking</strong>: Works with networks containing historical innovations</li>
* </ul>
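*
* <p>A typical fitness function decodes each chromosome into a network and scores its outputs against a
* dataset. The sketch below is illustrative only: the sample data, expected values and error measure are
* assumptions, not part of this class:
* <pre>{@code
* double fitness(NeatChromosome chromosome, List<Map<Integer, Float>> samples, List<Float> expected) {
*     final FeedForwardNetwork network = new FeedForwardNetwork(
*             Set.of(0, 1), Set.of(2), chromosome.getConnections(), Activations::sigmoid);
*
*     double error = 0.0d;
*     for (int i = 0; i < samples.size(); i++) {
*         final float predicted = network.compute(samples.get(i)).get(2);
*         error += Math.pow(predicted - expected.get(i), 2);
*     }
*     // Lower squared error maps to higher fitness
*     return 1.0d / (1.0d + error);
* }
* }</pre>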
*
* @see NeatChromosome
* @see Connection
* @see Activations
* @see NeatUtils#partitionLayersNodes
*/
public class FeedForwardNetwork {
public static final Logger logger = LogManager.getLogger(FeedForwardNetwork.class);
private final Set<Integer> inputNodeIndices;
private final Set<Integer> outputNodeIndices;
private final List<Connection> connections;
private final List<List<Integer>> layers;
private final Map<Integer, Set<Connection>> backwardConnections;
private final Function<Float, Float> activationFunction;
/**
* Constructs a new feed-forward network with the specified topology and activation function.
*
* <p>The constructor analyzes the network topology, computes evaluation layers using topological
* sorting, and precomputes connection mappings for efficient forward propagation. The network
* is immediately ready for evaluation after construction.
*
* @param _inputNodeIndices set of input node indices
* @param _outputNodeIndices set of output node indices
* @param _connections list of network connections defining the topology
* @param _activationFunction activation function to apply to hidden and output nodes
* @throws IllegalArgumentException if any of the node index sets or the connection list is null or empty
* @throws NullPointerException if the activation function is null
*/
public FeedForwardNetwork(final Set<Integer> _inputNodeIndices, final Set<Integer> _outputNodeIndices,
final List<Connection> _connections, final Function<Float, Float> _activationFunction) {
Validate.isTrue(CollectionUtils.isNotEmpty(_inputNodeIndices));
Validate.isTrue(CollectionUtils.isNotEmpty(_outputNodeIndices));
Validate.isTrue(CollectionUtils.isNotEmpty(_connections));
Validate.notNull(_activationFunction);
this.inputNodeIndices = _inputNodeIndices;
this.outputNodeIndices = _outputNodeIndices;
this.connections = _connections;
this.activationFunction = _activationFunction;
this.layers = NeatUtils.partitionLayersNodes(this.inputNodeIndices, this.outputNodeIndices, this.connections);
this.backwardConnections = NeatUtils.computeBackwardConnections(this.connections);
}
/**
* Computes the network output for the given input values.
*
* <p>This method performs forward propagation through the network, computing node activations
* layer by layer in topological order. Input values are assigned to input nodes, then each
* subsequent layer is computed by applying weighted sums and activation functions.
*
* <p>The computation process:
* <ol>
*   <li>Input values are assigned to input nodes</li>
*   <li>For each layer (starting from the first hidden layer), each node:
*     <ul>
*       <li>computes the weighted sum of its already evaluated incoming nodes</li>
*       <li>applies the activation function to the sum</li>
*       <li>stores the result for use in subsequent layers</li>
*     </ul>
*   </li>
*   <li>Output values are extracted from the output nodes and returned</li>
* </ol>
*
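* <p>Example usage (a minimal sketch, assuming the {@code network} instance and node indices from the
* class-level example):
* <pre>{@code
* Map<Integer, Float> outputs = network.compute(Map.of(0, 1.0f, 1, 0.5f, 2, -0.3f));
* float firstOutput = outputs.get(3);
* }</pre>
*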
* @param inputValues mapping from input node indices to their values
* @return mapping from output node indices to their computed values
* @throws NullPointerException if {@code inputValues} is null
* @throws IllegalArgumentException if {@code inputValues} has the wrong size, is missing a value for an
*         input node, or if an output node ends up without a computed value
*/
public Map<Integer, Float> compute(final Map<Integer, Float> inputValues) {
Validate.notNull(inputValues);
Validate.isTrue(inputValues.size() == inputNodeIndices.size());
final Map<Integer, Float> nodeValues = new HashMap<>();
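// Seed the node values with the provided inputs; every declared input node must be covered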
for (final Integer inputNodeIndex : inputNodeIndices) {
Float nodeValue = inputValues.get(inputNodeIndex);
if (nodeValue == null) {
throw new IllegalArgumentException("Input vector missing values for input node " + inputNodeIndex);
}
nodeValues.put(inputNodeIndex, nodeValue);
}
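// Propagate layer by layer in topological order; layer 0 holds the input nodes and is already populated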
int layerIndex = 1;
while (layerIndex < layers.size()) {
final List<Integer> layer = layers.get(layerIndex);
if (CollectionUtils.isNotEmpty(layer)) {
for (Integer nodeIndex : layer) {
float sum = 0.0f;
final var incomingConnections = backwardConnections.getOrDefault(nodeIndex, Set.of());
for (final Connection incomingConnection : incomingConnections) {
if (incomingConnection.toNodeIndex() != nodeIndex) {
throw new IllegalStateException("Connection " + incomingConnection + " does not terminate at node " + nodeIndex);
}
// The incoming connection may be dangling (e.g. disabled upstream), so its source node may have no computed value
if (nodeValues.containsKey(incomingConnection.fromNodeIndex())) {
final float weight = incomingConnection.weight();
final float incomingNodeValue = nodeValues.get(incomingConnection.fromNodeIndex());
sum += weight * incomingNodeValue;
}
}
final Float outputValue = activationFunction.apply(sum);
nodeValues.put(nodeIndex, outputValue);
}
}
layerIndex++;
}
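// Gather the computed values of the declared output nodes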
final Map<Integer, Float> outputValues = new HashMap<>();
for (final Integer outputNodeIndex : outputNodeIndices) {
final Float value = nodeValues.get(outputNodeIndex);
if (value == null) {
throw new IllegalArgumentException("Missing output value for node " + outputNodeIndex);
}
outputValues.put(outputNodeIndex, value);
}
return outputValues;
}
}