/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.training.transferlearning;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.Cifar10;
import ai.djl.examples.training.util.AbstractTraining;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.mxnet.zoo.MxModelZoo;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingListener;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.loss.Loss;
import ai.djl.training.metrics.Accuracy;
import ai.djl.training.metrics.TrainingMetric;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.optimizer.learningrate.LearningRateTracker;
import ai.djl.training.optimizer.learningrate.MultiFactorTracker;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.util.Progress;
import ai.djl.zoo.ModelZoo;
import ai.djl.zoo.cv.classification.ResNetV1;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public final class TrainResnetWithCifar10
extends AbstractTraining {
    public static void main(String[] args) {
        new TrainResnetWithCifar10().runExample(args);
    }

    @Override
    protected void train(Arguments arguments) throws IOException, ModelNotFoundException {
        try (Model model = this.getModel(arguments);){
            this.batchSize = arguments.getBatchSize();
            Dataset trainDataset = this.getDataset(model.getNDManager(), Dataset.Usage.TRAIN, arguments);
            Dataset validationDataset = this.getDataset(model.getNDManager(), Dataset.Usage.TEST, arguments);
            TrainingConfig config = this.setupTrainingConfig(arguments);
            try (Trainer trainer = model.newTrainer(config);){
                trainer.setMetrics(this.metrics);
                trainer.setTrainingListener((TrainingListener)this);
                Shape inputShape = new Shape(new long[]{1L, 3L, 32L, 32L});
                trainer.initialize(new Shape[]{inputShape});
                TrainingUtils.fit(trainer, arguments.getEpoch(), trainDataset, validationDataset, arguments.getOutputDir(), "resnetv1");
            }
            model.setProperty("Epoch", String.valueOf(arguments.getEpoch()));
            model.setProperty("Accuracy", String.format("%.2f", Float.valueOf(this.getValidationAccuracy())));
            model.save(Paths.get("build/model", new String[0]), "resnetv1");
        }
        catch (MalformedModelException e) {
            throw new IllegalArgumentException(e);
        }
    }

    private Model getModel(Arguments arguments) throws IOException, ModelNotFoundException, MalformedModelException {
        boolean isSymbolic = arguments.isSymbolic();
        boolean preTrained = arguments.isPreTrained();
        Map<String, String> criteria = arguments.getCriteria();
        if (isSymbolic) {
            if (criteria == null) {
                criteria = new ConcurrentHashMap<String, String>();
                criteria.put("layers", "50");
                criteria.put("flavor", "v1");
            }
            ZooModel model = MxModelZoo.RESNET.loadModel(criteria, (Progress)new ProgressBar());
            SequentialBlock newBlock = new SequentialBlock();
            SymbolBlock block = (SymbolBlock)model.getBlock();
            block.removeLastBlock();
            newBlock.add((Block)block);
            newBlock.add(x -> new NDList(new NDArray[]{x.singletonOrThrow().squeeze()}));
            newBlock.add((Block)new Linear.Builder().setOutChannels(10L).build());
            newBlock.add(Blocks.batchFlattenBlock());
            model.setBlock((Block)newBlock);
            if (!preTrained) {
                model.getBlock().clear();
            }
            return model;
        }
        if (preTrained) {
            if (criteria == null) {
                criteria = new ConcurrentHashMap<String, String>();
                criteria.put("layers", "50");
                criteria.put("flavor", "v1");
                criteria.put("dataset", "cifar10");
            }
            return ModelZoo.RESNET.loadModel(criteria, (Progress)new ProgressBar());
        }
        Model model = Model.newInstance();
        Block resNet50 = new ResNetV1.Builder().setImageShape(new Shape(new long[]{3L, 32L, 32L})).setNumLayers(50).setOutSize(10L).build();
        model.setBlock(resNet50);
        return model;
    }

    private TrainingConfig setupTrainingConfig(Arguments arguments) {
        int[] epochs = arguments.isPreTrained() ? new int[]{2, 5, 8} : new int[]{20, 60, 90, 120, 180};
        int[] steps = Arrays.stream(epochs).map(k -> k * 60000 / this.batchSize).toArray();
        XavierInitializer initializer = new XavierInitializer(XavierInitializer.RandomType.UNIFORM, XavierInitializer.FactorType.AVG, 2.0);
        MultiFactorTracker learningRateTracker = ((MultiFactorTracker.Builder)((MultiFactorTracker.Builder)((MultiFactorTracker.Builder)LearningRateTracker.multiFactorTracker().setSteps(steps).optBaseLearningRate(0.001f)).optFactor((float)Math.sqrt(0.1f)).optWarmUpBeginLearningRate(1.0E-4f)).optWarmUpSteps(200)).build();
        Adam optimizer = ((Adam.Builder)((Adam.Builder)((Adam.Builder)Optimizer.adam().setRescaleGrad(1.0f / (float)this.batchSize)).optLearningRateTracker((LearningRateTracker)learningRateTracker).optWeightDecays(0.001f)).optClipGrad(5.0f)).build();
        this.loss = Loss.softmaxCrossEntropyLoss();
        return new DefaultTrainingConfig((Initializer)initializer, this.loss).setOptimizer((Optimizer)optimizer).addTrainingMetric((TrainingMetric)new Accuracy()).setBatchSize(this.batchSize).setDevices(Device.getDevices((int)arguments.getMaxGpus()));
    }

    private Dataset getDataset(NDManager manager, Dataset.Usage usage, Arguments arguments) throws IOException {
        Pipeline pipeline = new Pipeline(new Transform[]{new ToTensor(), new Normalize(new float[]{0.4914f, 0.4822f, 0.4465f}, new float[]{0.2023f, 0.1994f, 0.201f})});
        long maxIterations = arguments.getMaxIterations();
        Cifar10 cifar10 = ((Cifar10.Builder)((Cifar10.Builder)((Cifar10.Builder)Cifar10.builder((NDManager)manager).optUsage(usage).setSampling((long)this.batchSize, true)).optMaxIteration(maxIterations)).optPipeline(pipeline)).build();
        cifar10.prepare((Progress)new ProgressBar());
        int dataSize = (int)Math.min(cifar10.size() / (long)this.batchSize, maxIterations);
        if (usage == Dataset.Usage.TRAIN) {
            this.trainDataSize = dataSize;
        } else if (usage == Dataset.Usage.TEST) {
            this.validateDataSize = dataSize;
        }
        return cifar10;
    }
}

