/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.PikachuDetection;
import ai.djl.examples.training.util.AbstractTraining;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.MultiBoxDetection;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingListener;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.loss.SingleShotDetectionLoss;
import ai.djl.training.metrics.BoundingBoxError;
import ai.djl.training.metrics.SingleShotDetectionAccuracy;
import ai.djl.training.metrics.TrainingMetric;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.optimizer.Sgd;
import ai.djl.training.optimizer.learningrate.LearningRateTracker;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;
import ai.djl.zoo.cv.object_detection.ssd.SingleShotDetection;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Example that trains a Single Shot Detection (SSD) model on the Pikachu detection dataset.
 *
 * <p>After training, the model is saved to the output directory under the name {@code "ssd"},
 * together with the epoch count and the last recorded validation loss, class accuracy and
 * bounding box error as model properties.
 */
public final class TrainPikachu extends AbstractTraining {

    private static final Logger logger = LoggerFactory.getLogger(TrainPikachu.class);

    private float trainingClassAccuracy;
    private float trainingBoundingBoxError;
    private float validationClassAccuracy;
    private float validationBoundingBoxError;

    public static void main(String[] args) {
        new TrainPikachu().runExample(args);
    }

    /**
     * Builds the SSD training block, fits it on the Pikachu train split (validating on the test
     * split) and saves the resulting model.
     *
     * @param arguments command line arguments (batch size, epochs, output dir, GPU count, ...)
     * @throws IOException if the dataset cannot be prepared or the model cannot be saved
     */
    @Override
    protected void train(Arguments arguments) throws IOException {
        // batchSize must be set before the config is built: the optimizer rescales by it.
        this.batchSize = arguments.getBatchSize();
        TrainingConfig config = setupTrainingConfig(arguments);

        try (Model model = Model.newInstance()) {
            model.setBlock(getSsdTrainBlock());

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(this.metrics);
                trainer.setTrainingListener(this);

                Dataset pikachuDetectionTrain = getDataset(Dataset.Usage.TRAIN, arguments);
                Dataset pikachuDetectionTest = getDataset(Dataset.Usage.TEST, arguments);

                // (batch, 3, 256, 256) -- presumably a batch of 3-channel 256x256 images.
                Shape inputShape = new Shape(new long[] {this.batchSize, 3L, 256L, 256L});
                trainer.initialize(new Shape[] {inputShape});

                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        pikachuDetectionTrain,
                        pikachuDetectionTest,
                        arguments.getOutputDir(),
                        "ssd");
            }

            // The validation* fields hold whatever printTrainingStatus recorded last.
            model.setProperty("Epoch", String.valueOf(arguments.getEpoch()));
            model.setProperty("Loss", String.format("%.5f", this.validationLoss));
            model.setProperty(
                    "ClassAccuracy", String.format("%.5f", this.validationClassAccuracy));
            model.setProperty(
                    "BoundingBoxError", String.format("%.5f", this.validationBoundingBoxError));
            model.save(Paths.get(arguments.getOutputDir()), "ssd");
        }
    }

    /*
     * NOTE(review): the body of this method could not be recovered -- the CFR decompiler
     * (0.152) failed on it with "ConfusedCFRException: Started 2 blocks at once". The stub below
     * always throws; restore the implementation from the original source before relying on it.
     */
    public int predict(String outputDir, String imageFile)
            throws IOException, MalformedModelException, TranslateException {
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Formats a one-line training status from the most recent training metrics and caches the
     * latest training loss, class accuracy and bounding box error in this instance.
     *
     * @param metrics the metrics collected by the trainer
     * @return a human-readable status line
     */
    @Override
    public String getTrainingStatus(Metrics metrics) {
        this.trainingLoss = getLatestValue(metrics, "train_" + this.loss.getName());
        this.trainingClassAccuracy = getLatestValue(metrics, "train_classAccuracy");
        this.trainingBoundingBoxError = getLatestValue(metrics, "train_boundingBoxError");

        StringBuilder sb = new StringBuilder();
        // Fixed: the loss conversion was "%2.3ef", which printed a stray literal 'f'.
        sb.append(
                String.format(
                        "loss: %2.3e, classAccuracy: %.4f, bboxError: %2.3e,",
                        this.trainingLoss,
                        this.trainingClassAccuracy,
                        this.trainingBoundingBoxError));

        List<Metric> batchTimes = metrics.getMetric("train");
        if (!batchTimes.isEmpty()) {
            // Divided by 1e9 to get seconds -- assumes the "train" metric is in nanoseconds.
            float batchTime =
                    batchTimes.get(batchTimes.size() - 1).getValue().longValue() / 1.0E9f;
            sb.append(
                    String.format(" speed: %.2f images/sec", (float) this.batchSize / batchTime));
        }
        return sb.toString();
    }

    /**
     * Logs the latest training metrics and, when validation has run, the latest validation
     * metrics, caching all of them in this instance's fields.
     *
     * @param metrics the metrics collected by the trainer
     */
    @Override
    public void printTrainingStatus(Metrics metrics) {
        this.trainingLoss = getLatestValue(metrics, "train_" + this.loss.getName());
        this.trainingClassAccuracy = getLatestValue(metrics, "train_classAccuracy");
        this.trainingBoundingBoxError = getLatestValue(metrics, "train_boundingBoxError");
        logger.info(
                "train loss: {}, train class accuracy: {}, train bounding box error: {}",
                this.trainingLoss,
                this.trainingClassAccuracy,
                this.trainingBoundingBoxError);

        // Only the validation loss is checked for presence; the other validation metrics are
        // assumed to be recorded alongside it.
        if (metrics.getMetric("validate_" + this.loss.getName()).isEmpty()) {
            logger.info("validation has not been run.");
            return;
        }
        this.validationLoss = getLatestValue(metrics, "validate_" + this.loss.getName());
        this.validationClassAccuracy = getLatestValue(metrics, "validate_classAccuracy");
        this.validationBoundingBoxError = getLatestValue(metrics, "validate_boundingBoxError");
        logger.info(
                "validate loss: {}, validate class accuracy: {}, validate bounding box error: {}",
                this.validationLoss,
                this.validationClassAccuracy,
                this.validationBoundingBoxError);
    }

    /**
     * Returns the most recently recorded value of the named metric as a float.
     *
     * @throws IndexOutOfBoundsException if no value has been recorded under {@code metricName}
     */
    private static float getLatestValue(Metrics metrics, String metricName) {
        List<Metric> values = metrics.getMetric(metricName);
        return values.get(values.size() - 1).getValue().floatValue();
    }

    /**
     * Creates and prepares the Pikachu dataset for the given split. It also records the number
     * of batches to run (capped by {@code maxIterations}) in {@code trainDataSize} or
     * {@code validateDataSize}.
     */
    private Dataset getDataset(Dataset.Usage usage, Arguments arguments) throws IOException {
        Pipeline pipeline = new Pipeline(new Transform[] {new ToTensor()});
        PikachuDetection pikachuDetection =
                new PikachuDetection.Builder()
                        .optUsage(usage)
                        .optPipeline(pipeline)
                        .setSampling(this.batchSize, true)
                        .build();
        pikachuDetection.prepare(new ProgressBar());

        long maxIterations = arguments.getMaxIterations();
        int dataSize = (int) Math.min(pikachuDetection.size() / this.batchSize, maxIterations);
        if (usage == Dataset.Usage.TRAIN) {
            this.trainDataSize = dataSize;
        } else if (usage == Dataset.Usage.TEST) {
            this.validateDataSize = dataSize;
        }
        return pikachuDetection;
    }

    /**
     * Builds the training configuration.
     *
     * <p>The configuration uses:
     *
     * <ul>
     *   <li>Xavier initialization;
     *   <li>SGD with a fixed 0.2 learning rate, weight decay 5e-4 and gradients rescaled by
     *       1/batchSize;
     *   <li>SSD loss, with class-accuracy and bounding-box-error metrics.
     * </ul>
     *
     * <p>Side effect: it assigns {@code this.loss}.
     */
    private TrainingConfig setupTrainingConfig(Arguments arguments) {
        XavierInitializer initializer =
                new XavierInitializer(
                        XavierInitializer.RandomType.UNIFORM, XavierInitializer.FactorType.AVG, 2.0);
        Sgd optimizer =
                Optimizer.sgd()
                        .setRescaleGrad(1.0f / this.batchSize)
                        .setLearningRateTracker(LearningRateTracker.fixedLearningRate(0.2f))
                        .optWeightDecays(5.0E-4f)
                        .build();
        this.loss = new SingleShotDetectionLoss("ssd_loss");
        return new DefaultTrainingConfig(initializer, this.loss)
                .setOptimizer(optimizer)
                .setBatchSize(this.batchSize)
                .addTrainingMetric(new SingleShotDetectionAccuracy("classAccuracy"))
                .addTrainingMetric(new BoundingBoxError("boundingBoxError"))
                .setDevices(Device.getDevices((int) arguments.getMaxGpus()));
    }

    /**
     * Builds the SSD block used for training.
     *
     * <p>The base network is three down-sampling blocks with 16, 32 and 64 filters. It uses 5
     * anchor scales, each with ratios (1, 2, 0.5) and two sizes, and a single object class.
     *
     * @return the SSD training block
     */
    public static Block getSsdTrainBlock() {
        int[] numFilters = {16, 32, 64};
        SequentialBlock baseBlock = new SequentialBlock();
        for (int numFilter : numFilters) {
            baseBlock.add(SingleShotDetection.getDownSamplingBlock(numFilter));
        }

        List<List<Float>> ratios = new ArrayList<>();
        for (int i = 0; i < 5; ++i) {
            ratios.add(Arrays.asList(1.0f, 2.0f, 0.5f));
        }

        List<List<Float>> sizes = new ArrayList<>();
        sizes.add(Arrays.asList(0.2f, 0.272f));
        sizes.add(Arrays.asList(0.37f, 0.447f));
        sizes.add(Arrays.asList(0.54f, 0.619f));
        sizes.add(Arrays.asList(0.71f, 0.79f));
        sizes.add(Arrays.asList(0.88f, 0.961f));

        return new SingleShotDetection.Builder()
                .setNumClasses(1)
                .setNumFeatures(3)
                .optGlobalPool(true)
                .setRatios(ratios)
                .setSizes(sizes)
                .setBaseNetwork(baseBlock)
                .build();
    }

    /**
     * Wraps a trained SSD block with post-processing that turns its raw outputs into detections.
     *
     * <p>The training block's outputs are read as (anchors, class predictions, bounding box
     * predictions). The class predictions are softmaxed over the last axis and transposed with
     * axes (0, 2, 1). All three are then passed to {@link MultiBoxDetection}. The resulting
     * single array is split at indices 1 and 2 along axis 2.
     *
     * @param ssdTrain the (trained) SSD training block
     * @return a block producing detections
     */
    public static Block getSsdPredictBlock(Block ssdTrain) {
        SequentialBlock ssdPredict = new SequentialBlock();
        ssdPredict.add(ssdTrain);
        ssdPredict.add(
                new LambdaBlock(
                        output -> {
                            NDArray anchors = output.get(0);
                            NDArray classPredictions =
                                    output.get(1).softmax(-1).transpose(new int[] {0, 2, 1});
                            NDArray boundingBoxPredictions = output.get(2);
                            MultiBoxDetection multiBoxDetection =
                                    new MultiBoxDetection.Builder().build();
                            NDList detections =
                                    multiBoxDetection.detection(
                                            new NDList(
                                                    new NDArray[] {
                                                        classPredictions,
                                                        boundingBoxPredictions,
                                                        anchors
                                                    }));
                            return detections.singletonOrThrow().split(new int[] {1, 2}, 2);
                        }));
        return ssdPredict;
    }
}

