/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.training.util;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.examples.util.MemoryUtils;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.training.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import java.io.IOException;
import java.util.List;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for the training examples: parses the command line, invokes the
 * subclass-supplied {@link #train(Arguments)} routine, and reports progress and
 * timing statistics through the logger.
 *
 * <p>The status methods dereference {@link #loss} to build metric names, and the
 * progress bars are sized from {@link #trainDataSize} / {@link #validateDataSize};
 * none of these are assigned in this class, so subclasses are expected to set them
 * before any listener callback fires.
 */
public abstract class AbstractTraining implements TrainingListener {
    private static final Logger logger = LoggerFactory.getLogger(AbstractTraining.class);

    // Divisors used when reporting timing metrics. The "epoch" metric is recorded
    // below from System.nanoTime(); the other timing metrics are presumably also
    // in nanoseconds, since they are reported as milliseconds after this division.
    private static final float NANOS_PER_MILLI = 1000000.0f;
    private static final float NANOS_PER_SECOND = 1.0E9f;

    protected float trainingAccuracy;
    protected float trainingLoss;
    protected float validationAccuracy;
    protected float validationLoss;
    protected int batchSize;
    protected int trainDataSize;
    protected int validateDataSize;
    protected int trainingProgress;
    protected int validateProgress;
    private long epochTime;
    private int numEpochs;
    private ProgressBar trainingProgressBar;
    private ProgressBar validateProgressBar;
    protected Metrics metrics = new Metrics();
    protected Loss loss;

    /**
     * Parses {@code args}, runs the training, and logs the timing percentiles.
     *
     * @param args the raw command line arguments
     * @return {@code true} when training and reporting completed; {@code false} when
     *     the arguments could not be parsed (usage help is printed) or when any other
     *     throwable was raised (it is logged)
     */
    public boolean runExample(String[] args) {
        Options options = Arguments.getOptions();
        try {
            DefaultParser parser = new DefaultParser();
            CommandLine cmd = parser.parse(options, args, null, false);
            Arguments arguments = new Arguments(cmd);
            int maxGpus = arguments.getMaxGpus();
            batchSize = arguments.getBatchSize();
            String devices = maxGpus > 0 ? maxGpus + " GPUs" : Device.cpu().toString();
            logger.info(
                    "Running {} on: {}, epoch: {}.",
                    getClass().getSimpleName(),
                    devices,
                    arguments.getEpoch());

            // Time the engine lookup separately so its cost shows up in the log.
            long init = System.nanoTime();
            String version = Engine.getInstance().getVersion();
            long loaded = System.nanoTime();
            logger.info(
                    String.format(
                            "Load library %s in %.3f ms.",
                            version, (float) (loaded - init) / NANOS_PER_MILLI));

            // Start of the first epoch; onEpoch() measures against this timestamp.
            epochTime = System.nanoTime();
            train(arguments);

            logger.info("Training: {} batches", trainDataSize);
            logger.info("Validation: {} batches", validateDataSize);

            logPercentiles("train", NANOS_PER_MILLI, "ms");
            logPercentiles("forward", NANOS_PER_MILLI, "ms");
            logPercentiles("training-metrics", NANOS_PER_MILLI, "ms");
            logPercentiles("backward", NANOS_PER_MILLI, "ms");
            logPercentiles("step", NANOS_PER_MILLI, "ms");
            logPercentiles("epoch", NANOS_PER_SECOND, "s");

            if (arguments.getOutputDir() != null) {
                MemoryUtils.dumpMemoryInfo(metrics, arguments.getOutputDir());
                TrainingUtils.dumpTrainingTimeInfo(metrics, arguments.getOutputDir());
            }
            return true;
        } catch (ParseException e) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.setLeftPadding(1);
            formatter.setWidth(120);
            formatter.printHelp(e.getMessage(), options);
        } catch (Throwable t) {
            // Intentionally broad: this is the example's outermost boundary and callers
            // rely on the boolean result rather than on a propagated failure.
            logger.error("Unexpected error", t);
        }
        return false;
    }

    /**
     * Runs the actual training; implemented by each example.
     *
     * @param var1 the parsed command line arguments
     * @throws IOException declared for implementations that perform I/O
     * @throws ModelNotFoundException declared for implementations that load a model
     */
    protected abstract void train(Arguments var1) throws IOException, ModelNotFoundException;

    /**
     * Records memory usage and advances the training progress bar by one batch.
     * Presumably a {@link TrainingListener} callback; the interface is not visible here.
     */
    public void onTrainingBatch() {
        MemoryUtils.collectMemoryInfo(metrics);
        if (trainingProgressBar == null) {
            trainingProgressBar = new ProgressBar("Training", (long) trainDataSize);
        }
        // Capture the counter before building the status text, as the original did.
        long current = trainingProgress++;
        trainingProgressBar.update(current, getTrainingStatus(metrics));
    }

    /**
     * Records memory usage and advances the validation progress bar by one batch.
     * Presumably a {@link TrainingListener} callback; the interface is not visible here.
     */
    public void onValidationBatch() {
        MemoryUtils.collectMemoryInfo(metrics);
        if (validateProgressBar == null) {
            validateProgressBar = new ProgressBar("Validating", (long) validateDataSize);
        }
        validateProgressBar.update((long) validateProgress++);
    }

    /**
     * Records the elapsed epoch time, logs the current results, and resets the
     * per-epoch progress counters.
     * Presumably a {@link TrainingListener} callback; the interface is not visible here.
     */
    public void onEpoch() {
        // epochTime stays 0 until runExample() starts the clock, so skip the metric
        // when training was driven some other way.
        if (epochTime > 0L) {
            metrics.addMetric("epoch", (Number) (System.nanoTime() - epochTime));
        }
        logger.info("Epoch {} finished.", numEpochs);
        printTrainingStatus(metrics);
        epochTime = System.nanoTime();
        ++numEpochs;
        trainingProgress = 0;
        validateProgress = 0;
    }

    /**
     * Returns the most recently captured training accuracy.
     *
     * @return the training accuracy
     */
    public float getTrainingAccuracy() {
        return trainingAccuracy;
    }

    /**
     * Returns the most recently captured training loss.
     *
     * @return the training loss
     */
    public float getTrainingLoss() {
        return trainingLoss;
    }

    /**
     * Returns the most recently captured validation accuracy.
     *
     * @return the validation accuracy
     */
    public float getValidationAccuracy() {
        return validationAccuracy;
    }

    /**
     * Returns the most recently captured validation loss.
     *
     * @return the validation loss
     */
    public float getValidationLoss() {
        return validationLoss;
    }

    /**
     * Refreshes the cached training loss/accuracy from {@code metrics} and formats a
     * one-line status for the progress bar.
     *
     * <p>When a metric has no samples yet the previously cached value is kept instead
     * of failing with an index error. The speed is appended only when a {@code "train"}
     * sample with a positive duration is available.
     *
     * @param metrics the metrics to read from
     * @return text of the form {@code "accuracy: x loss: y[ speed: z images/sec]"}
     * @throws NullPointerException if {@link #loss} has not been assigned
     */
    public String getTrainingStatus(Metrics metrics) {
        refreshTrainingResults(metrics);
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("accuracy: %.2f loss: %.2f", trainingAccuracy, trainingLoss));

        List<Metric> batchTimes = metrics.getMetric("train");
        if (batchTimes != null && !batchTimes.isEmpty()) {
            long lastNanos = batchTimes.get(batchTimes.size() - 1).getValue().longValue();
            float batchTime = (float) lastNanos / NANOS_PER_SECOND;
            // A zero duration would otherwise be reported as "Infinity images/sec".
            if (batchTime > 0f) {
                sb.append(
                        String.format(
                                " speed: %.2f images/sec", (float) batchSize / batchTime));
            }
        }
        return sb.toString();
    }

    /**
     * Refreshes the cached training and validation results from {@code metrics} and
     * logs them. Validation results are logged only when a validation loss sample
     * exists; otherwise a note that validation has not been run is logged.
     *
     * @param metrics the metrics to read from
     * @throws NullPointerException if {@link #loss} has not been assigned
     */
    public void printTrainingStatus(Metrics metrics) {
        refreshTrainingResults(metrics);
        logger.info("train accuracy: {}, train loss: {}", trainingAccuracy, trainingLoss);

        List<Metric> validateLosses = metrics.getMetric("validate_" + loss.getName());
        if (validateLosses == null || validateLosses.isEmpty()) {
            logger.info("validation has not been run.");
            return;
        }
        validationLoss = validateLosses.get(validateLosses.size() - 1).getValue().floatValue();
        validationAccuracy = latestValue(metrics, "validate_Accuracy", validationAccuracy);
        logger.info(
                "validate accuracy: {}, validate loss: {}", validationAccuracy, validationLoss);
    }

    /** Logs the 50th and 90th percentile of one timing metric, scaled by {@code divisor}. */
    private void logPercentiles(String metricName, float divisor, String unit) {
        float p50 = (float) metrics.percentile(metricName, 50).getValue().longValue() / divisor;
        float p90 = (float) metrics.percentile(metricName, 90).getValue().longValue() / divisor;
        logger.info(
                String.format("%s P50: %.3f %s, P90: %.3f %s", metricName, p50, unit, p90, unit));
    }

    /** Updates the cached training loss and accuracy from the newest samples, if any. */
    private void refreshTrainingResults(Metrics source) {
        trainingLoss = latestValue(source, "train_" + loss.getName(), trainingLoss);
        trainingAccuracy = latestValue(source, "train_Accuracy", trainingAccuracy);
    }

    /** Returns the newest sample of {@code metricName} as a float, or {@code fallback} if none. */
    private static float latestValue(Metrics source, String metricName, float fallback) {
        List<Metric> history = source.getMetric(metricName);
        if (history == null || history.isEmpty()) {
            return fallback;
        }
        return history.get(history.size() - 1).getValue().floatValue();
    }
}

