View Javadoc
1   package net.bmahe.genetics4j.neat;
2   
3   import java.util.ArrayDeque;
4   import java.util.ArrayList;
5   import java.util.Deque;
6   import java.util.HashMap;
7   import java.util.HashSet;
8   import java.util.List;
9   import java.util.Map;
10  import java.util.Map.Entry;
11  import java.util.Set;
12  import java.util.function.BiPredicate;
13  import java.util.random.RandomGenerator;
14  
15  import org.apache.commons.collections4.CollectionUtils;
16  import org.apache.commons.lang3.Validate;
17  
18  import net.bmahe.genetics4j.core.Genotype;
19  import net.bmahe.genetics4j.core.Individual;
20  import net.bmahe.genetics4j.core.Population;
21  import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;
22  
23  /**
24   * Utility class providing core algorithmic operations for the NEAT (NeuroEvolution of Augmenting Topologies) algorithm.
25   * 
26   * <p>NeatUtils contains essential algorithms and helper methods for implementing NEAT neural network evolution,
27   * including network topology analysis, compatibility distance calculation, speciation, and structural operations.
28   * These utilities support the NEAT algorithm's key features of topology innovation, structural mutation, and 
29   * species-based population organization.
30   * 
31   * <p>Key functionality areas:
32   * <ul>
33   * <li><strong>Network topology analysis</strong>: Computing network layers, forward/backward connections, and dead node detection</li>
34   * <li><strong>Compatibility distance</strong>: Measuring genetic similarity between neural networks for speciation</li>
35   * <li><strong>Speciation management</strong>: Organizing populations into species based on genetic similarity</li>
36   * <li><strong>Structural analysis</strong>: Analyzing network connectivity patterns and structural properties</li>
37   * </ul>
38   * 
39   * <p>NEAT algorithm integration:
40   * <ul>
41   * <li><strong>Innovation tracking</strong>: Support for historical marking and innovation numbers</li>
42   * <li><strong>Structural mutations</strong>: Utilities for add-node and add-connection operations</li>
43   * <li><strong>Network evaluation</strong>: Layer-based network evaluation ordering</li>
44   * <li><strong>Population diversity</strong>: Species-based diversity maintenance</li>
45   * </ul>
46   * 
47   * <p>Core NEAT concepts implemented:
48   * <ul>
49   * <li><strong>Genetic similarity</strong>: Compatibility distance based on excess, disjoint, and weight differences</li>
50   * <li><strong>Topological innovation</strong>: Structural changes tracked through innovation numbers</li>
51   * <li><strong>Speciation</strong>: Dynamic species formation based on genetic distance thresholds</li>
52   * <li><strong>Network evaluation</strong>: Feed-forward evaluation through computed network layers</li>
53   * </ul>
54   * 
55   * <p>Algorithmic foundations:
56   * <ul>
57   * <li><strong>Graph algorithms</strong>: Topological sorting, connectivity analysis, and layer computation</li>
58   * <li><strong>Genetic distance metrics</strong>: NEAT-specific compatibility distance calculation</li>
59   * <li><strong>Population clustering</strong>: Species formation and maintenance algorithms</li>
60   * <li><strong>Network optimization</strong>: Dead node removal and structural simplification</li>
61   * </ul>
62   * 
63   * @see NeatChromosome
64   * @see Connection
65   * @see Species
66   * @see InnovationManager
67   */
68  public class NeatUtils {
69  
70  	private NeatUtils() {
71  	}
72  
73  	/**
74  	 * Working backward from the output nodes, we identify the nodes that did not
75  	 * get visited as dead nodes
76  	 * 
77  	 * @param connections
78  	 * @param forwardConnections
79  	 * @param backwardConnections
80  	 * @param outputNodeIndices
81  	 * @return
82  	 */
83  	public static Set<Integer> computeDeadNodes(final List<Connection> connections,
84  			final Map<Integer, Set<Integer>> forwardConnections, final Map<Integer, Set<Integer>> backwardConnections,
85  			final Set<Integer> outputNodeIndices) {
86  		Validate.notNull(connections);
87  
88  		final Set<Integer> deadNodes = new HashSet<>();
89  		for (final Connection connection : connections) {
90  			deadNodes.add(connection.fromNodeIndex());
91  			deadNodes.add(connection.toNodeIndex());
92  		}
93  		deadNodes.removeAll(outputNodeIndices);
94  
95  		final Set<Integer> visited = new HashSet<>();
96  		final Deque<Integer> toVisit = new ArrayDeque<>(outputNodeIndices);
97  		while (toVisit.size() > 0) {
98  			final Integer currentNode = toVisit.poll();
99  
100 			deadNodes.remove(currentNode);
101 			if (visited.contains(currentNode) == false) {
102 
103 				visited.add(currentNode);
104 
105 				final var next = backwardConnections.getOrDefault(currentNode, Set.of());
106 				if (next.size() > 0) {
107 					toVisit.addAll(next);
108 				}
109 			}
110 		}
111 
112 		return deadNodes;
113 	}
114 
115 	public static Map<Integer, Set<Integer>> computeForwardLinks(final List<Connection> connections) {
116 		Validate.notNull(connections);
117 
118 		final Map<Integer, Set<Integer>> forwardConnections = new HashMap<>();
119 		for (final Connection connection : connections) {
120 			final var fromNodeIndex = connection.fromNodeIndex();
121 			final var toNodeIndex = connection.toNodeIndex();
122 
123 			if (connection.isEnabled()) {
124 				final var toNodes = forwardConnections.computeIfAbsent(fromNodeIndex, k -> new HashSet<>());
125 
126 				if (toNodes.add(toNodeIndex) == false) {
127 					throw new IllegalArgumentException(
128 							"Found duplicate entries for nodes defined in connection " + connection);
129 				}
130 			}
131 		}
132 
133 		return forwardConnections;
134 	}
135 
136 	public static Map<Integer, Set<Integer>> computeBackwardLinks(final List<Connection> connections) {
137 		Validate.notNull(connections);
138 
139 		final Map<Integer, Set<Integer>> backwardConnections = new HashMap<>();
140 		for (final Connection connection : connections) {
141 			final var fromNodeIndex = connection.fromNodeIndex();
142 			final var toNodeIndex = connection.toNodeIndex();
143 
144 			if (connection.isEnabled()) {
145 				final var fromNodes = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());
146 
147 				if (fromNodes.add(fromNodeIndex) == false) {
148 					throw new IllegalArgumentException(
149 							"Found duplicate entries for nodes defined in connection " + connection);
150 				}
151 			}
152 		}
153 		return backwardConnections;
154 	}
155 
156 	public static Map<Integer, Set<Connection>> computeBackwardConnections(final List<Connection> connections) {
157 		Validate.notNull(connections);
158 
159 		final Map<Integer, Set<Connection>> backwardConnections = new HashMap<>();
160 		for (final Connection connection : connections) {
161 			final var toNodeIndex = connection.toNodeIndex();
162 
163 			if (connection.isEnabled()) {
164 				final var fromConnections = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());
165 
166 				if (fromConnections.stream()
167 						.anyMatch(existingConnection -> existingConnection.fromNodeIndex() == connection.fromNodeIndex())) {
168 					throw new IllegalArgumentException(
169 							"Found duplicate entries for nodes defined in connection " + connection);
170 				}
171 				fromConnections.add(connection);
172 			}
173 		}
174 		return backwardConnections;
175 	}
176 
177 	public static List<List<Integer>> partitionLayersNodes(final Set<Integer> inputNodeIndices,
178 			final Set<Integer> outputNodeIndices, final List<Connection> connections) {
179 		Validate.isTrue(CollectionUtils.isNotEmpty(inputNodeIndices));
180 		Validate.isTrue(CollectionUtils.isNotEmpty(outputNodeIndices));
181 		Validate.isTrue(CollectionUtils.isNotEmpty(connections));
182 
183 		final Map<Integer, Set<Integer>> forwardConnections = computeForwardLinks(connections);
184 		final Map<Integer, Set<Integer>> backwardConnections = computeBackwardLinks(connections);
185 
186 		// Is it useful? If it's connected to the input node, it's not dead
187 		final var deadNodes = computeDeadNodes(connections, forwardConnections, backwardConnections, outputNodeIndices);
188 
189 		final Set<Integer> processedSet = new HashSet<>();
190 		final List<List<Integer>> layers = new ArrayList<>();
191 		processedSet.addAll(inputNodeIndices);
192 		layers.add(new ArrayList<>(inputNodeIndices));
193 
194 		boolean done = false;
195 		while (done == false) {
196 			final List<Integer> layer = new ArrayList<>();
197 
198 			final Set<Integer> layerCandidates = new HashSet<>();
199 			for (final Entry<Integer, Set<Integer>> entry : forwardConnections.entrySet()) {
200 				final var key = entry.getKey();
201 				final var values = entry.getValue();
202 
203 				if (processedSet.contains(key) == true) {
204 					for (final Integer candidate : values) {
205 						if (deadNodes.contains(candidate) == false && processedSet.contains(candidate) == false
206 								&& outputNodeIndices.contains(candidate) == false) {
207 							layerCandidates.add(candidate);
208 						}
209 					}
210 				}
211 			}
212 
213 			/**
214 			 * We need to ensure that all the nodes pointed at the candidate are either a
215 			 * dead node (and we don't care) or is already in the processedSet
216 			 */
217 			for (final Integer candidate : layerCandidates) {
218 				final var backwardLinks = backwardConnections.getOrDefault(candidate, Set.of());
219 
220 				final boolean allBackwardInEndSet = backwardLinks.stream()
221 						.allMatch(next -> processedSet.contains(next) || deadNodes.contains(next));
222 
223 				if (allBackwardInEndSet) {
224 					layer.add(candidate);
225 				}
226 			}
227 
228 			if (layer.size() == 0) {
229 				done = true;
230 				layer.addAll(outputNodeIndices);
231 			} else {
232 				processedSet.addAll(layer);
233 			}
234 			layers.add(layer);
235 		}
236 		return layers;
237 	}
238 
239 	public static float compatibilityDistance(final List<Connection> firstConnections,
240 			final List<Connection> secondConnections, final float c1, final float c2, final float c3) {
241 		if (firstConnections == null || secondConnections == null) {
242 			return Float.MAX_VALUE;
243 		}
244 
245 		/**
246 		 * Both connections are expected to already be sorted
247 		 */
248 
249 		final int maxConnectionSize = Math.max(firstConnections.size(), secondConnections.size());
250 		final float n = maxConnectionSize < 20 ? 1.0f : maxConnectionSize;
251 
252 		int disjointGenes = 0;
253 
254 		float sumWeightDifference = 0;
255 		int numMatchingGenes = 0;
256 
257 		int indexFirst = 0;
258 		int indexSecond = 0;
259 
260 		while (indexFirst < firstConnections.size() && indexSecond < secondConnections.size()) {
261 
262 			final Connection firstConnection = firstConnections.get(indexFirst);
263 			final int firstInnovation = firstConnection.innovation();
264 
265 			final Connection secondConnection = secondConnections.get(indexSecond);
266 			final int secondInnovation = secondConnection.innovation();
267 
268 			if (firstInnovation == secondInnovation) {
269 				sumWeightDifference += Math.abs(secondConnection.weight() - firstConnection.weight());
270 				numMatchingGenes++;
271 
272 				indexFirst++;
273 				indexSecond++;
274 			} else {
275 
276 				disjointGenes++;
277 
278 				if (firstInnovation < secondInnovation) {
279 					indexFirst++;
280 				} else {
281 					indexSecond++;
282 				}
283 			}
284 		}
285 
286 		int excessGenes = 0;
287 		/**
288 		 * We have consumed all elements from secondConnections and thus have their
289 		 * remaining difference as excess genes
290 		 */
291 		if (indexFirst < firstConnections.size()) {
292 			excessGenes += firstConnections.size() - indexSecond;
293 		} else if (indexSecond < secondConnections.size()) {
294 			excessGenes += secondConnections.size() - indexFirst;
295 		}
296 
297 		final float averageWeightDifference = sumWeightDifference / Math.max(1, numMatchingGenes);
298 
299 		return (c1 * excessGenes) / n + (c2 * disjointGenes) / n + c3 * averageWeightDifference;
300 	}
301 
302 	public static float compatibilityDistance(final Genotype genotype1, final Genotype genotype2,
303 			final int chromosomeIndex, final float c1, final float c2, final float c3) {
304 		Validate.notNull(genotype1);
305 		Validate.notNull(genotype2);
306 		Validate.isTrue(chromosomeIndex >= 0);
307 		Validate.isTrue(chromosomeIndex < genotype1.getSize());
308 		Validate.isTrue(chromosomeIndex < genotype2.getSize());
309 
310 		final var neatChromosome1 = genotype1.getChromosome(chromosomeIndex, NeatChromosome.class);
311 		final var connections1 = neatChromosome1.getConnections();
312 
313 		final var neatChromosome2 = genotype2.getChromosome(chromosomeIndex, NeatChromosome.class);
314 		final var connections2 = neatChromosome2.getConnections();
315 
316 		return compatibilityDistance(connections1, connections2, c1, c2, c3);
317 	}
318 
319 	public static <T extends Comparable<T>> List<Species<T>> speciate(final RandomGenerator random,
320 			final SpeciesIdGenerator speciesIdGenerator, final List<Species<T>> seedSpecies,
321 			final Population<T> population, final BiPredicate<Individual<T>, Individual<T>> speciesPredicate) {
322 		Validate.notNull(random);
323 		Validate.notNull(speciesIdGenerator);
324 		Validate.notNull(seedSpecies);
325 		Validate.notNull(population);
326 		Validate.notNull(speciesPredicate);
327 
328 		final List<Species<T>> species = new ArrayList<>();
329 
330 		for (final Species<T> speciesIterator : seedSpecies) {
331 			final var speciesId = speciesIterator.getId();
332 			final int numMembers = speciesIterator.getNumMembers();
333 			if (numMembers > 0) {
334 				final int randomIndex = random.nextInt(numMembers);
335 				final var newAncestors = List.of(speciesIterator.getMembers()
336 						.get(randomIndex));
337 				final var newSpecies = new Species<>(speciesId, newAncestors);
338 				species.add(newSpecies);
339 			}
340 		}
341 
342 		for (final Individual<T> individual : population) {
343 
344 			boolean existingSpeciesFound = false;
345 			int currentSpeciesIndex = 0;
346 			while (existingSpeciesFound == false && currentSpeciesIndex < species.size()) {
347 
348 				final var currentSpecies = species.get(currentSpeciesIndex);
349 
350 				final boolean anyAncestorMatch = currentSpecies.getAncestors()
351 						.stream()
352 						.anyMatch(candidate -> speciesPredicate.test(individual, candidate));
353 
354 				final boolean anyMemberMatch = currentSpecies.getMembers()
355 						.stream()
356 						.anyMatch(candidate -> speciesPredicate.test(individual, candidate));
357 
358 				if (anyAncestorMatch || anyMemberMatch) {
359 					currentSpecies.addMember(individual);
360 					existingSpeciesFound = true;
361 				} else {
362 					currentSpeciesIndex++;
363 				}
364 			}
365 
366 			if (existingSpeciesFound == false) {
367 				final int newSpeciesId = speciesIdGenerator.computeNewId();
368 				final var newSpecies = new Species<T>(newSpeciesId, List.of());
369 				newSpecies.addMember(individual);
370 				species.add(newSpecies);
371 			}
372 		}
373 
374 		return species.stream()
375 				.filter(sp -> sp.getNumMembers() > 0)
376 				.toList();
377 	}
378 }