1 package net.bmahe.genetics4j.gp.program;
2
3 import org.apache.commons.lang3.Validate;
4
5 import net.bmahe.genetics4j.core.chromosomes.TreeNode;
6 import net.bmahe.genetics4j.gp.Operation;
7 import net.bmahe.genetics4j.gp.OperationFactory;
8
9 public class GrowProgramGenerator implements ProgramGenerator {
10
11 private final ProgramHelper programHelper;
12
13 @SuppressWarnings({ "unchecked", "rawtypes" })
14 private <T, U> TreeNode<Operation<T>> generate(final Program program, final Class<U> acceptedType,
15 final int maxDepth, final int depth) {
16
17 OperationFactory currentNode = depth < maxDepth - 1
18 ? programHelper.pickRandomFunctionOrTerminal(program, acceptedType)
19 : programHelper.pickRandomTerminal(program, acceptedType);
20
21 final Operation<T> currentOperation = currentNode.build(program.inputSpec());
22 final TreeNode<Operation<T>> currentTreeNode = new TreeNode<>(currentOperation);
23
24 final Class[] acceptedTypes = currentNode.acceptedTypes();
25
26 for (int i = 0; i < acceptedTypes.length; i++) {
27 final Class childAcceptedType = acceptedTypes[i];
28 final TreeNode<Operation<T>> operation = generate(program, childAcceptedType, maxDepth, depth + 1);
29
30 currentTreeNode.addChild(operation);
31 }
32
33 return currentTreeNode;
34 }
35
36 public GrowProgramGenerator(final ProgramHelper _programHelper) {
37 Validate.notNull(_programHelper);
38
39 this.programHelper = _programHelper;
40 }
41
42 @Override
43 public TreeNode<Operation<?>> generate(final Program program) {
44 return generate(program, program.maxDepth());
45 }
46
47 @SuppressWarnings("rawtypes")
48 @Override
49 public TreeNode<Operation<?>> generate(final Program program, final int maxDepth) {
50 Validate.notNull(program);
51 Validate.isTrue(maxDepth > 0);
52
53 final OperationFactory currentNode = programHelper.pickRandomFunctionOrTerminal(program);
54
55 final Operation currentOperation = currentNode.build(program.inputSpec());
56 final TreeNode<Operation<?>> currentTreeNode = new TreeNode<>(currentOperation);
57
58 final Class[] acceptedTypes = currentNode.acceptedTypes();
59
60 for (int i = 0; i < acceptedTypes.length; i++) {
61 final Class acceptedType = acceptedTypes[i];
62 final TreeNode<Operation<?>> operation = generate(program, acceptedType, maxDepth, 1);
63
64 currentTreeNode.addChild(operation);
65 }
66
67 return currentTreeNode;
68 }
69
70 @Override
71 public <T, U> TreeNode<Operation<T>> generate(final Program program, final int maxDepth, final Class<U> rootType) {
72 Validate.notNull(program);
73 Validate.notNull(rootType);
74 Validate.isTrue(maxDepth > 0);
75
76 return generate(program, rootType, maxDepth, 0);
77 }
78 }