1 package net.bmahe.genetics4j.gpu.spec.fitness;
2
3 import java.util.HashMap;
4 import java.util.List;
5 import java.util.Map;
6 import java.util.concurrent.CompletableFuture;
7 import java.util.concurrent.ConcurrentHashMap;
8 import java.util.concurrent.ExecutorService;
9
10 import org.apache.commons.collections4.MapUtils;
11 import org.apache.commons.lang3.Validate;
12 import org.apache.logging.log4j.LogManager;
13 import org.apache.logging.log4j.Logger;
14 import org.jocl.CL;
15 import org.jocl.Pointer;
16 import org.jocl.Sizeof;
17
18 import net.bmahe.genetics4j.core.Genotype;
19 import net.bmahe.genetics4j.gpu.opencl.OpenCLExecutionContext;
20 import net.bmahe.genetics4j.gpu.opencl.model.Device;
21 import net.bmahe.genetics4j.gpu.spec.fitness.cldata.CLData;
22 import net.bmahe.genetics4j.gpu.spec.fitness.kernelcontext.KernelExecutionContext;
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102 public class SingleKernelFitness<T extends Comparable<T>> extends OpenCLFitness<T> {
103 public static final Logger logger = LogManager.getLogger(SingleKernelFitness.class);
104
105 private final SingleKernelFitnessDescriptor singleKernelFitnessDescriptor;
106 private final FitnessExtractor<T> fitnessExtractor;
107
108 private final Map<Device, Map<Integer, CLData>> staticData = new ConcurrentHashMap<>();
109 private final Map<Device, Map<Integer, CLData>> data = new ConcurrentHashMap<>();
110 private final Map<Device, Map<Integer, CLData>> resultData = new ConcurrentHashMap<>();
111
112 private final Map<Device, KernelExecutionContext> kernelExecutionContexts = new ConcurrentHashMap<>();
113
114 protected void clearStaticData(final Device device) {
115 if (MapUtils.isEmpty(staticData) || MapUtils.isEmpty(staticData.get(device))) {
116 return;
117 }
118
119 final Map<Integer, CLData> mapData = staticData.get(device);
120 for (final CLData clData : mapData.values()) {
121 CL.clReleaseMemObject(clData.clMem());
122 }
123
124 mapData.clear();
125 staticData.remove(device);
126 }
127
128 protected void clearData(final Device device) {
129 if (MapUtils.isEmpty(data) || MapUtils.isEmpty(data.get(device))) {
130 return;
131 }
132
133 final Map<Integer, CLData> mapData = data.get(device);
134 for (final CLData clData : mapData.values()) {
135 CL.clReleaseMemObject(clData.clMem());
136 }
137
138 mapData.clear();
139 data.remove(device);
140 }
141
142 protected void clearResultData(final Device device) {
143 if (MapUtils.isEmpty(resultData) || MapUtils.isEmpty(resultData.get(device))) {
144 return;
145 }
146
147 final Map<Integer, CLData> mapData = resultData.get(device);
148 for (final CLData clData : mapData.values()) {
149 CL.clReleaseMemObject(clData.clMem());
150 }
151
152 mapData.clear();
153 resultData.remove(device);
154 }
155
156
157
158
159
160
161
162
163 public SingleKernelFitness(final SingleKernelFitnessDescriptor _singleKernelFitnessDescriptor,
164 final FitnessExtractor<T> _fitnessExtractor) {
165 Validate.notNull(_singleKernelFitnessDescriptor);
166 Validate.notNull(_fitnessExtractor);
167
168 this.singleKernelFitnessDescriptor = _singleKernelFitnessDescriptor;
169 this.fitnessExtractor = _fitnessExtractor;
170 }
171
172 @Override
173 public void beforeAllEvaluations(final OpenCLExecutionContext openCLExecutionContext,
174 final ExecutorService executorService) {
175 super.beforeAllEvaluations(openCLExecutionContext, executorService);
176
177 final var device = openCLExecutionContext.device();
178 clearStaticData(device);
179
180 final var staticDataLoaders = singleKernelFitnessDescriptor.staticDataLoaders();
181 for (final var entry : staticDataLoaders.entrySet()) {
182 final int argumentIdx = entry.getKey();
183 final var dataSupplier = entry.getValue();
184
185 if (logger.isTraceEnabled()) {
186 final var deviceName = openCLExecutionContext.device().name();
187 logger.trace("[{}] Loading static data for index {}", deviceName, argumentIdx);
188 }
189 final CLData clData = dataSupplier.load(openCLExecutionContext);
190
191 final var mapData = staticData.computeIfAbsent(device, k -> new HashMap<>());
192 if (mapData.put(argumentIdx, clData) != null) {
193 throw new IllegalArgumentException("Multiple data configured for index " + argumentIdx);
194 }
195 }
196 }
197
198 @Override
199 public void beforeEvaluation(OpenCLExecutionContext openCLExecutionContext, ExecutorService executorService,
200 long generation, final List<Genotype> genotypes) {
201 super.beforeEvaluation(openCLExecutionContext, executorService, generation, genotypes);
202
203 final var device = openCLExecutionContext.device();
204 final var kernels = openCLExecutionContext.kernels();
205
206 final var kernelName = singleKernelFitnessDescriptor.kernelName();
207 final var kernel = kernels.get(kernelName);
208
209 if (kernelExecutionContexts.containsKey(device)) {
210 throw new IllegalStateException("Found existing kernelExecutionContext");
211 }
212 final var kernelExecutionContextComputer = singleKernelFitnessDescriptor.kernelExecutionContextComputer();
213 final var kernelExecutionContext = kernelExecutionContextComputer
214 .compute(openCLExecutionContext, kernelName, generation, genotypes);
215 kernelExecutionContexts.put(device, kernelExecutionContext);
216
217 final var mapData = staticData.get(device);
218 if (MapUtils.isNotEmpty(mapData)) {
219 for (final var entry : mapData.entrySet()) {
220 final int argumentIdx = entry.getKey();
221 final var clStaticData = entry.getValue();
222
223 logger.trace("[{}] Loading static data for index {}", device.name(), argumentIdx);
224
225 CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clStaticData.clMem()));
226 }
227 }
228
229 final var dataLoaders = singleKernelFitnessDescriptor.dataLoaders();
230 if (MapUtils.isNotEmpty(dataLoaders)) {
231 for (final var entry : dataLoaders.entrySet()) {
232 final int argumentIdx = entry.getKey();
233 final var dataLoader = entry.getValue();
234
235 final var clDdata = dataLoader.load(openCLExecutionContext, generation, genotypes);
236
237 final var dataMapping = data.computeIfAbsent(device, k -> new HashMap<>());
238 if (dataMapping.put(argumentIdx, clDdata) != null) {
239 throw new IllegalArgumentException("Multiple data configured for index " + argumentIdx);
240 }
241 logger.trace("[{}] Loading data for index {}", device.name(), argumentIdx);
242
243 CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clDdata.clMem()));
244 }
245 }
246
247 final var localMemoryAllocators = singleKernelFitnessDescriptor.localMemoryAllocators();
248 if (MapUtils.isNotEmpty(localMemoryAllocators)) {
249 for (final var entry : localMemoryAllocators.entrySet()) {
250 final int argumentIdx = entry.getKey();
251 final var localMemoryAllocator = entry.getValue();
252
253 final var size = localMemoryAllocator
254 .load(openCLExecutionContext, kernelExecutionContext, generation, genotypes);
255 logger.trace("[{}] Setting local data for index {} with size of {}", device.name(), argumentIdx, size);
256
257 CL.clSetKernelArg(kernel, argumentIdx, size, null);
258 }
259 }
260
261 final var resultAllocators = singleKernelFitnessDescriptor.resultAllocators();
262 if (MapUtils.isNotEmpty(resultAllocators)) {
263 for (final var entry : resultAllocators.entrySet()) {
264 final int argumentIdx = entry.getKey();
265 final var resultAllocator = entry.getValue();
266
267 final var clDdata = resultAllocator
268 .load(openCLExecutionContext, kernelExecutionContext, generation, genotypes);
269
270 final var dataMapping = resultData.computeIfAbsent(device, k -> new HashMap<>());
271 if (dataMapping.put(argumentIdx, clDdata) != null) {
272 throw new IllegalArgumentException("Multiple result allocators configured for index " + argumentIdx);
273 }
274 logger.trace("[{}] Preparing result data memory for index {}", device.name(), argumentIdx);
275
276 CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clDdata.clMem()));
277 }
278 }
279
280 }
281
282 @Override
283 public CompletableFuture<List<T>> compute(final OpenCLExecutionContext openCLExecutionContext,
284 final ExecutorService executorService, final long generation, List<Genotype> genotypes) {
285
286 return CompletableFuture.supplyAsync(() -> {
287 final var clCommandQueue = openCLExecutionContext.clCommandQueue();
288 final var kernels = openCLExecutionContext.kernels();
289
290 final var kernelName = singleKernelFitnessDescriptor.kernelName();
291 final var kernel = kernels.get(kernelName);
292 if (kernel == null) {
293 throw new IllegalStateException("Could not find kernel [" + kernelName + "]");
294 }
295
296 final var device = openCLExecutionContext.device();
297 final var kernelExecutionContext = kernelExecutionContexts.get(device);
298
299 final var globalWorkDimensions = kernelExecutionContext.globalWorkDimensions();
300 final var globalWorkSize = kernelExecutionContext.globalWorkSize();
301 final long[] workGroupSize = kernelExecutionContext.workGroupSize().orElse(null);
302
303 logger.trace(
304 "Starting computation on kernel {} for {} genotypes and global work size {} and local work size {}",
305 kernelName,
306 genotypes.size(),
307 globalWorkSize,
308 workGroupSize);
309 final long startTime = System.nanoTime();
310 CL.clEnqueueNDRangeKernel(
311 clCommandQueue,
312 kernel,
313 globalWorkDimensions,
314 null,
315 globalWorkSize,
316 workGroupSize,
317 0,
318 null,
319 null);
320
321 final long endTime = System.nanoTime();
322 final long duration = endTime - startTime;
323 if (logger.isDebugEnabled()) {
324 final var deviceName = openCLExecutionContext.device().name();
325 logger.debug("{} - Took {} microsec for {} genotypes", deviceName, duration / 1000., genotypes.size());
326 }
327 return kernelExecutionContext;
328 }, executorService).thenApply(kernelExecutionContext -> {
329
330 final var resultExtractor = new ResultExtractor(resultData);
331 return fitnessExtractor.compute(
332 openCLExecutionContext,
333 kernelExecutionContext,
334 executorService,
335 generation,
336 genotypes,
337 resultExtractor);
338 });
339 }
340
341 @Override
342 public void afterEvaluation(OpenCLExecutionContext openCLExecutionContext, ExecutorService executorService,
343 long generation, List<Genotype> genotypes) {
344 super.afterEvaluation(openCLExecutionContext, executorService, generation, genotypes);
345
346 final var device = openCLExecutionContext.device();
347 logger.trace("[{}] Releasing data", device.name());
348 clearData(device);
349 clearResultData(device);
350 kernelExecutionContexts.remove(device);
351 }
352
353 @Override
354 public void afterAllEvaluations(final OpenCLExecutionContext openCLExecutionContext,
355 final ExecutorService executorService) {
356 super.afterAllEvaluations(openCLExecutionContext, executorService);
357
358 final var device = openCLExecutionContext.device();
359 logger.trace("[{}] Releasing static data", device.name());
360 clearStaticData(device);
361 }
362
363
364
365
366
367
368
369
370
371
372 public static <U extends Comparable<U>> SingleKernelFitness<U> of(
373 final SingleKernelFitnessDescriptor singleKernelFitnessDescriptor,
374 final FitnessExtractor<U> fitnessExtractor) {
375 Validate.notNull(singleKernelFitnessDescriptor);
376 Validate.notNull(fitnessExtractor);
377
378 return new SingleKernelFitness<>(singleKernelFitnessDescriptor, fitnessExtractor);
379 }
380 }