1 package net.bmahe.genetics4j.neat.chromosomes; 2 3 import java.util.ArrayList; 4 import java.util.Collections; 5 import java.util.Comparator; 6 import java.util.List; 7 import java.util.Objects; 8 import java.util.Set; 9 import java.util.stream.Collectors; 10 import java.util.stream.IntStream; 11 12 import org.apache.commons.lang3.Validate; 13 14 import net.bmahe.genetics4j.core.chromosomes.Chromosome; 15 import net.bmahe.genetics4j.neat.Connection; 16 17 public class NeatChromosome implements Chromosome { 18 19 private final int numInputs; 20 private final int numOutputs; 21 private final float minWeightValue; 22 private final float maxWeightValue; 23 private final List<Connection> connections; 24 25 public NeatChromosome(final int _numInputs, final int _numOutputs, final float _minWeightValue, 26 final float _maxWeightValue, final List<Connection> _connections) { 27 Validate.isTrue(_numInputs > 0); 28 Validate.isTrue(_numOutputs > 0); 29 Validate.isTrue(_minWeightValue < _maxWeightValue); 30 Validate.notNull(_connections); 31 32 this.numInputs = _numInputs; 33 this.numOutputs = _numOutputs; 34 this.minWeightValue = _minWeightValue; 35 this.maxWeightValue = _maxWeightValue; 36 37 final List<Connection> copyOfConnections = new ArrayList<>(_connections); 38 Collections.sort(copyOfConnections, Comparator.comparing(Connection::innovation)); 39 this.connections = Collections.unmodifiableList(copyOfConnections); 40 } 41 42 @Override 43 public int getNumAlleles() { 44 return numInputs + numOutputs + connections.size(); 45 } 46 47 public int getNumInputs() { 48 return numInputs; 49 } 50 51 public int getNumOutputs() { 52 return numOutputs; 53 } 54 55 public float getMinWeightValue() { 56 return minWeightValue; 57 } 58 59 public float getMaxWeightValue() { 60 return maxWeightValue; 61 } 62 63 public List<Connection> getConnections() { 64 return connections; 65 } 66 67 public Set<Integer> getInputNodeIndices() { 68 return IntStream.range(0, numInputs) 69 .boxed() 70 .collect(Collectors.toSet()); 71 } 72 73 public Set<Integer> getOutputNodeIndices() { 74 return IntStream.range(numInputs, getNumInputs() + getNumOutputs()) 75 .boxed() 76 .collect(Collectors.toSet()); 77 } 78 79 @Override 80 public int hashCode() { 81 return Objects.hash(connections, maxWeightValue, minWeightValue, numInputs, numOutputs); 82 } 83 84 @Override 85 public boolean equals(Object obj) { 86 if (this == obj) 87 return true; 88 if (obj == null) 89 return false; 90 if (getClass() != obj.getClass()) 91 return false; 92 NeatChromosome other = (NeatChromosome) obj; 93 return Objects.equals(connections, other.connections) 94 && Float.floatToIntBits(maxWeightValue) == Float.floatToIntBits(other.maxWeightValue) 95 && Float.floatToIntBits(minWeightValue) == Float.floatToIntBits(other.minWeightValue) 96 && numInputs == other.numInputs && numOutputs == other.numOutputs; 97 } 98 99 @Override 100 public String toString() { 101 return "NeatChromosome [numInputs=" + numInputs + ", numOutputs=" + numOutputs + ", minWeightValue=" 102 + minWeightValue + ", maxWeightValue=" + maxWeightValue + ", connections=" + connections + "]"; 103 } 104 }