/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.training.util;

import ai.djl.Model;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.FileAttribute;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Helper routines shared by the training examples: a simple fit loop and a metrics dump. */
public final class TrainingUtils {
    private static final Logger logger = LoggerFactory.getLogger(TrainingUtils.class);

    private TrainingUtils() {
    }

    /**
     * Runs a basic training loop for a fixed number of epochs.
     *
     * <p>For every epoch, each training batch is passed to {@link Trainer#trainBatch} followed by
     * {@link Trainer#step}. If a validation dataset is given, each of its batches is then passed to
     * {@link Trainer#validateBatch}. Training metrics are reset at the end of every epoch. When
     * {@code outputDir} is non-null, the model is saved after each epoch with its {@code "Epoch"}
     * property set to the zero-based epoch index.
     *
     * @param trainer the trainer driving the model
     * @param numEpoch number of epochs to run; a value {@code <= 0} runs nothing
     * @param trainingDataset dataset iterated for training
     * @param validateDataset dataset iterated for validation, or {@code null} to skip validation
     * @param outputDir directory the model is saved to after each epoch, or {@code null} to skip saving
     * @param modelName name passed to {@code Model.save}
     * @throws IOException if saving the model fails
     */
    public static void fit(
            Trainer trainer,
            int numEpoch,
            Dataset trainingDataset,
            Dataset validateDataset,
            String outputDir,
            String modelName)
            throws IOException {
        // The output directory does not change between epochs, so resolve it once.
        Path modelDir = outputDir == null ? null : Paths.get(outputDir);
        for (int epoch = 0; epoch < numEpoch; ++epoch) {
            for (Batch batch : trainer.iterateDataset(trainingDataset)) {
                // Close the batch even when training throws, so its resources are not leaked.
                try {
                    trainer.trainBatch(batch);
                    trainer.step();
                } finally {
                    batch.close();
                }
            }
            if (validateDataset != null) {
                for (Batch batch : trainer.iterateDataset(validateDataset)) {
                    try {
                        trainer.validateBatch(batch);
                    } finally {
                        batch.close();
                    }
                }
            }
            trainer.resetTrainingMetrics();
            if (modelDir != null) {
                Model model = trainer.getModel();
                model.setProperty("Epoch", String.valueOf(epoch));
                model.save(modelDir, modelName);
            }
        }
    }

    /**
     * Appends every metric recorded under the name {@code "train"} to {@code logDir/training.log},
     * writing one {@code Metric.toString()} line per metric.
     *
     * <p>The directory is created if needed. This method is best-effort: an {@link IOException}
     * is logged rather than propagated. Nothing is written when {@code logDir} is {@code null} or
     * when no {@code "train"} metrics are available.
     *
     * @param metrics the metrics collection to read from
     * @param logDir target directory, or {@code null} to do nothing
     */
    public static void dumpTrainingTimeInfo(Metrics metrics, String logDir) {
        if (logDir == null) {
            return;
        }
        List<Metric> trainMetrics = metrics.getMetric("train");
        if (trainMetrics == null) {
            // Guard against a missing metric list instead of failing with an NPE that the
            // IOException handler below would not catch.
            return;
        }
        try {
            Path dir = Paths.get(logDir);
            Files.createDirectories(dir);
            Path file = dir.resolve("training.log");
            try (BufferedWriter writer =
                    Files.newBufferedWriter(
                            file, StandardOpenOption.CREATE, StandardOpenOption.APPEND)) {
                for (Metric metric : trainMetrics) {
                    writer.append(metric.toString());
                    writer.newLine();
                }
            }
        } catch (IOException e) {
            logger.error("Failed dump training log", e);
        }
    }
}

