// CSVEvolutionListener.java

package net.bmahe.genetics4j.extras.evolutionlisteners;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.immutables.value.Value;

import net.bmahe.genetics4j.core.Genotype;
import net.bmahe.genetics4j.core.evolutionlisteners.EvolutionListener;

/**
 * Evolution Listener which writes the output of each generation to a CSV file.
 * <p>
 * The CSV file is opened lazily the first time a generation is actually written
 * and is closed once {@link #onEvolution} is invoked with {@code isDone} set to
 * {@code true}.
 *
 * @author bruno
 *
 * @param <T> Fitness type
 * @param <U> Data type written to the CSV
 */
@Value.Immutable
public abstract class CSVEvolutionListener<T extends Comparable<T>, U> implements EvolutionListener<T> {
	public static final Logger logger = LogManager.getLogger(CSVEvolutionListener.class);

	public static final boolean DEFAULT_AUTO_FLUSH = true;

	/** Printer for the destination file; stays {@code null} until the first record has to be written. */
	private CSVPrinter csvPrinter;

	/**
	 * Opens the destination file as UTF-8 and configures a printer whose header
	 * row is made of the headers of the column extractors, in their declared
	 * order.
	 *
	 * @return a printer writing to {@link #filename()}
	 * @throws RuntimeException if the file cannot be opened
	 */
	protected CSVPrinter openPrinter() {

		final String[] headers = columnExtractors().stream()
				.map(ColumnExtractor::header)
				.toArray(String[]::new);

		try {
			return CSVFormat.DEFAULT.withAutoFlush(autoFlush())
					.withHeader(headers)
					.print(Path.of(filename()), StandardCharsets.UTF_8);
		} catch (IOException e) {
			logger.error("Could not open {}", filename(), e);
			throw new RuntimeException("Could not open file " + filename(), e);
		}

	}

	/**
	 * Whether or not the CSV writer has auto flush enabled. Defaults to
	 * {@value #DEFAULT_AUTO_FLUSH}
	 *
	 * @return {@code true} if the CSV writer is configured with auto flush
	 */
	@Value.Default
	public boolean autoFlush() {
		return DEFAULT_AUTO_FLUSH;
	}

	/**
	 * User defined function to provide some additional information when computing
	 * the value to write. Defaults to a function which always returns
	 * {@code null}, in which case the evolution steps carry an empty context.
	 *
	 * @return the function computing the context of a generation
	 */
	@Value.Default
	public GenerationFunction<T, U> evolutionContextSupplier() {
		return (generation, population, fitness, isDone) -> null;
	}

	/**
	 * Controls how often generations are written. When the value is greater than
	 * 0, only the generations whose number is a multiple of it are written, along
	 * with the final generation. Defaults to 0, which writes every generation.
	 *
	 * @return the write period, in generations
	 */
	@Value.Default
	public int skipN() {
		return 0;
	}

	/**
	 * Users can supply an optional set of filters to control which individuals get
	 * written and in which order. Default to have no impact.
	 *
	 * @return the transformation applied to the stream of individuals before they
	 *         are written
	 */
	@Value.Default
	public Function<Stream<EvolutionStep<T, U>>, Stream<EvolutionStep<T, U>>> filter() {
		return (stream) -> stream;
	}

	/**
	 * Destination file name for the CSV file
	 *
	 * @return the path of the CSV file to write
	 */
	@Value.Parameter
	public abstract String filename();

	/**
	 * List of Column Extractors. They specify how and what to write from each
	 * individual at a given generation
	 *
	 * @return the column extractors, in column order
	 */
	@Value.Parameter
	public abstract List<ColumnExtractor<T, U>> columnExtractors();

	/**
	 * Writes one CSV record per individual retained by {@link #filter()} for the
	 * given generation, unless the generation is skipped according to
	 * {@link #skipN()}. The final generation ({@code isDone} set to {@code true})
	 * is never skipped and closes the CSV file once written.
	 *
	 * @param generation generation number, must not be negative
	 * @param population individuals of the generation, must not be empty
	 * @param fitness    fitness of each individual, same size and order as
	 *                   {@code population}
	 * @param isDone     whether this is the last generation of the evolution
	 * @throws RuntimeException if the CSV file cannot be opened, written or closed
	 */
	@Override
	public void onEvolution(final long generation, final List<Genotype> population, final List<T> fitness,
			final boolean isDone) {
		Validate.isTrue(generation >= 0);
		Validate.notNull(population);
		Validate.notNull(fitness);
		Validate.isTrue(population.size() > 0);
		Validate.isTrue(population.size() == fitness.size());

		// The final generation is always written, whatever the write period is
		if (!isDone && skipN() > 0 && generation % skipN() != 0) {
			return;
		}

		if (csvPrinter == null) {
			csvPrinter = openPrinter();
		}

		final Optional<U> context = Optional
				.ofNullable(evolutionContextSupplier().apply(generation, population, fitness, isDone));

		final Stream<EvolutionStep<T, U>> rawIndividualStream = IntStream.range(0, population.size())
				.mapToObj(individualIndex -> EvolutionStep.of(context,
						generation,
						individualIndex,
						population.get(individualIndex),
						fitness.get(individualIndex),
						isDone));

		filter().apply(rawIndividualStream)
				.forEach(this::writeRecord);

		if (isDone && csvPrinter != null) {
			closePrinter();
		}
	}

	/** Extracts every column value of the given evolution step and writes them as one CSV record. */
	private void writeRecord(final EvolutionStep<T, U> evolutionStep) {
		final List<Object> columnValues = columnExtractors().stream()
				.map(ce -> ce.columnExtractorFunction().apply(evolutionStep))
				.collect(Collectors.toUnmodifiableList());

		try {
			csvPrinter.printRecord(columnValues);
		} catch (IOException e) {
			logger.error("Could not write values: {}", columnValues, e);
			throw new RuntimeException("Could not write values: " + columnValues, e);
		}
	}

	/** Flushes and closes the CSV printer, turning any I/O failure into an unchecked exception. */
	private void closePrinter() {
		try {
			csvPrinter.close(true);
		} catch (IOException e) {
			logger.error("Could not close CSV printer for filename {}", filename(), e);
			throw new RuntimeException("Could not close CSV printer for filename " + filename(), e);
		}
	}

	public static class Builder<T extends Comparable<T>, U> extends ImmutableCSVEvolutionListener.Builder<T, U> {
	}

	/**
	 * Creates a listener writing every individual of every generation.
	 *
	 * @param <T>              Fitness type
	 * @param <U>              Data type written to the CSV
	 * @param filename         destination file name
	 * @param columnExtractors columns to write
	 * @return a new listener
	 */
	public static <T extends Comparable<T>, U> CSVEvolutionListener<T, U> of(String filename,
			List<ColumnExtractor<T, U>> columnExtractors) {
		return ImmutableCSVEvolutionListener.of(filename, (Iterable<? extends ColumnExtractor<T, U>>) columnExtractors);
	}

	/**
	 * Creates a listener writing every individual of every generation.
	 *
	 * @param <T>              Fitness type
	 * @param <U>              Data type written to the CSV
	 * @param filename         destination file name
	 * @param columnExtractors columns to write
	 * @return a new listener
	 */
	public static <T extends Comparable<T>, U> CSVEvolutionListener<T, U> of(String filename,
			Iterable<? extends ColumnExtractor<T, U>> columnExtractors) {
		return ImmutableCSVEvolutionListener.of(filename, columnExtractors);
	}

	/**
	 * Creates a listener writing every individual of every generation, with a
	 * user supplied context made available to the column extractors.
	 *
	 * @param <T>                      Fitness type
	 * @param <U>                      Data type written to the CSV
	 * @param filename                 destination file name
	 * @param evolutionContextSupplier function computing the context of a
	 *                                 generation
	 * @param columnExtractors         columns to write
	 * @return a new listener
	 */
	public static <T extends Comparable<T>, U> CSVEvolutionListener<T, U> of(final String filename,
			final GenerationFunction<T, U> evolutionContextSupplier,
			final Iterable<? extends ColumnExtractor<T, U>> columnExtractors) {
		var csvEvolutionListenerBuilder = new CSVEvolutionListener.Builder<T, U>();

		csvEvolutionListenerBuilder.filename(filename)
				.evolutionContextSupplier(evolutionContextSupplier)
				.addAllColumnExtractors(columnExtractors);

		return csvEvolutionListenerBuilder.build();
	}

	/**
	 * Creates a listener with a user supplied context which only writes the
	 * generations selected by {@code skipN}.
	 *
	 * @param <T>                      Fitness type
	 * @param <U>                      Data type written to the CSV
	 * @param filename                 destination file name
	 * @param evolutionContextSupplier function computing the context of a
	 *                                 generation
	 * @param columnExtractors         columns to write
	 * @param skipN                    write period, see {@link #skipN()}
	 * @return a new listener
	 */
	public static <T extends Comparable<T>, U> CSVEvolutionListener<T, U> of(final String filename,
			final GenerationFunction<T, U> evolutionContextSupplier,
			final Iterable<? extends ColumnExtractor<T, U>> columnExtractors, final int skipN) {
		var csvEvolutionListenerBuilder = new CSVEvolutionListener.Builder<T, U>();

		csvEvolutionListenerBuilder.filename(filename)
				.evolutionContextSupplier(evolutionContextSupplier)
				.addAllColumnExtractors(columnExtractors)
				.skipN(skipN);

		return csvEvolutionListenerBuilder.build();
	}

	/**
	 * Creates a listener which, for each written generation, only writes the
	 * {@code topN} individuals with the greatest fitness according to
	 * {@code comparator}, greatest first.
	 *
	 * @param <T>                      Fitness type
	 * @param <U>                      Data type written to the CSV
	 * @param filename                 destination file name
	 * @param evolutionContextSupplier function computing the context of a
	 *                                 generation
	 * @param columnExtractors         columns to write
	 * @param comparator               ordering of the fitness values, must not be
	 *                                 null
	 * @param topN                     maximum number of individuals written per
	 *                                 generation, must not be negative
	 * @return a new listener
	 */
	public static <T extends Comparable<T>, U> CSVEvolutionListener<T, U> ofTopN(final String filename,
			final GenerationFunction<T, U> evolutionContextSupplier,
			final Iterable<? extends ColumnExtractor<T, U>> columnExtractors, final Comparator<T> comparator,
			final int topN) {
		// Fail at construction time rather than while writing the first generation
		Validate.notNull(comparator);
		Validate.isTrue(topN >= 0);

		// Reversed once up front so that the greatest fitness comes first
		final Comparator<T> greatestFirst = comparator.reversed();

		var csvEvolutionListenerBuilder = new CSVEvolutionListener.Builder<T, U>();

		csvEvolutionListenerBuilder.filename(filename)
				.evolutionContextSupplier(evolutionContextSupplier)
				.addAllColumnExtractors(columnExtractors)
				.filter(stream -> stream.sorted((a, b) -> greatestFirst.compare(a.fitness(), b.fitness()))
						.limit(topN));

		return csvEvolutionListenerBuilder.build();
	}

	/**
	 * Creates a listener which, for each written generation, only writes the
	 * {@code topN} individuals with the greatest fitness according to the natural
	 * ordering of the fitness type, greatest first.
	 *
	 * @param <T>                      Fitness type
	 * @param <U>                      Data type written to the CSV
	 * @param filename                 destination file name
	 * @param evolutionContextSupplier function computing the context of a
	 *                                 generation
	 * @param columnExtractors         columns to write
	 * @param topN                     maximum number of individuals written per
	 *                                 generation, must not be negative
	 * @return a new listener
	 */
	public static <T extends Comparable<T>, U> CSVEvolutionListener<T, U> ofTopN(final String filename,
			final GenerationFunction<T, U> evolutionContextSupplier,
			final Iterable<? extends ColumnExtractor<T, U>> columnExtractors, final int topN) {
		// Delegating keeps both overloads consistent: the greatest fitness values are retained
		return ofTopN(filename, evolutionContextSupplier, columnExtractors, Comparator.<T>naturalOrder(), topN);
	}
}