package net.bmahe.genetics4j.neat;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiPredicate;
import java.util.random.RandomGenerator;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.Validate;

import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.core.Individual;
import net.bmahe.genetics4j.core.Population;
import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;

/**
 * Utility class providing core algorithmic operations for the NEAT (NeuroEvolution of Augmenting Topologies) algorithm.
 * 
 * <p>NeatUtils contains essential algorithms and helper methods for implementing NEAT neural network evolution,
 * including network topology analysis, compatibility distance calculation, speciation, and structural operations. These
 * utilities support the NEAT algorithm's key features of topology innovation, structural mutation, and species-based
 * population organization.
 * 
 * <p>Key functionality areas:
 * <ul>
 * <li><strong>Network topology analysis</strong>: Computing network layers, forward/backward connections, and dead node
 * detection</li>
 * <li><strong>Compatibility distance</strong>: Measuring genetic similarity between neural networks for speciation</li>
 * <li><strong>Speciation management</strong>: Organizing populations into species based on genetic similarity</li>
 * <li><strong>Structural analysis</strong>: Analyzing network connectivity patterns and structural properties</li>
 * </ul>
 * 
 * <p>NEAT algorithm integration:
 * <ul>
 * <li><strong>Innovation tracking</strong>: Support for historical marking and innovation numbers</li>
 * <li><strong>Structural mutations</strong>: Utilities for add-node and add-connection operations</li>
 * <li><strong>Network evaluation</strong>: Layer-based network evaluation ordering</li>
 * <li><strong>Population diversity</strong>: Species-based diversity maintenance</li>
 * </ul>
 * 
 * <p>Core NEAT concepts implemented:
 * <ul>
 * <li><strong>Genetic similarity</strong>: Compatibility distance based on excess, disjoint, and weight
 * differences</li>
 * <li><strong>Topological innovation</strong>: Structural changes tracked through innovation numbers</li>
 * <li><strong>Speciation</strong>: Dynamic species formation based on genetic distance thresholds</li>
 * <li><strong>Network evaluation</strong>: Feed-forward evaluation through computed network layers</li>
 * </ul>
 * 
 * <p>Algorithmic foundations:
 * <ul>
 * <li><strong>Graph algorithms</strong>: Topological sorting, connectivity analysis, and layer computation</li>
 * <li><strong>Genetic distance metrics</strong>: NEAT-specific compatibility distance calculation</li>
 * <li><strong>Population clustering</strong>: Species formation and maintenance algorithms</li>
 * <li><strong>Network optimization</strong>: Dead node removal and structural simplification</li>
 * </ul>
 * 
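 * <p>Illustrative sketch of speciating a population with a compatibility-distance predicate. It only shows how the
 * utilities fit together; the coefficients, the 3.0 threshold and the {@code genotype()} accessor on
 * {@link Individual} are assumptions for the example rather than prescribed values, and the species id generator and
 * population are obtained elsewhere:
 * <pre>{@code
 * final RandomGenerator random = RandomGenerator.getDefault();
 * final SpeciesIdGenerator speciesIdGenerator = ...; // provided by the surrounding NEAT setup
 * final Population<Double> population = ...;         // current generation
 *
 * // Two individuals belong to the same species when the compatibility distance of
 * // their first chromosome stays below an (assumed) threshold of 3.0
 * final BiPredicate<Individual<Double>, Individual<Double>> speciesPredicate =
 *         (i1, i2) -> NeatUtils.compatibilityDistance(i1.genotype(), i2.genotype(), 0, 1.0f, 1.0f, 0.4f) < 3.0f;
 *
 * final List<Species<Double>> species =
 *         NeatUtils.speciate(random, speciesIdGenerator, List.of(), population, speciesPredicate);
 * }</pre>
 * 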
 * @see NeatChromosome
 * @see Connection
 * @see Species
 * @see InnovationManager
 */
public class NeatUtils {

	private NeatUtils() {
	}

	/**
	 * Computes the set of dead nodes, i.e. nodes that cannot reach any output node.
	 * <p>
	 * Working backward from the output nodes, any node that never gets visited is considered dead.
	 * 
	 * @param connections         list of connections defining the network
	 * @param forwardConnections  map from node index to the node indices it feeds into (not consulted by the traversal)
	 * @param backwardConnections map from node index to the node indices feeding into it
	 * @param outputNodeIndices   indices of the output nodes, used as starting points for the backward traversal
	 * @return set of node indices that cannot reach any output node
	 */
	public static Set<Integer> computeDeadNodes(final List<Connection> connections,
			final Map<Integer, Set<Integer>> forwardConnections, final Map<Integer, Set<Integer>> backwardConnections,
			final Set<Integer> outputNodeIndices) {
		Objects.requireNonNull(connections);

		final Set<Integer> deadNodes = new HashSet<>();
		for (final Connection connection : connections) {
			deadNodes.add(connection.fromNodeIndex());
			deadNodes.add(connection.toNodeIndex());
		}
		deadNodes.removeAll(outputNodeIndices);

		final Set<Integer> visited = new HashSet<>();
		final Deque<Integer> toVisit = new ArrayDeque<>(outputNodeIndices);
		while (toVisit.size() > 0) {
			final Integer currentNode = toVisit.poll();

			deadNodes.remove(currentNode);
			if (visited.contains(currentNode) == false) {

				visited.add(currentNode);

				final var next = backwardConnections.getOrDefault(currentNode, Set.of());
				if (next.size() > 0) {
					toVisit.addAll(next);
				}
			}
		}

		return deadNodes;
	}

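	/**
	 * Builds the forward adjacency map of the network, considering only enabled connections.
	 * 
	 * @param connections list of connections defining the network
	 * @return map from node index to the set of node indices it feeds into
	 * @throws IllegalArgumentException if two enabled connections link the same pair of nodes
	 */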
	public static Map<Integer, Set<Integer>> computeForwardLinks(final List<Connection> connections) {
		Objects.requireNonNull(connections);

		final Map<Integer, Set<Integer>> forwardConnections = new HashMap<>();
		for (final Connection connection : connections) {
			final var fromNodeIndex = connection.fromNodeIndex();
			final var toNodeIndex = connection.toNodeIndex();

			if (connection.isEnabled()) {
				final var toNodes = forwardConnections.computeIfAbsent(fromNodeIndex, k -> new HashSet<>());

				if (toNodes.add(toNodeIndex) == false) {
					throw new IllegalArgumentException(
							"Found duplicate entries for nodes defined in connection " + connection);
				}
			}
		}

		return forwardConnections;
	}

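	/**
	 * Builds the backward adjacency map of the network, considering only enabled connections.
	 * 
	 * @param connections list of connections defining the network
	 * @return map from node index to the set of node indices feeding into it
	 * @throws IllegalArgumentException if two enabled connections link the same pair of nodes
	 */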
	public static Map<Integer, Set<Integer>> computeBackwardLinks(final List<Connection> connections) {
		Objects.requireNonNull(connections);

		final Map<Integer, Set<Integer>> backwardConnections = new HashMap<>();
		for (final Connection connection : connections) {
			final var fromNodeIndex = connection.fromNodeIndex();
			final var toNodeIndex = connection.toNodeIndex();

			if (connection.isEnabled()) {
				final var fromNodes = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());

				if (fromNodes.add(fromNodeIndex) == false) {
					throw new IllegalArgumentException(
							"Found duplicate entries for nodes defined in connection " + connection);
				}
			}
		}
		return backwardConnections;
	}

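	/**
	 * Groups the enabled connections of the network by their destination node.
	 * 
	 * @param connections list of connections defining the network
	 * @return map from node index to the set of enabled connections pointing at it
	 * @throws IllegalArgumentException if two enabled connections link the same pair of nodes
	 */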
	public static Map<Integer, Set<Connection>> computeBackwardConnections(final List<Connection> connections) {
		Objects.requireNonNull(connections);

		final Map<Integer, Set<Connection>> backwardConnections = new HashMap<>();
		for (final Connection connection : connections) {
			final var toNodeIndex = connection.toNodeIndex();

			if (connection.isEnabled()) {
				final var fromConnections = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());

				if (fromConnections.stream()
						.anyMatch(existingConnection -> existingConnection.fromNodeIndex() == connection.fromNodeIndex())) {
					throw new IllegalArgumentException(
							"Found duplicate entries for nodes defined in connection " + connection);
				}
				fromConnections.add(connection);
			}
		}
		return backwardConnections;
	}

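	/**
	 * Partitions the nodes of the network into layers suitable for feed-forward evaluation.
	 * <p>
	 * The first layer contains the input nodes and the last layer contains the output nodes. A node joins a layer once
	 * every node feeding into it is either dead or already assigned to an earlier layer.
	 * 
	 * @param inputNodeIndices  indices of the input nodes
	 * @param outputNodeIndices indices of the output nodes
	 * @param connections       list of connections defining the network
	 * @return list of layers, each layer being a list of node indices, ordered from inputs to outputs
	 */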
	public static List<List<Integer>> partitionLayersNodes(final Set<Integer> inputNodeIndices,
			final Set<Integer> outputNodeIndices, final List<Connection> connections) {
		Validate.isTrue(CollectionUtils.isNotEmpty(inputNodeIndices));
		Validate.isTrue(CollectionUtils.isNotEmpty(outputNodeIndices));
		Validate.isTrue(CollectionUtils.isNotEmpty(connections));

		final Map<Integer, Set<Integer>> forwardConnections = computeForwardLinks(connections);
		final Map<Integer, Set<Integer>> backwardConnections = computeBackwardLinks(connections);

		// Dead nodes cannot reach any output node and are excluded from the layers
		final var deadNodes = computeDeadNodes(connections, forwardConnections, backwardConnections, outputNodeIndices);

		final Set<Integer> processedSet = new HashSet<>();
		final List<List<Integer>> layers = new ArrayList<>();
		processedSet.addAll(inputNodeIndices);
		layers.add(new ArrayList<>(inputNodeIndices));

		boolean done = false;
		while (done == false) {
			final List<Integer> layer = new ArrayList<>();

			final Set<Integer> layerCandidates = new HashSet<>();
			for (final Entry<Integer, Set<Integer>> entry : forwardConnections.entrySet()) {
				final var key = entry.getKey();
				final var values = entry.getValue();

				if (processedSet.contains(key) == true) {
					for (final Integer candidate : values) {
						if (deadNodes.contains(candidate) == false && processedSet.contains(candidate) == false
								&& outputNodeIndices.contains(candidate) == false) {
							layerCandidates.add(candidate);
						}
					}
				}
			}

			/**
			 * A candidate can only join the layer once every node pointing at it is either a dead node (which we can
			 * ignore) or already in the processedSet
			 */
			for (final Integer candidate : layerCandidates) {
				final var backwardLinks = backwardConnections.getOrDefault(candidate, Set.of());

				final boolean allBackwardInEndSet = backwardLinks.stream()
						.allMatch(next -> processedSet.contains(next) || deadNodes.contains(next));

				if (allBackwardInEndSet) {
					layer.add(candidate);
				}
			}

			if (layer.size() == 0) {
				done = true;
				layer.addAll(outputNodeIndices);
			} else {
				processedSet.addAll(layer);
			}
			layers.add(layer);
		}
		return layers;
	}

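	/**
	 * Computes the NEAT compatibility distance between two genomes expressed as sorted connection lists.
	 * <p>
	 * The distance follows the formula from the original NEAT paper:
	 * {@code distance = (c1 * excessGenes) / n + (c2 * disjointGenes) / n + c3 * averageWeightDifference}, where
	 * {@code n} is the size of the larger genome, treated as 1 when both genomes have fewer than 20 connections.
	 * 
	 * @param firstConnections  connections of the first genome, sorted by innovation number
	 * @param secondConnections connections of the second genome, sorted by innovation number
	 * @param c1                weight given to excess genes
	 * @param c2                weight given to disjoint genes
	 * @param c3                weight given to the average weight difference of matching genes
	 * @return the compatibility distance, or {@link Float#MAX_VALUE} if either connection list is null
	 */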
	public static float compatibilityDistance(final List<Connection> firstConnections,
			final List<Connection> secondConnections, final float c1, final float c2, final float c3) {
		if (firstConnections == null || secondConnections == null) {
			return Float.MAX_VALUE;
		}

		/**
		 * Both connection lists are expected to already be sorted by innovation number
		 */

		final int maxConnectionSize = Math.max(firstConnections.size(), secondConnections.size());
		final float n = maxConnectionSize < 20 ? 1.0f : maxConnectionSize;

		int disjointGenes = 0;

		float sumWeightDifference = 0;
		int numMatchingGenes = 0;

		int indexFirst = 0;
		int indexSecond = 0;

		while (indexFirst < firstConnections.size() && indexSecond < secondConnections.size()) {

			final Connection firstConnection = firstConnections.get(indexFirst);
			final int firstInnovation = firstConnection.innovation();

			final Connection secondConnection = secondConnections.get(indexSecond);
			final int secondInnovation = secondConnection.innovation();

			if (firstInnovation == secondInnovation) {
				sumWeightDifference += Math.abs(secondConnection.weight() - firstConnection.weight());
				numMatchingGenes++;

				indexFirst++;
				indexSecond++;
			} else {

				disjointGenes++;

				if (firstInnovation < secondInnovation) {
					indexFirst++;
				} else {
					indexSecond++;
				}
			}
		}

		int excessGenes = 0;
		/**
		 * Whichever genome still has unconsumed connections contributes its remaining genes as excess genes
		 */
		if (indexFirst < firstConnections.size()) {
			excessGenes += firstConnections.size() - indexFirst;
		} else if (indexSecond < secondConnections.size()) {
			excessGenes += secondConnections.size() - indexSecond;
		}

		final float averageWeightDifference = sumWeightDifference / Math.max(1, numMatchingGenes);

		return (c1 * excessGenes) / n + (c2 * disjointGenes) / n + c3 * averageWeightDifference;
	}

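	/**
	 * Computes the NEAT compatibility distance between two genotypes, based on the {@link NeatChromosome} found at the
	 * given chromosome index.
	 * 
	 * @param genotype1       first genotype
	 * @param genotype2       second genotype
	 * @param chromosomeIndex index of the NEAT chromosome to compare in both genotypes
	 * @param c1              weight given to excess genes
	 * @param c2              weight given to disjoint genes
	 * @param c3              weight given to the average weight difference of matching genes
	 * @return the compatibility distance between the two chromosomes
	 */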
	public static float compatibilityDistance(final Genotype genotype1, final Genotype genotype2,
			final int chromosomeIndex, final float c1, final float c2, final float c3) {
		Objects.requireNonNull(genotype1);
		Objects.requireNonNull(genotype2);
		Validate.isTrue(chromosomeIndex >= 0);
		Validate.isTrue(chromosomeIndex < genotype1.getSize());
		Validate.isTrue(chromosomeIndex < genotype2.getSize());

		final var neatChromosome1 = genotype1.getChromosome(chromosomeIndex, NeatChromosome.class);
		final var connections1 = neatChromosome1.getConnections();

		final var neatChromosome2 = genotype2.getChromosome(chromosomeIndex, NeatChromosome.class);
		final var connections2 = neatChromosome2.getConnections();

		return compatibilityDistance(connections1, connections2, c1, c2, c3);
	}

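	/**
	 * Assigns every individual of the population to a species, based on the supplied species predicate.
	 * <p>
	 * Each seed species is first reduced to a single randomly chosen ancestor. Individuals are then added to the first
	 * species for which the predicate matches one of its ancestors or current members; if no species matches, a new
	 * species is created through the {@link SpeciesIdGenerator}. Species left without members are dropped from the
	 * result.
	 * 
	 * @param <T>                type of the fitness measurement
	 * @param random             random generator used to pick the ancestor of each seed species
	 * @param speciesIdGenerator generator used to allocate ids for newly created species
	 * @param seedSpecies        species carried over from the previous generation
	 * @param population         population to speciate
	 * @param speciesPredicate   predicate deciding whether two individuals belong to the same species
	 * @return list of non-empty species covering the whole population
	 */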
	public static <T extends Comparable<T>> List<Species<T>> speciate(final RandomGenerator random,
			final SpeciesIdGenerator speciesIdGenerator, final List<Species<T>> seedSpecies,
			final Population<T> population, final BiPredicate<Individual<T>, Individual<T>> speciesPredicate) {
		Objects.requireNonNull(random);
		Objects.requireNonNull(speciesIdGenerator);
		Objects.requireNonNull(seedSpecies);
		Objects.requireNonNull(population);
		Objects.requireNonNull(speciesPredicate);

		final List<Species<T>> species = new ArrayList<>();

		for (final Species<T> speciesIterator : seedSpecies) {
			final var speciesId = speciesIterator.getId();
			final int numMembers = speciesIterator.getNumMembers();
			if (numMembers > 0) {
				final int randomIndex = random.nextInt(numMembers);
				final var newAncestors = List.of(speciesIterator.getMembers().get(randomIndex));
				final var newSpecies = new Species<>(speciesId, newAncestors);
				species.add(newSpecies);
			}
		}

		for (final Individual<T> individual : population) {

			boolean existingSpeciesFound = false;
			int currentSpeciesIndex = 0;
			while (existingSpeciesFound == false && currentSpeciesIndex < species.size()) {

				final var currentSpecies = species.get(currentSpeciesIndex);

				final boolean anyAncestorMatch = currentSpecies.getAncestors()
						.stream()
						.anyMatch(candidate -> speciesPredicate.test(individual, candidate));

				final boolean anyMemberMatch = currentSpecies.getMembers()
						.stream()
						.anyMatch(candidate -> speciesPredicate.test(individual, candidate));

				if (anyAncestorMatch || anyMemberMatch) {
					currentSpecies.addMember(individual);
					existingSpeciesFound = true;
				} else {
					currentSpeciesIndex++;
				}
			}

			if (existingSpeciesFound == false) {
				final int newSpeciesId = speciesIdGenerator.computeNewId();
				final var newSpecies = new Species<T>(newSpeciesId, List.of());
				newSpecies.addMember(individual);
				species.add(newSpecies);
			}
		}

		return species.stream().filter(sp -> sp.getNumMembers() > 0).toList();
	}
}