1 package net.bmahe.genetics4j.neat;
2
3 import java.util.ArrayDeque;
4 import java.util.ArrayList;
5 import java.util.Deque;
6 import java.util.HashMap;
7 import java.util.HashSet;
8 import java.util.List;
9 import java.util.Map;
10 import java.util.Map.Entry;
11 import java.util.Set;
12 import java.util.function.BiPredicate;
13 import java.util.random.RandomGenerator;
14
15 import org.apache.commons.collections4.CollectionUtils;
16 import org.apache.commons.lang3.Validate;
17
18 import net.bmahe.genetics4j.core.Genotype;
19 import net.bmahe.genetics4j.core.Individual;
20 import net.bmahe.genetics4j.core.Population;
21 import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70 public class NeatUtils {
71
72 private NeatUtils() {
73 }
74
75
76
77
78
79
80
81
82
83
84 public static Set<Integer> computeDeadNodes(final List<Connection> connections,
85 final Map<Integer, Set<Integer>> forwardConnections, final Map<Integer, Set<Integer>> backwardConnections,
86 final Set<Integer> outputNodeIndices) {
87 Validate.notNull(connections);
88
89 final Set<Integer> deadNodes = new HashSet<>();
90 for (final Connection connection : connections) {
91 deadNodes.add(connection.fromNodeIndex());
92 deadNodes.add(connection.toNodeIndex());
93 }
94 deadNodes.removeAll(outputNodeIndices);
95
96 final Set<Integer> visited = new HashSet<>();
97 final Deque<Integer> toVisit = new ArrayDeque<>(outputNodeIndices);
98 while (toVisit.size() > 0) {
99 final Integer currentNode = toVisit.poll();
100
101 deadNodes.remove(currentNode);
102 if (visited.contains(currentNode) == false) {
103
104 visited.add(currentNode);
105
106 final var next = backwardConnections.getOrDefault(currentNode, Set.of());
107 if (next.size() > 0) {
108 toVisit.addAll(next);
109 }
110 }
111 }
112
113 return deadNodes;
114 }
115
116 public static Map<Integer, Set<Integer>> computeForwardLinks(final List<Connection> connections) {
117 Validate.notNull(connections);
118
119 final Map<Integer, Set<Integer>> forwardConnections = new HashMap<>();
120 for (final Connection connection : connections) {
121 final var fromNodeIndex = connection.fromNodeIndex();
122 final var toNodeIndex = connection.toNodeIndex();
123
124 if (connection.isEnabled()) {
125 final var toNodes = forwardConnections.computeIfAbsent(fromNodeIndex, k -> new HashSet<>());
126
127 if (toNodes.add(toNodeIndex) == false) {
128 throw new IllegalArgumentException(
129 "Found duplicate entries for nodes defined in connection " + connection);
130 }
131 }
132 }
133
134 return forwardConnections;
135 }
136
137 public static Map<Integer, Set<Integer>> computeBackwardLinks(final List<Connection> connections) {
138 Validate.notNull(connections);
139
140 final Map<Integer, Set<Integer>> backwardConnections = new HashMap<>();
141 for (final Connection connection : connections) {
142 final var fromNodeIndex = connection.fromNodeIndex();
143 final var toNodeIndex = connection.toNodeIndex();
144
145 if (connection.isEnabled()) {
146 final var fromNodes = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());
147
148 if (fromNodes.add(fromNodeIndex) == false) {
149 throw new IllegalArgumentException(
150 "Found duplicate entries for nodes defined in connection " + connection);
151 }
152 }
153 }
154 return backwardConnections;
155 }
156
157 public static Map<Integer, Set<Connection>> computeBackwardConnections(final List<Connection> connections) {
158 Validate.notNull(connections);
159
160 final Map<Integer, Set<Connection>> backwardConnections = new HashMap<>();
161 for (final Connection connection : connections) {
162 final var toNodeIndex = connection.toNodeIndex();
163
164 if (connection.isEnabled()) {
165 final var fromConnections = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());
166
167 if (fromConnections.stream()
168 .anyMatch(existingConnection -> existingConnection.fromNodeIndex() == connection.fromNodeIndex())) {
169 throw new IllegalArgumentException(
170 "Found duplicate entries for nodes defined in connection " + connection);
171 }
172 fromConnections.add(connection);
173 }
174 }
175 return backwardConnections;
176 }
177
178 public static List<List<Integer>> partitionLayersNodes(final Set<Integer> inputNodeIndices,
179 final Set<Integer> outputNodeIndices, final List<Connection> connections) {
180 Validate.isTrue(CollectionUtils.isNotEmpty(inputNodeIndices));
181 Validate.isTrue(CollectionUtils.isNotEmpty(outputNodeIndices));
182 Validate.isTrue(CollectionUtils.isNotEmpty(connections));
183
184 final Map<Integer, Set<Integer>> forwardConnections = computeForwardLinks(connections);
185 final Map<Integer, Set<Integer>> backwardConnections = computeBackwardLinks(connections);
186
187
188 final var deadNodes = computeDeadNodes(connections, forwardConnections, backwardConnections, outputNodeIndices);
189
190 final Set<Integer> processedSet = new HashSet<>();
191 final List<List<Integer>> layers = new ArrayList<>();
192 processedSet.addAll(inputNodeIndices);
193 layers.add(new ArrayList<>(inputNodeIndices));
194
195 boolean done = false;
196 while (done == false) {
197 final List<Integer> layer = new ArrayList<>();
198
199 final Set<Integer> layerCandidates = new HashSet<>();
200 for (final Entry<Integer, Set<Integer>> entry : forwardConnections.entrySet()) {
201 final var key = entry.getKey();
202 final var values = entry.getValue();
203
204 if (processedSet.contains(key) == true) {
205 for (final Integer candidate : values) {
206 if (deadNodes.contains(candidate) == false && processedSet.contains(candidate) == false
207 && outputNodeIndices.contains(candidate) == false) {
208 layerCandidates.add(candidate);
209 }
210 }
211 }
212 }
213
214
215
216
217
218 for (final Integer candidate : layerCandidates) {
219 final var backwardLinks = backwardConnections.getOrDefault(candidate, Set.of());
220
221 final boolean allBackwardInEndSet = backwardLinks.stream()
222 .allMatch(next -> processedSet.contains(next) || deadNodes.contains(next));
223
224 if (allBackwardInEndSet) {
225 layer.add(candidate);
226 }
227 }
228
229 if (layer.size() == 0) {
230 done = true;
231 layer.addAll(outputNodeIndices);
232 } else {
233 processedSet.addAll(layer);
234 }
235 layers.add(layer);
236 }
237 return layers;
238 }
239
240 public static float compatibilityDistance(final List<Connection> firstConnections,
241 final List<Connection> secondConnections, final float c1, final float c2, final float c3) {
242 if (firstConnections == null || secondConnections == null) {
243 return Float.MAX_VALUE;
244 }
245
246
247
248
249
250 final int maxConnectionSize = Math.max(firstConnections.size(), secondConnections.size());
251 final float n = maxConnectionSize < 20 ? 1.0f : maxConnectionSize;
252
253 int disjointGenes = 0;
254
255 float sumWeightDifference = 0;
256 int numMatchingGenes = 0;
257
258 int indexFirst = 0;
259 int indexSecond = 0;
260
261 while (indexFirst < firstConnections.size() && indexSecond < secondConnections.size()) {
262
263 final Connection firstConnection = firstConnections.get(indexFirst);
264 final int firstInnovation = firstConnection.innovation();
265
266 final Connection secondConnection = secondConnections.get(indexSecond);
267 final int secondInnovation = secondConnection.innovation();
268
269 if (firstInnovation == secondInnovation) {
270 sumWeightDifference += Math.abs(secondConnection.weight() - firstConnection.weight());
271 numMatchingGenes++;
272
273 indexFirst++;
274 indexSecond++;
275 } else {
276
277 disjointGenes++;
278
279 if (firstInnovation < secondInnovation) {
280 indexFirst++;
281 } else {
282 indexSecond++;
283 }
284 }
285 }
286
287 int excessGenes = 0;
288
289
290
291 if (indexFirst < firstConnections.size()) {
292 excessGenes += firstConnections.size() - indexSecond;
293 } else if (indexSecond < secondConnections.size()) {
294 excessGenes += secondConnections.size() - indexFirst;
295 }
296
297 final float averageWeightDifference = sumWeightDifference / Math.max(1, numMatchingGenes);
298
299 return (c1 * excessGenes) / n + (c2 * disjointGenes) / n + c3 * averageWeightDifference;
300 }
301
302 public static float compatibilityDistance(final Genotype genotype1, final Genotype genotype2,
303 final int chromosomeIndex, final float c1, final float c2, final float c3) {
304 Validate.notNull(genotype1);
305 Validate.notNull(genotype2);
306 Validate.isTrue(chromosomeIndex >= 0);
307 Validate.isTrue(chromosomeIndex < genotype1.getSize());
308 Validate.isTrue(chromosomeIndex < genotype2.getSize());
309
310 final var neatChromosome1 = genotype1.getChromosome(chromosomeIndex, NeatChromosome.class);
311 final var connections1 = neatChromosome1.getConnections();
312
313 final var neatChromosome2 = genotype2.getChromosome(chromosomeIndex, NeatChromosome.class);
314 final var connections2 = neatChromosome2.getConnections();
315
316 return compatibilityDistance(connections1, connections2, c1, c2, c3);
317 }
318
319 public static <T extends Comparable<T>> List<Species<T>> speciate(final RandomGenerator random,
320 final SpeciesIdGenerator speciesIdGenerator, final List<Species<T>> seedSpecies,
321 final Population<T> population, final BiPredicate<Individual<T>, Individual<T>> speciesPredicate) {
322 Validate.notNull(random);
323 Validate.notNull(speciesIdGenerator);
324 Validate.notNull(seedSpecies);
325 Validate.notNull(population);
326 Validate.notNull(speciesPredicate);
327
328 final List<Species<T>> species = new ArrayList<>();
329
330 for (final Species<T> speciesIterator : seedSpecies) {
331 final var speciesId = speciesIterator.getId();
332 final int numMembers = speciesIterator.getNumMembers();
333 if (numMembers > 0) {
334 final int randomIndex = random.nextInt(numMembers);
335 final var newAncestors = List.of(speciesIterator.getMembers()
336 .get(randomIndex));
337 final var newSpecies = new Species<>(speciesId, newAncestors);
338 species.add(newSpecies);
339 }
340 }
341
342 for (final Individual<T> individual : population) {
343
344 boolean existingSpeciesFound = false;
345 int currentSpeciesIndex = 0;
346 while (existingSpeciesFound == false && currentSpeciesIndex < species.size()) {
347
348 final var currentSpecies = species.get(currentSpeciesIndex);
349
350 final boolean anyAncestorMatch = currentSpecies.getAncestors()
351 .stream()
352 .anyMatch(candidate -> speciesPredicate.test(individual, candidate));
353
354 final boolean anyMemberMatch = currentSpecies.getMembers()
355 .stream()
356 .anyMatch(candidate -> speciesPredicate.test(individual, candidate));
357
358 if (anyAncestorMatch || anyMemberMatch) {
359 currentSpecies.addMember(individual);
360 existingSpeciesFound = true;
361 } else {
362 currentSpeciesIndex++;
363 }
364 }
365
366 if (existingSpeciesFound == false) {
367 final int newSpeciesId = speciesIdGenerator.computeNewId();
368 final var newSpecies = new Species<T>(newSpeciesId, List.of());
369 newSpecies.addMember(individual);
370 species.add(newSpecies);
371 }
372 }
373
374 return species.stream()
375 .filter(sp -> sp.getNumMembers() > 0)
376 .toList();
377 }
378 }