MultiStageFitness.java

package net.bmahe.genetics4j.gpu.spec.fitness;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;

import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;

import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.gpu.opencl.OpenCLExecutionContext;
import net.bmahe.genetics4j.gpu.opencl.model.Device;
import net.bmahe.genetics4j.gpu.spec.fitness.cldata.CLData;
import net.bmahe.genetics4j.gpu.spec.fitness.kernelcontext.KernelExecutionContext;
import net.bmahe.genetics4j.gpu.spec.fitness.multistage.MultiStageDescriptor;
import net.bmahe.genetics4j.gpu.spec.fitness.multistage.StageDescriptor;

/**
 * {@link OpenCLFitness} implementation that evaluates genotypes by running, in order, the OpenCL kernels (stages)
 * described by a {@link MultiStageDescriptor}.
 *
 * <p>For every stage, kernel arguments are bound from:
 * <ul>
 * <li>static data loaded once per device in {@link #beforeAllEvaluations(OpenCLExecutionContext, ExecutorService)},
 * <li>results (or result sizes) produced by the previous stage,
 * <li>per-evaluation data loaders, local memory allocators and result allocators.
 * </ul>
 * After the last stage has been enqueued, the configured {@link FitnessExtractor} computes the fitness values from
 * that stage's results.
 *
 * @param <T> type of the fitness values
 */
public class MultiStageFitness<T extends Comparable<T>> extends OpenCLFitness<T> {
	public static final Logger logger = LogManager.getLogger(MultiStageFitness.class);

	private final MultiStageDescriptor multiStageDescriptor;
	private final FitnessExtractor<T> fitnessExtractor;

	/** Static buffers of each device, keyed by the name used in the stage descriptors. */
	private final Map<Device, Map<String, CLData>> staticData = new ConcurrentHashMap<>();

	/**
	 * Creates a multi-stage fitness.
	 *
	 * @param _multiStageDescriptor description of the stages to run; must not be null
	 * @param _fitnessExtractor     computes the fitness values from the last stage's results; must not be null
	 */
	public MultiStageFitness(final MultiStageDescriptor _multiStageDescriptor,
			final FitnessExtractor<T> _fitnessExtractor) {
		Validate.notNull(_multiStageDescriptor);
		Validate.notNull(_fitnessExtractor);

		this.multiStageDescriptor = _multiStageDescriptor;
		this.fitnessExtractor = _fitnessExtractor;
	}

	/**
	 * Releases every OpenCL memory object contained in {@code buffers} and empties the map.
	 *
	 * @param buffers buffers to release; may be null or empty
	 */
	private static void releaseAll(final Map<?, CLData> buffers) {
		if (MapUtils.isEmpty(buffers)) {
			return;
		}

		for (final CLData clData : buffers.values()) {
			CL.clReleaseMemObject(clData.clMem());
		}

		buffers.clear();
	}

	/**
	 * Best-effort variant of {@link #releaseAll(Map)} used on the failure path of
	 * {@link #compute(OpenCLExecutionContext, ExecutorService, long, List)}. A failing release is logged rather than
	 * thrown, so it cannot mask the exception that caused the failure.
	 */
	private static void releaseQuietly(final Device device, final Map<?, CLData> buffers) {
		if (MapUtils.isEmpty(buffers)) {
			return;
		}

		for (final CLData clData : buffers.values()) {
			try {
				CL.clReleaseMemObject(clData.clMem());
			} catch (final RuntimeException e) {
				logger.warn("[{}] Failed to release OpenCL memory object {}", device.name(), clData, e);
			}
		}

		buffers.clear();
	}

	/**
	 * Releases and forgets all the static buffers loaded for the given device.
	 *
	 * @param device device whose static data must be released
	 */
	protected void clearStaticData(final Device device) {
		if (MapUtils.isEmpty(staticData)) {
			return;
		}

		// Remove first so the device no longer references buffers that are about to be released
		final Map<String, CLData> mapData = staticData.remove(device);
		releaseAll(mapData);
	}

	/**
	 * Releases every per-evaluation data buffer contained in {@code data} and empties the map.
	 *
	 * @param data buffers to release, keyed by kernel argument index; may be null or empty
	 */
	protected void clearData(final Map<Integer, CLData> data) {
		releaseAll(data);
	}

	/**
	 * Releases every result buffer contained in {@code resultData} and empties the map.
	 *
	 * @param resultData buffers to release, keyed by kernel argument index; may be null or empty
	 */
	protected void clearResultData(final Map<Integer, CLData> resultData) {
		releaseAll(resultData);
	}

	/**
	 * Binds the static buffers requested by a stage as arguments of its kernel.
	 *
	 * @param openCLExecutionContext execution context of the device running the stage
	 * @param stageDescriptor        stage whose static arguments must be bound
	 * @throws IllegalArgumentException if the stage references a static data name that was not loaded for the device
	 */
	protected void prepareStaticData(final OpenCLExecutionContext openCLExecutionContext,
			final StageDescriptor stageDescriptor) {
		Validate.notNull(openCLExecutionContext);
		Validate.notNull(stageDescriptor);

		final var device = openCLExecutionContext.device();

		logger.trace("[{}] Preparing static data", device.name());

		final var mapStaticDataAsArgument = stageDescriptor.mapStaticDataAsArgument();
		if (MapUtils.isEmpty(mapStaticDataAsArgument)) {
			return;
		}

		final var kernel = openCLExecutionContext.kernels()
				.get(stageDescriptor.kernelName());

		// No static data may have been loaded for this device. In that case every lookup reports an unknown
		// argument instead of failing with a NullPointerException.
		final Map<String, CLData> staticDataMap = staticData.getOrDefault(device, Map.of());

		for (final var entry : mapStaticDataAsArgument.entrySet()) {
			final var argumentName = entry.getKey();
			final var argumentIndex = entry.getValue();

			final CLData clStaticData = staticDataMap.get(argumentName);
			if (clStaticData == null) {
				throw new IllegalArgumentException("Unknown static argument " + argumentName);
			}

			logger.trace("[{}] Index {} - Loading static data with name {}", device.name(), argumentIndex, argumentName);

			CL.clSetKernelArg(kernel, argumentIndex, Sizeof.cl_mem, Pointer.to(clStaticData.clMem()));
		}
	}

	/**
	 * Sets the local memory arguments of a stage's kernel.
	 *
	 * <p>For each local memory allocator, the size it returns is passed to {@code clSetKernelArg} with a null
	 * pointer.
	 */
	private void allocateLocalMemory(final OpenCLExecutionContext openCLExecutionContext,
			final StageDescriptor stageDescriptor, final long generation, final List<Genotype> genotypes,
			final KernelExecutionContext kernelExecutionContext) {
		Validate.notNull(openCLExecutionContext);
		Validate.notNull(stageDescriptor);
		Validate.notNull(kernelExecutionContext);

		final var device = openCLExecutionContext.device();

		logger.trace("[{}] Allocating local memory", device.name());

		final var localMemoryAllocators = stageDescriptor.localMemoryAllocators();
		if (MapUtils.isEmpty(localMemoryAllocators)) {
			return;
		}

		final var kernel = openCLExecutionContext.kernels()
				.get(stageDescriptor.kernelName());
		for (final var entry : localMemoryAllocators.entrySet()) {
			final int argumentIdx = entry.getKey();
			final var localMemoryAllocator = entry.getValue();

			final var size = localMemoryAllocator
					.load(openCLExecutionContext, kernelExecutionContext, generation, genotypes);
			logger.trace("[{}] Index {} - Setting local data with size of {}", device.name(), argumentIdx, size);

			CL.clSetKernelArg(kernel, argumentIdx, size, null);
		}
	}

	/**
	 * Loads the per-evaluation data of a stage through its data loaders and binds the buffers as kernel arguments.
	 *
	 * @param openCLExecutionContext execution context of the device running the stage
	 * @param stageDescriptor        stage whose data must be loaded
	 * @param data                   receives the created buffers keyed by argument index, so the caller can release
	 *                               them
	 * @param generation             current generation
	 * @param genotypes              genotypes being evaluated
	 * @throws IllegalArgumentException if {@code data} already holds a buffer for one of the argument indexes
	 */
	protected void loadData(final OpenCLExecutionContext openCLExecutionContext, final StageDescriptor stageDescriptor,
			final Map<Integer, CLData> data, final long generation, final List<Genotype> genotypes) {
		Validate.notNull(openCLExecutionContext);
		Validate.notNull(stageDescriptor);
		Validate.notNull(data);

		final var device = openCLExecutionContext.device();

		logger.trace("[{}] Loading data", device.name());

		final var dataLoaders = stageDescriptor.dataLoaders();
		if (MapUtils.isEmpty(dataLoaders)) {
			return;
		}

		final var kernel = openCLExecutionContext.kernels()
				.get(stageDescriptor.kernelName());
		for (final var entry : dataLoaders.entrySet()) {
			final int argumentIdx = entry.getKey();
			final var dataLoader = entry.getValue();

			final var clDdata = dataLoader.load(openCLExecutionContext, generation, genotypes);

			if (data.put(argumentIdx, clDdata) != null) {
				throw new IllegalArgumentException("Multiple data configured for index " + argumentIdx);
			}
			logger.trace("[{}] Index {} - Loading data of size {}", device.name(), argumentIdx, clDdata.size());

			CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clDdata.clMem()));
		}
	}

	/**
	 * Releases any static data previously held for the device, then loads every static data declared in the
	 * {@link MultiStageDescriptor}.
	 */
	@Override
	public void beforeAllEvaluations(final OpenCLExecutionContext openCLExecutionContext,
			final ExecutorService executorService) {
		super.beforeAllEvaluations(openCLExecutionContext, executorService);

		final var device = openCLExecutionContext.device();

		logger.trace("[{}] Loading static data", device.name());
		clearStaticData(device);

		final var staticDataLoaders = multiStageDescriptor.staticDataLoaders();
		for (final var entry : staticDataLoaders.entrySet()) {
			final String argumentName = entry.getKey();
			final var dataSupplier = entry.getValue();

			logger.trace("[{}] Loading static data for entry name {}", device.name(), argumentName);
			final CLData clData = dataSupplier.load(openCLExecutionContext);

			final var mapData = staticData.computeIfAbsent(device, k -> new HashMap<>());
			if (mapData.put(argumentName, clData) != null) {
				throw new IllegalArgumentException("Multiple data configured with name " + argumentName);
			}
		}
	}

	/**
	 * Looks up the result of the previous stage stored at {@code oldKeyArgument}.
	 *
	 * @throws IllegalArgumentException if the previous stage produced no result at that index
	 */
	private CLData findPreviousResult(final Device device, final Map<Integer, CLData> previousResultData,
			final Integer oldKeyArgument) {
		final CLData previousResult = previousResultData.get(oldKeyArgument);
		if (previousResult == null) {
			logger.error("[{}] Could not find previous argument with index {}. Known previous arguments: {}",
					device.name(),
					oldKeyArgument,
					previousResultData);

			throw new IllegalArgumentException("Could not find previous argument with index " + oldKeyArgument);
		}
		return previousResult;
	}

	/** Binds the previous-stage result buffers that the stage reuses as kernel arguments. */
	private void bindPreviousResults(final OpenCLExecutionContext openCLExecutionContext,
			final StageDescriptor stageDescriptor, final Map<Integer, CLData> previousResultData) {
		final Map<Integer, Integer> reusePreviousResultAsArguments = stageDescriptor.reusePreviousResultAsArguments();
		if (MapUtils.isEmpty(reusePreviousResultAsArguments)) {
			return;
		}

		final var device = openCLExecutionContext.device();
		final var kernel = openCLExecutionContext.kernels()
				.get(stageDescriptor.kernelName());
		for (final Entry<Integer, Integer> entry : reusePreviousResultAsArguments.entrySet()) {
			final var oldKeyArgument = entry.getKey();
			final var newKeyArgument = entry.getValue();

			final CLData previousResult = findPreviousResult(device, previousResultData, oldKeyArgument);

			logger.trace("[{}] Index {} - Reuse previous result that had index {}",
					device.name(),
					newKeyArgument,
					oldKeyArgument);
			CL.clSetKernelArg(kernel, newKeyArgument, Sizeof.cl_mem, Pointer.to(previousResult.clMem()));
		}
	}

	/** Binds the sizes of previous-stage results as {@code int} kernel arguments. */
	private void bindPreviousResultSizes(final OpenCLExecutionContext openCLExecutionContext,
			final StageDescriptor stageDescriptor, final Map<Integer, CLData> previousResultData) {
		final Map<Integer, Integer> reusePreviousResultSizeAsArguments = stageDescriptor
				.reusePreviousResultSizeAsArguments();
		if (MapUtils.isEmpty(reusePreviousResultSizeAsArguments)) {
			return;
		}

		final var device = openCLExecutionContext.device();
		final var kernel = openCLExecutionContext.kernels()
				.get(stageDescriptor.kernelName());
		for (final Entry<Integer, Integer> entry : reusePreviousResultSizeAsArguments.entrySet()) {
			final var oldKeyArgument = entry.getKey();
			final var newKeyArgument = entry.getValue();

			final CLData previousResult = findPreviousResult(device, previousResultData, oldKeyArgument);

			if (logger.isTraceEnabled()) {
				logger.trace("[{}] Index {} - Setting previous result size of {} of previous argument index {}",
						device.name(),
						newKeyArgument,
						previousResult.size(),
						oldKeyArgument);
			}

			CL.clSetKernelArg(kernel,
					newKeyArgument,
					Sizeof.cl_int,
					Pointer.to(new int[] { previousResult.size() }));
		}
	}

	/**
	 * Releases, and removes from {@code previousResultData}, every previous-stage result that the stage does not
	 * bind as a kernel argument.
	 *
	 * <p>This runs even when the stage reuses nothing, so previous results are never leaked. Results only used for
	 * their size have already been read by {@link #bindPreviousResultSizes} at this point.
	 */
	private void releaseUnusedPreviousResults(final StageDescriptor stageDescriptor,
			final Map<Integer, CLData> previousResultData) {
		final Map<Integer, Integer> reusePreviousResultAsArguments = stageDescriptor.reusePreviousResultAsArguments();
		final Set<Integer> previousResultsToKeep = MapUtils.isNotEmpty(reusePreviousResultAsArguments)
				? reusePreviousResultAsArguments.keySet()
				: Set.of();

		final var iterator = previousResultData.entrySet()
				.iterator();
		while (iterator.hasNext()) {
			final var entry = iterator.next();
			if (previousResultsToKeep.contains(entry.getKey()) == false) {
				// Read the value before removing it: an entry is not guaranteed to stay valid after removal
				final CLData unusedResult = entry.getValue();
				iterator.remove();
				CL.clReleaseMemObject(unusedResult.clMem());
			}
		}
	}

	/**
	 * Allocates the result buffers of a stage through its result allocators, records them in {@code resultData} and
	 * binds them as kernel arguments.
	 */
	private void allocateResults(final OpenCLExecutionContext openCLExecutionContext,
			final StageDescriptor stageDescriptor, final Map<Integer, CLData> resultData, final long generation,
			final List<Genotype> genotypes, final KernelExecutionContext kernelExecutionContext) {
		final var device = openCLExecutionContext.device();

		final var resultAllocators = stageDescriptor.resultAllocators();
		if (MapUtils.isEmpty(resultAllocators)) {
			logger.trace("[{}] No result allocator found", device.name());
			return;
		}

		logger.trace("[{}] Result allocators: {}", device.name(), resultAllocators);

		final var kernel = openCLExecutionContext.kernels()
				.get(stageDescriptor.kernelName());
		for (final var entry : resultAllocators.entrySet()) {
			final int argumentIdx = entry.getKey();
			final var resultAllocator = entry.getValue();

			final var clDdata = resultAllocator
					.load(openCLExecutionContext, kernelExecutionContext, generation, genotypes);

			if (resultData.put(argumentIdx, clDdata) != null) {
				throw new IllegalArgumentException("Multiple result allocators configured for index " + argumentIdx);
			}
			if (logger.isTraceEnabled()) {
				logger.trace("[{}] Index {} - Allocate result data memory of type {} and size {}",
						device.name(),
						argumentIdx,
						clDdata.clType(),
						clDdata.size());
			}

			CL.clSetKernelArg(kernel, argumentIdx, Sizeof.cl_mem, Pointer.to(clDdata.clMem()));
		}
	}

	/** Enqueues the stage's kernel on the device command queue, using the computed work sizes. */
	private void enqueueKernel(final OpenCLExecutionContext openCLExecutionContext,
			final StageDescriptor stageDescriptor, final KernelExecutionContext kernelExecutionContext,
			final int stageIndex, final List<Genotype> genotypes) {
		final var device = openCLExecutionContext.device();
		final var kernelName = stageDescriptor.kernelName();
		final var kernel = openCLExecutionContext.kernels()
				.get(kernelName);

		final var clCommandQueue = openCLExecutionContext.clCommandQueue();
		final var globalWorkDimensions = kernelExecutionContext.globalWorkDimensions();
		final var globalWorkSize = kernelExecutionContext.globalWorkSize();
		final long[] workGroupSize = kernelExecutionContext.workGroupSize()
				.orElse(null);

		logger.trace(
				"[{}] Starting computation on kernel {} for {} genotypes and global work size {} and local work size {}",
				device.name(),
				kernelName,
				genotypes.size(),
				globalWorkSize,
				workGroupSize);
		try {
			final long startTime = System.nanoTime();
			CL.clEnqueueNDRangeKernel(clCommandQueue,
					kernel,
					globalWorkDimensions,
					null,
					globalWorkSize,
					workGroupSize,
					0,
					null,
					null);

			// No clFinish here: the measured duration covers enqueuing the kernel, not necessarily its execution
			final long duration = System.nanoTime() - startTime;
			if (logger.isDebugEnabled()) {
				logger.debug("[{}] -  Stage {} - Took {} microsec for {} genotypes",
						device.name(),
						stageIndex,
						duration / 1000.,
						genotypes.size());
			}
		} catch (final RuntimeException e) {
			logger.error("[{}] Failure to compute", device.name(), e);
			throw e;
		}
	}

	/**
	 * Runs every stage, in order, on {@code executorService} and extracts the fitness values from the last stage's
	 * results.
	 *
	 * <p>Per-stage buffers are released as soon as they are no longer needed. If a stage fails, all buffers still
	 * owned by the evaluation are released before the returned future completes exceptionally.
	 *
	 * @return future completing with the fitness values computed by the {@link FitnessExtractor}; it completes
	 *         exceptionally with an {@link IllegalStateException} if no stage is configured
	 */
	@Override
	public CompletableFuture<List<T>> compute(final OpenCLExecutionContext openCLExecutionContext,
			final ExecutorService executorService, final long generation, final List<Genotype> genotypes) {
		Validate.notNull(openCLExecutionContext);

		return CompletableFuture.supplyAsync(() -> {
			final var device = openCLExecutionContext.device();
			final var stageDescriptors = multiStageDescriptor.stageDescriptors();

			// Buffers currently owned by this evaluation, keyed by kernel argument index
			final Map<Integer, CLData> data = new ConcurrentHashMap<>();
			final Map<Integer, CLData> resultData = new ConcurrentHashMap<>();
			final Map<Integer, CLData> previousResultData = new HashMap<>();

			List<T> finalResults = null;
			try {
				for (int i = 0; i < stageDescriptors.size(); i++) {
					final StageDescriptor stageDescriptor = stageDescriptors.get(i);
					final var kernelName = stageDescriptor.kernelName();

					logger.debug("[{}] Executing {}-th stage for kernel {}", device.name(), i, kernelName);

					final var kernelExecutionContext = stageDescriptor.kernelExecutionContextComputer()
							.compute(openCLExecutionContext, kernelName, generation, genotypes);

					// The results of the previous stage become the inputs of this one
					previousResultData.putAll(resultData);
					resultData.clear();

					bindPreviousResults(openCLExecutionContext, stageDescriptor, previousResultData);
					bindPreviousResultSizes(openCLExecutionContext, stageDescriptor, previousResultData);
					releaseUnusedPreviousResults(stageDescriptor, previousResultData);

					prepareStaticData(openCLExecutionContext, stageDescriptor);
					loadData(openCLExecutionContext, stageDescriptor, data, generation, genotypes);
					allocateLocalMemory(openCLExecutionContext,
							stageDescriptor,
							generation,
							genotypes,
							kernelExecutionContext);
					allocateResults(openCLExecutionContext,
							stageDescriptor,
							resultData,
							generation,
							genotypes,
							kernelExecutionContext);

					enqueueKernel(openCLExecutionContext, stageDescriptor, kernelExecutionContext, i, genotypes);

					if (i == stageDescriptors.size() - 1) {
						finalResults = fitnessExtractor.compute(openCLExecutionContext,
								kernelExecutionContext,
								executorService,
								generation,
								genotypes,
								new ResultExtractor(Map.of(device, resultData)));

						clearResultData(resultData);
					}

					// Reused previous results and this stage's data are no longer needed
					releaseAll(previousResultData);
					clearData(data);
				}
			} finally {
				// On success these maps are already empty. After a failure they hold the buffers that would otherwise
				// leak.
				releaseQuietly(device, data);
				releaseQuietly(device, resultData);
				releaseQuietly(device, previousResultData);
			}

			if (finalResults == null) {
				throw new IllegalStateException("final results cannot be null");
			}
			return finalResults;
		}, executorService);
	}

	/**
	 * Delegates to the parent implementation. The buffers used by
	 * {@link #compute(OpenCLExecutionContext, ExecutorService, long, List)} are released inside compute itself, so
	 * nothing else is released here.
	 */
	@Override
	public void afterEvaluation(final OpenCLExecutionContext openCLExecutionContext,
			final ExecutorService executorService, final long generation, final List<Genotype> genotypes) {
		super.afterEvaluation(openCLExecutionContext, executorService, generation, genotypes);

		final var device = openCLExecutionContext.device();
		logger.trace("[{}] Releasing data", device.name());
	}

	/** Delegates to the parent implementation, then releases the static data loaded for the device. */
	@Override
	public void afterAllEvaluations(final OpenCLExecutionContext openCLExecutionContext,
			final ExecutorService executorService) {
		super.afterAllEvaluations(openCLExecutionContext, executorService);

		final var device = openCLExecutionContext.device();
		logger.trace("[{}] Releasing static data", device.name());
		clearStaticData(device);
	}

	/**
	 * Static factory equivalent to {@link #MultiStageFitness(MultiStageDescriptor, FitnessExtractor)}.
	 *
	 * @param multiStageDescriptor description of the stages to run; must not be null
	 * @param fitnessExtractor     computes the fitness values from the last stage's results; must not be null
	 * @param <U>                  type of the fitness values
	 * @return a new multi-stage fitness
	 */
	public static <U extends Comparable<U>> MultiStageFitness<U> of(final MultiStageDescriptor multiStageDescriptor,
			final FitnessExtractor<U> fitnessExtractor) {
		Validate.notNull(multiStageDescriptor);
		Validate.notNull(fitnessExtractor);

		return new MultiStageFitness<>(multiStageDescriptor, fitnessExtractor);
	}
}