View Javadoc
1   package net.bmahe.genetics4j.gp.program;
2   
3   import java.util.random.RandomGenerator;
4   
5   import org.apache.commons.lang3.Validate;
6   
7   import net.bmahe.genetics4j.core.chromosomes.TreeNode;
8   import net.bmahe.genetics4j.gp.Operation;
9   import net.bmahe.genetics4j.gp.OperationFactory;
10  
11  public class StdProgramGenerator implements ProgramGenerator {
12  
13  	private final ProgramHelper programHelper;
14  	private final RandomGenerator randomGenerator;
15  
16  	@SuppressWarnings({ "unchecked", "rawtypes" })
17  	private <T, U> TreeNode<Operation<T>> generate(final Program program, Class<U> acceptedType, int maxDepth,
18  			int depth) {
19  
20  		OperationFactory currentNode = depth < maxDepth - 1 && randomGenerator.nextDouble() < 0.5
21  				? programHelper.pickRandomFunction(program, acceptedType)
22  				: programHelper.pickRandomTerminal(program, acceptedType);
23  
24  		final Operation<T> currentOperation = currentNode.build(program.inputSpec());
25  		final TreeNode<Operation<T>> currentTreeNode = new TreeNode<>(currentOperation);
26  
27  		final Class[] acceptedTypes = currentNode.acceptedTypes();
28  
29  		for (int i = 0; i < acceptedTypes.length; i++) {
30  			final Class childAcceptedType = acceptedTypes[i];
31  			final TreeNode<Operation<T>> operation = generate(program, childAcceptedType, maxDepth, depth + 1);
32  
33  			currentTreeNode.addChild(operation);
34  		}
35  
36  		return currentTreeNode;
37  	}
38  
39  	public StdProgramGenerator(final ProgramHelper _programHelper, final RandomGenerator _randomGenerator) {
40  		Validate.notNull(_programHelper);
41  		Validate.notNull(_randomGenerator);
42  
43  		this.programHelper = _programHelper;
44  		this.randomGenerator = _randomGenerator;
45  	}
46  
47  	@SuppressWarnings("rawtypes")
48  	@Override
49  	public TreeNode<Operation<?>> generate(final Program program, final int maxDepth) {
50  		Validate.notNull(program);
51  		Validate.isTrue(maxDepth > 0);
52  
53  		final OperationFactory currentNode = randomGenerator.nextDouble() < 0.98
54  				? programHelper.pickRandomFunction(program)
55  				: programHelper.pickRandomTerminal(program);
56  
57  		final Operation currentOperation = currentNode.build(program.inputSpec());
58  		final TreeNode<Operation<?>> currentTreeNode = new TreeNode<>(currentOperation);
59  
60  		Class[] acceptedTypes = currentNode.acceptedTypes();
61  
62  		for (int i = 0; i < acceptedTypes.length; i++) {
63  			final Class acceptedType = acceptedTypes[i];
64  			final TreeNode<Operation<?>> operation = generate(program, acceptedType, maxDepth, 1);
65  
66  			currentTreeNode.addChild(operation);
67  		}
68  
69  		return currentTreeNode;
70  	}
71  
72  	@Override
73  	public TreeNode<Operation<?>> generate(final Program program) {
74  		return generate(program, program.maxDepth());
75  	}
76  
77  	@Override
78  	public <T, U> TreeNode<Operation<T>> generate(final Program program, final int maxDepth, final Class<U> rootType) {
79  		Validate.notNull(program);
80  		Validate.notNull(rootType);
81  		Validate.isTrue(maxDepth > 0);
82  
83  		return generate(program, rootType, maxDepth, 0);
84  	}
85  }