RecurrentNetwork.java
package net.bmahe.genetics4j.neat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
/**
 * Implements a recurrent neural network evaluator for NEAT chromosomes.
 *
 * <p>Unlike {@link FeedForwardNetwork}, this implementation can execute arbitrary directed graphs that include
 * recurrent (cyclic) connections. Activations are propagated iteratively until they converge or a maximum number of
 * iterations is reached, making it suitable for tasks that rely on short-term memory or feedback loops.</p>
 *
 * <p>Each iteration is a synchronous update: every non-input node is recomputed from the values produced by the
 * previous iteration, so the order in which nodes are visited does not influence the result. Node values are kept
 * between calls to {@link #step(Map)}, which lets the network carry memory across successive inputs, whereas
 * {@link #compute(Map)} discards that memory before evaluating.</p>
 *
 * <p>Instances hold mutable node state and perform no synchronization; they should not be shared between threads
 * without external locking.</p>
 */
public class RecurrentNetwork {
	public static final Logger logger = LogManager.getLogger(RecurrentNetwork.class);

	/** Default maximum number of recurrent iterations. */
	public static final int DEFAULT_MAX_ITERATIONS = 12;

	/** Default convergence threshold applied to node deltas between iterations. */
	public static final float DEFAULT_CONVERGENCE_THRESHOLD = 1e-3f;

	/** Default value assigned to non-input nodes before the first iteration. */
	public static final float DEFAULT_INITIAL_STATE_VALUE = 0.0f;

	private final Set<Integer> inputNodeIndices;
	private final Set<Integer> outputNodeIndices;
	private final List<Connection> connections;
	private final Map<Integer, Set<Connection>> backwardConnections;
	private final List<Integer> evaluatedNodeIndices;
	private final Set<Integer> allNodeIndices;
	private final Function<Float, Float> activationFunction;
	private final int maxIterations;
	private final float convergenceThreshold;
	private final float initialStateValue;
	private final Map<Integer, Float> nodeState;

	/**
	 * Creates a recurrent network using {@link #DEFAULT_MAX_ITERATIONS}, {@link #DEFAULT_CONVERGENCE_THRESHOLD} and
	 * {@link #DEFAULT_INITIAL_STATE_VALUE}.
	 *
	 * @param _inputNodeIndices   indices of the input nodes; must be non-null and non-empty
	 * @param _outputNodeIndices  indices of the output nodes; must be non-null and non-empty
	 * @param _connections        connections of the network; must be non-null
	 * @param _activationFunction function applied to the weighted sum of every non-input node; must be non-null
	 * @throws IllegalArgumentException if the input or output node sets are null or empty
	 * @throws NullPointerException     if the connections or the activation function are null
	 */
	public RecurrentNetwork(final Set<Integer> _inputNodeIndices,
			final Set<Integer> _outputNodeIndices,
			final List<Connection> _connections,
			final Function<Float, Float> _activationFunction) {
		this(_inputNodeIndices,
				_outputNodeIndices,
				_connections,
				_activationFunction,
				DEFAULT_MAX_ITERATIONS,
				DEFAULT_CONVERGENCE_THRESHOLD,
				DEFAULT_INITIAL_STATE_VALUE);
	}

	/**
	 * Creates a recurrent network with explicit iteration and convergence settings.
	 *
	 * @param _inputNodeIndices     indices of the input nodes; must be non-null and non-empty
	 * @param _outputNodeIndices    indices of the output nodes; must be non-null and non-empty
	 * @param _connections          connections of the network; must be non-null
	 * @param _activationFunction   function applied to the weighted sum of every non-input node; must be non-null
	 * @param _maxIterations        upper bound on propagation iterations per {@link #step(Map)}; must be strictly
	 *                              positive
	 * @param _convergenceThreshold propagation stops early once no node changes by more than this amount; must be
	 *                              non-negative
	 * @param _initialStateValue    value every node holds after {@link #resetState()}
	 * @throws IllegalArgumentException if a node set is null or empty, {@code _maxIterations} is not positive or
	 *                                  {@code _convergenceThreshold} is negative or NaN
	 * @throws NullPointerException     if the connections or the activation function are null
	 */
	public RecurrentNetwork(final Set<Integer> _inputNodeIndices,
			final Set<Integer> _outputNodeIndices,
			final List<Connection> _connections,
			final Function<Float, Float> _activationFunction,
			final int _maxIterations,
			final float _convergenceThreshold,
			final float _initialStateValue) {
		Validate.isTrue(CollectionUtils.isNotEmpty(_inputNodeIndices), "inputNodeIndices cannot be null or empty");
		Validate.isTrue(CollectionUtils.isNotEmpty(_outputNodeIndices), "outputNodeIndices cannot be null or empty");
		Objects.requireNonNull(_connections, "connections cannot be null");
		Objects.requireNonNull(_activationFunction, "activationFunction cannot be null");
		Validate.isTrue(_maxIterations > 0, "maxIterations must be strictly positive");
		Validate.isTrue(_convergenceThreshold >= 0.0f, "convergenceThreshold must be non-negative");

		this.inputNodeIndices = Collections.unmodifiableSet(new HashSet<>(_inputNodeIndices));
		this.outputNodeIndices = Collections.unmodifiableSet(new HashSet<>(_outputNodeIndices));
		this.connections = Collections.unmodifiableList(new ArrayList<>(_connections));
		this.activationFunction = _activationFunction;
		this.maxIterations = _maxIterations;
		this.convergenceThreshold = _convergenceThreshold;
		this.initialStateValue = _initialStateValue;

		final Map<Integer, Set<Connection>> backward = NeatUtils.computeBackwardConnections(this.connections);
		this.backwardConnections = Collections.unmodifiableMap(backward);

		// Every node referenced anywhere in the network gets a state slot
		final Set<Integer> allNodes = new HashSet<>(this.inputNodeIndices);
		allNodes.addAll(this.outputNodeIndices);
		for (final Connection connection : this.connections) {
			allNodes.add(connection.fromNodeIndex());
			allNodes.add(connection.toNodeIndex());
		}
		this.allNodeIndices = Collections.unmodifiableSet(allNodes);

		// Input nodes are clamped to the provided values and therefore never recomputed
		this.evaluatedNodeIndices = Collections.unmodifiableList(
				this.allNodeIndices.stream()
						.filter(nodeIndex -> this.inputNodeIndices.contains(nodeIndex) == false)
						.sorted()
						.toList());

		this.nodeState = new HashMap<>();
		// Private helper rather than the overridable resetState(): subclasses are not yet initialized here
		initializeState();
	}

	/**
	 * Evaluates the network from a clean state.
	 *
	 * <p>Calls {@link #resetState()} and then {@link #step(Map)}, so the result does not depend on previous
	 * evaluations.</p>
	 *
	 * @param inputValues value for every input node, keyed by node index
	 * @return the value of every output node, keyed by node index
	 * @throws NullPointerException     if {@code inputValues} is null
	 * @throws IllegalArgumentException if {@code inputValues} does not hold exactly one non-null value per input node
	 */
	public Map<Integer, Float> compute(final Map<Integer, Float> inputValues) {
		resetState();
		return step(inputValues);
	}

	/**
	 * Advances the network by one input presentation while keeping the state left by the previous call.
	 *
	 * <p>Input nodes are set to the provided values, then all other nodes are recomputed synchronously until the
	 * largest per-node change is at most the convergence threshold, or the maximum number of iterations is reached.
	 * The resulting node values become the state for the next call.</p>
	 *
	 * @param inputValues value for every input node, keyed by node index
	 * @return the value of every output node, keyed by node index
	 * @throws NullPointerException     if {@code inputValues} is null
	 * @throws IllegalArgumentException if {@code inputValues} does not hold exactly one non-null value per input node
	 */
	public Map<Integer, Float> step(final Map<Integer, Float> inputValues) {
		validateInputValues(inputValues);

		// Start from the carried-over state, with input nodes overwritten by the new inputs
		final Map<Integer, Float> seededValues = new HashMap<>(nodeState);
		for (final Integer inputNodeIndex : inputNodeIndices) {
			seededValues.put(inputNodeIndex, inputValues.get(inputNodeIndex));
		}

		Map<Integer, Float> currentValues = seededValues;
		for (int iteration = 0; iteration < maxIterations; iteration++) {
			// Synchronous update: every node reads from currentValues and writes to nextValues
			final Map<Integer, Float> nextValues = new HashMap<>(currentValues);
			float maxDelta = 0.0f;
			for (final Integer nodeIndex : evaluatedNodeIndices) {
				final float newValue = evaluateNode(nodeIndex, currentValues);
				final float previousValue = currentValues.getOrDefault(nodeIndex, initialStateValue);
				nextValues.put(nodeIndex, newValue);
				maxDelta = Math.max(maxDelta, nodeDelta(previousValue, newValue));
			}
			currentValues = nextValues;
			if (maxDelta <= convergenceThreshold) {
				break;
			}
		}

		nodeState.clear();
		nodeState.putAll(currentValues);
		return buildOutputValues(currentValues);
	}

	/**
	 * Sets every node of the network back to the configured initial state value, discarding any memory accumulated by
	 * previous calls to {@link #step(Map)}.
	 */
	public void resetState() {
		initializeState();
	}

	/** Fills the node state with the initial state value for every known node. */
	private void initializeState() {
		nodeState.clear();
		for (final Integer nodeIndex : allNodeIndices) {
			nodeState.put(nodeIndex, initialStateValue);
		}
	}

	/** Ensures {@code inputValues} holds exactly one non-null value for each input node. */
	private void validateInputValues(final Map<Integer, Float> inputValues) {
		Objects.requireNonNull(inputValues);
		Validate.isTrue(
				inputValues.size() == inputNodeIndices.size(),
				"Missing input values: expected %d entries but found %d",
				inputNodeIndices.size(),
				inputValues.size());
		for (final Integer inputNodeIndex : inputNodeIndices) {
			if (inputValues.containsKey(inputNodeIndex) == false) {
				throw new IllegalArgumentException("Input vector missing values for input node " + inputNodeIndex);
			}
			// A null value would otherwise surface later as an unboxing NullPointerException
			if (inputValues.get(inputNodeIndex) == null) {
				throw new IllegalArgumentException("Input vector has a null value for input node " + inputNodeIndex);
			}
		}
	}

	/** Applies the activation function to the weighted sum of a node's incoming connections, read from nodeValues. */
	private float evaluateNode(final Integer nodeIndex, final Map<Integer, Float> nodeValues) {
		float sum = 0.0f;
		for (final Connection incomingConnection : backwardConnections.getOrDefault(nodeIndex, Set.of())) {
			final float incomingValue = nodeValues.getOrDefault(incomingConnection.fromNodeIndex(), initialStateValue);
			sum += incomingConnection.weight() * incomingValue;
		}
		return activationFunction.apply(sum);
	}

	/**
	 * Measures how much a node value changed between two iterations.
	 *
	 * <p>Identical values, including two NaNs or two equal infinities, yield {@code 0}. Any other change whose
	 * magnitude is not a number, such as a finite value becoming NaN, yields {@link Float#POSITIVE_INFINITY}.
	 * Such a change therefore cannot be mistaken for convergence.</p>
	 */
	private static float nodeDelta(final float previousValue, final float newValue) {
		if (Float.compare(previousValue, newValue) == 0) {
			return 0.0f;
		}
		final float delta = Math.abs(newValue - previousValue);
		return Float.isNaN(delta) ? Float.POSITIVE_INFINITY : delta;
	}

	/** Extracts the output node values from the given node values. */
	private Map<Integer, Float> buildOutputValues(final Map<Integer, Float> nodeValues) {
		final Map<Integer, Float> outputValues = new HashMap<>();
		for (final Integer outputNodeIndex : outputNodeIndices) {
			final Float value = nodeValues.get(outputNodeIndex);
			if (value == null) {
				throw new IllegalArgumentException("Missing output value for node " + outputNodeIndex);
			}
			outputValues.put(outputNodeIndex, value);
		}
		return outputValues;
	}
}