1 package net.bmahe.genetics4j.gpu.spec.fitness;
2
3 import java.util.Map;
4
5 import org.apache.commons.lang3.Validate;
6 import org.apache.logging.log4j.LogManager;
7 import org.apache.logging.log4j.Logger;
8 import org.jocl.CL;
9 import org.jocl.Pointer;
10 import org.jocl.Sizeof;
11
12 import net.bmahe.genetics4j.gpu.opencl.OpenCLExecutionContext;
13 import net.bmahe.genetics4j.gpu.opencl.model.Device;
14 import net.bmahe.genetics4j.gpu.spec.fitness.cldata.CLData;
15
16 /**
17 * Utility class for extracting computation results from OpenCL device memory after GPU kernel execution.
18 *
19 * <p>ResultExtractor provides type-safe methods for retrieving different data types from OpenCL memory buffers that
20 * contain the results of GPU-accelerated fitness evaluation. This class handles the device-to-host data transfer and
21 * type conversion necessary to make GPU computation results available to the evolutionary algorithm.
22 *
23 * <p>Key functionality includes:
24 * <ul>
25 * <li><strong>Type-safe extraction</strong>: Methods for extracting float, int, long arrays with type validation</li>
26 * <li><strong>Image data support</strong>: Specialized extraction for OpenCL image objects</li>
27 * <li><strong>Device management</strong>: Tracks result data across multiple devices</li>
28 * <li><strong>Argument indexing</strong>: Maps kernel arguments to their corresponding result data</li>
29 * </ul>
30 *
31 * <p>Common usage patterns:
32 *
33 * <pre>{@code
34 * // Extract fitness values as float array
35 * float[] fitnessValues = resultExtractor.extractFloatArray(context, 0);
36 *
37 * // Extract integer results (e.g., classification results)
38 * int[] classifications = resultExtractor.extractIntArray(context, 1);
39 *
40 * // Extract long results (e.g., counters or large indices)
41 * long[] counters = resultExtractor.extractLongArray(context, 2);
42 *
43 * // Extract image data for visualization
44 * byte[] imageData = resultExtractor.extractImageAsByteArray(context, 3, width, height, channels, channelSize);
45 *
46 * // Use extracted results in fitness evaluation
47 * List<Double> fitness = IntStream.range(0, fitnessValues.length)
48 * .mapToDouble(i -> (double) fitnessValues[i])
49 * .boxed()
50 * .collect(Collectors.toList());
51 * }</pre>
52 *
53 * <p>Data extraction workflow:
54 * <ol>
55 * <li><strong>Kernel execution</strong>: GPU kernels compute results and store them in device memory</li>
56 * <li><strong>Result mapping</strong>: Results are mapped by device and kernel argument index</li>
57 * <li><strong>Type validation</strong>: Data types are validated before extraction</li>
58 * <li><strong>Data transfer</strong>: Results are transferred from device to host memory</li>
59 * <li><strong>Type conversion</strong>: Data is converted to appropriate Java types</li>
60 * </ol>
61 *
62 * <p>Error handling and validation:
63 * <ul>
64 * <li><strong>Device validation</strong>: Ensures requested device has result data</li>
65 * <li><strong>Argument validation</strong>: Validates argument indices exist in result mapping</li>
66 * <li><strong>Type checking</strong>: Ensures extracted data matches expected OpenCL types</li>
67 * <li><strong>Transfer validation</strong>: Validates successful device-to-host data transfer</li>
68 * </ul>
69 *
70 * <p>Performance considerations:
71 * <ul>
72 * <li><strong>Synchronous transfers</strong>: Uses blocking transfers to ensure data availability</li>
73 * <li><strong>Memory efficiency</strong>: Allocates host memory based on actual data sizes</li>
74 * <li><strong>Transfer optimization</strong>: Minimizes number of device-to-host transfers</li>
75 * <li><strong>Type safety</strong>: Validates types at runtime to prevent data corruption</li>
76 * </ul>
77 *
78 * @see CLData
79 * @see net.bmahe.genetics4j.gpu.spec.fitness.OpenCLFitness
80 * @see OpenCLExecutionContext
81 */
82 public class ResultExtractor {
83 public static final Logger logger = LogManager.getLogger(ResultExtractor.class);
84
85 private final Map<Device, Map<Integer, CLData>> resultData;
86
87 /**
88 * Extracts CLData for the specified device and kernel argument index.
89 *
90 * @param device the OpenCL device to extract data from
91 * @param argumentIndex the kernel argument index for the data
92 * @return the CLData object containing the result data
93 * @throws IllegalArgumentException if device is null, argumentIndex is negative, device not found, or argument index
94 * not found
95 */
96 protected CLData extractClData(final Device device, final int argumentIndex) {
97 Validate.notNull(device);
98 Validate.isTrue(argumentIndex >= 0);
99
100 if (resultData.containsKey(device) == false) {
101 throw new IllegalArgumentException("Could not find entry for device [" + device.name() + "]");
102 }
103
104 final var deviceResults = resultData.get(device);
105
106 if (deviceResults.containsKey(argumentIndex) == false) {
107 throw new IllegalArgumentException("No data defined for argument " + argumentIndex);
108 }
109
110 final var clData = deviceResults.get(argumentIndex);
111 return clData;
112 }
113
114 /**
115 * Constructs a ResultExtractor with the specified result data mapping.
116 *
117 * @param _resultData mapping from devices to their kernel argument results
118 */
119 public ResultExtractor(final Map<Device, Map<Integer, CLData>> _resultData) {
120
121 this.resultData = _resultData;
122 }
123
124 /**
125 * Extracts image data from OpenCL device memory as a byte array.
126 *
127 * <p>This method reads an OpenCL image object from device memory and converts it to a byte array suitable for host
128 * processing. The image dimensions and channel information must be provided to properly interpret the image data.
129 *
130 * @param openCLExecutionContext the OpenCL execution context
131 * @param argumentIndex the kernel argument index containing the image data
132 * @param width the image width in pixels
133 * @param height the image height in pixels
134 * @param numChannels the number of color channels (e.g., 3 for RGB, 4 for RGBA)
135 * @param channelSize the size of each channel in bytes
136 * @return byte array containing the image data
137 * @throws IllegalArgumentException if any parameter is invalid
138 */
139 public byte[] extractImageAsByteArray(final OpenCLExecutionContext openCLExecutionContext, final int argumentIndex,
140 final int width, final int height, final int numChannels, final int channelSize) {
141 Validate.isTrue(argumentIndex >= 0);
142 Validate.isTrue(width > 0);
143 Validate.isTrue(height > 0);
144 Validate.isTrue(numChannels > 0);
145 Validate.isTrue(channelSize > 0);
146
147 final var device = openCLExecutionContext.device();
148 final var clData = extractClData(device, argumentIndex);
149
150 final var clCommandQueue = openCLExecutionContext.clCommandQueue();
151
152 final byte[] data = new byte[width * height * numChannels * channelSize];
153 CL.clEnqueueReadImage(
154 clCommandQueue,
155 clData.clMem(),
156 CL.CL_TRUE,
157 new long[] { 0, 0, 0 },
158 new long[] { width, height, 1 },
159 0,
160 0,
161 Pointer.to(data),
162 0,
163 null,
164 null);
165
166 return data;
167 }
168
169 /**
170 * Extracts floating-point data from OpenCL device memory as a float array.
171 *
172 * <p>This method reads floating-point data from device memory and transfers it to host memory. The data type is
173 * validated to ensure it contains floating-point values before extraction.
174 *
175 * @param openCLExecutionContext the OpenCL execution context
176 * @param argumentIndex the kernel argument index containing the float data
177 * @return float array containing the extracted data
178 * @throws IllegalArgumentException if the data is not of type float
179 */
180 public float[] extractFloatArray(final OpenCLExecutionContext openCLExecutionContext, final int argumentIndex) {
181 final var device = openCLExecutionContext.device();
182 final var clData = extractClData(device, argumentIndex);
183
184 if (clData.clType() != Sizeof.cl_float) {
185 throw new IllegalArgumentException("Data is not of type of float[]");
186 }
187
188 final var clCommandQueue = openCLExecutionContext.clCommandQueue();
189
190 final float[] data = new float[clData.size()];
191 CL.clEnqueueReadBuffer(
192 clCommandQueue,
193 clData.clMem(),
194 CL.CL_TRUE,
195 0,
196 clData.size() * Sizeof.cl_float,
197 Pointer.to(data),
198 0,
199 null,
200 null);
201
202 return data;
203 }
204
205 /**
206 * Extracts integer data from OpenCL device memory as an int array.
207 *
208 * <p>This method reads integer data from device memory and transfers it to host memory. The data type is validated
209 * to ensure it contains integer values before extraction.
210 *
211 * @param openCLExecutionContext the OpenCL execution context
212 * @param argumentIndex the kernel argument index containing the integer data
213 * @return int array containing the extracted data
214 * @throws IllegalArgumentException if the data is not of type int
215 */
216 public int[] extractIntArray(final OpenCLExecutionContext openCLExecutionContext, final int argumentIndex) {
217 final var device = openCLExecutionContext.device();
218 final var clData = extractClData(device, argumentIndex);
219
220 if (clData.clType() != Sizeof.cl_int) {
221 throw new IllegalArgumentException("Data is not of type of int[]");
222 }
223
224 final var clCommandQueue = openCLExecutionContext.clCommandQueue();
225
226 final int[] data = new int[clData.size()];
227 CL.clEnqueueReadBuffer(
228 clCommandQueue,
229 clData.clMem(),
230 CL.CL_TRUE,
231 0,
232 clData.size() * Sizeof.cl_int,
233 Pointer.to(data),
234 0,
235 null,
236 null);
237
238 return data;
239 }
240
241 /**
242 * Extracts long integer data from OpenCL device memory as a long array.
243 *
244 * <p>This method reads long integer data from device memory and transfers it to host memory. The data type is
245 * validated to ensure it contains long integer values before extraction.
246 *
247 * @param openCLExecutionContext the OpenCL execution context
248 * @param argumentIndex the kernel argument index containing the long integer data
249 * @return long array containing the extracted data
250 * @throws IllegalArgumentException if the data is not of type long
251 */
252 public long[] extractLongArray(final OpenCLExecutionContext openCLExecutionContext, final int argumentIndex) {
253 final var device = openCLExecutionContext.device();
254 final var clData = extractClData(device, argumentIndex);
255
256 if (clData.clType() != Sizeof.cl_long) {
257 throw new IllegalArgumentException("Data is not of type of long[]");
258 }
259
260 final var clCommandQueue = openCLExecutionContext.clCommandQueue();
261
262 final long[] data = new long[clData.size()];
263 CL.clEnqueueReadBuffer(
264 clCommandQueue,
265 clData.clMem(),
266 CL.CL_TRUE,
267 0,
268 clData.size() * Sizeof.cl_long,
269 Pointer.to(data),
270 0,
271 null,
272 null);
273 return data;
274 }
275 }