NeatChromosomeCombinator.java
package net.bmahe.genetics4j.neat.combination;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.random.RandomGenerator;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import net.bmahe.genetics4j.core.chromosomes.Chromosome;
import net.bmahe.genetics4j.core.combination.ChromosomeCombinator;
import net.bmahe.genetics4j.core.spec.AbstractEAConfiguration;
import net.bmahe.genetics4j.neat.Connection;
import net.bmahe.genetics4j.neat.chromosomes.NeatChromosome;
import net.bmahe.genetics4j.neat.combination.parentcompare.ChosenOtherChromosome;
import net.bmahe.genetics4j.neat.combination.parentcompare.ParentComparisonHandler;
import net.bmahe.genetics4j.neat.spec.combination.NeatCombination;
import net.bmahe.genetics4j.neat.spec.combination.parentcompare.ParentComparisonPolicy;
public class NeatChromosomeCombinator<T extends Comparable<T>> implements ChromosomeCombinator<T> {
public static final Logger logger = LogManager.getLogger(NeatChromosomeCombinator.class);
private final RandomGenerator randomGenerator;
private final NeatCombination neatCombination;
private final ParentComparisonHandler parentComparisonHandler;
private boolean linksCacheContainsConnection(final Map<Integer, Set<Integer>> linksCache,
final Connection connection) {
Validate.notNull(linksCache);
Validate.notNull(connection);
final int fromNodeIndex = connection.fromNodeIndex();
final int toNodeIndex = connection.toNodeIndex();
return linksCache.containsKey(fromNodeIndex) == true && linksCache.get(fromNodeIndex)
.contains(toNodeIndex) == true;
}
private void insertInlinksCache(final Map<Integer, Set<Integer>> linksCache, final Connection connection) {
Validate.notNull(linksCache);
Validate.notNull(connection);
final int fromNodeIndex = connection.fromNodeIndex();
final int toNodeIndex = connection.toNodeIndex();
linksCache.computeIfAbsent(fromNodeIndex, k -> new HashSet<>())
.add(toNodeIndex);
}
protected boolean shouldReEnable(final Connection chosenParent, final Connection otherParent) {
Validate.notNull(chosenParent);
Validate.notNull(otherParent);
boolean shouldReEnable = false;
if (chosenParent.isEnabled() == false && otherParent.isEnabled() == true) {
if (randomGenerator.nextDouble() < neatCombination.reenableGeneInheritanceThresold()) {
shouldReEnable = true;
}
}
return shouldReEnable;
}
public NeatChromosomeCombinator(final RandomGenerator _randomGenerator, final NeatCombination _neatCombination,
final ParentComparisonHandler _parentComparisonHandler) {
Validate.notNull(_randomGenerator);
Validate.notNull(_neatCombination);
Validate.notNull(_parentComparisonHandler);
this.randomGenerator = _randomGenerator;
this.neatCombination = _neatCombination;
this.parentComparisonHandler = _parentComparisonHandler;
}
@Override
public List<Chromosome> combine(final AbstractEAConfiguration<T> eaConfiguration, final Chromosome firstChromosome,
final T firstParentFitness, final Chromosome secondChromosome, final T secondParentFitness) {
Validate.notNull(eaConfiguration);
Validate.notNull(firstChromosome);
Validate.notNull(firstParentFitness);
Validate.isInstanceOf(NeatChromosome.class, firstChromosome);
Validate.notNull(secondChromosome);
Validate.notNull(secondParentFitness);
Validate.isInstanceOf(NeatChromosome.class, secondChromosome);
final NeatChromosome firstNeatChromosome = (NeatChromosome) firstChromosome;
final NeatChromosome secondNeatChromosome = (NeatChromosome) secondChromosome;
final Comparator<T> fitnessComparator = eaConfiguration.fitnessComparator();
final double inheritanceThresold = neatCombination.inheritanceThresold();
final ParentComparisonPolicy parentComparisonPolicy = neatCombination.parentComparisonPolicy();
final int fitnessComparison = fitnessComparator.compare(firstParentFitness, secondParentFitness);
final ChosenOtherChromosome comparedChromosomes = parentComparisonHandler
.compare(parentComparisonPolicy, firstNeatChromosome, secondNeatChromosome, fitnessComparison);
final NeatChromosome bestChromosome = comparedChromosomes.chosen();
final NeatChromosome worstChromosome = comparedChromosomes.other();
final List<Connection> combinedConnections = new ArrayList<>();
final Map<Integer, Set<Integer>> linksCache = new HashMap<>();
final var bestConnections = bestChromosome.getConnections();
final var worstConnections = worstChromosome.getConnections();
int indexBest = 0;
int indexWorst = 0;
while (indexBest < bestConnections.size() && indexWorst < worstConnections.size()) {
final var bestConnection = bestConnections.get(indexBest);
final var worstConnection = worstConnections.get(indexWorst);
if (bestConnection.innovation() == worstConnection.innovation()) {
/**
* If innovation is the same, we pick the connection randomly
*/
var original = bestConnection;
var other = worstConnection;
if (randomGenerator.nextDouble() < 1 - inheritanceThresold) {
original = worstConnection;
other = bestConnection;
}
if (linksCacheContainsConnection(linksCache, original) == false) {
/**
* If the chosen gene is disabled but the other one is enabled, then there is a
* chance we will re-enable it
*/
final boolean isEnabled = shouldReEnable(original, other) ? true : original.isEnabled();
final var childConnection = Connection.builder()
.from(original)
.isEnabled(isEnabled)
.build();
combinedConnections.add(childConnection);
insertInlinksCache(linksCache, original);
}
indexBest++;
indexWorst++;
} else if (bestConnection.innovation() > worstConnection.innovation()) {
/**
* If the fitnesses are equal, then we randomly inherit from the parent
* Otherwise, we do not inherit from the lesser gene
*/
if (fitnessComparison == 0 && randomGenerator.nextDouble() < 1.0 - inheritanceThresold) {
final var original = worstConnection;
if (linksCacheContainsConnection(linksCache, original) == false) {
combinedConnections.add(Connection.copyOf(original));
insertInlinksCache(linksCache, original);
}
}
indexWorst++;
} else {
/**
* If the fitnesses are equal, then we randomly inherit from the parent
* Otherwise, we always inherit from the better gene
*/
if (fitnessComparison != 0 || randomGenerator.nextDouble() < inheritanceThresold) {
if (linksCacheContainsConnection(linksCache, bestConnection) == false) {
combinedConnections.add(Connection.copyOf(bestConnection));
insertInlinksCache(linksCache, bestConnection);
}
}
indexBest++;
}
}
/*
* Case where the best connection has more genes. It's called excess genes
*/
while (indexBest < bestConnections.size()) {
/**
* If the fitnesses are equal, then we randomly inherit from the parent
* Otherwise, we always inherit from the better gene
*/
if (fitnessComparison != 0 || randomGenerator.nextDouble() < inheritanceThresold) {
final var bestConnection = bestConnections.get(indexBest);
if (linksCacheContainsConnection(linksCache, bestConnection) == false) {
combinedConnections.add(Connection.copyOf(bestConnection));
insertInlinksCache(linksCache, bestConnection);
}
}
indexBest++;
}
/*
* Case where the worst connection has more genes. It's called excess genes.
* Since we don't inherit when their fitness aren't equal, it means we can skip
* the excess genes from the weaker connections. However we will randomly
* inherit if their fitnesses are equal
*/
while (fitnessComparison == 0 && indexWorst < worstConnections.size()) {
if (randomGenerator.nextDouble() < 1.0 - inheritanceThresold) {
final var worstConnection = worstConnections.get(indexWorst);
if (linksCacheContainsConnection(linksCache, worstConnection) == false) {
combinedConnections.add(Connection.copyOf(worstConnection));
insertInlinksCache(linksCache, worstConnection);
}
}
indexWorst++;
}
return List.of(new NeatChromosome(bestChromosome.getNumInputs(),
bestChromosome.getNumOutputs(),
bestChromosome.getMinWeightValue(),
bestChromosome.getMaxWeightValue(),
combinedConnections));
}
}