// MultiStageFitness.java

package net.bmahe.genetics4j.gpu.spec.fitness;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;

import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;

import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.gpu.opencl.OpenCLExecutionContext;
import net.bmahe.genetics4j.gpu.opencl.model.Device;
import net.bmahe.genetics4j.gpu.spec.fitness.cldata.CLData;
import net.bmahe.genetics4j.gpu.spec.fitness.kernelcontext.KernelExecutionContext;
import net.bmahe.genetics4j.gpu.spec.fitness.multistage.MultiStageDescriptor;
import net.bmahe.genetics4j.gpu.spec.fitness.multistage.StageDescriptor;

/**
 * GPU-accelerated fitness evaluator that executes multiple sequential OpenCL kernels for complex fitness computation.
 * 
 * <p>MultiStageFitness provides a framework for implementing fitness evaluation that requires multiple sequential
 * GPU kernel executions, where each stage can use results from previous stages as input. This is ideal for complex
 * fitness functions that require multiple computational phases, such as neural network training, multi-objective
 * optimization, or hierarchical problem decomposition.
 * 
 * <p>Key features:
 * <ul>
 * <li><strong>Sequential execution</strong>: Multiple OpenCL kernels executed in sequence</li>
 * <li><strong>Inter-stage data flow</strong>: Results from earlier stages used as inputs to later stages</li>
 * <li><strong>Memory optimization</strong>: Automatic cleanup and reuse of intermediate results</li>
 * <li><strong>Pipeline processing</strong>: Support for complex computational pipelines</li>
 * <li><strong>Stage configuration</strong>: Individual configuration for each computational stage</li>
 * </ul>
 * 
 * <p>Multi-stage computation architecture:
 * <ul>
 * <li><strong>Stage descriptors</strong>: Each stage defines its kernel, data loaders, and result allocators</li>
 * <li><strong>Data reuse patterns</strong>: Previous stage results can be reused as arguments or size parameters</li>
 * <li><strong>Memory lifecycle</strong>: Automatic management of intermediate results between stages</li>
 * <li><strong>Static data sharing</strong>: Algorithm parameters shared across all stages</li>
 * </ul>
 * 
 * <p>Typical usage pattern:
 * <pre>{@code
 * // Define multi-stage descriptor with sequential kernels
 * MultiStageDescriptor descriptor = MultiStageDescriptor.builder()
 *     .addStaticDataLoader("parameters", parametersLoader)
 *     .addStage(StageDescriptor.builder()
 *         .kernelName("preprocessing")
 *         .addDataLoader(0, inputDataLoader)
 *         .addResultAllocator(1, preprocessedResultAllocator)
 *         .build())
 *     .addStage(StageDescriptor.builder()
 *         .kernelName("fitness_evaluation")
 *         .reusePreviousResultAsArgument(1, 0)  // Use previous result as input
 *         .addResultAllocator(1, fitnessResultAllocator)
 *         .build())
 *     .build();
 * 
 * // Define fitness extraction from final stage results
 * FitnessExtractor<Double> extractor = (context, kernelCtx, executor, generation, genotypes, results) -> {
 *     float[] fitnessValues = results.extractFloatArray(context, 1);
 *     return Arrays.stream(fitnessValues)
 *         .mapToDouble(f -> (double) f)
 *         .boxed()
 *         .collect(Collectors.toList());
 * };
 * 
 * // Create multi-stage fitness evaluator
 * MultiStageFitness<Double> fitness = MultiStageFitness.of(descriptor, extractor);
 * }</pre>
 * 
 * <p>Stage execution workflow:
 * <ol>
 * <li><strong>Initialization</strong>: Load shared static data once before all evaluations</li>
 * <li><strong>Stage iteration</strong>: For each stage in sequence:</li>
 * <li><strong>Context computation</strong>: Calculate kernel execution parameters for the stage</li>
 * <li><strong>Data preparation</strong>: Load stage-specific data and map previous results</li>
 * <li><strong>Kernel execution</strong>: Execute the stage kernel with configured parameters</li>
 * <li><strong>Result management</strong>: Store results for potential use in subsequent stages</li>
 * <li><strong>Final extraction</strong>: Extract fitness values from the last stage results</li>
 * <li><strong>Cleanup</strong>: Release all intermediate and final result memory</li>
 * </ol>
 * 
 * <p>Inter-stage data flow patterns:
 * <ul>
 * <li><strong>Result reuse</strong>: Use previous stage output buffers as input to subsequent stages</li>
 * <li><strong>Size propagation</strong>: Use previous stage result sizes as parameters for memory allocation</li>
 * <li><strong>Memory optimization</strong>: Automatic cleanup of intermediate results no longer needed</li>
 * <li><strong>Data type preservation</strong>: Maintain OpenCL data types across stage boundaries</li>
 * </ul>
 * 
 * <p>Memory management strategy:
 * <ul>
 * <li><strong>Static data persistence</strong>: Shared parameters allocated once across all stages</li>
 * <li><strong>Intermediate cleanup</strong>: Automatic release of stage results when no longer needed</li>
 * <li><strong>Result chaining</strong>: Efficient memory reuse between consecutive stages</li>
 * <li><strong>Final cleanup</strong>: Complete memory cleanup after fitness extraction</li>
 * </ul>
 * 
 * <p>Performance optimization features:
 * <ul>
 * <li><strong>Pipeline efficiency</strong>: Minimized memory transfers between stages</li>
 * <li><strong>Memory coalescing</strong>: Optimized data layouts for GPU memory access</li>
 * <li><strong>Stage-specific tuning</strong>: Individual work group optimization per stage</li>
 * <li><strong>Asynchronous execution</strong>: Non-blocking fitness computation</li>
 * </ul>
 * 
 * @param <T> the fitness value type, must be Comparable for optimization algorithms
 * @see OpenCLFitness
 * @see MultiStageDescriptor
 * @see StageDescriptor
 * @see FitnessExtractor
 */
public class MultiStageFitness<T extends Comparable<T>> extends OpenCLFitness<T> {
	public static final Logger logger = LogManager.getLogger(MultiStageFitness.class);

	private final MultiStageDescriptor multiStageDescriptor;
	private final FitnessExtractor<T> fitnessExtractor;

	// Per-device cache of the shared (static) OpenCL buffers, keyed by the argument
	// name declared by the MultiStageDescriptor's static data loaders. Loaded in
	// beforeAllEvaluations and released in afterAllEvaluations.
	private final Map<Device, Map<String, CLData>> staticData = new ConcurrentHashMap<>();

	/**
	 * Releases every static OpenCL buffer cached for the given device and removes the
	 * device entry from the cache. Safe to call when nothing is cached for the device.
	 *
	 * @param device the device whose static data should be released
	 */
	protected void clearStaticData(final Device device) {
		if (MapUtils.isEmpty(staticData) || MapUtils.isEmpty(staticData.get(device))) {
			return;
		}

		final Map<String, CLData> mapData = staticData.get(device);
		for (final CLData clData : mapData.values()) {
			CL.clReleaseMemObject(clData.clMem());
		}

		mapData.clear();
		staticData.remove(device);
	}

	/**
	 * Releases every OpenCL buffer held in the given map and clears the map in place.
	 * Safe to call with a {@code null} or empty map.
	 *
	 * @param data mapping of kernel argument index to OpenCL buffer; cleared in place
	 */
	protected void clearData(final Map<Integer, CLData> data) {
		if (MapUtils.isEmpty(data)) {
			return;
		}

		for (final CLData clData : data.values()) {
			CL.clReleaseMemObject(clData.clMem());
		}

		data.clear();
	}

	/**
	 * Releases every OpenCL result buffer held in the given map and clears the map in
	 * place. Result buffers share the exact lifecycle of regular data buffers, so this
	 * delegates to {@link #clearData(Map)} rather than duplicating the release loop.
	 *
	 * @param resultData mapping of kernel argument index to result buffer; cleared in place
	 */
	protected void clearResultData(final Map<Integer, CLData> resultData) {
		clearData(resultData);
	}

	/**
	 * Binds the previously loaded static data buffers as kernel arguments for the
	 * given stage, following the name-to-argument-index mapping of the stage.
	 *
	 * @param openCLExecutionContext the OpenCL context of the target device
	 * @param stageDescriptor the stage whose kernel arguments are being bound
	 * @throws IllegalArgumentException if the stage references a static argument name
	 *         that was never loaded for this device
	 */
	protected void prepareStaticData(final OpenCLExecutionContext openCLExecutionContext,
			final StageDescriptor stageDescriptor) {
		Validate.notNull(openCLExecutionContext);
		Validate.notNull(stageDescriptor);

		final var device = openCLExecutionContext.device();

		logger.trace("[{}] Preparing static data", device.name());

		final var kernels = openCLExecutionContext.kernels();
		final var kernelName = stageDescriptor.kernelName();

		final var kernel = kernels.get(kernelName);

		final var mapStaticDataAsArgument = stageDescriptor.mapStaticDataAsArgument();
		for (final var entry : mapStaticDataAsArgument.entrySet()) {
			final var argumentName = entry.getKey();
			final var argumentIndex = entry.getValue();

			final var staticDataMap = staticData.get(device);

			// Guard both the missing-device and missing-name cases: without the null
			// check, a device with no static data at all would surface as a bare NPE
			// instead of this descriptive error.
			if (staticDataMap == null || !staticDataMap.containsKey(argumentName)) {
				throw new IllegalArgumentException("Unknown static argument " + argumentName);
			}

			final CLData clStaticData = staticDataMap.get(argumentName);

			logger.trace("[{}] Index {} - Loading static data with name {}", device.name(), argumentIndex, argumentName);

			CL.clSetKernelArg(kernel, argumentIndex, Sizeof.cl_mem, Pointer.to(clStaticData.clMem()));
		}
	}

	/**
	 * Sets up kernel-local memory arguments for the given stage. Each configured local
	 * memory allocator computes a byte size, which is bound with a {@code null} pointer
	 * as required by OpenCL for local (scratch) memory.
	 *
	 * @param openCLExecutionContext the OpenCL context of the target device
	 * @param stageDescriptor the stage whose local memory arguments are being bound
	 * @param generation current generation number, forwarded to the allocators
	 * @param genotypes population being evaluated, forwarded to the allocators
	 * @param kernelExecutionContext execution parameters of the stage kernel
	 */
	private void allocateLocalMemory(final OpenCLExecutionContext openCLExecutionContext,
			final StageDescriptor stageDescriptor, final long generation, final List<Genotype> genotypes,
			final KernelExecutionContext kernelExecutionContext) {
		Validate.notNull(openCLExecutionContext);
		Validate.notNull(stageDescriptor);
		Validate.notNull(kernelExecutionContext);

		final var device = openCLExecutionContext.device();

		logger.trace("[{}] Allocating local memory", device.name());

		final var kernels = openCLExecutionContext.kernels();
		final var kernelName = stageDescriptor.kernelName();

		final var kernel = kernels.get(kernelName);

		final var localMemoryAllocators = stageDescriptor.localMemoryAllocators();
		if (MapUtils.isNotEmpty(localMemoryAllocators)) {
			for (final var entry : localMemoryAllocators.entrySet()) {
				final int argumentIdx = entry.getKey();
				final var localMemoryAllocator = entry.getValue();

				final var size = localMemoryAllocator
						.load(openCLExecutionContext, kernelExecutionContext, generation, genotypes);
				logger.trace("[{}] Index {} - Setting local data with size of {}", device.name(), argumentIdx, size);

				// null data pointer + non-zero size designates __local memory in OpenCL
				CL.clSetKernelArg(kernel, argumentIdx, size, null);
			}
		}
	}

	/**
	 * Loads stage-specific data buffers through the configured data loaders and binds
	 * them as kernel arguments. Loaded buffers are recorded in {@code data} so the
	 * caller can release them after the stage completes.
	 *
	 * @param openCLExecutionContext the OpenCL context of the target device
	 * @param stageDescriptor the stage whose data loaders should run
	 * @param data out-parameter collecting the loaded buffers, keyed by argument index
	 * @param generation current generation number, forwarded to the loaders
	 * @param genotypes population being evaluated, forwarded to the loaders
	 * @throws IllegalArgumentException if two loaders target the same argument index
	 */
	protected void loadData(final OpenCLExecutionContext openCLExecutionContext, final StageDescriptor stageDescriptor,
			final Map<Integer, CLData> data, final long generation, final List<Genotype> genotypes) {
		Validate.notNull(openCLExecutionContext);
		Validate.notNull(stageDescriptor);
		Validate.notNull(data);

		final var device = openCLExecutionContext.device();

		logger.trace("[{}] Loading data", device.name());

		final var kernels = openCLExecutionContext.kernels();
		final var kernelName = stageDescriptor.kernelName();

		final var kernel = kernels.get(kernelName);

		final var dataLoaders = stageDescriptor.dataLoaders();
		if (MapUtils.isNotEmpty(dataLoaders)) {
			for (final var entry : dataLoaders.entrySet()) {
				final int argumentIdx = entry.getKey();
				final var dataLoader = entry.getValue();

				final var clData = dataLoader.load(openCLExecutionContext, generation, genotypes);

				if (data.put(argumentIdx, clData) != null) {
					throw new IllegalArgumentException("Multiple data configured for index " + argumentIdx);
				}
				logger.trace("[{}] Index {} - Loading data of size {}", device.name(), argumentIdx, clData.size());

				CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clData.clMem()));
			}
		}
	}

	/**
	 * Loads the static (shared) data once per device before any evaluation, replacing
	 * any previously cached buffers for that device.
	 *
	 * @param openCLExecutionContext the OpenCL context of the target device
	 * @param executorService executor forwarded to the parent hook
	 * @throws IllegalArgumentException if two static loaders share the same name
	 */
	@Override
	public void beforeAllEvaluations(final OpenCLExecutionContext openCLExecutionContext,
			final ExecutorService executorService) {
		super.beforeAllEvaluations(openCLExecutionContext, executorService);

		final var device = openCLExecutionContext.device();

		logger.trace("[{}] Loading static data", device.name());
		clearStaticData(device);

		final var staticDataLoaders = multiStageDescriptor.staticDataLoaders();
		for (final var entry : staticDataLoaders.entrySet()) {
			final String argumentName = entry.getKey();
			final var dataSupplier = entry.getValue();

			if (logger.isTraceEnabled()) {
				final var deviceName = openCLExecutionContext.device()
						.name();
				logger.trace("[{}] Loading static data for entry name {}", deviceName, argumentName);
			}
			final CLData clData = dataSupplier.load(openCLExecutionContext);

			final var mapData = staticData.computeIfAbsent(device, k -> new HashMap<>());
			if (mapData.put(argumentName, clData) != null) {
				throw new IllegalArgumentException("Multiple data configured with name " + argumentName);
			}
		}
	}

	/**
	 * Constructs a MultiStageFitness with the specified stage descriptor and fitness extractor.
	 * 
	 * @param _multiStageDescriptor configuration for multi-stage kernel execution and data management
	 * @param _fitnessExtractor function to extract fitness values from final stage results
	 * @throws IllegalArgumentException if any parameter is null
	 */
	public MultiStageFitness(final MultiStageDescriptor _multiStageDescriptor,
			final FitnessExtractor<T> _fitnessExtractor) {
		Validate.notNull(_multiStageDescriptor);
		Validate.notNull(_fitnessExtractor);

		this.multiStageDescriptor = _multiStageDescriptor;
		this.fitnessExtractor = _fitnessExtractor;
	}

	/**
	 * Executes every configured stage in sequence on the device and extracts the
	 * fitness values from the results of the final stage.
	 *
	 * <p>For each stage: computes the kernel execution context, rebinds previous-stage
	 * results (as buffers and/or as size arguments), releases previous-stage buffers
	 * that are not carried forward, binds static/stage/local data, allocates result
	 * buffers, and enqueues the kernel. All intermediate and final buffers are
	 * released before this method's future completes.
	 *
	 * @param openCLExecutionContext the OpenCL context of the target device
	 * @param executorService executor running the asynchronous computation
	 * @param generation current generation number
	 * @param genotypes population to evaluate
	 * @return future completing with one fitness value per genotype
	 * @throws IllegalArgumentException if a stage references an unknown previous result
	 * @throws IllegalStateException if no stage produced final results
	 */
	@Override
	public CompletableFuture<List<T>> compute(final OpenCLExecutionContext openCLExecutionContext,
			final ExecutorService executorService, final long generation, final List<Genotype> genotypes) {
		Validate.notNull(openCLExecutionContext);

		return CompletableFuture.supplyAsync(() -> {

			List<T> finalResults = null;

			final var device = openCLExecutionContext.device();

			final Map<Integer, CLData> data = new ConcurrentHashMap<>();
			Map<Integer, CLData> resultData = new ConcurrentHashMap<>();

			final var stageDescriptors = multiStageDescriptor.stageDescriptors();
			for (int i = 0; i < stageDescriptors.size(); i++) {
				final StageDescriptor stageDescriptor = stageDescriptors.get(i);
				final var kernels = openCLExecutionContext.kernels();
				final var kernelName = stageDescriptor.kernelName();

				final var kernel = kernels.get(kernelName);

				logger.debug("[{}] Executing {}-th stage for kernel {}", device.name(), i, kernelName);

				/**
				 * Compute the Kernel Execution Context
				 */
				final var kernelExecutionContextComputer = stageDescriptor.kernelExecutionContextComputer();
				final var kernelExecutionContext = kernelExecutionContextComputer
						.compute(openCLExecutionContext, kernelName, generation, genotypes);

				/**
				 * Map previous results to new arguments
				 */
				final Map<Integer, CLData> oldResultData = new HashMap<>(resultData);
				resultData = new ConcurrentHashMap<>();
				final Map<Integer, Integer> reusePreviousResultSizeAsArguments = stageDescriptor
						.reusePreviousResultSizeAsArguments();
				final Map<Integer, Integer> reusePreviousResultAsArguments = stageDescriptor
						.reusePreviousResultAsArguments();
				final Set<CLData> reusedArguments = new HashSet<>();

				if (MapUtils.isNotEmpty(reusePreviousResultAsArguments)) {
					for (final Entry<Integer, Integer> entry : reusePreviousResultAsArguments.entrySet()) {
						final var oldKeyArgument = entry.getKey();
						final var newKeyArgument = entry.getValue();

						final var previousResultData = oldResultData.get(oldKeyArgument);
						if (previousResultData == null) {
							logger.error(
									"[{}] Could not find previous argument with index {}. Known previous arguments: {}",
									device.name(),
									oldKeyArgument,
									oldResultData);

							throw new IllegalArgumentException(
									"Could not find previous argument with index " + oldKeyArgument);
						}

						logger.trace("[{}] Index {} - Reuse previous result that had index {}",
								device.name(),
								newKeyArgument,
								oldKeyArgument);
						CL.clSetKernelArg(kernel, newKeyArgument, Sizeof.cl_mem, Pointer.to(previousResultData.clMem()));
						reusedArguments.add(previousResultData);
					}
				}

				if (MapUtils.isNotEmpty(reusePreviousResultSizeAsArguments)) {
					for (final Entry<Integer, Integer> entry : reusePreviousResultSizeAsArguments.entrySet()) {
						final var oldKeyArgument = entry.getKey();
						final var newKeyArgument = entry.getValue();

						final var previousResultData = oldResultData.get(oldKeyArgument);
						if (previousResultData == null) {
							logger.error(
									"[{}] Could not find previous argument with index {}. Known previous arguments: {}",
									device.name(),
									oldKeyArgument,
									oldResultData);

							throw new IllegalArgumentException(
									"Could not find previous argument with index " + oldKeyArgument);
						}

						if (logger.isTraceEnabled()) {
							logger.trace("[{}] Index {} - Setting previous result size of {} of previous argument index {}",
									device.name(),
									newKeyArgument,
									previousResultData.size(),
									oldKeyArgument);
						}

						CL.clSetKernelArg(kernel,
								newKeyArgument,
								Sizeof.cl_int,
								Pointer.to(new int[] { previousResultData.size() }));
					}
				}

				// Release previous-stage result buffers that are not carried forward as
				// kernel arguments. This must run even when no reuse is configured for
				// this stage; otherwise the previous stage's intermediate buffers leak.
				// Note: OpenCL defers actual deletion while enqueued commands still use
				// a buffer, so releasing here is safe with an asynchronous queue.
				final Set<Integer> previousResultsToKeep = reusePreviousResultAsArguments != null
						? reusePreviousResultAsArguments.keySet()
						: Set.of();
				for (final Entry<Integer, CLData> previousEntry : oldResultData.entrySet()) {
					if (!previousResultsToKeep.contains(previousEntry.getKey())) {
						CL.clReleaseMemObject(previousEntry.getValue()
								.clMem());
					}
				}

				prepareStaticData(openCLExecutionContext, stageDescriptor);
				loadData(openCLExecutionContext, stageDescriptor, data, generation, genotypes);
				allocateLocalMemory(openCLExecutionContext, stageDescriptor, generation, genotypes, kernelExecutionContext);

				/**
				 * Allocate memory for results
				 */
				final var resultAllocators = stageDescriptor.resultAllocators();
				if (MapUtils.isNotEmpty(resultAllocators)) {
					logger.trace("[{}] Result allocators: {}", device.name(), resultAllocators);

					for (final var entry : resultAllocators.entrySet()) {
						final int argumentIdx = entry.getKey();
						final var resultAllocator = entry.getValue();

						final var clData = resultAllocator
								.load(openCLExecutionContext, kernelExecutionContext, generation, genotypes);

						if (resultData.put(argumentIdx, clData) != null) {
							throw new IllegalArgumentException(
									"Multiple result allocators configured for index " + argumentIdx);
						}
						if (logger.isTraceEnabled()) {
							logger.trace("[{}] Index {} - Allocate result data memory of type {} and size {}",
									device.name(),
									argumentIdx,
									clData.clType(),
									clData.size());
						}

						CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clData.clMem()));
					}
				} else {
					logger.trace("[{}] No result allocator found", device.name());
				}

				final var clCommandQueue = openCLExecutionContext.clCommandQueue();
				final var globalWorkDimensions = kernelExecutionContext.globalWorkDimensions();
				final var globalWorkSize = kernelExecutionContext.globalWorkSize();
				final long[] workGroupSize = kernelExecutionContext.workGroupSize()
						.orElse(null);

				logger.trace(
						"[{}] Starting computation on kernel {} for {} genotypes and global work size {} and local work size {}",
						device.name(),
						kernelName,
						genotypes.size(),
						globalWorkSize,
						workGroupSize);
				try {
					final long startTime = System.nanoTime();
					// Asynchronous enqueue; the fitness extractor's read-back is expected
					// to synchronize with the queue.
					CL.clEnqueueNDRangeKernel(clCommandQueue,
							kernel,
							globalWorkDimensions,
							null,
							globalWorkSize,
							workGroupSize,
							0,
							null,
							null);

					final long endTime = System.nanoTime();
					final long duration = endTime - startTime;
					if (logger.isDebugEnabled()) {
						final var deviceName = openCLExecutionContext.device()
								.name();
						logger.debug("[{}] -  Stage {} - Took {} microsec for {} genotypes",
								deviceName,
								i,
								duration / 1000.,
								genotypes.size());
					}
				} catch (Exception e) {
					logger.error("[{}] Failure to compute", device.name(), e);
					throw e;
				}

				// Final stage: extract the fitness values, then release its results.
				if (i == stageDescriptors.size() - 1) {
					finalResults = fitnessExtractor.compute(openCLExecutionContext,
							kernelExecutionContext,
							executorService,
							generation,
							genotypes,
							new ResultExtractor(Map.of(device, resultData)));

					clearResultData(resultData);
				}

				// Buffers reused as arguments for this stage are no longer needed once
				// the kernel has been enqueued with them bound.
				for (final CLData clData : reusedArguments) {
					CL.clReleaseMemObject(clData.clMem());
				}
				clearData(data);
			}

			if (finalResults == null) {
				throw new IllegalStateException("final results cannot be null");
			}
			return finalResults;
		}, executorService);
	}

	/**
	 * Per-evaluation hook; all per-evaluation buffers are already released inside
	 * {@link #compute}, so this only delegates to the parent and traces.
	 */
	@Override
	public void afterEvaluation(OpenCLExecutionContext openCLExecutionContext, ExecutorService executorService,
			long generation, List<Genotype> genotypes) {
		super.afterEvaluation(openCLExecutionContext, executorService, generation, genotypes);

		final var device = openCLExecutionContext.device();
		logger.trace("[{}] Releasing data", device.name());
	}

	/**
	 * Releases the static data cached for the device once all evaluations are done.
	 */
	@Override
	public void afterAllEvaluations(final OpenCLExecutionContext openCLExecutionContext,
			final ExecutorService executorService) {
		super.afterAllEvaluations(openCLExecutionContext, executorService);

		final var device = openCLExecutionContext.device();
		logger.trace("[{}] Releasing static data", device.name());
		clearStaticData(device);
	}

	/**
	 * Creates a new MultiStageFitness instance with the specified configuration.
	 * 
	 * @param <U> the fitness value type
	 * @param multiStageDescriptor configuration for multi-stage kernel execution and data management
	 * @param fitnessExtractor function to extract fitness values from final stage results
	 * @return a new MultiStageFitness instance
	 * @throws IllegalArgumentException if any parameter is null
	 */
	public static <U extends Comparable<U>> MultiStageFitness<U> of(final MultiStageDescriptor multiStageDescriptor,
			final FitnessExtractor<U> fitnessExtractor) {
		Validate.notNull(multiStageDescriptor);
		Validate.notNull(fitnessExtractor);

		return new MultiStageFitness<>(multiStageDescriptor, fitnessExtractor);
	}
}