/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.Mnist;
import ai.djl.examples.training.util.AbstractTraining;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingListener;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.loss.Loss;
import ai.djl.training.metrics.Accuracy;
import ai.djl.training.metrics.TrainingMetric;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.optimizer.Sgd;
import ai.djl.training.optimizer.learningrate.FactorTracker;
import ai.djl.training.optimizer.learningrate.LearningRateTracker;
import ai.djl.training.util.ProgressBar;
import ai.djl.util.Progress;
import ai.djl.zoo.cv.classification.Mlp;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Locale;

/**
 * Example that trains a multilayer perceptron ({@link Mlp}) on the MNIST dataset and saves the
 * resulting model to the configured output directory, tagged with the epoch count and the
 * validation accuracy reported by {@code getValidationAccuracy()}.
 */
public final class TrainMnist extends AbstractTraining {

    /** Height of an input image in pixels. */
    private static final int IMAGE_HEIGHT = 28;

    /** Width of an input image in pixels. */
    private static final int IMAGE_WIDTH = 28;

    /**
     * Sample count used to size the learning-rate decay step. Presumably the size of the MNIST
     * training split, so that (without an iteration cap) the rate decays about once per epoch —
     * TODO confirm against {@link Mnist}.
     */
    private static final int MNIST_TRAIN_SIZE = 60000;

    /** Model name passed to {@link TrainingUtils#fit} and used when saving the model. */
    private static final String MODEL_NAME = "mlp";

    /**
     * Runs the example.
     *
     * @param args command-line arguments, handed unchanged to {@code runExample}
     */
    public static void main(String[] args) {
        new TrainMnist().runExample(args);
    }

    /**
     * Builds the MLP, trains it on the MNIST training split while validating on the test split,
     * and saves it with {@code Epoch} and {@code Accuracy} properties.
     *
     * @param arguments example arguments (batch size, epochs, iteration cap, GPU limit, output dir)
     * @throws IOException propagated from dataset preparation or model saving
     * @throws IllegalArgumentException if the configured batch size is not positive
     */
    @Override
    protected void train(Arguments arguments) throws IOException {
        int batchSize = arguments.getBatchSize();
        if (batchSize <= 0) {
            // Fail fast with a clear message: the batch size is used as a divisor both when
            // sizing the learning-rate schedule and when computing batches per epoch.
            throw new IllegalArgumentException("Batch size must be positive: " + batchSize);
        }

        Block block = new Mlp(IMAGE_HEIGHT, IMAGE_WIDTH);
        try (Model model = Model.newInstance()) {
            model.setBlock(block);

            Dataset trainingSet = getDataset(model.getNDManager(), Dataset.Usage.TRAIN, arguments);
            Dataset validateSet = getDataset(model.getNDManager(), Dataset.Usage.TEST, arguments);
            TrainingConfig config = setupTrainingConfig(arguments);

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(metrics);
                trainer.setTrainingListener(this);

                // Shape used to initialize parameters: 1 x (28 * 28) flattened pixels.
                // The leading 1 is presumably the batch axis.
                Shape inputShape = new Shape(1, IMAGE_HEIGHT * IMAGE_WIDTH);
                trainer.initialize(inputShape);

                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        trainingSet,
                        validateSet,
                        arguments.getOutputDir(),
                        MODEL_NAME);
            }

            model.setProperty("Epoch", String.valueOf(arguments.getEpoch()));
            // Locale.ROOT keeps the decimal separator a '.' regardless of the JVM's default
            // locale, so the stored property is stable across machines.
            model.setProperty(
                    "Accuracy", String.format(Locale.ROOT, "%.2f", getValidationAccuracy()));
            model.save(Paths.get(arguments.getOutputDir()), MODEL_NAME);
        }
    }

    /**
     * Creates the training configuration. It combines:
     * <ul>
     *   <li>Xavier initialization;</li>
     *   <li>softmax cross-entropy loss, also stored in {@code loss};</li>
     *   <li>SGD with momentum, weight decay and gradient clipping;</li>
     *   <li>a factor-decay learning-rate schedule with warm-up;</li>
     *   <li>accuracy as a training metric;</li>
     *   <li>devices limited by the configured maximum GPU count.</li>
     * </ul>
     *
     * @param arguments example arguments providing the batch size and maximum GPU count
     * @return the training configuration
     */
    private TrainingConfig setupTrainingConfig(Arguments arguments) {
        int batchSize = arguments.getBatchSize();

        // Clamp to at least 1 so a batch size larger than MNIST_TRAIN_SIZE still yields a
        // usable (non-zero) decay step.
        int decayStep = Math.max(1, MNIST_TRAIN_SIZE / batchSize);
        FactorTracker factorTracker =
                LearningRateTracker.factorTracker()
                        .optBaseLearningRate(0.1f)
                        .setStep(decayStep)
                        .optFactor(0.1f)
                        .optWarmUpBeginLearningRate(0.01f)
                        .optWarmUpSteps(500)
                        .optStopFactorLearningRate(0.001f)
                        .build();

        Sgd optimizer =
                Optimizer.sgd()
                        // Scale gradients by 1/batchSize (presumably to average over the batch).
                        .setRescaleGrad(1.0f / batchSize)
                        .setLearningRateTracker(factorTracker)
                        .optWeightDecays(0.001f)
                        .optMomentum(0.9f)
                        .optClipGrad(1.0f)
                        .build();

        loss = Loss.softmaxCrossEntropyLoss();
        return new DefaultTrainingConfig(new XavierInitializer(), loss)
                .setOptimizer(optimizer)
                .addTrainingMetric(new Accuracy())
                .setBatchSize(batchSize)
                .setDevices(Device.getDevices((int) arguments.getMaxGpus()));
    }

    /**
     * Builds and prepares one MNIST split. It also records how many batches per epoch the split
     * yields: {@code trainDataSize} for {@link Dataset.Usage#TRAIN}, {@code validateDataSize}
     * for any other usage.
     *
     * @param manager manager handed to the {@link Mnist} builder
     * @param usage which split to load
     * @param arguments example arguments providing the batch size and iteration cap
     * @return the prepared dataset
     * @throws IOException propagated from {@code prepare}
     */
    private Dataset getDataset(NDManager manager, Dataset.Usage usage, Arguments arguments)
            throws IOException {
        int batchSize = arguments.getBatchSize();
        long maxIterations = arguments.getMaxIterations();

        Mnist mnist =
                Mnist.builder(manager)
                        .optUsage(usage)
                        .setSampling((long) batchSize, true)
                        .optMaxIteration(maxIterations)
                        .build();
        mnist.prepare(new ProgressBar());

        // Batches per epoch: number of full batches in the split, capped by the iteration limit.
        int batchesPerEpoch = (int) Math.min(mnist.size() / batchSize, maxIterations);
        if (usage == Dataset.Usage.TRAIN) {
            trainDataSize = batchesPerEpoch;
        } else {
            validateDataSize = batchesPerEpoch;
        }
        return mnist;
    }
}

