1 package net.bmahe.genetics4j.samples.clustering; 2 3 import java.io.FileNotFoundException; 4 import java.io.FileReader; 5 import java.io.IOException; 6 import java.io.Reader; 7 import java.nio.charset.StandardCharsets; 8 import java.nio.file.Path; 9 import java.util.ArrayList; 10 import java.util.List; 11 12 import org.apache.commons.csv.CSVFormat; 13 import org.apache.commons.csv.CSVPrinter; 14 import org.apache.commons.csv.CSVRecord; 15 import org.apache.commons.lang3.Validate; 16 import org.apache.logging.log4j.LogManager; 17 import org.apache.logging.log4j.Logger; 18 19 public class IOUtils { 20 final static public Logger logger = LogManager.getLogger(IOUtils.class); 21 22 public static double[][] loadClusters(final String filename) { 23 logger.info("Loading clusters from {}", filename); 24 25 Reader in; 26 try { 27 in = new FileReader(filename); 28 } catch (FileNotFoundException e) { 29 throw new RuntimeException(e); 30 } 31 32 Iterable<CSVRecord> records; 33 try { 34 records = CSVFormat.DEFAULT.withFirstRecordAsHeader() 35 .parse(in); 36 } catch (IOException e) { 37 throw new RuntimeException(e); 38 } 39 40 final List<double[]> entries = new ArrayList<>(); 41 for (final CSVRecord record : records) { 42 final double x = Double.parseDouble(record.get(1)); 43 final double y = Double.parseDouble(record.get(2)); 44 45 entries.add(new double[] { x, y }); 46 } 47 48 final double[][] clusters = new double[entries.size()][2]; 49 for (int i = 0; i < entries.size(); i++) { 50 clusters[i][0] = entries.get(i)[0]; 51 clusters[i][1] = entries.get(i)[1]; 52 } 53 return clusters; 54 } 55 56 public static double[][] loadDataPoints(final String filename) throws IOException { 57 final Reader in = new FileReader(filename); 58 final Iterable<CSVRecord> records = CSVFormat.DEFAULT.withFirstRecordAsHeader() 59 .withSkipHeaderRecord(true) 60 .parse(in); 61 final List<double[]> entries = new ArrayList<>(); 62 for (final CSVRecord record : records) { 63 final double cluster = Double.parseDouble(record.get(0)); 64 final double x = Double.parseDouble(record.get(1)); 65 final double y = Double.parseDouble(record.get(2)); 66 67 entries.add(new double[] { cluster, x, y }); 68 } 69 70 final double[][] clusters = new double[entries.size()][3]; 71 for (int i = 0; i < entries.size(); i++) { 72 clusters[i][0] = entries.get(i)[1]; 73 clusters[i][1] = entries.get(i)[2]; 74 clusters[i][2] = entries.get(i)[0]; 75 } 76 return clusters; 77 } 78 79 public static void persistClusters(final double[][] clusters, final String clustersFilename) throws IOException { 80 logger.info("Saving clusters to CSV: {}", clustersFilename); 81 82 final CSVPrinter csvPrinter; 83 try { 84 csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true) 85 .withHeader(new String[] { "cluster", "x", "y" }) 86 .print(Path.of(clustersFilename), StandardCharsets.UTF_8); 87 } catch (IOException e) { 88 logger.error("Could not open {}", clustersFilename, e); 89 throw new RuntimeException("Could not open file " + clustersFilename, e); 90 } 91 92 for (int i = 0; i < clusters.length; i++) { 93 try { 94 csvPrinter.printRecord(i, clusters[i][0], clusters[i][1]); 95 } catch (IOException e) { 96 throw new RuntimeException("Could not write data", e); 97 } 98 } 99 csvPrinter.close(true); 100 } 101 102 public static void persistDataPoints(final double[][] data, final String filename) throws IOException { 103 Validate.notBlank(filename); 104 105 logger.info("Saving data to CSV: {}", filename); 106 107 final int numDataPoints = data.length; 108 109 final CSVPrinter csvPrinter; 110 try { 111 csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true) 112 .withHeader(new String[] { "cluster", "x", "y" }) 113 .print(Path.of(filename), StandardCharsets.UTF_8); 114 } catch (IOException e) { 115 logger.error("Could not open {}", filename, e); 116 throw new RuntimeException("Could not open file " + filename, e); 117 } 118 119 for (int i = 0; i < numDataPoints; i++) { 120 try { 121 csvPrinter.printRecord((int) data[i][2], data[i][0], data[i][1]); 122 } catch (IOException e) { 123 throw new RuntimeException("Could not write data", e); 124 } 125 } 126 csvPrinter.close(true); 127 } 128 129 public static void persistDataPoints(final double[][] data, final int[] closestClusterIndex, final String filename) 130 throws IOException { 131 Validate.notBlank(filename); 132 133 logger.info("Saving data to CSV: {}", filename); 134 135 final int numDataPoints = data.length; 136 137 final CSVPrinter csvPrinter; 138 try { 139 csvPrinter = CSVFormat.DEFAULT.withAutoFlush(true) 140 .withHeader(new String[] { "cluster", "x", "y" }) 141 .print(Path.of(filename), StandardCharsets.UTF_8); 142 } catch (IOException e) { 143 logger.error("Could not open {}", filename, e); 144 throw new RuntimeException("Could not open file " + filename, e); 145 } 146 147 for (int i = 0; i < numDataPoints; i++) { 148 try { 149 csvPrinter.printRecord(closestClusterIndex[i], data[i][0], data[i][1]); 150 } catch (IOException e) { 151 throw new RuntimeException("Could not write data", e); 152 } 153 } 154 csvPrinter.close(true); 155 } 156 }