1 package net.bmahe.genetics4j.neat;
2
3 import java.util.ArrayDeque;
4 import java.util.ArrayList;
5 import java.util.Deque;
6 import java.util.HashMap;
7 import java.util.HashSet;
8 import java.util.List;
9 import java.util.Map;
10 import java.util.Map.Entry;
11 import java.util.Objects;
12 import java.util.Set;
13 import java.util.function.BiPredicate;
14 import java.util.random.RandomGenerator;
15
16 import org.apache.commons.collections4.CollectionUtils;
17 import org.apache.commons.lang3.Validate;
18
19 import net.bmahe.genetics4j.core.Genotype;
20 import net.bmahe.genetics4j.core.Individual;
21 import net.bmahe.genetics4j.core.Population;
22 import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71 public class NeatUtils {
72
73 private NeatUtils() {
74 }
75
76
77
78
79
80
81
82
83
84
85 public static Set<Integer> computeDeadNodes(final List<Connection> connections,
86 final Map<Integer, Set<Integer>> forwardConnections, final Map<Integer, Set<Integer>> backwardConnections,
87 final Set<Integer> outputNodeIndices) {
88 Objects.requireNonNull(connections);
89
90 final Set<Integer> deadNodes = new HashSet<>();
91 for (final Connection connection : connections) {
92 deadNodes.add(connection.fromNodeIndex());
93 deadNodes.add(connection.toNodeIndex());
94 }
95 deadNodes.removeAll(outputNodeIndices);
96
97 final Set<Integer> visited = new HashSet<>();
98 final Deque<Integer> toVisit = new ArrayDeque<>(outputNodeIndices);
99 while (toVisit.size() > 0) {
100 final Integer currentNode = toVisit.poll();
101
102 deadNodes.remove(currentNode);
103 if (visited.contains(currentNode) == false) {
104
105 visited.add(currentNode);
106
107 final var next = backwardConnections.getOrDefault(currentNode, Set.of());
108 if (next.size() > 0) {
109 toVisit.addAll(next);
110 }
111 }
112 }
113
114 return deadNodes;
115 }
116
117 public static Map<Integer, Set<Integer>> computeForwardLinks(final List<Connection> connections) {
118 Objects.requireNonNull(connections);
119
120 final Map<Integer, Set<Integer>> forwardConnections = new HashMap<>();
121 for (final Connection connection : connections) {
122 final var fromNodeIndex = connection.fromNodeIndex();
123 final var toNodeIndex = connection.toNodeIndex();
124
125 if (connection.isEnabled()) {
126 final var toNodes = forwardConnections.computeIfAbsent(fromNodeIndex, k -> new HashSet<>());
127
128 if (toNodes.add(toNodeIndex) == false) {
129 throw new IllegalArgumentException(
130 "Found duplicate entries for nodes defined in connection " + connection);
131 }
132 }
133 }
134
135 return forwardConnections;
136 }
137
138 public static Map<Integer, Set<Integer>> computeBackwardLinks(final List<Connection> connections) {
139 Objects.requireNonNull(connections);
140
141 final Map<Integer, Set<Integer>> backwardConnections = new HashMap<>();
142 for (final Connection connection : connections) {
143 final var fromNodeIndex = connection.fromNodeIndex();
144 final var toNodeIndex = connection.toNodeIndex();
145
146 if (connection.isEnabled()) {
147 final var fromNodes = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());
148
149 if (fromNodes.add(fromNodeIndex) == false) {
150 throw new IllegalArgumentException(
151 "Found duplicate entries for nodes defined in connection " + connection);
152 }
153 }
154 }
155 return backwardConnections;
156 }
157
158 public static Map<Integer, Set<Connection>> computeBackwardConnections(final List<Connection> connections) {
159 Objects.requireNonNull(connections);
160
161 final Map<Integer, Set<Connection>> backwardConnections = new HashMap<>();
162 for (final Connection connection : connections) {
163 final var toNodeIndex = connection.toNodeIndex();
164
165 if (connection.isEnabled()) {
166 final var fromConnections = backwardConnections.computeIfAbsent(toNodeIndex, k -> new HashSet<>());
167
168 if (fromConnections.stream()
169 .anyMatch(existingConnection -> existingConnection.fromNodeIndex() == connection.fromNodeIndex())) {
170 throw new IllegalArgumentException(
171 "Found duplicate entries for nodes defined in connection " + connection);
172 }
173 fromConnections.add(connection);
174 }
175 }
176 return backwardConnections;
177 }
178
179 public static List<List<Integer>> partitionLayersNodes(final Set<Integer> inputNodeIndices,
180 final Set<Integer> outputNodeIndices, final List<Connection> connections) {
181 Validate.isTrue(CollectionUtils.isNotEmpty(inputNodeIndices));
182 Validate.isTrue(CollectionUtils.isNotEmpty(outputNodeIndices));
183 Validate.isTrue(CollectionUtils.isNotEmpty(connections));
184
185 final Map<Integer, Set<Integer>> forwardConnections = computeForwardLinks(connections);
186 final Map<Integer, Set<Integer>> backwardConnections = computeBackwardLinks(connections);
187
188
189 final var deadNodes = computeDeadNodes(connections, forwardConnections, backwardConnections, outputNodeIndices);
190
191 final Set<Integer> processedSet = new HashSet<>();
192 final List<List<Integer>> layers = new ArrayList<>();
193 processedSet.addAll(inputNodeIndices);
194 layers.add(new ArrayList<>(inputNodeIndices));
195
196 boolean done = false;
197 while (done == false) {
198 final List<Integer> layer = new ArrayList<>();
199
200 final Set<Integer> layerCandidates = new HashSet<>();
201 for (final Entry<Integer, Set<Integer>> entry : forwardConnections.entrySet()) {
202 final var key = entry.getKey();
203 final var values = entry.getValue();
204
205 if (processedSet.contains(key) == true) {
206 for (final Integer candidate : values) {
207 if (deadNodes.contains(candidate) == false && processedSet.contains(candidate) == false
208 && outputNodeIndices.contains(candidate) == false) {
209 layerCandidates.add(candidate);
210 }
211 }
212 }
213 }
214
215
216
217
218
219 for (final Integer candidate : layerCandidates) {
220 final var backwardLinks = backwardConnections.getOrDefault(candidate, Set.of());
221
222 final boolean allBackwardInEndSet = backwardLinks.stream()
223 .allMatch(next -> processedSet.contains(next) || deadNodes.contains(next));
224
225 if (allBackwardInEndSet) {
226 layer.add(candidate);
227 }
228 }
229
230 if (layer.size() == 0) {
231 done = true;
232 layer.addAll(outputNodeIndices);
233 } else {
234 processedSet.addAll(layer);
235 }
236 layers.add(layer);
237 }
238 return layers;
239 }
240
241 public static float compatibilityDistance(final List<Connection> firstConnections,
242 final List<Connection> secondConnections, final float c1, final float c2, final float c3) {
243 if (firstConnections == null || secondConnections == null) {
244 return Float.MAX_VALUE;
245 }
246
247
248
249
250
251 final int maxConnectionSize = Math.max(firstConnections.size(), secondConnections.size());
252 final float n = maxConnectionSize < 20 ? 1.0f : maxConnectionSize;
253
254 int disjointGenes = 0;
255
256 float sumWeightDifference = 0;
257 int numMatchingGenes = 0;
258
259 int indexFirst = 0;
260 int indexSecond = 0;
261
262 while (indexFirst < firstConnections.size() && indexSecond < secondConnections.size()) {
263
264 final Connection firstConnection = firstConnections.get(indexFirst);
265 final int firstInnovation = firstConnection.innovation();
266
267 final Connection secondConnection = secondConnections.get(indexSecond);
268 final int secondInnovation = secondConnection.innovation();
269
270 if (firstInnovation == secondInnovation) {
271 sumWeightDifference += Math.abs(secondConnection.weight() - firstConnection.weight());
272 numMatchingGenes++;
273
274 indexFirst++;
275 indexSecond++;
276 } else {
277
278 disjointGenes++;
279
280 if (firstInnovation < secondInnovation) {
281 indexFirst++;
282 } else {
283 indexSecond++;
284 }
285 }
286 }
287
288 int excessGenes = 0;
289
290
291
292 if (indexFirst < firstConnections.size()) {
293 excessGenes += firstConnections.size() - indexSecond;
294 } else if (indexSecond < secondConnections.size()) {
295 excessGenes += secondConnections.size() - indexFirst;
296 }
297
298 final float averageWeightDifference = sumWeightDifference / Math.max(1, numMatchingGenes);
299
300 return (c1 * excessGenes) / n + (c2 * disjointGenes) / n + c3 * averageWeightDifference;
301 }
302
303 public static float compatibilityDistance(final Genotype genotype1, final Genotype genotype2,
304 final int chromosomeIndex, final float c1, final float c2, final float c3) {
305 Objects.requireNonNull(genotype1);
306 Objects.requireNonNull(genotype2);
307 Validate.isTrue(chromosomeIndex >= 0);
308 Validate.isTrue(chromosomeIndex < genotype1.getSize());
309 Validate.isTrue(chromosomeIndex < genotype2.getSize());
310
311 final var neatChromosome1 = genotype1.getChromosome(chromosomeIndex, NeatChromosome.class);
312 final var connections1 = neatChromosome1.getConnections();
313
314 final var neatChromosome2 = genotype2.getChromosome(chromosomeIndex, NeatChromosome.class);
315 final var connections2 = neatChromosome2.getConnections();
316
317 return compatibilityDistance(connections1, connections2, c1, c2, c3);
318 }
319
320 public static <T extends Comparable<T>> List<Species<T>> speciate(final RandomGenerator random,
321 final SpeciesIdGenerator speciesIdGenerator, final List<Species<T>> seedSpecies,
322 final Population<T> population, final BiPredicate<Individual<T>, Individual<T>> speciesPredicate) {
323 Objects.requireNonNull(random);
324 Objects.requireNonNull(speciesIdGenerator);
325 Objects.requireNonNull(seedSpecies);
326 Objects.requireNonNull(population);
327 Objects.requireNonNull(speciesPredicate);
328
329 final List<Species<T>> species = new ArrayList<>();
330
331 for (final Species<T> speciesIterator : seedSpecies) {
332 final var speciesId = speciesIterator.getId();
333 final int numMembers = speciesIterator.getNumMembers();
334 if (numMembers > 0) {
335 final int randomIndex = random.nextInt(numMembers);
336 final var newAncestors = List.of(speciesIterator.getMembers().get(randomIndex));
337 final var newSpecies = new Species<>(speciesId, newAncestors);
338 species.add(newSpecies);
339 }
340 }
341
342 for (final Individual<T> individual : population) {
343
344 boolean existingSpeciesFound = false;
345 int currentSpeciesIndex = 0;
346 while (existingSpeciesFound == false && currentSpeciesIndex < species.size()) {
347
348 final var currentSpecies = species.get(currentSpeciesIndex);
349
350 final boolean anyAncestorMatch = currentSpecies.getAncestors()
351 .stream()
352 .anyMatch(candidate -> speciesPredicate.test(individual, candidate));
353
354 final boolean anyMemberMatch = currentSpecies.getMembers()
355 .stream()
356 .anyMatch(candidate -> speciesPredicate.test(individual, candidate));
357
358 if (anyAncestorMatch || anyMemberMatch) {
359 currentSpecies.addMember(individual);
360 existingSpeciesFound = true;
361 } else {
362 currentSpeciesIndex++;
363 }
364 }
365
366 if (existingSpeciesFound == false) {
367 final int newSpeciesId = speciesIdGenerator.computeNewId();
368 final var newSpecies = new Species<T>(newSpeciesId, List.of());
369 newSpecies.addMember(individual);
370 species.add(newSpecies);
371 }
372 }
373
374 return species.stream().filter(sp -> sp.getNumMembers() > 0).toList();
375 }
376 }