// GPUFitnessEvaluator.java

package net.bmahe.genetics4j.gpu;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;

import org.apache.commons.collections4.ListUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jocl.CL;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_context_properties;
import org.jocl.cl_device_id;
import org.jocl.cl_kernel;
import org.jocl.cl_platform_id;
import org.jocl.cl_program;
import org.jocl.cl_queue_properties;

import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.core.evaluation.FitnessEvaluator;
import net.bmahe.genetics4j.gpu.opencl.DeviceReader;
import net.bmahe.genetics4j.gpu.opencl.DeviceUtils;
import net.bmahe.genetics4j.gpu.opencl.KernelInfoReader;
import net.bmahe.genetics4j.gpu.opencl.OpenCLExecutionContext;
import net.bmahe.genetics4j.gpu.opencl.PlatformReader;
import net.bmahe.genetics4j.gpu.opencl.PlatformUtils;
import net.bmahe.genetics4j.gpu.opencl.model.Device;
import net.bmahe.genetics4j.gpu.opencl.model.KernelInfo;
import net.bmahe.genetics4j.gpu.opencl.model.Platform;
import net.bmahe.genetics4j.gpu.spec.GPUEAConfiguration;
import net.bmahe.genetics4j.gpu.spec.GPUEAExecutionContext;
import net.bmahe.genetics4j.gpu.spec.Program;

/**
 * {@link FitnessEvaluator} that computes fitness values on OpenCL devices.
 * <p>
 * Lifecycle:
 * <ol>
 * <li>{@link #preEvaluation()} selects the platforms and devices accepted by the
 * filters of the {@link GPUEAExecutionContext}, then creates one OpenCL context,
 * command queue, program and set of kernels per selected device.</li>
 * <li>{@link #evaluate(long, List)} splits the population into at most one
 * partition per device and delegates each partition to the configured fitness
 * implementation.</li>
 * <li>{@link #postEvaluation()} releases every OpenCL handle created by
 * {@link #preEvaluation()}.</li>
 * </ol>
 *
 * @param <T> type of the fitness values
 */
public class GPUFitnessEvaluator<T extends Comparable<T>> implements FitnessEvaluator<T> {
	public static final Logger logger = LogManager.getLogger(GPUFitnessEvaluator.class);

	private final GPUEAExecutionContext<T> gpuEAExecutionContext;
	private final GPUEAConfiguration<T> gpuEAConfiguration;
	private final ExecutorService executorService;

	/** Platform/device pairs retained by the last {@link #preEvaluation()}; {@code null} once released. */
	private List<Pair<Platform, Device>> selectedPlatformToDevice;

	/*
	 * Native OpenCL handles owned by this evaluator. A handle is registered as
	 * soon as it is created so that it can always be released, even when a
	 * later initialization step fails.
	 */
	final List<cl_context> clContexts = new ArrayList<>();
	final List<cl_command_queue> clCommandQueues = new ArrayList<>();
	final List<cl_program> clPrograms = new ArrayList<>();
	final List<Map<String, cl_kernel>> clKernels = new ArrayList<>();

	/** One fully initialized execution context per selected device. */
	final List<OpenCLExecutionContext> clExecutionContexts = new ArrayList<>();

	/**
	 * Creates the evaluator and enables JOCL exceptions, so that OpenCL errors
	 * surface as exceptions rather than as return codes.
	 *
	 * @param _gpuEAExecutionContext source of the platform and device filters
	 * @param _gpuEAConfiguration    source of the OpenCL program specification and
	 *                               of the fitness implementation
	 * @param _executorService       executor handed over to the fitness
	 *                               implementation
	 * @throws NullPointerException if any argument is {@code null}
	 */
	public GPUFitnessEvaluator(final GPUEAExecutionContext<T> _gpuEAExecutionContext,
			final GPUEAConfiguration<T> _gpuEAConfiguration, final ExecutorService _executorService) {
		Validate.notNull(_gpuEAExecutionContext);
		Validate.notNull(_gpuEAConfiguration);
		Validate.notNull(_executorService);

		this.gpuEAExecutionContext = _gpuEAExecutionContext;
		this.gpuEAConfiguration = _gpuEAConfiguration;
		this.executorService = _executorService;

		CL.setExceptionsEnabled(true);
	}

	/**
	 * Reads a classpath resource as a UTF-8 string.
	 *
	 * @param filename name of the resource; must not be blank
	 * @return content of the resource
	 * @throws IllegalStateException if the resource cannot be read
	 */
	private String loadResource(final String filename) {
		Validate.notBlank(filename);

		try {
			return IOUtils.resourceToString(filename, StandardCharsets.UTF_8);
		} catch (IOException e) {
			throw new IllegalStateException("Unable to load resource " + filename, e);
		}
	}

	/**
	 * Collects the program sources: the inline content of the program
	 * specification first, followed by the content of each of its resources.
	 *
	 * @return a new list with all the sources
	 */
	private List<String> grabProgramSources() {
		final Program programSpec = gpuEAConfiguration.program();

		logger.info("Load program source: {}", programSpec);

		final List<String> sources = new ArrayList<>();
		sources.addAll(programSpec.content());

		programSpec.resources()
				.stream()
				.map(this::loadResource)
				.forEach(sources::add);

		return sources;
	}

	/**
	 * Enumerates the OpenCL platforms and their devices and retains the pairs
	 * accepted by both the platform filters and the device filters.
	 *
	 * @return the retained platform/device pairs, possibly empty
	 */
	private List<Pair<Platform, Device>> selectPlatformsAndDevices() {
		final var platformReader = new PlatformReader();
		final var deviceReader = new DeviceReader();

		final int numPlatforms = PlatformUtils.numPlatforms();
		logger.info("Found {} platforms", numPlatforms);

		final List<cl_platform_id> platformIds = PlatformUtils.platformIds(numPlatforms);

		logger.info("Selecting platform and devices");
		final var platformFilters = gpuEAExecutionContext.platformFilters();
		final var deviceFilters = gpuEAExecutionContext.deviceFilters();

		return platformIds.stream()
				.map(platformReader::read)
				.filter(platformFilters)
				.flatMap(platform -> {
					final var platformId = platform.platformId();
					final int numDevices = DeviceUtils.numDevices(platformId);
					logger.trace("\tPlatform {}: {} devices", platform.name(), numDevices);

					final var deviceIds = DeviceUtils.getDeviceIds(platformId, numDevices);
					return deviceIds.stream()
							.map(deviceId -> Pair.of(platform, deviceId));
				})
				.map(platformToDeviceId -> {
					final var platform = platformToDeviceId.getLeft();
					final var deviceId = platformToDeviceId.getRight();

					return Pair.of(platform, deviceReader.read(platform.platformId(), deviceId));
				})
				.filter(platformToDevice -> deviceFilters.test(platformToDevice.getRight()))
				.toList();
	}

	/**
	 * Creates the OpenCL context, command queue, program and kernels for one
	 * device.
	 * <p>
	 * Each native handle is added to the tracking lists right after its creation;
	 * should a later step throw, the handles created so far remain reachable
	 * for {@link #releaseOpenCLResources()}.
	 *
	 * @param platform         platform owning the device
	 * @param device           device to initialize
	 * @param programsArr      sources of the OpenCL program
	 * @param kernelInfoReader reader used to describe each created kernel
	 * @return the execution context for the device
	 */
	private OpenCLExecutionContext createExecutionContext(final Platform platform, final Device device,
			final String[] programsArr, final KernelInfoReader kernelInfoReader) {
		logger.info("Processing platform [{}] / device [{}]", platform.name(), device.name());

		logger.info("\tCreating context");
		final cl_context_properties contextProperties = new cl_context_properties();
		contextProperties.addProperty(CL.CL_CONTEXT_PLATFORM, platform.platformId());

		final cl_context context = CL
				.clCreateContext(contextProperties, 1, new cl_device_id[] { device.deviceId() }, null, null, null);
		clContexts.add(context);

		logger.info("\tCreating command queue");
		final cl_queue_properties queueProperties = new cl_queue_properties();
		queueProperties.addProperty(CL.CL_QUEUE_PROPERTIES,
				CL.CL_QUEUE_PROFILING_ENABLE | CL.CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE);
		final cl_command_queue commandQueue = CL
				.clCreateCommandQueueWithProperties(context, device.deviceId(), queueProperties, null);
		clCommandQueues.add(commandQueue);

		logger.info("\tCreate program");
		final cl_program program = CL.clCreateProgramWithSource(context, programsArr.length, programsArr, null, null);
		clPrograms.add(program);

		final var programSpec = gpuEAConfiguration.program();
		final var buildOptions = programSpec.buildOptions()
				.orElse(null);
		logger.info("\tBuilding program with options: {}", buildOptions);
		CL.clBuildProgram(program, 0, null, buildOptions, null, null);

		final Set<String> kernelNames = programSpec.kernelNames();

		final Map<String, cl_kernel> kernels = new HashMap<>();
		// Registered before being filled so that kernels created ahead of a failure are still tracked
		clKernels.add(kernels);
		final Map<String, KernelInfo> kernelInfos = new HashMap<>();
		for (final String kernelName : kernelNames) {

			logger.info("\tCreate kernel {}", kernelName);
			final cl_kernel kernel = CL.clCreateKernel(program, kernelName, null);
			Validate.notNull(kernel);

			kernels.put(kernelName, kernel);

			final var kernelInfo = kernelInfoReader.read(device.deviceId(), kernel, kernelName);
			logger.trace("\t{}", kernelInfo);
			kernelInfos.put(kernelName, kernelInfo);
		}

		return OpenCLExecutionContext.builder()
				.platform(platform)
				.device(device)
				.clContext(context)
				.clCommandQueue(commandQueue)
				.kernels(kernels)
				.kernelInfos(kernelInfos)
				.clProgram(program)
				.build();
	}

	/**
	 * Releases every tracked OpenCL handle (kernels, then programs, then command
	 * queues, then contexts) and resets the evaluator to its un-initialized
	 * state.
	 */
	private void releaseOpenCLResources() {
		logger.debug("Releasing kernels");
		for (final Map<String, cl_kernel> kernels : clKernels) {
			for (final cl_kernel clKernel : kernels.values()) {
				CL.clReleaseKernel(clKernel);
			}
		}
		clKernels.clear();

		logger.debug("Releasing programs");
		for (final cl_program clProgram : clPrograms) {
			CL.clReleaseProgram(clProgram);
		}
		clPrograms.clear();

		logger.debug("Releasing command queues");
		for (final cl_command_queue clCommandQueue : clCommandQueues) {
			CL.clReleaseCommandQueue(clCommandQueue);
		}
		clCommandQueues.clear();

		logger.debug("Releasing contexts");
		for (final cl_context clContext : clContexts) {
			CL.clReleaseContext(clContext);
		}
		clContexts.clear();

		clExecutionContexts.clear();
		selectedPlatformToDevice = null;
	}

	/**
	 * Selects the devices, builds the OpenCL program on each of them and invokes
	 * the {@code beforeAllEvaluations} hooks of the fitness implementation.
	 * <p>
	 * If the OpenCL initialization of any device fails, every handle created so
	 * far is released before the exception is rethrown.
	 *
	 * @throws IllegalArgumentException if no device is accepted by the filters
	 */
	@Override
	public void preEvaluation() {
		logger.trace("Init...");
		FitnessEvaluator.super.preEvaluation();

		final var kernelInfoReader = new KernelInfoReader();

		selectedPlatformToDevice = selectPlatformsAndDevices();

		if (logger.isTraceEnabled()) {
			logger.trace("============================");
			logger.trace("Selected devices:");
			selectedPlatformToDevice.forEach(pd -> {
				logger.trace("{}", pd.getLeft());
				logger.trace("\t{}", pd.getRight());
			});
			logger.trace("============================");
		}

		Validate.isTrue(!selectedPlatformToDevice.isEmpty(),
				"No OpenCL device matched the configured platform and device filters");

		final List<String> programs = grabProgramSources();
		final String[] programsArr = programs.toArray(String[]::new);

		try {
			for (final var platformAndDevice : selectedPlatformToDevice) {
				final var openCLExecutionContext = createExecutionContext(platformAndDevice.getLeft(),
						platformAndDevice.getRight(),
						programsArr,
						kernelInfoReader);

				clExecutionContexts.add(openCLExecutionContext);
			}
		} catch (final RuntimeException e) {
			// Do not leak the native handles of a partially completed initialization
			logger.error("OpenCL initialization failed; releasing the resources created so far", e);
			try {
				releaseOpenCLResources();
			} catch (final RuntimeException releaseException) {
				e.addSuppressed(releaseException);
			}
			throw e;
		}

		final var fitness = gpuEAConfiguration.fitness();
		fitness.beforeAllEvaluations();
		for (final OpenCLExecutionContext clExecutionContext : clExecutionContexts) {
			fitness.beforeAllEvaluations(clExecutionContext, executorService);
		}
	}

	/**
	 * Computes the fitness of each genotype.
	 * <p>
	 * The genotypes are split into consecutive partitions of
	 * {@code ceil(genotypes.size() / number of devices)} elements; each partition
	 * is submitted to a distinct device and the partial results are concatenated
	 * in partition order once all of them have completed.
	 *
	 * @param generation current generation
	 * @param genotypes  genotypes to evaluate
	 * @return the fitness values gathered from all partitions; an empty list when
	 *         {@code genotypes} is empty
	 * @throws NullPointerException  if {@code genotypes} is {@code null}
	 * @throws IllegalStateException if no OpenCL execution context is available,
	 *                               i.e. {@link #preEvaluation()} did not
	 *                               complete
	 */
	@Override
	public List<T> evaluate(final long generation, final List<Genotype> genotypes) {
		Validate.notNull(genotypes);
		Validate.validState(!clExecutionContexts.isEmpty(),
				"No OpenCL execution context available; preEvaluation() must complete before evaluate()");

		if (genotypes.isEmpty()) {
			// A partition size of 0 is rejected by ListUtils.partition; there is nothing to compute anyway
			return new ArrayList<>();
		}

		final var fitness = gpuEAConfiguration.fitness();

		// TODO make it configurable from execution context
		final int partitionSize = (int) (Math.ceil((double) genotypes.size() / clExecutionContexts.size()));
		final var subGenotypes = ListUtils.partition(genotypes, partitionSize);
		logger.debug("Genotype decomposed in {} partition(s)", subGenotypes.size());
		if (logger.isTraceEnabled()) {
			for (int i = 0; i < subGenotypes.size(); i++) {
				final List<Genotype> subGenotype = subGenotypes.get(i);
				logger.trace("\tPartition {} with {} elements", i, subGenotype.size());
			}
		}

		final List<CompletableFuture<List<T>>> subResultsCF = new ArrayList<>(subGenotypes.size());
		for (int i = 0; i < subGenotypes.size(); i++) {
			final var openCLExecutionContext = clExecutionContexts.get(i % clExecutionContexts.size());
			final var subGenotype = subGenotypes.get(i);

			fitness.beforeEvaluation(generation, subGenotype);
			fitness.beforeEvaluation(openCLExecutionContext, executorService, generation, subGenotype);

			// The afterEvaluation hooks are chained with thenApply: they run once the computation completes normally
			final var resultsCF = fitness.compute(openCLExecutionContext, executorService, generation, subGenotype)
					.thenApply((results) -> {

						fitness.afterEvaluation(openCLExecutionContext, executorService, generation, subGenotype);
						fitness.afterEvaluation(generation, subGenotype);

						return results;
					});

			subResultsCF.add(resultsCF);
		}

		// Joining in submission order keeps the results aligned with the input genotypes
		final List<T> resultsEvaluation = new ArrayList<>(genotypes.size());
		for (final CompletableFuture<List<T>> subResultCF : subResultsCF) {
			final var fitnessResults = subResultCF.join();
			resultsEvaluation.addAll(fitnessResults);
		}
		return resultsEvaluation;
	}

	/**
	 * Invokes the {@code afterAllEvaluations} hooks of the fitness
	 * implementation, then releases every OpenCL handle owned by this evaluator.
	 * The release also takes place when one of the hooks throws.
	 */
	@Override
	public void postEvaluation() {

		final var fitness = gpuEAConfiguration.fitness();

		try {
			for (final OpenCLExecutionContext clExecutionContext : clExecutionContexts) {
				fitness.afterAllEvaluations(clExecutionContext, executorService);
			}
			fitness.afterAllEvaluations();
		} finally {
			releaseOpenCLResources();
		}

		FitnessEvaluator.super.postEvaluation();
	}
}