1 package net.bmahe.genetics4j.neat;
2
3 import java.util.ArrayDeque;
4 import java.util.ArrayList;
5 import java.util.Deque;
6 import java.util.HashMap;
7 import java.util.HashSet;
8 import java.util.List;
9 import java.util.Map;
10 import java.util.Map.Entry;
11 import java.util.Set;
12 import java.util.function.BiPredicate;
13 import java.util.random.RandomGenerator;
14
15 import org.apache.commons.collections4.CollectionUtils;
16 import org.apache.commons.lang3.Validate;
17
18 import net.bmahe.genetics4j.core.Genotype;
19 import net.bmahe.genetics4j.core.Individual;
20 import net.bmahe.genetics4j.core.Population;
21 import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68 public class NeatUtils {
69
70 private NeatUtils() {
71 }
72
73
74
75
76
77
78
79
80
81
82
83 public static Set<Integer> computeDeadNodes(final List<Connection> connections,
84 final Map<Integer, Set<Integer>> forwardConnections, final Map<Integer, Set<Integer>> backwardConnections,
85 final Set<Integer> outputNodeIndices) {
86 Validate.notNull(connections);
87
88 final Set<Integer> deadNodes = new HashSet<>();
89 for (final Connection connection : connections) {
90 deadNodes.add(connection.fromNodeIndex());
91 deadNodes.add(connection.toNodeIndex());
92 }
93 deadNodes.removeAll(outputNodeIndices);
94
95 final Set<Integer> visited = new HashSet<>();
96 final Deque<Integer> toVisit = new ArrayDeque<>(outputNodeIndices);
97 while (toVisit.size() > 0) {
98 final Integer currentNode = toVisit.poll();
99
100 deadNodes.remove(currentNode);
101 if (visited.contains(currentNode) == false) {
102
103 visited.add(currentNode);
104
105 final var next = backwardConnections.getOrDefault(currentNode, Set.of());
106 if (next.size() > 0) {
107 toVisit.addAll(next);
108 }
109 }
110 }
111
112 return deadNodes;
113 }
114
115 public static Map<Integer, Set<Integer>> computeForwardLinks(final List<Connection> connections) {
116 Validate.notNull(connections);
117
118 final Map<Integer, Set<Integer>> forwardConnections = new HashMap<>();
119 for (final Connection connection : connections) {
120 final var fromNodeIndex = connection.fromNodeIndex();
121 final var toNodeIndex = connection.toNodeIndex();
122
123 if (connection.isEnabled()) {
124 final var toNodes = forwardConnections.computeIfAbsent(fromNodeIndex, k -> new HashSet<>());
125
126 if (toNodes.add(toNodeIndex) == false) {
127 throw new IllegalArgumentException(
128 "Found duplicate entries for nodes defined in connection " + connection);
129 }
130 }
131 }
132
133 return forwardConnections;
134 }
135
136 public static Map<Integer, Set<Integer>> computeBackwardLinks(final List<Connection> connections) {
137 Validate.notNull(connections);
138
139 final Map<Integer, Set<Integer>> backwardConnections = new HashMap<>();
140 for (final Connection connection : connections) {
141 final var fromNodeIndex = connection.fromNodeIndex();
142 final var toNodeIndex = connection.toNodeIndex();
143
144 if (connection.isEnabled()) {
145 final var fromNodes = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());
146
147 if (fromNodes.add(fromNodeIndex) == false) {
148 throw new IllegalArgumentException(
149 "Found duplicate entries for nodes defined in connection " + connection);
150 }
151 }
152 }
153 return backwardConnections;
154 }
155
156 public static Map<Integer, Set<Connection>> computeBackwardConnections(final List<Connection> connections) {
157 Validate.notNull(connections);
158
159 final Map<Integer, Set<Connection>> backwardConnections = new HashMap<>();
160 for (final Connection connection : connections) {
161 final var toNodeIndex = connection.toNodeIndex();
162
163 if (connection.isEnabled()) {
164 final var fromConnections = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());
165
166 if (fromConnections.stream()
167 .anyMatch(existingConnection -> existingConnection.fromNodeIndex() == connection.fromNodeIndex())) {
168 throw new IllegalArgumentException(
169 "Found duplicate entries for nodes defined in connection " + connection);
170 }
171 fromConnections.add(connection);
172 }
173 }
174 return backwardConnections;
175 }
176
177 public static List<List<Integer>> partitionLayersNodes(final Set<Integer> inputNodeIndices,
178 final Set<Integer> outputNodeIndices, final List<Connection> connections) {
179 Validate.isTrue(CollectionUtils.isNotEmpty(inputNodeIndices));
180 Validate.isTrue(CollectionUtils.isNotEmpty(outputNodeIndices));
181 Validate.isTrue(CollectionUtils.isNotEmpty(connections));
182
183 final Map<Integer, Set<Integer>> forwardConnections = computeForwardLinks(connections);
184 final Map<Integer, Set<Integer>> backwardConnections = computeBackwardLinks(connections);
185
186
187 final var deadNodes = computeDeadNodes(connections, forwardConnections, backwardConnections, outputNodeIndices);
188
189 final Set<Integer> processedSet = new HashSet<>();
190 final List<List<Integer>> layers = new ArrayList<>();
191 processedSet.addAll(inputNodeIndices);
192 layers.add(new ArrayList<>(inputNodeIndices));
193
194 boolean done = false;
195 while (done == false) {
196 final List<Integer> layer = new ArrayList<>();
197
198 final Set<Integer> layerCandidates = new HashSet<>();
199 for (final Entry<Integer, Set<Integer>> entry : forwardConnections.entrySet()) {
200 final var key = entry.getKey();
201 final var values = entry.getValue();
202
203 if (processedSet.contains(key) == true) {
204 for (final Integer candidate : values) {
205 if (deadNodes.contains(candidate) == false && processedSet.contains(candidate) == false
206 && outputNodeIndices.contains(candidate) == false) {
207 layerCandidates.add(candidate);
208 }
209 }
210 }
211 }
212
213
214
215
216
217 for (final Integer candidate : layerCandidates) {
218 final var backwardLinks = backwardConnections.getOrDefault(candidate, Set.of());
219
220 final boolean allBackwardInEndSet = backwardLinks.stream()
221 .allMatch(next -> processedSet.contains(next) || deadNodes.contains(next));
222
223 if (allBackwardInEndSet) {
224 layer.add(candidate);
225 }
226 }
227
228 if (layer.size() == 0) {
229 done = true;
230 layer.addAll(outputNodeIndices);
231 } else {
232 processedSet.addAll(layer);
233 }
234 layers.add(layer);
235 }
236 return layers;
237 }
238
239 public static float compatibilityDistance(final List<Connection> firstConnections,
240 final List<Connection> secondConnections, final float c1, final float c2, final float c3) {
241 if (firstConnections == null || secondConnections == null) {
242 return Float.MAX_VALUE;
243 }
244
245
246
247
248
249 final int maxConnectionSize = Math.max(firstConnections.size(), secondConnections.size());
250 final float n = maxConnectionSize < 20 ? 1.0f : maxConnectionSize;
251
252 int disjointGenes = 0;
253
254 float sumWeightDifference = 0;
255 int numMatchingGenes = 0;
256
257 int indexFirst = 0;
258 int indexSecond = 0;
259
260 while (indexFirst < firstConnections.size() && indexSecond < secondConnections.size()) {
261
262 final Connection firstConnection = firstConnections.get(indexFirst);
263 final int firstInnovation = firstConnection.innovation();
264
265 final Connection secondConnection = secondConnections.get(indexSecond);
266 final int secondInnovation = secondConnection.innovation();
267
268 if (firstInnovation == secondInnovation) {
269 sumWeightDifference += Math.abs(secondConnection.weight() - firstConnection.weight());
270 numMatchingGenes++;
271
272 indexFirst++;
273 indexSecond++;
274 } else {
275
276 disjointGenes++;
277
278 if (firstInnovation < secondInnovation) {
279 indexFirst++;
280 } else {
281 indexSecond++;
282 }
283 }
284 }
285
286 int excessGenes = 0;
287
288
289
290
291 if (indexFirst < firstConnections.size()) {
292 excessGenes += firstConnections.size() - indexSecond;
293 } else if (indexSecond < secondConnections.size()) {
294 excessGenes += secondConnections.size() - indexFirst;
295 }
296
297 final float averageWeightDifference = sumWeightDifference / Math.max(1, numMatchingGenes);
298
299 return (c1 * excessGenes) / n + (c2 * disjointGenes) / n + c3 * averageWeightDifference;
300 }
301
302 public static float compatibilityDistance(final Genotype genotype1, final Genotype genotype2,
303 final int chromosomeIndex, final float c1, final float c2, final float c3) {
304 Validate.notNull(genotype1);
305 Validate.notNull(genotype2);
306 Validate.isTrue(chromosomeIndex >= 0);
307 Validate.isTrue(chromosomeIndex < genotype1.getSize());
308 Validate.isTrue(chromosomeIndex < genotype2.getSize());
309
310 final var neatChromosome1 = genotype1.getChromosome(chromosomeIndex, NeatChromosome.class);
311 final var connections1 = neatChromosome1.getConnections();
312
313 final var neatChromosome2 = genotype2.getChromosome(chromosomeIndex, NeatChromosome.class);
314 final var connections2 = neatChromosome2.getConnections();
315
316 return compatibilityDistance(connections1, connections2, c1, c2, c3);
317 }
318
319 public static <T extends Comparable<T>> List<Species<T>> speciate(final RandomGenerator random,
320 final SpeciesIdGenerator speciesIdGenerator, final List<Species<T>> seedSpecies,
321 final Population<T> population, final BiPredicate<Individual<T>, Individual<T>> speciesPredicate) {
322 Validate.notNull(random);
323 Validate.notNull(speciesIdGenerator);
324 Validate.notNull(seedSpecies);
325 Validate.notNull(population);
326 Validate.notNull(speciesPredicate);
327
328 final List<Species<T>> species = new ArrayList<>();
329
330 for (final Species<T> speciesIterator : seedSpecies) {
331 final var speciesId = speciesIterator.getId();
332 final int numMembers = speciesIterator.getNumMembers();
333 if (numMembers > 0) {
334 final int randomIndex = random.nextInt(numMembers);
335 final var newAncestors = List.of(speciesIterator.getMembers()
336 .get(randomIndex));
337 final var newSpecies = new Species<>(speciesId, newAncestors);
338 species.add(newSpecies);
339 }
340 }
341
342 for (final Individual<T> individual : population) {
343
344 boolean existingSpeciesFound = false;
345 int currentSpeciesIndex = 0;
346 while (existingSpeciesFound == false && currentSpeciesIndex < species.size()) {
347
348 final var currentSpecies = species.get(currentSpeciesIndex);
349
350 final boolean anyAncestorMatch = currentSpecies.getAncestors()
351 .stream()
352 .anyMatch(candidate -> speciesPredicate.test(individual, candidate));
353
354 final boolean anyMemberMatch = currentSpecies.getMembers()
355 .stream()
356 .anyMatch(candidate -> speciesPredicate.test(individual, candidate));
357
358 if (anyAncestorMatch || anyMemberMatch) {
359 currentSpecies.addMember(individual);
360 existingSpeciesFound = true;
361 } else {
362 currentSpeciesIndex++;
363 }
364 }
365
366 if (existingSpeciesFound == false) {
367 final int newSpeciesId = speciesIdGenerator.computeNewId();
368 final var newSpecies = new Species<T>(newSpeciesId, List.of());
369 newSpecies.addMember(individual);
370 species.add(newSpecies);
371 }
372 }
373
374 return species.stream()
375 .filter(sp -> sp.getNumMembers() > 0)
376 .toList();
377 }
378 }