View Javadoc
1   package net.bmahe.genetics4j.neat;
2   
3   import java.util.ArrayDeque;
4   import java.util.ArrayList;
5   import java.util.Deque;
6   import java.util.HashMap;
7   import java.util.HashSet;
8   import java.util.List;
9   import java.util.Map;
10  import java.util.Map.Entry;
11  import java.util.Set;
12  import java.util.function.BiPredicate;
13  import java.util.random.RandomGenerator;
14  
15  import org.apache.commons.collections4.CollectionUtils;
16  import org.apache.commons.lang3.Validate;
17  
18  import net.bmahe.genetics4j.core.Genotype;
19  import net.bmahe.genetics4j.core.Individual;
20  import net.bmahe.genetics4j.core.Population;
21  import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;
22  
23  /**
24   * Utility class providing core algorithmic operations for the NEAT (NeuroEvolution of Augmenting Topologies) algorithm.
25   * 
26   * <p>NeatUtils contains essential algorithms and helper methods for implementing NEAT neural network evolution,
27   * including network topology analysis, compatibility distance calculation, speciation, and structural operations. These
28   * utilities support the NEAT algorithm's key features of topology innovation, structural mutation, and species-based
29   * population organization.
30   * 
31   * <p>Key functionality areas:
32   * <ul>
33   * <li><strong>Network topology analysis</strong>: Computing network layers, forward/backward connections, and dead node
34   * detection</li>
35   * <li><strong>Compatibility distance</strong>: Measuring genetic similarity between neural networks for speciation</li>
36   * <li><strong>Speciation management</strong>: Organizing populations into species based on genetic similarity</li>
37   * <li><strong>Structural analysis</strong>: Analyzing network connectivity patterns and structural properties</li>
38   * </ul>
39   * 
40   * <p>NEAT algorithm integration:
41   * <ul>
42   * <li><strong>Innovation tracking</strong>: Support for historical marking and innovation numbers</li>
43   * <li><strong>Structural mutations</strong>: Utilities for add-node and add-connection operations</li>
44   * <li><strong>Network evaluation</strong>: Layer-based network evaluation ordering</li>
45   * <li><strong>Population diversity</strong>: Species-based diversity maintenance</li>
46   * </ul>
47   * 
48   * <p>Core NEAT concepts implemented:
49   * <ul>
50   * <li><strong>Genetic similarity</strong>: Compatibility distance based on excess, disjoint, and weight
51   * differences</li>
52   * <li><strong>Topological innovation</strong>: Structural changes tracked through innovation numbers</li>
53   * <li><strong>Speciation</strong>: Dynamic species formation based on genetic distance thresholds</li>
54   * <li><strong>Network evaluation</strong>: Feed-forward evaluation through computed network layers</li>
55   * </ul>
56   * 
57   * <p>Algorithmic foundations:
58   * <ul>
59   * <li><strong>Graph algorithms</strong>: Topological sorting, connectivity analysis, and layer computation</li>
60   * <li><strong>Genetic distance metrics</strong>: NEAT-specific compatibility distance calculation</li>
61   * <li><strong>Population clustering</strong>: Species formation and maintenance algorithms</li>
62   * <li><strong>Network optimization</strong>: Dead node removal and structural simplification</li>
63   * </ul>
64   * 
65   * @see NeatChromosome
66   * @see Connection
67   * @see Species
68   * @see InnovationManager
69   */
70  public class NeatUtils {
71  
72  	private NeatUtils() {
73  	}
74  
75  	/**
76  	 * Working backward from the output nodes, we identify the nodes that did not get visited as dead nodes
77  	 * 
78  	 * @param connections
79  	 * @param forwardConnections
80  	 * @param backwardConnections
81  	 * @param outputNodeIndices
82  	 * @return
83  	 */
84  	public static Set<Integer> computeDeadNodes(final List<Connection> connections,
85  			final Map<Integer, Set<Integer>> forwardConnections, final Map<Integer, Set<Integer>> backwardConnections,
86  			final Set<Integer> outputNodeIndices) {
87  		Validate.notNull(connections);
88  
89  		final Set<Integer> deadNodes = new HashSet<>();
90  		for (final Connection connection : connections) {
91  			deadNodes.add(connection.fromNodeIndex());
92  			deadNodes.add(connection.toNodeIndex());
93  		}
94  		deadNodes.removeAll(outputNodeIndices);
95  
96  		final Set<Integer> visited = new HashSet<>();
97  		final Deque<Integer> toVisit = new ArrayDeque<>(outputNodeIndices);
98  		while (toVisit.size() > 0) {
99  			final Integer currentNode = toVisit.poll();
100 
101 			deadNodes.remove(currentNode);
102 			if (visited.contains(currentNode) == false) {
103 
104 				visited.add(currentNode);
105 
106 				final var next = backwardConnections.getOrDefault(currentNode, Set.of());
107 				if (next.size() > 0) {
108 					toVisit.addAll(next);
109 				}
110 			}
111 		}
112 
113 		return deadNodes;
114 	}
115 
116 	public static Map<Integer, Set<Integer>> computeForwardLinks(final List<Connection> connections) {
117 		Validate.notNull(connections);
118 
119 		final Map<Integer, Set<Integer>> forwardConnections = new HashMap<>();
120 		for (final Connection connection : connections) {
121 			final var fromNodeIndex = connection.fromNodeIndex();
122 			final var toNodeIndex = connection.toNodeIndex();
123 
124 			if (connection.isEnabled()) {
125 				final var toNodes = forwardConnections.computeIfAbsent(fromNodeIndex, k -> new HashSet<>());
126 
127 				if (toNodes.add(toNodeIndex) == false) {
128 					throw new IllegalArgumentException(
129 							"Found duplicate entries for nodes defined in connection " + connection);
130 				}
131 			}
132 		}
133 
134 		return forwardConnections;
135 	}
136 
137 	public static Map<Integer, Set<Integer>> computeBackwardLinks(final List<Connection> connections) {
138 		Validate.notNull(connections);
139 
140 		final Map<Integer, Set<Integer>> backwardConnections = new HashMap<>();
141 		for (final Connection connection : connections) {
142 			final var fromNodeIndex = connection.fromNodeIndex();
143 			final var toNodeIndex = connection.toNodeIndex();
144 
145 			if (connection.isEnabled()) {
146 				final var fromNodes = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());
147 
148 				if (fromNodes.add(fromNodeIndex) == false) {
149 					throw new IllegalArgumentException(
150 							"Found duplicate entries for nodes defined in connection " + connection);
151 				}
152 			}
153 		}
154 		return backwardConnections;
155 	}
156 
157 	public static Map<Integer, Set<Connection>> computeBackwardConnections(final List<Connection> connections) {
158 		Validate.notNull(connections);
159 
160 		final Map<Integer, Set<Connection>> backwardConnections = new HashMap<>();
161 		for (final Connection connection : connections) {
162 			final var toNodeIndex = connection.toNodeIndex();
163 
164 			if (connection.isEnabled()) {
165 				final var fromConnections = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());
166 
167 				if (fromConnections.stream()
168 						.anyMatch(existingConnection -> existingConnection.fromNodeIndex() == connection.fromNodeIndex())) {
169 					throw new IllegalArgumentException(
170 							"Found duplicate entries for nodes defined in connection " + connection);
171 				}
172 				fromConnections.add(connection);
173 			}
174 		}
175 		return backwardConnections;
176 	}
177 
178 	public static List<List<Integer>> partitionLayersNodes(final Set<Integer> inputNodeIndices,
179 			final Set<Integer> outputNodeIndices, final List<Connection> connections) {
180 		Validate.isTrue(CollectionUtils.isNotEmpty(inputNodeIndices));
181 		Validate.isTrue(CollectionUtils.isNotEmpty(outputNodeIndices));
182 		Validate.isTrue(CollectionUtils.isNotEmpty(connections));
183 
184 		final Map<Integer, Set<Integer>> forwardConnections = computeForwardLinks(connections);
185 		final Map<Integer, Set<Integer>> backwardConnections = computeBackwardLinks(connections);
186 
187 		// Is it useful? If it's connected to the input node, it's not dead
188 		final var deadNodes = computeDeadNodes(connections, forwardConnections, backwardConnections, outputNodeIndices);
189 
190 		final Set<Integer> processedSet = new HashSet<>();
191 		final List<List<Integer>> layers = new ArrayList<>();
192 		processedSet.addAll(inputNodeIndices);
193 		layers.add(new ArrayList<>(inputNodeIndices));
194 
195 		boolean done = false;
196 		while (done == false) {
197 			final List<Integer> layer = new ArrayList<>();
198 
199 			final Set<Integer> layerCandidates = new HashSet<>();
200 			for (final Entry<Integer, Set<Integer>> entry : forwardConnections.entrySet()) {
201 				final var key = entry.getKey();
202 				final var values = entry.getValue();
203 
204 				if (processedSet.contains(key) == true) {
205 					for (final Integer candidate : values) {
206 						if (deadNodes.contains(candidate) == false && processedSet.contains(candidate) == false
207 								&& outputNodeIndices.contains(candidate) == false) {
208 							layerCandidates.add(candidate);
209 						}
210 					}
211 				}
212 			}
213 
214 			/**
215 			 * We need to ensure that all the nodes pointed at the candidate are either a dead node (and we don't care) or
216 			 * is already in the processedSet
217 			 */
218 			for (final Integer candidate : layerCandidates) {
219 				final var backwardLinks = backwardConnections.getOrDefault(candidate, Set.of());
220 
221 				final boolean allBackwardInEndSet = backwardLinks.stream()
222 						.allMatch(next -> processedSet.contains(next) || deadNodes.contains(next));
223 
224 				if (allBackwardInEndSet) {
225 					layer.add(candidate);
226 				}
227 			}
228 
229 			if (layer.size() == 0) {
230 				done = true;
231 				layer.addAll(outputNodeIndices);
232 			} else {
233 				processedSet.addAll(layer);
234 			}
235 			layers.add(layer);
236 		}
237 		return layers;
238 	}
239 
240 	public static float compatibilityDistance(final List<Connection> firstConnections,
241 			final List<Connection> secondConnections, final float c1, final float c2, final float c3) {
242 		if (firstConnections == null || secondConnections == null) {
243 			return Float.MAX_VALUE;
244 		}
245 
246 		/**
247 		 * Both connections are expected to already be sorted
248 		 */
249 
250 		final int maxConnectionSize = Math.max(firstConnections.size(), secondConnections.size());
251 		final float n = maxConnectionSize < 20 ? 1.0f : maxConnectionSize;
252 
253 		int disjointGenes = 0;
254 
255 		float sumWeightDifference = 0;
256 		int numMatchingGenes = 0;
257 
258 		int indexFirst = 0;
259 		int indexSecond = 0;
260 
261 		while (indexFirst < firstConnections.size() && indexSecond < secondConnections.size()) {
262 
263 			final Connection firstConnection = firstConnections.get(indexFirst);
264 			final int firstInnovation = firstConnection.innovation();
265 
266 			final Connection secondConnection = secondConnections.get(indexSecond);
267 			final int secondInnovation = secondConnection.innovation();
268 
269 			if (firstInnovation == secondInnovation) {
270 				sumWeightDifference += Math.abs(secondConnection.weight() - firstConnection.weight());
271 				numMatchingGenes++;
272 
273 				indexFirst++;
274 				indexSecond++;
275 			} else {
276 
277 				disjointGenes++;
278 
279 				if (firstInnovation < secondInnovation) {
280 					indexFirst++;
281 				} else {
282 					indexSecond++;
283 				}
284 			}
285 		}
286 
287 		int excessGenes = 0;
288 		/**
289 		 * We have consumed all elements from secondConnections and thus have their remaining difference as excess genes
290 		 */
291 		if (indexFirst < firstConnections.size()) {
292 			excessGenes += firstConnections.size() - indexSecond;
293 		} else if (indexSecond < secondConnections.size()) {
294 			excessGenes += secondConnections.size() - indexFirst;
295 		}
296 
297 		final float averageWeightDifference = sumWeightDifference / Math.max(1, numMatchingGenes);
298 
299 		return (c1 * excessGenes) / n + (c2 * disjointGenes) / n + c3 * averageWeightDifference;
300 	}
301 
302 	public static float compatibilityDistance(final Genotype genotype1, final Genotype genotype2,
303 			final int chromosomeIndex, final float c1, final float c2, final float c3) {
304 		Validate.notNull(genotype1);
305 		Validate.notNull(genotype2);
306 		Validate.isTrue(chromosomeIndex >= 0);
307 		Validate.isTrue(chromosomeIndex < genotype1.getSize());
308 		Validate.isTrue(chromosomeIndex < genotype2.getSize());
309 
310 		final var neatChromosome1 = genotype1.getChromosome(chromosomeIndex, NeatChromosome.class);
311 		final var connections1 = neatChromosome1.getConnections();
312 
313 		final var neatChromosome2 = genotype2.getChromosome(chromosomeIndex, NeatChromosome.class);
314 		final var connections2 = neatChromosome2.getConnections();
315 
316 		return compatibilityDistance(connections1, connections2, c1, c2, c3);
317 	}
318 
319 	public static <T extends Comparable<T>> List<Species<T>> speciate(final RandomGenerator random,
320 			final SpeciesIdGenerator speciesIdGenerator, final List<Species<T>> seedSpecies,
321 			final Population<T> population, final BiPredicate<Individual<T>, Individual<T>> speciesPredicate) {
322 		Validate.notNull(random);
323 		Validate.notNull(speciesIdGenerator);
324 		Validate.notNull(seedSpecies);
325 		Validate.notNull(population);
326 		Validate.notNull(speciesPredicate);
327 
328 		final List<Species<T>> species = new ArrayList<>();
329 
330 		for (final Species<T> speciesIterator : seedSpecies) {
331 			final var speciesId = speciesIterator.getId();
332 			final int numMembers = speciesIterator.getNumMembers();
333 			if (numMembers > 0) {
334 				final int randomIndex = random.nextInt(numMembers);
335 				final var newAncestors = List.of(speciesIterator.getMembers()
336 						.get(randomIndex));
337 				final var newSpecies = new Species<>(speciesId, newAncestors);
338 				species.add(newSpecies);
339 			}
340 		}
341 
342 		for (final Individual<T> individual : population) {
343 
344 			boolean existingSpeciesFound = false;
345 			int currentSpeciesIndex = 0;
346 			while (existingSpeciesFound == false && currentSpeciesIndex < species.size()) {
347 
348 				final var currentSpecies = species.get(currentSpeciesIndex);
349 
350 				final boolean anyAncestorMatch = currentSpecies.getAncestors()
351 						.stream()
352 						.anyMatch(candidate -> speciesPredicate.test(individual, candidate));
353 
354 				final boolean anyMemberMatch = currentSpecies.getMembers()
355 						.stream()
356 						.anyMatch(candidate -> speciesPredicate.test(individual, candidate));
357 
358 				if (anyAncestorMatch || anyMemberMatch) {
359 					currentSpecies.addMember(individual);
360 					existingSpeciesFound = true;
361 				} else {
362 					currentSpeciesIndex++;
363 				}
364 			}
365 
366 			if (existingSpeciesFound == false) {
367 				final int newSpeciesId = speciesIdGenerator.computeNewId();
368 				final var newSpecies = new Species<T>(newSpeciesId, List.of());
369 				newSpecies.addMember(individual);
370 				species.add(newSpecies);
371 			}
372 		}
373 
374 		return species.stream()
375 				.filter(sp -> sp.getNumMembers() > 0)
376 				.toList();
377 	}
378 }