// ResultExtractor.java
package net.bmahe.genetics4j.gpu.spec.fitness;
import java.util.Map;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import net.bmahe.genetics4j.gpu.opencl.OpenCLExecutionContext;
import net.bmahe.genetics4j.gpu.opencl.model.Device;
import net.bmahe.genetics4j.gpu.spec.fitness.cldata.CLData;
/**
 * Copies OpenCL results from device memory back into freshly allocated host arrays.
 * <p>
 * Results are looked up by device and then by argument index in the map supplied at construction time. Every read
 * is issued as a blocking read on the command queue of the supplied execution context, so the returned array is
 * fully populated when the method returns.
 */
public class ResultExtractor {
	public static final Logger logger = LogManager.getLogger(ResultExtractor.class);

	private final Map<Device, Map<Integer, CLData>> resultData;

	/**
	 * Creates an extractor backed by the given result map.
	 * <p>
	 * The map is stored by reference and is not copied.
	 *
	 * @param _resultData results indexed by device, then by argument index; must not be {@code null}
	 * @throws NullPointerException if {@code _resultData} is {@code null}
	 */
	public ResultExtractor(final Map<Device, Map<Integer, CLData>> _resultData) {
		Validate.notNull(_resultData);

		this.resultData = _resultData;
	}

	/**
	 * Looks up the data registered for a device and argument index.
	 *
	 * @param device        device whose results are requested; must not be {@code null}
	 * @param argumentIndex index of the argument; must be non-negative
	 * @return the data registered for that device and argument index
	 * @throws NullPointerException     if {@code device} is {@code null}
	 * @throws IllegalArgumentException if {@code argumentIndex} is negative, if the device has no entry, or if the
	 *                                  device has no data for {@code argumentIndex}
	 */
	protected CLData extractClData(final Device device, final int argumentIndex) {
		Validate.notNull(device);
		Validate.isTrue(argumentIndex >= 0);

		if (!resultData.containsKey(device)) {
			throw new IllegalArgumentException("Could not find entry for device [" + device.name() + "]");
		}

		final var deviceResults = resultData.get(device);
		if (!deviceResults.containsKey(argumentIndex)) {
			throw new IllegalArgumentException("No data defined for argument " + argumentIndex);
		}

		return deviceResults.get(argumentIndex);
	}

	/**
	 * Reads a {@code width x height} region, starting at origin {@code (0, 0, 0)}, of the image registered for the
	 * given argument index into a new byte array.
	 * <p>
	 * {@code numChannels} and {@code channelSize} are only used to size the returned array
	 * ({@code width * height * numChannels * channelSize} bytes); they are not passed to OpenCL.
	 * <p>
	 * NOTE(review): this assumes {@code numChannels * channelSize} matches the element size of the underlying image.
	 * The native read is not bounded by the Java array length here, so a value that is too small could overrun the
	 * host buffer. Confirm at call sites.
	 *
	 * @param openCLExecutionContext context providing the device and command queue; must not be {@code null}
	 * @param argumentIndex          index of the argument holding the image; must be non-negative
	 * @param width                  width of the region to read; must be positive
	 * @param height                 height of the region to read; must be positive
	 * @param numChannels            number of channels per pixel; must be positive
	 * @param channelSize            size of one channel in bytes; must be positive
	 * @return a new array holding the bytes read from the image
	 * @throws NullPointerException     if {@code openCLExecutionContext} is {@code null}
	 * @throws IllegalArgumentException if a numeric argument is out of range, if the requested image does not fit
	 *                                  in a Java array, or if no data is registered for the device and argument
	 */
	public byte[] extractImageAsByteArray(final OpenCLExecutionContext openCLExecutionContext, final int argumentIndex,
			final int width, final int height, final int numChannels, final int channelSize) {
		Validate.notNull(openCLExecutionContext);
		Validate.isTrue(argumentIndex >= 0);
		Validate.isTrue(width > 0);
		Validate.isTrue(height > 0);
		Validate.isTrue(numChannels > 0);
		Validate.isTrue(channelSize > 0);

		final var clData = extractClData(openCLExecutionContext.device(), argumentIndex);

		// Sized with overflow-checked arithmetic: a silently wrapped int product would allocate a buffer smaller
		// than what the native read writes into it.
		final byte[] data = new byte[imageByteCount(width, height, numChannels, channelSize)];

		CL.clEnqueueReadImage(openCLExecutionContext.clCommandQueue(),
				clData.clMem(),
				CL.CL_TRUE,
				new long[] { 0, 0, 0 },
				new long[] { width, height, 1 },
				0,
				0,
				Pointer.to(data),
				0,
				null,
				null);

		return data;
	}

	/**
	 * Reads the buffer registered for the given argument index into a new {@code float[]}.
	 * <p>
	 * NOTE(review): the type check compares {@code clData.clType()} with {@code Sizeof.cl_float}, which is an
	 * element size. If {@code clType} really is a size, this cannot tell float data from other 4-byte types.
	 * Confirm how {@code CLData} populates it.
	 *
	 * @param openCLExecutionContext context providing the device and command queue; must not be {@code null}
	 * @param argumentIndex          index of the argument holding the data; must be non-negative
	 * @return a new array of {@code clData.size()} floats read from the device
	 * @throws NullPointerException     if {@code openCLExecutionContext} is {@code null}
	 * @throws IllegalArgumentException if no data is registered for the device and argument, or if the data is not
	 *                                  tagged as float
	 */
	public float[] extractFloatArray(final OpenCLExecutionContext openCLExecutionContext, final int argumentIndex) {
		Validate.notNull(openCLExecutionContext);

		final var clData = extractClData(openCLExecutionContext.device(), argumentIndex);
		if (clData.clType() != Sizeof.cl_float) {
			throw new IllegalArgumentException("Data is not of type of float[]");
		}

		final float[] data = new float[clData.size()];
		readBuffer(openCLExecutionContext, clData, Sizeof.cl_float, Pointer.to(data));
		return data;
	}

	/**
	 * Reads the buffer registered for the given argument index into a new {@code int[]}.
	 * <p>
	 * NOTE(review): the type check compares {@code clData.clType()} with {@code Sizeof.cl_int}, which is an element
	 * size. If {@code clType} really is a size, this cannot tell int data from other 4-byte types. Confirm how
	 * {@code CLData} populates it.
	 *
	 * @param openCLExecutionContext context providing the device and command queue; must not be {@code null}
	 * @param argumentIndex          index of the argument holding the data; must be non-negative
	 * @return a new array of {@code clData.size()} ints read from the device
	 * @throws NullPointerException     if {@code openCLExecutionContext} is {@code null}
	 * @throws IllegalArgumentException if no data is registered for the device and argument, or if the data is not
	 *                                  tagged as int
	 */
	public int[] extractIntArray(final OpenCLExecutionContext openCLExecutionContext, final int argumentIndex) {
		Validate.notNull(openCLExecutionContext);

		final var clData = extractClData(openCLExecutionContext.device(), argumentIndex);
		if (clData.clType() != Sizeof.cl_int) {
			throw new IllegalArgumentException("Data is not of type of int[]");
		}

		final int[] data = new int[clData.size()];
		readBuffer(openCLExecutionContext, clData, Sizeof.cl_int, Pointer.to(data));
		return data;
	}

	/**
	 * Reads the buffer registered for the given argument index into a new {@code long[]}.
	 *
	 * @param openCLExecutionContext context providing the device and command queue; must not be {@code null}
	 * @param argumentIndex          index of the argument holding the data; must be non-negative
	 * @return a new array of {@code clData.size()} longs read from the device
	 * @throws NullPointerException     if {@code openCLExecutionContext} is {@code null}
	 * @throws IllegalArgumentException if no data is registered for the device and argument, or if the data is not
	 *                                  tagged as long
	 */
	public long[] extractLongArray(final OpenCLExecutionContext openCLExecutionContext, final int argumentIndex) {
		Validate.notNull(openCLExecutionContext);

		final var clData = extractClData(openCLExecutionContext.device(), argumentIndex);
		if (clData.clType() != Sizeof.cl_long) {
			throw new IllegalArgumentException("Data is not of type of long[]");
		}

		final long[] data = new long[clData.size()];
		readBuffer(openCLExecutionContext, clData, Sizeof.cl_long, Pointer.to(data));
		return data;
	}

	/**
	 * Issues a blocking read of {@code clData.size()} elements of {@code elementSize} bytes into the host pointer.
	 */
	private static void readBuffer(final OpenCLExecutionContext openCLExecutionContext, final CLData clData,
			final int elementSize, final Pointer hostPointer) {
		// Widen before multiplying: an int product overflows for large buffers before it is converted to the
		// long byte count that the OpenCL call expects.
		final long byteCount = (long) clData.size() * elementSize;

		CL.clEnqueueReadBuffer(openCLExecutionContext.clCommandQueue(),
				clData.clMem(),
				CL.CL_TRUE,
				0,
				byteCount,
				hostPointer,
				0,
				null,
				null);
	}

	/**
	 * Computes {@code width * height * numChannels * channelSize}, rejecting results that do not fit in an int.
	 */
	private static int imageByteCount(final int width, final int height, final int numChannels,
			final int channelSize) {
		try {
			// Each pairwise product of two ints always fits in a long; only the final product can overflow.
			final long pixelCount = (long) width * height;
			final long bytesPerPixel = (long) numChannels * channelSize;
			return Math.toIntExact(Math.multiplyExact(pixelCount, bytesPerPixel));
		} catch (final ArithmeticException e) {
			throw new IllegalArgumentException("Image of " + width + "x" + height + " with " + numChannels
					+ " channel(s) of " + channelSize + " byte(s) is too large for a byte array", e);
		}
	}
}