FeedForwardNetwork.java

package net.bmahe.genetics4j.neat;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class FeedForwardNetwork {
	public static final Logger logger = LogManager.getLogger(FeedForwardNetwork.class);

	private final Set<Integer> inputNodeIndices;
	private final Set<Integer> outputNodeIndices;
	private final List<Connection> connections;

	private final List<List<Integer>> layers;
	private final Map<Integer, Set<Connection>> backwardConnections;

	private final Function<Float, Float> activationFunction;

	public FeedForwardNetwork(final Set<Integer> _inputNodeIndices, final Set<Integer> _outputNodeIndices,
			final List<Connection> _connections, final Function<Float, Float> _activationFunction) {
		Validate.isTrue(CollectionUtils.isNotEmpty(_inputNodeIndices));
		Validate.isTrue(CollectionUtils.isNotEmpty(_outputNodeIndices));
		Validate.isTrue(CollectionUtils.isNotEmpty(_connections));
		Validate.notNull(_activationFunction);

		this.inputNodeIndices = _inputNodeIndices;
		this.outputNodeIndices = _outputNodeIndices;
		this.connections = _connections;
		this.activationFunction = _activationFunction;

		this.layers = NeatUtils.partitionLayersNodes(this.inputNodeIndices, this.outputNodeIndices, this.connections);
		this.backwardConnections = NeatUtils.computeBackwardConnections(this.connections);
	}

	public Map<Integer, Float> compute(final Map<Integer, Float> inputValues) {
		Validate.notNull(inputValues);
		Validate.isTrue(inputValues.size() == inputNodeIndices.size());

		final Map<Integer, Float> nodeValues = new HashMap<>();

		for (final Integer inputNodeIndex : inputNodeIndices) {
			Float nodeValue = inputValues.get(inputNodeIndex);
			if (nodeValue == null) {
				throw new IllegalArgumentException("Input vector missing values for input node " + inputNodeIndex);
			}
			nodeValues.put(inputNodeIndex, nodeValue);
		}

		int layerIndex = 1;
		while (layerIndex < layers.size()) {

			final List<Integer> layer = layers.get(layerIndex);

			if (CollectionUtils.isNotEmpty(layer)) {

				for (Integer nodeIndex : layer) {
					float sum = 0.0f;
					final var incomingNodes = backwardConnections.getOrDefault(nodeIndex, Set.of());
					for (final Connection incomingConnection : incomingNodes) {
						if (incomingConnection.toNodeIndex() != nodeIndex) {
							throw new IllegalStateException();
						}

						// Incoming connection may have been disabled and dangling
						if (nodeValues.containsKey(incomingConnection.fromNodeIndex())) {
							final float weight = incomingConnection.weight();
							final float incomingNodeValue = nodeValues.get(incomingConnection.fromNodeIndex());

							sum += weight * incomingNodeValue;
						}
					}
					final Float outputValue = activationFunction.apply(sum);
					nodeValues.put(nodeIndex, outputValue);
				}
			}

			layerIndex++;
		}

		final Map<Integer, Float> outputValues = new HashMap<>();
		for (final Integer outputNodeIndex : outputNodeIndices) {
			final Float value = nodeValues.get(outputNodeIndex);
			if (value == null) {
				throw new IllegalArgumentException("Missing output value for node " + outputNodeIndex);
			}
			outputValues.put(outputNodeIndex, value);
		}
		return outputValues;
	}
}