MultiStageFitness.java
package net.bmahe.genetics4j.gpu.spec.fitness;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;

import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;

import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.gpu.opencl.OpenCLExecutionContext;
import net.bmahe.genetics4j.gpu.opencl.model.Device;
import net.bmahe.genetics4j.gpu.spec.fitness.cldata.CLData;
import net.bmahe.genetics4j.gpu.spec.fitness.kernelcontext.KernelExecutionContext;
import net.bmahe.genetics4j.gpu.spec.fitness.multistage.MultiStageDescriptor;
import net.bmahe.genetics4j.gpu.spec.fitness.multistage.StageDescriptor;

/**
* GPU-accelerated fitness evaluator that executes multiple sequential OpenCL kernels for complex fitness computation.
*
* <p>MultiStageFitness provides a framework for implementing fitness evaluation that requires multiple sequential
* GPU kernel executions, where each stage can use results from previous stages as input. This is ideal for complex
* fitness functions that require multiple computational phases, such as neural network training, multi-objective
* optimization, or hierarchical problem decomposition.
*
* <p>Key features:
* <ul>
* <li><strong>Sequential execution</strong>: Multiple OpenCL kernels executed in sequence</li>
* <li><strong>Inter-stage data flow</strong>: Results from earlier stages used as inputs to later stages</li>
* <li><strong>Memory optimization</strong>: Automatic cleanup and reuse of intermediate results</li>
* <li><strong>Pipeline processing</strong>: Support for complex computational pipelines</li>
* <li><strong>Stage configuration</strong>: Individual configuration for each computational stage</li>
* </ul>
*
* <p>Multi-stage computation architecture:
* <ul>
* <li><strong>Stage descriptors</strong>: Each stage defines its kernel, data loaders, and result allocators</li>
* <li><strong>Data reuse patterns</strong>: Previous stage results can be reused as arguments or size parameters</li>
* <li><strong>Memory lifecycle</strong>: Automatic management of intermediate results between stages</li>
* <li><strong>Static data sharing</strong>: Algorithm parameters shared across all stages</li>
* </ul>
*
* <p>Typical usage pattern:
* <pre>{@code
* // Define a multi-stage descriptor with sequential kernels
* MultiStageDescriptor descriptor = MultiStageDescriptor.builder()
*         .addStaticDataLoader("parameters", parametersLoader)
*         .addStage(StageDescriptor.builder()
*                 .kernelName("preprocessing")
*                 .addDataLoader(0, inputDataLoader)
*                 .addResultAllocator(1, preprocessedResultAllocator)
*                 .build())
*         .addStage(StageDescriptor.builder()
*                 .kernelName("fitness_evaluation")
*                 .reusePreviousResultAsArgument(1, 0) // Use the previous stage's result as input
*                 .addResultAllocator(1, fitnessResultAllocator)
*                 .build())
*         .build();
*
* // Define fitness extraction from the final stage results
* FitnessExtractor<Double> extractor = (context, kernelCtx, executor, generation, genotypes, results) -> {
*     final float[] fitnessValues = results.extractFloatArray(context, 1);
*     return IntStream.range(0, fitnessValues.length)
*             .mapToObj(i -> (double) fitnessValues[i])
*             .collect(Collectors.toList());
* };
*
* // Create the multi-stage fitness evaluator
* MultiStageFitness<Double> fitness = MultiStageFitness.of(descriptor, extractor);
* }</pre>
*
* <p>Stage execution workflow:
* <ol>
* <li><strong>Initialization</strong>: Load shared static data once before all evaluations</li>
* <li><strong>Stage iteration</strong>: For each stage in sequence:
* <ol>
* <li><strong>Context computation</strong>: Calculate kernel execution parameters for the stage</li>
* <li><strong>Data preparation</strong>: Load stage-specific data and map previous results</li>
* <li><strong>Kernel execution</strong>: Execute the stage kernel with configured parameters</li>
* <li><strong>Result management</strong>: Store results for potential use in subsequent stages</li>
* </ol>
* </li>
* <li><strong>Final extraction</strong>: Extract fitness values from the last stage results</li>
* <li><strong>Cleanup</strong>: Release all intermediate and final result memory</li>
* </ol>
*
* <p>Inter-stage data flow patterns (see the snippet following this list):
* <ul>
* <li><strong>Result reuse</strong>: Use previous stage output buffers as input to subsequent stages</li>
* <li><strong>Size propagation</strong>: Use previous stage result sizes as parameters for memory allocation</li>
* <li><strong>Memory optimization</strong>: Automatic cleanup of intermediate results no longer needed</li>
* <li><strong>Data type preservation</strong>: Maintain OpenCL data types across stage boundaries</li>
* </ul>
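*
* <p>For example, a later stage can consume both the buffer and the size of a result produced by an earlier
* stage. The snippet below is a minimal sketch: the kernel name and allocator are placeholders, and the builder
* method {@code reusePreviousResultSizeAsArgument} is assumed to mirror the corresponding accessor on
* {@link StageDescriptor}.
* <pre>{@code
* StageDescriptor secondStage = StageDescriptor.builder()
*         .kernelName("aggregate")
*         // Bind the previous stage's result at index 1 to argument 0 of this kernel
*         .reusePreviousResultAsArgument(1, 0)
*         // Pass the size of that same result as an int argument at index 2
*         .reusePreviousResultSizeAsArgument(1, 2)
*         // Allocate the buffer that will hold this stage's output at argument 3
*         .addResultAllocator(3, finalResultAllocator)
*         .build();
* }</pre>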
*
* <p>Memory management strategy:
* <ul>
* <li><strong>Static data persistence</strong>: Shared parameters allocated once across all stages</li>
* <li><strong>Intermediate cleanup</strong>: Automatic release of stage results when no longer needed</li>
* <li><strong>Result chaining</strong>: Efficient memory reuse between consecutive stages</li>
* <li><strong>Final cleanup</strong>: Complete memory cleanup after fitness extraction</li>
* </ul>
*
* <p>Performance considerations:
* <ul>
* <li><strong>Pipeline efficiency</strong>: Intermediate results stay in device memory between stages, avoiding host round-trips</li>
* <li><strong>Memory layout</strong>: Data loaders and result allocators control buffer layout, enabling coalesced access in kernels</li>
* <li><strong>Stage-specific tuning</strong>: Each stage computes its own global and local work sizes</li>
* <li><strong>Asynchronous execution</strong>: {@code compute} returns a {@link CompletableFuture} and runs on the supplied executor service</li>
* </ul>
*
* @param <T> the fitness value type, must be Comparable for optimization algorithms
* @see OpenCLFitness
* @see MultiStageDescriptor
* @see StageDescriptor
* @see FitnessExtractor
*/
public class MultiStageFitness<T extends Comparable<T>> extends OpenCLFitness<T> {
public static final Logger logger = LogManager.getLogger(MultiStageFitness.class);
private final MultiStageDescriptor multiStageDescriptor;
private final FitnessExtractor<T> fitnessExtractor;
private final Map<Device, Map<String, CLData>> staticData = new ConcurrentHashMap<>();
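
/**
* Releases the OpenCL memory objects holding static data for the given device and removes the device entry from
* the cache.
*
* @param device the OpenCL device whose static data should be released
*/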
protected void clearStaticData(final Device device) {
if (MapUtils.isEmpty(staticData) || MapUtils.isEmpty(staticData.get(device))) {
return;
}
final Map<String, CLData> mapData = staticData.get(device);
for (final CLData clData : mapData.values()) {
CL.clReleaseMemObject(clData.clMem());
}
mapData.clear();
staticData.remove(device);
}
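
/**
* Releases the OpenCL memory objects held in the given kernel argument map and clears it.
*
* @param data mapping of kernel argument index to loaded data
*/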
protected void clearData(final Map<Integer, CLData> data) {
if (MapUtils.isEmpty(data)) {
return;
}
for (final CLData clData : data.values()) {
CL.clReleaseMemObject(clData.clMem());
}
data.clear();
}
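
/**
* Releases the OpenCL memory objects holding stage results and clears the map.
*
* @param resultData mapping of kernel argument index to result buffers
*/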
protected void clearResultData(final Map<Integer, CLData> resultData) {
if (MapUtils.isEmpty(resultData)) {
return;
}
for (final CLData clData : resultData.values()) {
CL.clReleaseMemObject(clData.clMem());
}
resultData.clear();
}
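
/**
* Binds the previously loaded static data to the kernel arguments declared by the given stage.
*
* @param openCLExecutionContext the OpenCL execution context for the target device
* @param stageDescriptor the stage whose static data arguments are being set
* @throws IllegalArgumentException if the stage references a static data name that was never loaded
*/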
protected void prepareStaticData(final OpenCLExecutionContext openCLExecutionContext,
final StageDescriptor stageDescriptor) {
Validate.notNull(openCLExecutionContext);
Validate.notNull(stageDescriptor);
final var device = openCLExecutionContext.device();
logger.trace("[{}] Preparing static data", device.name());
final var kernels = openCLExecutionContext.kernels();
final var kernelName = stageDescriptor.kernelName();
final var kernel = kernels.get(kernelName);
final var mapStaticDataAsArgument = stageDescriptor.mapStaticDataAsArgument();
for (final var entry : mapStaticDataAsArgument.entrySet()) {
final var argumentName = entry.getKey();
final var argumentIndex = entry.getValue();
final var staticDataMap = staticData.get(device);
if (staticDataMap.containsKey(argumentName) == false) {
throw new IllegalArgumentException("Unknown static argument " + argumentName);
}
final CLData clStaticData = staticDataMap.get(argumentName);
logger.trace("[{}] Index {} - Loading static data with name {}", device.name(), argumentIndex, argumentName);
CL.clSetKernelArg(kernel, argumentIndex, Sizeof.cl_mem, Pointer.to(clStaticData.clMem()));
}
}
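
/**
* Allocates local (work-group) memory for the kernel arguments configured by the stage's local memory allocators.
*
* @param openCLExecutionContext the OpenCL execution context for the target device
* @param stageDescriptor the stage whose local memory allocators are applied
* @param generation the current generation
* @param genotypes the genotypes being evaluated
* @param kernelExecutionContext the execution context describing the kernel's work sizes
*/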
private void allocateLocalMemory(OpenCLExecutionContext openCLExecutionContext, StageDescriptor stageDescriptor,
long generation, List<Genotype> genotypes, final KernelExecutionContext kernelExecutionContext) {
Validate.notNull(openCLExecutionContext);
Validate.notNull(stageDescriptor);
Validate.notNull(kernelExecutionContext);
final var device = openCLExecutionContext.device();
logger.trace("[{}] Allocating local memory", device.name());
final var kernels = openCLExecutionContext.kernels();
final var kernelName = stageDescriptor.kernelName();
final var kernel = kernels.get(kernelName);
final var localMemoryAllocators = stageDescriptor.localMemoryAllocators();
if (MapUtils.isNotEmpty(localMemoryAllocators)) {
for (final var entry : localMemoryAllocators.entrySet()) {
final int argumentIdx = entry.getKey();
final var localMemoryAllocator = entry.getValue();
final var size = localMemoryAllocator
.load(openCLExecutionContext, kernelExecutionContext, generation, genotypes);
logger.trace("[{}] Index {} - Setting local data with size of {}", device.name(), argumentIdx, size);
CL.clSetKernelArg(kernel, argumentIdx, size, null);
}
}
}
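
/**
* Loads stage-specific data through the configured data loaders and binds each buffer to its kernel argument.
*
* @param openCLExecutionContext the OpenCL execution context for the target device
* @param stageDescriptor the stage whose data loaders are executed
* @param data map collecting the loaded buffers so they can be released once the stage completes
* @param generation the current generation
* @param genotypes the genotypes being evaluated
* @throws IllegalArgumentException if two data loaders target the same argument index
*/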
protected void loadData(final OpenCLExecutionContext openCLExecutionContext, final StageDescriptor stageDescriptor,
final Map<Integer, CLData> data, final long generation, final List<Genotype> genotypes) {
Validate.notNull(openCLExecutionContext);
Validate.notNull(stageDescriptor);
Validate.notNull(data);
final var device = openCLExecutionContext.device();
logger.trace("[{}] Loading data", device.name());
final var kernels = openCLExecutionContext.kernels();
final var kernelName = stageDescriptor.kernelName();
final var kernel = kernels.get(kernelName);
final var dataLoaders = stageDescriptor.dataLoaders();
if (MapUtils.isNotEmpty(dataLoaders)) {
for (final var entry : dataLoaders.entrySet()) {
final int argumentIdx = entry.getKey();
final var dataLoader = entry.getValue();
final var clData = dataLoader.load(openCLExecutionContext, generation, genotypes);
if (data.put(argumentIdx, clData) != null) {
throw new IllegalArgumentException("Multiple data configured for index " + argumentIdx);
}
logger.trace("[{}] Index {} - Loading data of size {}", device.name(), argumentIdx, clData.size());
CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clData.clMem()));
}
}
}
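
/**
* Loads the static data shared by all stages onto the device before any evaluation takes place, replacing any
* static data previously loaded for that device.
*/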
@Override
public void beforeAllEvaluations(final OpenCLExecutionContext openCLExecutionContext,
final ExecutorService executorService) {
super.beforeAllEvaluations(openCLExecutionContext, executorService);
final var device = openCLExecutionContext.device();
logger.trace("[{}] Loading static data", device.name());
clearStaticData(device);
final var staticDataLoaders = multiStageDescriptor.staticDataLoaders();
for (final var entry : staticDataLoaders.entrySet()) {
final String argumentName = entry.getKey();
final var dataSupplier = entry.getValue();
if (logger.isTraceEnabled()) {
final var deviceName = openCLExecutionContext.device()
.name();
logger.trace("[{}] Loading static data for entry name {}", deviceName, argumentName);
}
final CLData clData = dataSupplier.load(openCLExecutionContext);
final var mapData = staticData.computeIfAbsent(device, k -> new HashMap<>());
if (mapData.put(argumentName, clData) != null) {
throw new IllegalArgumentException("Multiple data configured with name " + argumentName);
}
}
}

/**
* Constructs a MultiStageFitness with the specified stage descriptor and fitness extractor.
*
* @param _multiStageDescriptor configuration for multi-stage kernel execution and data management
* @param _fitnessExtractor function to extract fitness values from final stage results
* @throws IllegalArgumentException if any parameter is null
*/
public MultiStageFitness(final MultiStageDescriptor _multiStageDescriptor,
final FitnessExtractor<T> _fitnessExtractor) {
Validate.notNull(_multiStageDescriptor);
Validate.notNull(_fitnessExtractor);
this.multiStageDescriptor = _multiStageDescriptor;
this.fitnessExtractor = _fitnessExtractor;
}
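
/**
* Executes every configured stage in sequence on the device and extracts the fitness values from the results of
* the final stage. Intermediate buffers are released as soon as they are no longer referenced by a later stage.
*/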
@Override
public CompletableFuture<List<T>> compute(final OpenCLExecutionContext openCLExecutionContext,
final ExecutorService executorService, final long generation, final List<Genotype> genotypes) {
Validate.notNull(openCLExecutionContext);
return CompletableFuture.supplyAsync(() -> {
List<T> finalResults = null;
final var device = openCLExecutionContext.device();
final Map<Integer, CLData> data = new ConcurrentHashMap<>();
Map<Integer, CLData> resultData = new ConcurrentHashMap<>();
final var stageDescriptors = multiStageDescriptor.stageDescriptors();
for (int i = 0; i < stageDescriptors.size(); i++) {
final StageDescriptor stageDescriptor = stageDescriptors.get(i);
final var kernels = openCLExecutionContext.kernels();
final var kernelName = stageDescriptor.kernelName();
final var kernel = kernels.get(kernelName);
logger.debug("[{}] Executing {}-th stage for kernel {}", device.name(), i, kernelName);
// Compute the kernel execution context for this stage
final var kernelExecutionContextComputer = stageDescriptor.kernelExecutionContextComputer();
final var kernelExecutionContext = kernelExecutionContextComputer
.compute(openCLExecutionContext, kernelName, generation, genotypes);
// Map results from the previous stage to this stage's kernel arguments
final Map<Integer, CLData> oldResultData = new HashMap<>(resultData);
resultData = new ConcurrentHashMap<>();
final Map<Integer, Integer> reusePreviousResultSizeAsArguments = stageDescriptor
.reusePreviousResultSizeAsArguments();
final Map<Integer, Integer> reusePreviousResultAsArguments = stageDescriptor
.reusePreviousResultAsArguments();
final Set<CLData> reusedArguments = new HashSet<>();
if (MapUtils.isNotEmpty(reusePreviousResultAsArguments)
|| MapUtils.isNotEmpty(reusePreviousResultSizeAsArguments)) {
if (MapUtils.isNotEmpty(reusePreviousResultAsArguments)) {
for (final Entry<Integer, Integer> entry : reusePreviousResultAsArguments.entrySet()) {
final var oldKeyArgument = entry.getKey();
final var newKeyArgument = entry.getValue();
final var previousResultData = oldResultData.get(oldKeyArgument);
if (previousResultData == null) {
logger.error(
"[{}] Could not find previous argument with index {}. Known previous arguments: {}",
device.name(),
oldKeyArgument,
oldResultData);
throw new IllegalArgumentException(
"Could not find previous argument with index " + oldKeyArgument);
}
logger.trace("[{}] Index {} - Reuse previous result that had index {}",
device.name(),
newKeyArgument,
oldKeyArgument);
CL.clSetKernelArg(kernel, newKeyArgument, Sizeof.cl_mem, Pointer.to(previousResultData.clMem()));
reusedArguments.add(previousResultData);
}
}
if (MapUtils.isNotEmpty(reusePreviousResultSizeAsArguments)) {
for (final Entry<Integer, Integer> entry : reusePreviousResultSizeAsArguments.entrySet()) {
final var oldKeyArgument = entry.getKey();
final var newKeyArgument = entry.getValue();
final var previousResultData = oldResultData.get(oldKeyArgument);
if (previousResultData == null) {
logger.error(
"[{}] Could not find previous argument with index {}. Known previous arguments: {}",
device.name(),
oldKeyArgument,
oldResultData);
throw new IllegalArgumentException(
"Could not find previous argument with index " + oldKeyArgument);
}
if (logger.isTraceEnabled()) {
logger.trace("[{}] Index {} - Setting previous result size of {} of previous argument index {}",
device.name(),
newKeyArgument,
previousResultData.size(),
oldKeyArgument);
}
CL.clSetKernelArg(kernel,
newKeyArgument,
Sizeof.cl_int,
Pointer.to(new int[] { previousResultData.size() }));
}
}
// Clean up unused results
final var previousResultsToKeep = reusePreviousResultAsArguments.keySet();
for (Entry<Integer, CLData> entry2 : oldResultData.entrySet()) {
if (previousResultsToKeep.contains(entry2.getKey()) == false) {
CL.clReleaseMemObject(entry2.getValue()
.clMem());
}
}
}
prepareStaticData(openCLExecutionContext, stageDescriptor);
loadData(openCLExecutionContext, stageDescriptor, data, generation, genotypes);
allocateLocalMemory(openCLExecutionContext, stageDescriptor, generation, genotypes, kernelExecutionContext);
// Allocate memory for this stage's results
final var resultAllocators = stageDescriptor.resultAllocators();
if (MapUtils.isNotEmpty(resultAllocators)) {
logger.trace("[{}] Result allocators: {}", device.name(), resultAllocators);
for (final var entry : resultAllocators.entrySet()) {
final int argumentIdx = entry.getKey();
final var resultAllocator = entry.getValue();
final var clData = resultAllocator
.load(openCLExecutionContext, kernelExecutionContext, generation, genotypes);
if (resultData.put(argumentIdx, clData) != null) {
throw new IllegalArgumentException(
"Multiple result allocators configured for index " + argumentIdx);
}
if (logger.isTraceEnabled()) {
logger.trace("[{}] Index {} - Allocate result data memory of type {} and size {}",
device.name(),
argumentIdx,
clData.clType(),
clData.size());
}
CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clData.clMem()));
}
} else {
logger.trace("[{}] No result allocator found", device.name());
}
final var clCommandQueue = openCLExecutionContext.clCommandQueue();
final var globalWorkDimensions = kernelExecutionContext.globalWorkDimensions();
final var globalWorkSize = kernelExecutionContext.globalWorkSize();
final long[] workGroupSize = kernelExecutionContext.workGroupSize()
.orElse(null);
logger.trace(
"[{}] Starting computation on kernel {} for {} genotypes and global work size {} and local work size {}",
device.name(),
kernelName,
genotypes.size(),
globalWorkSize,
workGroupSize);
try {
final long startTime = System.nanoTime();
CL.clEnqueueNDRangeKernel(clCommandQueue,
kernel,
globalWorkDimensions,
null,
globalWorkSize,
workGroupSize,
0,
null,
null);
// CL.clFinish(openCLExecutionContext.clCommandQueue());
final long endTime = System.nanoTime();
final long duration = endTime - startTime;
if (logger.isDebugEnabled()) {
final var deviceName = openCLExecutionContext.device()
.name();
logger.debug("[{}] - Stage {} - Took {} microsec for {} genotypes",
deviceName,
i,
duration / 1000.,
genotypes.size());
}
} catch (Exception e) {
logger.error("[{}] Failure to compute", device.name(), e);
throw e;
}
if (i == stageDescriptors.size() - 1) {
finalResults = fitnessExtractor.compute(openCLExecutionContext,
kernelExecutionContext,
executorService,
generation,
genotypes,
new ResultExtractor(Map.of(device, resultData)));
clearResultData(resultData);
}
for (final CLData clData : reusedArguments) {
CL.clReleaseMemObject(clData.clMem());
}
clearData(data);
}
if (finalResults == null) {
throw new IllegalStateException("final results cannot be null");
}
return finalResults;
}, executorService);
}

@Override
public void afterEvaluation(final OpenCLExecutionContext openCLExecutionContext, final ExecutorService executorService,
final long generation, final List<Genotype> genotypes) {
super.afterEvaluation(openCLExecutionContext, executorService, generation, genotypes);
final var device = openCLExecutionContext.device();
logger.trace("[{}] Releasing data", device.name());
}
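
/**
* Releases the static data allocated for the device once all evaluations have completed.
*/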
@Override
public void afterAllEvaluations(final OpenCLExecutionContext openCLExecutionContext,
final ExecutorService executorService) {
super.afterAllEvaluations(openCLExecutionContext, executorService);
final var device = openCLExecutionContext.device();
logger.trace("[{}] Releasing static data", device.name());
clearStaticData(device);
}

/**
* Creates a new MultiStageFitness instance with the specified configuration.
*
* @param <U> the fitness value type
* @param multiStageDescriptor configuration for multi-stage kernel execution and data management
* @param fitnessExtractor function to extract fitness values from final stage results
* @return a new MultiStageFitness instance
* @throws IllegalArgumentException if any parameter is null
*/
public static <U extends Comparable<U>> MultiStageFitness<U> of(final MultiStageDescriptor multiStageDescriptor,
final FitnessExtractor<U> fitnessExtractor) {
Validate.notNull(multiStageDescriptor);
Validate.notNull(fitnessExtractor);
return new MultiStageFitness<>(multiStageDescriptor, fitnessExtractor);
}
}