RecurrentNetwork.java

package net.bmahe.genetics4j.neat;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Implements a recurrent neural network evaluator for NEAT chromosomes.
 *
 * <p>Unlike {@link FeedForwardNetwork}, this implementation can execute arbitrary directed graphs that include
 * recurrent (cyclic) connections. Activations are propagated iteratively until they converge or a maximum number of
 * iterations is reached, making it suitable for tasks that rely on short-term memory or feedback loops.</p>
 *
 * <p>Each iteration is a synchronous update: every non-input node is recomputed from the values produced by the
 * previous iteration as {@code activation(sum(weight * sourceValue))} over its incoming connections (a node without
 * incoming connections therefore receives {@code activation(0)}). Because a signal advances one connection per
 * iteration, a path longer than the configured maximum number of iterations cannot reach the outputs within a single
 * {@link #step(Map)}.</p>
 *
 * <p>The node values reached at the end of {@link #step(Map)} are kept as internal state and seed the next call to
 * {@link #step(Map)}, which is what gives the network memory across calls. {@link #compute(Map)} discards that state
 * first. Instances hold this mutable state without any synchronization and must not be used concurrently from several
 * threads without external synchronization.</p>
 */
public class RecurrentNetwork {

	public static final Logger logger = LogManager.getLogger(RecurrentNetwork.class);

	/** Default maximum number of recurrent iterations. */
	public static final int DEFAULT_MAX_ITERATIONS = 12;

	/** Default convergence threshold applied to node deltas between iterations. */
	public static final float DEFAULT_CONVERGENCE_THRESHOLD = 1e-3f;

	/** Default value assigned to non-input nodes before the first iteration. */
	public static final float DEFAULT_INITIAL_STATE_VALUE = 0.0f;

	private final Set<Integer> inputNodeIndices;
	private final Set<Integer> outputNodeIndices;
	private final List<Connection> connections;

	/** Incoming connections per destination node, as computed by {@code NeatUtils.computeBackwardConnections}. */
	private final Map<Integer, Set<Connection>> backwardConnections;
	/** Every non-input node, in ascending index order; these are the nodes recomputed on each iteration. */
	private final List<Integer> evaluatedNodeIndices;
	/** Inputs, outputs and every endpoint of every connection. */
	private final Set<Integer> allNodeIndices;

	private final Function<Float, Float> activationFunction;
	private final int maxIterations;
	private final float convergenceThreshold;
	private final float initialStateValue;
	/** Node values carried over from one {@link #step(Map)} call to the next. */
	private final Map<Integer, Float> nodeState;

	/**
	 * Creates a recurrent network using {@link #DEFAULT_MAX_ITERATIONS}, {@link #DEFAULT_CONVERGENCE_THRESHOLD} and
	 * {@link #DEFAULT_INITIAL_STATE_VALUE}.
	 *
	 * @param _inputNodeIndices   indices of the input nodes; must be non-null and non-empty
	 * @param _outputNodeIndices  indices of the output nodes; must be non-null and non-empty
	 * @param _connections        connections of the network; must be non-null (copied defensively)
	 * @param _activationFunction activation applied to the weighted sum of every non-input node; must be non-null
	 * @throws IllegalArgumentException if either index set is null or empty
	 * @throws NullPointerException     if {@code _connections} or {@code _activationFunction} is null
	 */
	public RecurrentNetwork(final Set<Integer> _inputNodeIndices,
			final Set<Integer> _outputNodeIndices,
			final List<Connection> _connections,
			final Function<Float, Float> _activationFunction) {
		this(_inputNodeIndices,
				_outputNodeIndices,
				_connections,
				_activationFunction,
				DEFAULT_MAX_ITERATIONS,
				DEFAULT_CONVERGENCE_THRESHOLD,
				DEFAULT_INITIAL_STATE_VALUE);
	}

	/**
	 * Creates a recurrent network with explicit iteration settings.
	 *
	 * <p>Whether disabled connections are ignored is decided by {@code NeatUtils.computeBackwardConnections}, which
	 * builds the per-node list of incoming connections used during evaluation.</p>
	 *
	 * @param _inputNodeIndices     indices of the input nodes; must be non-null and non-empty
	 * @param _outputNodeIndices    indices of the output nodes; must be non-null and non-empty
	 * @param _connections          connections of the network; must be non-null (copied defensively)
	 * @param _activationFunction   activation applied to the weighted sum of every non-input node; must be non-null
	 * @param _maxIterations        upper bound on iterations per {@link #step(Map)}; must be strictly positive
	 * @param _convergenceThreshold iteration stops once no non-input node changes by more than this; must be
	 *                              non-negative
	 * @param _initialStateValue    value every node holds after construction or {@link #resetState()}
	 * @throws IllegalArgumentException if an index set is null or empty, {@code _maxIterations <= 0}, or
	 *                                  {@code _convergenceThreshold} is negative or NaN
	 * @throws NullPointerException     if {@code _connections} or {@code _activationFunction} is null
	 */
	public RecurrentNetwork(final Set<Integer> _inputNodeIndices,
			final Set<Integer> _outputNodeIndices,
			final List<Connection> _connections,
			final Function<Float, Float> _activationFunction,
			final int _maxIterations,
			final float _convergenceThreshold,
			final float _initialStateValue) {
		Validate.isTrue(CollectionUtils.isNotEmpty(_inputNodeIndices), "inputNodeIndices must not be empty");
		Validate.isTrue(CollectionUtils.isNotEmpty(_outputNodeIndices), "outputNodeIndices must not be empty");
		Objects.requireNonNull(_connections);
		Objects.requireNonNull(_activationFunction);
		Validate.isTrue(_maxIterations > 0, "maxIterations must be strictly positive");
		Validate.isTrue(_convergenceThreshold >= 0.0f, "convergenceThreshold must be non-negative");

		this.inputNodeIndices = Collections.unmodifiableSet(new HashSet<>(_inputNodeIndices));
		this.outputNodeIndices = Collections.unmodifiableSet(new HashSet<>(_outputNodeIndices));
		this.connections = Collections.unmodifiableList(new ArrayList<>(_connections));
		this.activationFunction = _activationFunction;
		this.maxIterations = _maxIterations;
		this.convergenceThreshold = _convergenceThreshold;
		this.initialStateValue = _initialStateValue;

		final Map<Integer, Set<Connection>> backward = NeatUtils.computeBackwardConnections(this.connections);
		this.backwardConnections = Collections.unmodifiableMap(backward);

		final Set<Integer> allNodes = new HashSet<>(this.inputNodeIndices);
		allNodes.addAll(this.outputNodeIndices);
		for (final Connection connection : this.connections) {
			allNodes.add(connection.fromNodeIndex());
			allNodes.add(connection.toNodeIndex());
		}

		this.allNodeIndices = Collections.unmodifiableSet(allNodes);
		this.evaluatedNodeIndices = Collections.unmodifiableList(
				this.allNodeIndices.stream()
						.filter(nodeIndex -> this.inputNodeIndices.contains(nodeIndex) == false)
						.sorted()
						.toList());
		this.nodeState = new HashMap<>();
		// Deliberately not resetState(): calling an overridable method here would let a subclass run against a
		// partially constructed instance.
		initializeState();
	}

	/**
	 * Evaluates the network from a fresh state: the recurrent memory is reset and a single {@link #step(Map)} is run.
	 *
	 * @param inputValues value for each input node; see {@link #step(Map)} for the requirements
	 * @return a new map from each output node index to its value
	 * @throws NullPointerException     if {@code inputValues} is null
	 * @throws IllegalArgumentException if {@code inputValues} does not hold a non-null value for exactly the input
	 *                                  nodes
	 */
	public Map<Integer, Float> compute(final Map<Integer, Float> inputValues) {
		resetState();
		return step(inputValues);
	}

	/**
	 * Advances the network by one time step, starting from the node values left by the previous step.
	 *
	 * <p>The input nodes are set to {@code inputValues}, then all non-input nodes are updated synchronously until the
	 * largest absolute change of any of them is at most the convergence threshold, or until the maximum number of
	 * iterations is reached. The resulting node values become the state for the next call.</p>
	 *
	 * @param inputValues value for each input node; its key set must be exactly the input node indices and no value
	 *                    may be null
	 * @return a new map from each output node index to its value
	 * @throws NullPointerException     if {@code inputValues} is null
	 * @throws IllegalArgumentException if {@code inputValues} does not hold a non-null value for exactly the input
	 *                                  nodes
	 */
	public Map<Integer, Float> step(final Map<Integer, Float> inputValues) {
		validateInputValues(inputValues);

		Map<Integer, Float> currentValues = new HashMap<>(nodeState);
		for (final Integer inputNodeIndex : inputNodeIndices) {
			currentValues.put(inputNodeIndex, inputValues.get(inputNodeIndex));
		}

		for (int iteration = 0; iteration < maxIterations; iteration++) {
			// Write into a copy so every node in this iteration reads values from the previous iteration only
			final Map<Integer, Float> nextValues = new HashMap<>(currentValues);

			float maxDelta = 0.0f;
			for (final Integer nodeIndex : evaluatedNodeIndices) {
				final float newValue = activationFunction.apply(weightedInputSum(nodeIndex, currentValues));
				final float previousValue = currentValues.getOrDefault(nodeIndex, initialStateValue);
				nextValues.put(nodeIndex, newValue);
				maxDelta = Math.max(maxDelta, changeBetween(previousValue, newValue));
			}

			currentValues = nextValues;

			if (maxDelta <= convergenceThreshold) {
				break;
			}
		}

		nodeState.clear();
		nodeState.putAll(currentValues);

		return buildOutputValues(currentValues);
	}

	/**
	 * Discards the recurrent memory by setting every known node, inputs included, back to the initial state value.
	 */
	public void resetState() {
		initializeState();
	}

	/** Fills {@link #nodeState} with the initial state value for every known node. */
	private void initializeState() {
		nodeState.clear();
		for (final Integer nodeIndex : allNodeIndices) {
			nodeState.put(nodeIndex, initialStateValue);
		}
	}

	/** Ensures {@code inputValues} holds a non-null value for exactly the configured input nodes. */
	private void validateInputValues(final Map<Integer, Float> inputValues) {
		Objects.requireNonNull(inputValues);
		Validate.isTrue(
				inputValues.size() == inputNodeIndices.size(),
					"Unexpected number of input values: expected %d entries but found %d",
					inputNodeIndices.size(),
					inputValues.size());

		for (final Integer inputNodeIndex : inputNodeIndices) {
			// A null value is treated as missing; otherwise it would surface later as an opaque unboxing NPE
			if (inputValues.get(inputNodeIndex) == null) {
				throw new IllegalArgumentException("Input vector missing values for input node " + inputNodeIndex);
			}
		}
	}

	/** Sums {@code weight * sourceValue} over the connections feeding {@code nodeIndex}, reading from {@code values}. */
	private float weightedInputSum(final Integer nodeIndex, final Map<Integer, Float> values) {
		final Set<Connection> incomingConnections = backwardConnections.getOrDefault(nodeIndex, Set.of());

		float sum = 0.0f;
		for (final Connection incomingConnection : incomingConnections) {
			final float weight = incomingConnection.weight();
			final float incomingValue = values.getOrDefault(incomingConnection.fromNodeIndex(), initialStateValue);
			sum += weight * incomingValue;
		}
		return sum;
	}

	/**
	 * Returns the absolute change between two consecutive values of a node.
	 *
	 * <p>A transition to or from NaN is reported as {@link Float#POSITIVE_INFINITY} so that it never counts as
	 * converged; a plain {@code delta > maxDelta} comparison would silently ignore it.</p>
	 */
	private static float changeBetween(final float previousValue, final float newValue) {
		if (Float.compare(previousValue, newValue) == 0) {
			// Identical values, including NaN -> NaN and matching infinities, have not moved
			return 0.0f;
		}
		final float delta = Math.abs(newValue - previousValue);
		// At this point a NaN delta can only come from exactly one side being NaN
		return Float.isNaN(delta) ? Float.POSITIVE_INFINITY : delta;
	}

	/** Copies the value of every output node out of {@code nodeValues} into a new map. */
	private Map<Integer, Float> buildOutputValues(final Map<Integer, Float> nodeValues) {
		final Map<Integer, Float> outputValues = new HashMap<>();
		for (final Integer outputNodeIndex : outputNodeIndices) {
			final Float value = nodeValues.get(outputNodeIndex);
			if (value == null) {
				throw new IllegalArgumentException("Missing output value for node " + outputNodeIndex);
			}
			outputValues.put(outputNodeIndex, value);
		}
		return outputValues;
	}
}