1 package net.bmahe.genetics4j.samples.clustering;
2
3 import java.io.FileNotFoundException;
4 import java.io.FileReader;
5 import java.io.IOException;
6 import java.io.Reader;
7 import java.nio.charset.StandardCharsets;
8 import java.nio.file.Path;
9 import java.util.ArrayList;
10 import java.util.List;
11
12 import org.apache.commons.csv.CSVFormat;
13 import org.apache.commons.csv.CSVPrinter;
14 import org.apache.commons.csv.CSVRecord;
15 import org.apache.commons.lang3.Validate;
16 import org.apache.logging.log4j.LogManager;
17 import org.apache.logging.log4j.Logger;
18
19 public class IOUtils {
20 final static public Logger logger = LogManager.getLogger(IOUtils.class);
21
22 public static double[][] loadClusters(final String filename) {
23 logger.info("Loading clusters from {}", filename);
24
25 Reader in;
26 try {
27 in = new FileReader(filename);
28 } catch (FileNotFoundException e) {
29 throw new RuntimeException(e);
30 }
31
32 Iterable<CSVRecord> records;
33 try {
34 records = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(in);
35 } catch (IOException e) {
36 throw new RuntimeException(e);
37 }
38
39 final List<double[]> entries = new ArrayList<>();
40 for (final CSVRecord record : records) {
41 final double x = Double.parseDouble(record.get(1));
42 final double y = Double.parseDouble(record.get(2));
43
44 entries.add(new double[] { x, y });
45 }
46
47 final double[][] clusters = new double[entries.size()][2];
48 for (int i = 0; i < entries.size(); i++) {
49 clusters[i][0] = entries.get(i)[0];
50 clusters[i][1] = entries.get(i)[1];
51 }
52 return clusters;
53 }
54
55 public static double[][] loadDataPoints(final String filename) throws IOException {
56 final Reader in = new FileReader(filename);
57 final Iterable<CSVRecord> records = CSVFormat.DEFAULT.withFirstRecordAsHeader()
58 .withSkipHeaderRecord(true)
59 .parse(in);
60 final List<double[]> entries = new ArrayList<>();
61 for (final CSVRecord record : records) {
62 final double cluster = Double.parseDouble(record.get(0));
63 final double x = Double.parseDouble(record.get(1));
64 final double y = Double.parseDouble(record.get(2));
65
66 entries.add(new double[] { cluster, x, y });
67 }
68
69 final double[][] clusters = new double[entries.size()][3];
70 for (int i = 0; i < entries.size(); i++) {
71 clusters[i][0] = entries.get(i)[1];
72 clusters[i][1] = entries.get(i)[2];
73 clusters[i][2] = entries.get(i)[0];
74 }
75 return clusters;
76 }
77
78 public static void persistClusters(final double[][] clusters, final String clustersFilename) throws IOException {
79 logger.info("Saving clusters to CSV: {}", clustersFilename);
80
81 final CSVPrinter csvPrinter;
82 try {
83 csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
84 .withHeader(new String[] { "cluster", "x", "y" })
85 .print(Path.of(clustersFilename), StandardCharsets.UTF_8);
86 } catch (IOException e) {
87 logger.error("Could not open {}", clustersFilename, e);
88 throw new RuntimeException("Could not open file " + clustersFilename, e);
89 }
90
91 for (int i = 0; i < clusters.length; i++) {
92 try {
93 csvPrinter.printRecord(i, clusters[i][0], clusters[i][1]);
94 } catch (IOException e) {
95 throw new RuntimeException("Could not write data", e);
96 }
97 }
98 csvPrinter.close(true);
99 }
100
101 public static void persistDataPoints(final double[][] data, final String filename) throws IOException {
102 Validate.notBlank(filename);
103
104 logger.info("Saving data to CSV: {}", filename);
105
106 final int numDataPoints = data.length;
107
108 final CSVPrinter csvPrinter;
109 try {
110 csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
111 .withHeader(new String[] { "cluster", "x", "y" })
112 .print(Path.of(filename), StandardCharsets.UTF_8);
113 } catch (IOException e) {
114 logger.error("Could not open {}", filename, e);
115 throw new RuntimeException("Could not open file " + filename, e);
116 }
117
118 for (int i = 0; i < numDataPoints; i++) {
119 try {
120 csvPrinter.printRecord((int) data[i][2], data[i][0], data[i][1]);
121 } catch (IOException e) {
122 throw new RuntimeException("Could not write data", e);
123 }
124 }
125 csvPrinter.close(true);
126 }
127
128 public static void persistDataPoints(final double[][] data, final int[] closestClusterIndex, final String filename)
129 throws IOException {
130 Validate.notBlank(filename);
131
132 logger.info("Saving data to CSV: {}", filename);
133
134 final int numDataPoints = data.length;
135
136 final CSVPrinter csvPrinter;
137 try {
138 csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true)
139 .withHeader(new String[] { "cluster", "x", "y" })
140 .print(Path.of(filename), StandardCharsets.UTF_8);
141 } catch (IOException e) {
142 logger.error("Could not open {}", filename, e);
143 throw new RuntimeException("Could not open file " + filename, e);
144 }
145
146 for (int i = 0; i < numDataPoints; i++) {
147 try {
148 csvPrinter.printRecord(closestClusterIndex[i], data[i][0], data[i][1]);
149 } catch (IOException e) {
150 throw new RuntimeException("Could not write data", e);
151 }
152 }
153 csvPrinter.close(true);
154 }
155 }