View Javadoc
1   package net.bmahe.genetics4j.gpu;
2   
3   import java.io.IOException;
4   import java.nio.charset.StandardCharsets;
5   import java.util.ArrayList;
6   import java.util.HashMap;
7   import java.util.List;
8   import java.util.Map;
9   import java.util.Set;
10  import java.util.concurrent.CompletableFuture;
11  import java.util.concurrent.ExecutorService;
12  
13  import org.apache.commons.collections4.ListUtils;
14  import org.apache.commons.io.IOUtils;
15  import org.apache.commons.lang3.Validate;
16  import org.apache.commons.lang3.tuple.Pair;
17  import org.apache.logging.log4j.LogManager;
18  import org.apache.logging.log4j.Logger;
19  import org.jocl.CL;
20  import org.jocl.cl_command_queue;
21  import org.jocl.cl_context;
22  import org.jocl.cl_context_properties;
23  import org.jocl.cl_device_id;
24  import org.jocl.cl_kernel;
25  import org.jocl.cl_platform_id;
26  import org.jocl.cl_program;
27  import org.jocl.cl_queue_properties;
28  
29  import net.bmahe.genetics4j.core.Genotype;
30  import net.bmahe.genetics4j.core.evaluation.FitnessEvaluator;
31  import net.bmahe.genetics4j.gpu.opencl.DeviceReader;
32  import net.bmahe.genetics4j.gpu.opencl.DeviceUtils;
33  import net.bmahe.genetics4j.gpu.opencl.KernelInfoReader;
34  import net.bmahe.genetics4j.gpu.opencl.OpenCLExecutionContext;
35  import net.bmahe.genetics4j.gpu.opencl.PlatformReader;
36  import net.bmahe.genetics4j.gpu.opencl.PlatformUtils;
37  import net.bmahe.genetics4j.gpu.opencl.model.Device;
38  import net.bmahe.genetics4j.gpu.opencl.model.KernelInfo;
39  import net.bmahe.genetics4j.gpu.opencl.model.Platform;
40  import net.bmahe.genetics4j.gpu.spec.GPUEAConfiguration;
41  import net.bmahe.genetics4j.gpu.spec.GPUEAExecutionContext;
42  import net.bmahe.genetics4j.gpu.spec.Program;
43  
44  public class GPUFitnessEvaluator<T extends Comparable<T>> implements FitnessEvaluator<T> {
45  	public static final Logger logger = LogManager.getLogger(GPUFitnessEvaluator.class);
46  
47  	private final GPUEAExecutionContext<T> gpuEAExecutionContext;
48  	private final GPUEAConfiguration<T> gpuEAConfiguration;
49  	private final ExecutorService executorService;
50  
51  	private List<Pair<Platform, Device>> selectedPlatformToDevice;
52  
53  	final List<cl_context> clContexts = new ArrayList<>();
54  	final List<cl_command_queue> clCommandQueues = new ArrayList<>();
55  	final List<cl_program> clPrograms = new ArrayList<>();
56  	final List<Map<String, cl_kernel>> clKernels = new ArrayList<>();
57  	final List<OpenCLExecutionContext> clExecutionContexts = new ArrayList<>();
58  
59  	public GPUFitnessEvaluator(final GPUEAExecutionContext<T> _gpuEAExecutionContext,
60  			final GPUEAConfiguration<T> _gpuEAConfiguration, final ExecutorService _executorService) {
61  		Validate.notNull(_gpuEAExecutionContext);
62  		Validate.notNull(_gpuEAConfiguration);
63  		Validate.notNull(_executorService);
64  
65  		this.gpuEAExecutionContext = _gpuEAExecutionContext;
66  		this.gpuEAConfiguration = _gpuEAConfiguration;
67  		this.executorService = _executorService;
68  
69  		CL.setExceptionsEnabled(true);
70  	}
71  
72  	private String loadResource(final String filename) {
73  		Validate.notBlank(filename);
74  
75  		try {
76  			return IOUtils.resourceToString(filename, StandardCharsets.UTF_8);
77  		} catch (IOException e) {
78  			throw new IllegalStateException("Unable to load resource " + filename, e);
79  		}
80  	}
81  
82  	private List<String> grabProgramSources() {
83  		final Program programSpec = gpuEAConfiguration.program();
84  
85  		logger.info("Load program source: {}", programSpec);
86  
87  		final List<String> sources = new ArrayList<>();
88  
89  		sources.addAll(programSpec.content());
90  
91  		programSpec.resources()
92  				.stream()
93  				.map(resource -> loadResource(resource))
94  				.forEach(program -> {
95  					sources.add(program);
96  				});
97  
98  		return sources;
99  	}
100 
101 	@Override
102 	public void preEvaluation() {
103 		logger.trace("Init...");
104 		FitnessEvaluator.super.preEvaluation();
105 
106 		final var platformReader = new PlatformReader();
107 		final var deviceReader = new DeviceReader();
108 		final var kernelInfoReader = new KernelInfoReader();
109 
110 		final int numPlatforms = PlatformUtils.numPlatforms();
111 		logger.info("Found {} platforms", numPlatforms);
112 
113 		final List<cl_platform_id> platformIds = PlatformUtils.platformIds(numPlatforms);
114 
115 		logger.info("Selecting platform and devices");
116 		final var platformFilters = gpuEAExecutionContext.platformFilters();
117 		final var deviceFilters = gpuEAExecutionContext.deviceFilters();
118 
119 		selectedPlatformToDevice = platformIds.stream()
120 				.map(platformReader::read)
121 				.filter(platformFilters)
122 				.flatMap(platform -> {
123 					final var platformId = platform.platformId();
124 					final int numDevices = DeviceUtils.numDevices(platformId);
125 					logger.trace("\tPlatform {}: {} devices", platform.name(), numDevices);
126 
127 					final var deviceIds = DeviceUtils.getDeviceIds(platformId, numDevices);
128 					return deviceIds.stream()
129 							.map(deviceId -> Pair.of(platform, deviceId));
130 				})
131 				.map(platformToDeviceId -> {
132 					final var platform = platformToDeviceId.getLeft();
133 					final var platformId = platform.platformId();
134 					final var deviceID = platformToDeviceId.getRight();
135 
136 					return Pair.of(platform, deviceReader.read(platformId, deviceID));
137 				})
138 				.filter(platformToDevice -> deviceFilters.test(platformToDevice.getRight()))
139 				.toList();
140 
141 		if (logger.isTraceEnabled()) {
142 			logger.trace("============================");
143 			logger.trace("Selected devices:");
144 			selectedPlatformToDevice.forEach(pd -> {
145 				logger.trace("{}", pd.getLeft());
146 				logger.trace("\t{}", pd.getRight());
147 			});
148 			logger.trace("============================");
149 		}
150 
151 		Validate.isTrue(selectedPlatformToDevice.size() > 0);
152 
153 		final List<String> programs = grabProgramSources();
154 		final String[] programsArr = programs.toArray(new String[programs.size()]);
155 
156 		for (final var platformAndDevice : selectedPlatformToDevice) {
157 			final var platform = platformAndDevice.getLeft();
158 			final var device = platformAndDevice.getRight();
159 
160 			logger.info("Processing platform [{}] / device [{}]", platform.name(), device.name());
161 
162 			logger.info("\tCreating context");
163 			cl_context_properties contextProperties = new cl_context_properties();
164 			contextProperties.addProperty(CL.CL_CONTEXT_PLATFORM, platform.platformId());
165 
166 			final cl_context context = CL
167 					.clCreateContext(contextProperties, 1, new cl_device_id[] { device.deviceId() }, null, null, null);
168 
169 			logger.info("\tCreating command queue");
170 			final cl_queue_properties queueProperties = new cl_queue_properties();
171 			queueProperties.addProperty(CL.CL_QUEUE_PROPERTIES,
172 					CL.CL_QUEUE_PROFILING_ENABLE | CL.CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE);
173 			final cl_command_queue commandQueue = CL
174 					.clCreateCommandQueueWithProperties(context, device.deviceId(), queueProperties, null);
175 
176 			logger.info("\tCreate program");
177 			final cl_program program = CL.clCreateProgramWithSource(context, programsArr.length, programsArr, null, null);
178 
179 			final var programSpec = gpuEAConfiguration.program();
180 			final var buildOptions = programSpec.buildOptions()
181 					.orElse(null);
182 			logger.info("\tBuilding program with options: {}", buildOptions);
183 			CL.clBuildProgram(program, 0, null, buildOptions, null, null);
184 
185 			final Set<String> kernelNames = gpuEAConfiguration.program()
186 					.kernelNames();
187 
188 			final Map<String, cl_kernel> kernels = new HashMap<>();
189 			final Map<String, KernelInfo> kernelInfos = new HashMap<>();
190 			for (final String kernelName : kernelNames) {
191 
192 				logger.info("\tCreate kernel {}", kernelName);
193 				final cl_kernel kernel = CL.clCreateKernel(program, kernelName, null);
194 				Validate.notNull(kernel);
195 
196 				kernels.put(kernelName, kernel);
197 
198 				final var kernelInfo = kernelInfoReader.read(device.deviceId(), kernel, kernelName);
199 				logger.trace("\t{}", kernelInfo);
200 				kernelInfos.put(kernelName, kernelInfo);
201 			}
202 
203 			clContexts.add(context);
204 			clCommandQueues.add(commandQueue);
205 			clKernels.add(kernels);
206 			clPrograms.add(program);
207 
208 			final var openCLExecutionContext = OpenCLExecutionContext.builder()
209 					.platform(platform)
210 					.device(device)
211 					.clContext(context)
212 					.clCommandQueue(commandQueue)
213 					.kernels(kernels)
214 					.kernelInfos(kernelInfos)
215 					.clProgram(program)
216 					.build();
217 
218 			clExecutionContexts.add(openCLExecutionContext);
219 		}
220 
221 		final var fitness = gpuEAConfiguration.fitness();
222 		fitness.beforeAllEvaluations();
223 		for (final OpenCLExecutionContext clExecutionContext : clExecutionContexts) {
224 			fitness.beforeAllEvaluations(clExecutionContext, executorService);
225 		}
226 	}
227 
228 	@Override
229 	public List<T> evaluate(final long generation, final List<Genotype> genotypes) {
230 
231 		final var fitness = gpuEAConfiguration.fitness();
232 
233 		/**
234 		 * TODO make it configurable from execution context
235 		 */
236 		final int partitionSize = (int) (Math.ceil((double) genotypes.size() / clExecutionContexts.size()));
237 		final var subGenotypes = ListUtils.partition(genotypes, partitionSize);
238 		logger.debug("Genotype decomposed in {} partition(s)", subGenotypes.size());
239 		if (logger.isTraceEnabled()) {
240 			for (int i = 0; i < subGenotypes.size(); i++) {
241 				final List<Genotype> subGenotype = subGenotypes.get(i);
242 				logger.trace("\tPartition {} with {} elements", i, subGenotype.size());
243 			}
244 		}
245 
246 		final List<CompletableFuture<List<T>>> subResultsCF = new ArrayList<>();
247 		for (int i = 0; i < subGenotypes.size(); i++) {
248 			final var openCLExecutionContext = clExecutionContexts.get(i % clExecutionContexts.size());
249 			final var subGenotype = subGenotypes.get(i);
250 
251 			fitness.beforeEvaluation(generation, subGenotype);
252 			fitness.beforeEvaluation(openCLExecutionContext, executorService, generation, subGenotype);
253 
254 			final var resultsCF = fitness.compute(openCLExecutionContext, executorService, generation, subGenotype)
255 					.thenApply((results) -> {
256 
257 						fitness.afterEvaluation(openCLExecutionContext, executorService, generation, subGenotype);
258 						fitness.afterEvaluation(generation, subGenotype);
259 
260 						return results;
261 					});
262 
263 			subResultsCF.add(resultsCF);
264 		}
265 
266 		final List<T> resultsEvaluation = new ArrayList<>(genotypes.size());
267 		for (final CompletableFuture<List<T>> subResultCF : subResultsCF) {
268 			final var fitnessResults = subResultCF.join();
269 			resultsEvaluation.addAll(fitnessResults);
270 		}
271 		return resultsEvaluation;
272 	}
273 
274 	@Override
275 	public void postEvaluation() {
276 
277 		final var fitness = gpuEAConfiguration.fitness();
278 
279 		for (final OpenCLExecutionContext clExecutionContext : clExecutionContexts) {
280 			fitness.afterAllEvaluations(clExecutionContext, executorService);
281 		}
282 		fitness.afterAllEvaluations();
283 
284 		logger.debug("Releasing kernels");
285 
286 		for (final Map<String, cl_kernel> kernels : clKernels) {
287 			for (final cl_kernel clKernel : kernels.values()) {
288 				CL.clReleaseKernel(clKernel);
289 			}
290 		}
291 		clKernels.clear();
292 
293 		logger.debug("Releasing programs");
294 		for (final cl_program clProgram : clPrograms) {
295 			CL.clReleaseProgram(clProgram);
296 		}
297 		clPrograms.clear();
298 
299 		logger.debug("Releasing command queues");
300 		for (final cl_command_queue clCommandQueue : clCommandQueues) {
301 			CL.clReleaseCommandQueue(clCommandQueue);
302 		}
303 		clCommandQueues.clear();
304 
305 		logger.debug("Releasing contexts");
306 		for (final cl_context clContext : clContexts) {
307 			CL.clReleaseContext(clContext);
308 		}
309 		clContexts.clear();
310 
311 		clExecutionContexts.clear();
312 		selectedPlatformToDevice = null;
313 
314 		FitnessEvaluator.super.postEvaluation();
315 	}
316 }