1 package net.bmahe.genetics4j.gpu;
2
3 import java.io.IOException;
4 import java.nio.charset.StandardCharsets;
5 import java.util.ArrayList;
6 import java.util.HashMap;
7 import java.util.List;
8 import java.util.Map;
9 import java.util.Set;
10 import java.util.concurrent.CompletableFuture;
11 import java.util.concurrent.ExecutorService;
12
13 import org.apache.commons.collections4.ListUtils;
14 import org.apache.commons.io.IOUtils;
15 import org.apache.commons.lang3.Validate;
16 import org.apache.commons.lang3.tuple.Pair;
17 import org.apache.logging.log4j.LogManager;
18 import org.apache.logging.log4j.Logger;
19 import org.jocl.CL;
20 import org.jocl.cl_command_queue;
21 import org.jocl.cl_context;
22 import org.jocl.cl_context_properties;
23 import org.jocl.cl_device_id;
24 import org.jocl.cl_kernel;
25 import org.jocl.cl_platform_id;
26 import org.jocl.cl_program;
27 import org.jocl.cl_queue_properties;
28
29 import net.bmahe.genetics4j.core.Genotype;
30 import net.bmahe.genetics4j.core.evaluation.FitnessEvaluator;
31 import net.bmahe.genetics4j.gpu.opencl.DeviceReader;
32 import net.bmahe.genetics4j.gpu.opencl.DeviceUtils;
33 import net.bmahe.genetics4j.gpu.opencl.KernelInfoReader;
34 import net.bmahe.genetics4j.gpu.opencl.OpenCLExecutionContext;
35 import net.bmahe.genetics4j.gpu.opencl.PlatformReader;
36 import net.bmahe.genetics4j.gpu.opencl.PlatformUtils;
37 import net.bmahe.genetics4j.gpu.opencl.model.Device;
38 import net.bmahe.genetics4j.gpu.opencl.model.KernelInfo;
39 import net.bmahe.genetics4j.gpu.opencl.model.Platform;
40 import net.bmahe.genetics4j.gpu.spec.GPUEAConfiguration;
41 import net.bmahe.genetics4j.gpu.spec.GPUEAExecutionContext;
42 import net.bmahe.genetics4j.gpu.spec.Program;
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133 public class GPUFitnessEvaluator<T extends Comparable<T>> implements FitnessEvaluator<T> {
134 public static final Logger logger = LogManager.getLogger(GPUFitnessEvaluator.class);
135
136 private final GPUEAExecutionContext<T> gpuEAExecutionContext;
137 private final GPUEAConfiguration<T> gpuEAConfiguration;
138 private final ExecutorService executorService;
139
140 private List<Pair<Platform, Device>> selectedPlatformToDevice;
141
142 final List<cl_context> clContexts = new ArrayList<>();
143 final List<cl_command_queue> clCommandQueues = new ArrayList<>();
144 final List<cl_program> clPrograms = new ArrayList<>();
145 final List<Map<String, cl_kernel>> clKernels = new ArrayList<>();
146 final List<OpenCLExecutionContext> clExecutionContexts = new ArrayList<>();
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162 public GPUFitnessEvaluator(final GPUEAExecutionContext<T> _gpuEAExecutionContext,
163 final GPUEAConfiguration<T> _gpuEAConfiguration,
164 final ExecutorService _executorService) {
165 Validate.notNull(_gpuEAExecutionContext);
166 Validate.notNull(_gpuEAConfiguration);
167 Validate.notNull(_executorService);
168
169 this.gpuEAExecutionContext = _gpuEAExecutionContext;
170 this.gpuEAConfiguration = _gpuEAConfiguration;
171 this.executorService = _executorService;
172
173 CL.setExceptionsEnabled(true);
174 }
175
176 private String loadResource(final String filename) {
177 Validate.notBlank(filename);
178
179 try {
180 return IOUtils.resourceToString(filename, StandardCharsets.UTF_8);
181 } catch (IOException e) {
182 throw new IllegalStateException("Unable to load resource " + filename, e);
183 }
184 }
185
186 private List<String> grabProgramSources() {
187 final Program programSpec = gpuEAConfiguration.program();
188
189 logger.info("Load program source: {}", programSpec);
190
191 final List<String> sources = new ArrayList<>();
192
193 sources.addAll(programSpec.content());
194
195 programSpec.resources().stream().map(resource -> loadResource(resource)).forEach(program -> {
196 sources.add(program);
197 });
198
199 return sources;
200 }
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230 @Override
231 public void preEvaluation() {
232 logger.trace("Init...");
233 FitnessEvaluator.super.preEvaluation();
234
235 final var platformReader = new PlatformReader();
236 final var deviceReader = new DeviceReader();
237 final var kernelInfoReader = new KernelInfoReader();
238
239 final int numPlatforms = PlatformUtils.numPlatforms();
240 logger.info("Found {} platforms", numPlatforms);
241
242 final List<cl_platform_id> platformIds = PlatformUtils.platformIds(numPlatforms);
243
244 logger.info("Selecting platform and devices");
245 final var platformFilters = gpuEAExecutionContext.platformFilters();
246 final var deviceFilters = gpuEAExecutionContext.deviceFilters();
247
248 selectedPlatformToDevice = platformIds.stream()
249 .map(platformReader::read)
250 .filter(platformFilters)
251 .flatMap(platform -> {
252 final var platformId = platform.platformId();
253 final int numDevices = DeviceUtils.numDevices(platformId);
254 logger.trace("\tPlatform {}: {} devices", platform.name(), numDevices);
255
256 final var deviceIds = DeviceUtils.getDeviceIds(platformId, numDevices);
257 return deviceIds.stream().map(deviceId -> Pair.of(platform, deviceId));
258 })
259 .map(platformToDeviceId -> {
260 final var platform = platformToDeviceId.getLeft();
261 final var platformId = platform.platformId();
262 final var deviceID = platformToDeviceId.getRight();
263
264 return Pair.of(platform, deviceReader.read(platformId, deviceID));
265 })
266 .filter(platformToDevice -> deviceFilters.test(platformToDevice.getRight()))
267 .toList();
268
269 if (logger.isTraceEnabled()) {
270 logger.trace("============================");
271 logger.trace("Selected devices:");
272 selectedPlatformToDevice.forEach(pd -> {
273 logger.trace("{}", pd.getLeft());
274 logger.trace("\t{}", pd.getRight());
275 });
276 logger.trace("============================");
277 }
278
279 Validate.isTrue(selectedPlatformToDevice.size() > 0);
280
281 final List<String> programs = grabProgramSources();
282 final String[] programsArr = programs.toArray(new String[programs.size()]);
283
284 for (final var platformAndDevice : selectedPlatformToDevice) {
285 final var platform = platformAndDevice.getLeft();
286 final var device = platformAndDevice.getRight();
287
288 logger.info("Processing platform [{}] / device [{}]", platform.name(), device.name());
289
290 logger.info("\tCreating context");
291 cl_context_properties contextProperties = new cl_context_properties();
292 contextProperties.addProperty(CL.CL_CONTEXT_PLATFORM, platform.platformId());
293
294 final cl_context context = CL
295 .clCreateContext(contextProperties, 1, new cl_device_id[] { device.deviceId() }, null, null, null);
296
297 logger.info("\tCreating command queue");
298 final cl_queue_properties queueProperties = new cl_queue_properties();
299 queueProperties.addProperty(
300 CL.CL_QUEUE_PROPERTIES,
301 CL.CL_QUEUE_PROFILING_ENABLE | CL.CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE);
302 final cl_command_queue commandQueue = CL
303 .clCreateCommandQueueWithProperties(context, device.deviceId(), queueProperties, null);
304
305 logger.info("\tCreate program");
306 final cl_program program = CL.clCreateProgramWithSource(context, programsArr.length, programsArr, null, null);
307
308 final var programSpec = gpuEAConfiguration.program();
309 final var buildOptions = programSpec.buildOptions().orElse(null);
310 logger.info("\tBuilding program with options: {}", buildOptions);
311 CL.clBuildProgram(program, 0, null, buildOptions, null, null);
312
313 final Set<String> kernelNames = gpuEAConfiguration.program().kernelNames();
314
315 final Map<String, cl_kernel> kernels = new HashMap<>();
316 final Map<String, KernelInfo> kernelInfos = new HashMap<>();
317 for (final String kernelName : kernelNames) {
318
319 logger.info("\tCreate kernel {}", kernelName);
320 final cl_kernel kernel = CL.clCreateKernel(program, kernelName, null);
321 Validate.notNull(kernel);
322
323 kernels.put(kernelName, kernel);
324
325 final var kernelInfo = kernelInfoReader.read(device.deviceId(), kernel, kernelName);
326 logger.trace("\t{}", kernelInfo);
327 kernelInfos.put(kernelName, kernelInfo);
328 }
329
330 clContexts.add(context);
331 clCommandQueues.add(commandQueue);
332 clKernels.add(kernels);
333 clPrograms.add(program);
334
335 final var openCLExecutionContext = OpenCLExecutionContext.builder()
336 .platform(platform)
337 .device(device)
338 .clContext(context)
339 .clCommandQueue(commandQueue)
340 .kernels(kernels)
341 .kernelInfos(kernelInfos)
342 .clProgram(program)
343 .build();
344
345 clExecutionContexts.add(openCLExecutionContext);
346 }
347
348 final var fitness = gpuEAConfiguration.fitness();
349 fitness.beforeAllEvaluations();
350 for (final OpenCLExecutionContext clExecutionContext : clExecutionContexts) {
351 fitness.beforeAllEvaluations(clExecutionContext, executorService);
352 }
353 }
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396 @Override
397 public List<T> evaluate(final long generation, final List<Genotype> genotypes) {
398
399 final var fitness = gpuEAConfiguration.fitness();
400
401
402
403
404 final int partitionSize = (int) (Math.ceil((double) genotypes.size() / clExecutionContexts.size()));
405 final var subGenotypes = ListUtils.partition(genotypes, partitionSize);
406 logger.debug("Genotype decomposed in {} partition(s)", subGenotypes.size());
407 if (logger.isTraceEnabled()) {
408 for (int i = 0; i < subGenotypes.size(); i++) {
409 final List<Genotype> subGenotype = subGenotypes.get(i);
410 logger.trace("\tPartition {} with {} elements", i, subGenotype.size());
411 }
412 }
413
414 final List<CompletableFuture<List<T>>> subResultsCF = new ArrayList<>();
415 for (int i = 0; i < subGenotypes.size(); i++) {
416 final var openCLExecutionContext = clExecutionContexts.get(i % clExecutionContexts.size());
417 final var subGenotype = subGenotypes.get(i);
418
419 fitness.beforeEvaluation(generation, subGenotype);
420 fitness.beforeEvaluation(openCLExecutionContext, executorService, generation, subGenotype);
421
422 final var resultsCF = fitness.compute(openCLExecutionContext, executorService, generation, subGenotype)
423 .thenApply((results) -> {
424
425 fitness.afterEvaluation(openCLExecutionContext, executorService, generation, subGenotype);
426 fitness.afterEvaluation(generation, subGenotype);
427
428 return results;
429 });
430
431 subResultsCF.add(resultsCF);
432 }
433
434 final List<T> resultsEvaluation = new ArrayList<>(genotypes.size());
435 for (final CompletableFuture<List<T>> subResultCF : subResultsCF) {
436 final var fitnessResults = subResultCF.join();
437 resultsEvaluation.addAll(fitnessResults);
438 }
439 return resultsEvaluation;
440 }
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471 @Override
472 public void postEvaluation() {
473
474 final var fitness = gpuEAConfiguration.fitness();
475
476 for (final OpenCLExecutionContext clExecutionContext : clExecutionContexts) {
477 fitness.afterAllEvaluations(clExecutionContext, executorService);
478 }
479 fitness.afterAllEvaluations();
480
481 logger.debug("Releasing kernels");
482
483 for (final Map<String, cl_kernel> kernels : clKernels) {
484 for (final cl_kernel clKernel : kernels.values()) {
485 CL.clReleaseKernel(clKernel);
486 }
487 }
488 clKernels.clear();
489
490 logger.debug("Releasing programs");
491 for (final cl_program clProgram : clPrograms) {
492 CL.clReleaseProgram(clProgram);
493 }
494 clPrograms.clear();
495
496 logger.debug("Releasing command queues");
497 for (final cl_command_queue clCommandQueue : clCommandQueues) {
498 CL.clReleaseCommandQueue(clCommandQueue);
499 }
500 clCommandQueues.clear();
501
502 logger.debug("Releasing contexts");
503 for (final cl_context clContext : clContexts) {
504 CL.clReleaseContext(clContext);
505 }
506 clContexts.clear();
507
508 clExecutionContexts.clear();
509 selectedPlatformToDevice = null;
510
511 FitnessEvaluator.super.postEvaluation();
512 }
513 }