// SingleKernelFitness.java
package net.bmahe.genetics4j.gpu.spec.fitness;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.gpu.opencl.OpenCLExecutionContext;
import net.bmahe.genetics4j.gpu.opencl.model.Device;
import net.bmahe.genetics4j.gpu.spec.fitness.cldata.CLData;
import net.bmahe.genetics4j.gpu.spec.fitness.kernelcontext.KernelExecutionContext;
/**
 * {@link OpenCLFitness} implementation that evaluates a population with a single OpenCL kernel.
 *
 * <p>The kernel name, its argument loaders/allocators and the way its execution context is computed all come from a
 * {@link SingleKernelFitnessDescriptor}; the raw kernel output is turned into fitness values by a
 * {@link FitnessExtractor}.
 *
 * <p>All OpenCL buffers and kernel execution contexts created by this class are tracked per {@link Device}:
 * <ul>
 * <li>static data is loaded in {@link #beforeAllEvaluations} and released in {@link #afterAllEvaluations};</li>
 * <li>per-evaluation data, result buffers and the kernel execution context are created in
 * {@link #beforeEvaluation} and released in {@link #afterEvaluation}.</li>
 * </ul>
 *
 * @param <T> Type of the fitness values
 */
public class SingleKernelFitness<T extends Comparable<T>> extends OpenCLFitness<T> {
    public static final Logger logger = LogManager.getLogger(SingleKernelFitness.class);

    private final SingleKernelFitnessDescriptor singleKernelFitnessDescriptor;
    private final FitnessExtractor<T> fitnessExtractor;

    /** Buffers loaded once in {@link #beforeAllEvaluations}, keyed by device then by kernel argument index. */
    private final Map<Device, Map<Integer, CLData>> staticData = new ConcurrentHashMap<>();

    /** Buffers loaded for each evaluation, keyed by device then by kernel argument index. */
    private final Map<Device, Map<Integer, CLData>> data = new ConcurrentHashMap<>();

    /** Result buffers allocated for each evaluation, keyed by device then by kernel argument index. */
    private final Map<Device, Map<Integer, CLData>> resultData = new ConcurrentHashMap<>();

    /** Kernel execution context of the evaluation currently prepared for each device. */
    private final Map<Device, KernelExecutionContext> kernelExecutionContexts = new ConcurrentHashMap<>();

    /**
     * Releases every OpenCL memory object registered for {@code device} in {@code dataByDevice} and forgets them.
     *
     * <p>The device entry is detached from the map before the memory objects are released, so that no handle stays
     * registered if one of the releases fails.
     *
     * @param dataByDevice Registry of buffers, keyed by device then by kernel argument index
     * @param device       Device whose buffers should be released
     */
    private static void releaseDeviceData(final Map<Device, Map<Integer, CLData>> dataByDevice,
            final Device device) {
        if (dataByDevice.isEmpty()) {
            return;
        }

        final Map<Integer, CLData> deviceData = dataByDevice.remove(device);
        if (MapUtils.isEmpty(deviceData)) {
            return;
        }

        for (final CLData clData : deviceData.values()) {
            CL.clReleaseMemObject(clData.clMem());
        }
        deviceData.clear();
    }

    /**
     * Releases the static data buffers held for the given device, if any.
     *
     * @param device Device whose static data should be released
     */
    protected void clearStaticData(final Device device) {
        releaseDeviceData(staticData, device);
    }

    /**
     * Releases the per-evaluation data buffers held for the given device, if any.
     *
     * @param device Device whose data should be released
     */
    protected void clearData(final Device device) {
        releaseDeviceData(data, device);
    }

    /**
     * Releases the result buffers held for the given device, if any.
     *
     * @param device Device whose result data should be released
     */
    protected void clearResultData(final Device device) {
        releaseDeviceData(resultData, device);
    }

    /**
     * Creates a fitness evaluator backed by a single kernel.
     *
     * @param _singleKernelFitnessDescriptor Description of the kernel and of its arguments; must not be null
     * @param _fitnessExtractor              Converts the kernel results into fitness values; must not be null
     */
    public SingleKernelFitness(final SingleKernelFitnessDescriptor _singleKernelFitnessDescriptor,
            final FitnessExtractor<T> _fitnessExtractor) {
        Validate.notNull(_singleKernelFitnessDescriptor);
        Validate.notNull(_fitnessExtractor);

        this.singleKernelFitnessDescriptor = _singleKernelFitnessDescriptor;
        this.fitnessExtractor = _fitnessExtractor;
    }

    /**
     * Loads the static data declared by the descriptor for the device of the given execution context.
     *
     * <p>Any static data previously held for that device is released first.
     *
     * @throws IllegalArgumentException if more than one static data ends up registered for the same argument index
     */
    @Override
    public void beforeAllEvaluations(final OpenCLExecutionContext openCLExecutionContext,
            final ExecutorService executorService) {
        super.beforeAllEvaluations(openCLExecutionContext, executorService);

        final var device = openCLExecutionContext.device();
        clearStaticData(device);

        final var staticDataLoaders = singleKernelFitnessDescriptor.staticDataLoaders();
        // Guarded like the other loader maps in beforeEvaluation so an absent configuration is simply skipped.
        if (MapUtils.isNotEmpty(staticDataLoaders)) {
            for (final var entry : staticDataLoaders.entrySet()) {
                final int argumentIdx = entry.getKey();
                final var dataSupplier = entry.getValue();

                if (logger.isTraceEnabled()) {
                    logger.trace("[{}] Loading static data for index {}", device.name(), argumentIdx);
                }

                final CLData clData = dataSupplier.load(openCLExecutionContext);
                final var mapData = staticData.computeIfAbsent(device, k -> new HashMap<>());
                if (mapData.put(argumentIdx, clData) != null) {
                    throw new IllegalArgumentException("Multiple data configured for index " + argumentIdx);
                }
            }
        }
    }

    /**
     * Prepares the kernel for one evaluation on the device of the given execution context.
     *
     * <p>Computes and records the kernel execution context, then sets the kernel arguments in this order: static
     * data, per-evaluation data, local memory sizes and result buffers.
     *
     * @throws IllegalStateException    if a kernel execution context is already registered for the device, or if
     *                                  the kernel named by the descriptor cannot be found
     * @throws IllegalArgumentException if more than one data or result buffer ends up registered for the same
     *                                  argument index
     */
    @Override
    public void beforeEvaluation(OpenCLExecutionContext openCLExecutionContext, ExecutorService executorService,
            long generation, final List<Genotype> genotypes) {
        super.beforeEvaluation(openCLExecutionContext, executorService, generation, genotypes);

        final var device = openCLExecutionContext.device();
        final var kernels = openCLExecutionContext.kernels();
        final var kernelName = singleKernelFitnessDescriptor.kernelName();
        final var kernel = kernels.get(kernelName);

        if (kernelExecutionContexts.containsKey(device)) {
            throw new IllegalStateException("Found existing kernelExecutionContext");
        }

        // Fail fast, and before any per-device state is recorded, rather than handing a null kernel to OpenCL.
        if (kernel == null) {
            throw new IllegalStateException("Could not find kernel [" + kernelName + "]");
        }

        final var kernelExecutionContextComputer = singleKernelFitnessDescriptor.kernelExecutionContextComputer();
        final var kernelExecutionContext = kernelExecutionContextComputer
                .compute(openCLExecutionContext, kernelName, generation, genotypes);
        kernelExecutionContexts.put(device, kernelExecutionContext);

        // Static data was loaded in beforeAllEvaluations; here it only needs to be bound to the kernel.
        final var mapData = staticData.get(device);
        if (MapUtils.isNotEmpty(mapData)) {
            for (final var entry : mapData.entrySet()) {
                final int argumentIdx = entry.getKey();
                final var clStaticData = entry.getValue();

                logger.trace("[{}] Loading static data for index {}", device.name(), argumentIdx);
                CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clStaticData.clMem()));
            }
        }

        final var dataLoaders = singleKernelFitnessDescriptor.dataLoaders();
        if (MapUtils.isNotEmpty(dataLoaders)) {
            for (final var entry : dataLoaders.entrySet()) {
                final int argumentIdx = entry.getKey();
                final var dataLoader = entry.getValue();

                final var clDdata = dataLoader.load(openCLExecutionContext, generation, genotypes);
                final var dataMapping = data.computeIfAbsent(device, k -> new HashMap<>());
                if (dataMapping.put(argumentIdx, clDdata) != null) {
                    throw new IllegalArgumentException("Multiple data configured for index " + argumentIdx);
                }

                logger.trace("[{}] Loading data for index {}", device.name(), argumentIdx);
                CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clDdata.clMem()));
            }
        }

        final var localMemoryAllocators = singleKernelFitnessDescriptor.localMemoryAllocators();
        if (MapUtils.isNotEmpty(localMemoryAllocators)) {
            for (final var entry : localMemoryAllocators.entrySet()) {
                final int argumentIdx = entry.getKey();
                final var localMemoryAllocator = entry.getValue();

                final var size = localMemoryAllocator
                        .load(openCLExecutionContext, kernelExecutionContext, generation, genotypes);

                logger.trace("[{}] Setting local data for index {} with size of {}", device.name(), argumentIdx, size);
                // Only a size is passed, with a null value pointer.
                CL.clSetKernelArg(kernel, argumentIdx, size, null);
            }
        }

        final var resultAllocators = singleKernelFitnessDescriptor.resultAllocators();
        if (MapUtils.isNotEmpty(resultAllocators)) {
            for (final var entry : resultAllocators.entrySet()) {
                final int argumentIdx = entry.getKey();
                final var resultAllocator = entry.getValue();

                final var clDdata = resultAllocator
                        .load(openCLExecutionContext, kernelExecutionContext, generation, genotypes);
                final var dataMapping = resultData.computeIfAbsent(device, k -> new HashMap<>());
                if (dataMapping.put(argumentIdx, clDdata) != null) {
                    throw new IllegalArgumentException(
                            "Multiple result allocators configured for index " + argumentIdx);
                }

                logger.trace("[{}] Preparing result data memory for index {}", device.name(), argumentIdx);
                CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clDdata.clMem()));
            }
        }
    }

    /**
     * Enqueues the kernel on the command queue of the given execution context, then hands the result buffers to the
     * {@link FitnessExtractor} to produce the fitness values.
     *
     * <p>The kernel is enqueued on {@code executorService}. Failures inside that task, including the
     * {@link IllegalStateException} raised when the kernel or the kernel execution context prepared by
     * {@link #beforeEvaluation} is missing, complete the returned future exceptionally.
     *
     * @return Future holding the values produced by the fitness extractor
     */
    @Override
    public CompletableFuture<List<T>> compute(final OpenCLExecutionContext openCLExecutionContext,
            final ExecutorService executorService, final long generation, List<Genotype> genotypes) {

        return CompletableFuture.supplyAsync(() -> {
            final var clCommandQueue = openCLExecutionContext.clCommandQueue();
            final var kernels = openCLExecutionContext.kernels();

            final var kernelName = singleKernelFitnessDescriptor.kernelName();
            final var kernel = kernels.get(kernelName);
            if (kernel == null) {
                throw new IllegalStateException("Could not find kernel [" + kernelName + "]");
            }

            final var device = openCLExecutionContext.device();
            final var kernelExecutionContext = kernelExecutionContexts.get(device);
            if (kernelExecutionContext == null) {
                // Without this check the failure would surface as a bare NullPointerException below.
                throw new IllegalStateException("Could not find kernelExecutionContext for device ["
                        + device.name() + "]. Was beforeEvaluation called?");
            }

            final var globalWorkDimensions = kernelExecutionContext.globalWorkDimensions();
            final var globalWorkSize = kernelExecutionContext.globalWorkSize();
            // A null local work size is passed to OpenCL when the context does not specify one.
            final long[] workGroupSize = kernelExecutionContext.workGroupSize()
                    .orElse(null);

            logger.trace(
                    "Starting computation on kernel {} for {} genotypes and global work size {} and local work size {}",
                    kernelName,
                    genotypes.size(),
                    globalWorkSize,
                    workGroupSize);

            // NOTE(review): the interval below only brackets the enqueue call; clEnqueueNDRangeKernel is not
            // documented as waiting for the kernel to finish, so this may not reflect the kernel run time.
            final long startTime = System.nanoTime();
            CL.clEnqueueNDRangeKernel(clCommandQueue,
                    kernel,
                    globalWorkDimensions,
                    null,
                    globalWorkSize,
                    workGroupSize,
                    0,
                    null,
                    null);
            final long endTime = System.nanoTime();
            final long duration = endTime - startTime;

            if (logger.isDebugEnabled()) {
                final var deviceName = device.name();
                logger.debug("{} - Took {} microsec for {} genotypes", deviceName, duration / 1000., genotypes.size());
            }

            return kernelExecutionContext;
        }, executorService)
                .thenApply(kernelExecutionContext -> {
                    final var resultExtractor = new ResultExtractor(resultData);
                    return fitnessExtractor.compute(openCLExecutionContext,
                            kernelExecutionContext,
                            executorService,
                            generation,
                            genotypes,
                            resultExtractor);
                });
    }

    /**
     * Releases the per-evaluation data and result buffers of the device and forgets its kernel execution context.
     */
    @Override
    public void afterEvaluation(OpenCLExecutionContext openCLExecutionContext, ExecutorService executorService,
            long generation, List<Genotype> genotypes) {
        super.afterEvaluation(openCLExecutionContext, executorService, generation, genotypes);

        final var device = openCLExecutionContext.device();
        logger.trace("[{}] Releasing data", device.name());

        clearData(device);
        clearResultData(device);
        kernelExecutionContexts.remove(device);
    }

    /**
     * Releases the static data buffers held for the device of the given execution context.
     */
    @Override
    public void afterAllEvaluations(final OpenCLExecutionContext openCLExecutionContext,
            final ExecutorService executorService) {
        super.afterAllEvaluations(openCLExecutionContext, executorService);

        final var device = openCLExecutionContext.device();
        logger.trace("[{}] Releasing static data", device.name());

        clearStaticData(device);
    }

    /**
     * Static factory equivalent to calling the constructor.
     *
     * @param <U>                           Type of the fitness values
     * @param singleKernelFitnessDescriptor Description of the kernel and of its arguments; must not be null
     * @param fitnessExtractor              Converts the kernel results into fitness values; must not be null
     * @return A new {@link SingleKernelFitness}
     */
    public static <U extends Comparable<U>> SingleKernelFitness<U> of(
            final SingleKernelFitnessDescriptor singleKernelFitnessDescriptor,
            final FitnessExtractor<U> fitnessExtractor) {
        Validate.notNull(singleKernelFitnessDescriptor);
        Validate.notNull(fitnessExtractor);

        return new SingleKernelFitness<>(singleKernelFitnessDescriptor, fitnessExtractor);
    }
}