/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.inference.util;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.examples.inference.util.Arguments;
import ai.djl.examples.util.MemoryUtils;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.mxnet.zoo.MxModelZoo;
import ai.djl.repository.zoo.ModelLoader;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;
import ai.djl.zoo.ModelZoo;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.time.Duration;
import java.util.Map;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for command-line inference benchmarks.
 *
 * <p>Subclasses implement {@link #predict(Arguments, Metrics, int)}. {@link #runBenchmark(String[])}
 * parses the command line, initializes the engine, invokes {@code predict} once per round with a
 * fresh {@link Metrics} instance, and logs the timing statistics recorded in that instance.
 *
 * @param <T> the type of the prediction result returned by {@code predict}
 */
public abstract class AbstractBenchmark<T> {

    private static final Logger logger = LoggerFactory.getLogger(AbstractBenchmark.class);

    /** Divisor converting a nanosecond count into milliseconds. */
    private static final float NANOS_PER_MILLI = 1000000.0f;

    /** Result of the most recent {@code predict} call; {@code null} until the first round ran. */
    private T lastResult;

    /**
     * Progress bar recreated before every round with the round's iteration count as its total.
     * NOTE(review): presumably advanced by subclasses from {@code predict}; confirm in implementations.
     */
    protected ProgressBar progressBar;

    /**
     * Runs one benchmark round.
     *
     * <p>The caller reads these metric names from {@code metrics} when they are present:
     * {@code "thread"}, {@code "LoadModel"}, {@code "Inference"}, {@code "Preprocess"},
     * {@code "Postprocess"}, and (with {@code -Dcollect-memory=true}) {@code "Heap"},
     * {@code "NonHeap"}, {@code "cpu"} and {@code "rss"}.
     *
     * @param arguments the parsed command-line arguments
     * @param metrics a fresh metrics collector for this round
     * @param iteration the iteration count configured on the command line
     * @return the prediction result, which is logged and kept as {@link #getPredictResult()}
     * @throws IOException if an I/O operation fails
     * @throws ModelException if a model cannot be loaded
     * @throws TranslateException if pre- or post-processing fails
     */
    protected abstract T predict(Arguments arguments, Metrics metrics, int iteration)
            throws IOException, ModelException, TranslateException;

    /**
     * Returns the command-line options understood by this benchmark.
     *
     * @return the options; by default {@link Arguments#getOptions()}
     */
    protected Options getOptions() {
        return Arguments.getOptions();
    }

    /**
     * Converts a parsed command line into benchmark arguments.
     *
     * @param cmd the parsed command line
     * @return a new {@link Arguments} wrapping {@code cmd}
     */
    protected Arguments parseArguments(CommandLine cmd) {
        return new Arguments(cmd);
    }

    /**
     * Parses {@code args} and runs benchmark rounds until the configured duration is used up.
     *
     * <p>With a duration of zero exactly one round runs. With a positive duration, rounds repeat
     * until the measured time of the rounds reaches it. Elapsed time is measured with the
     * monotonic {@link System#nanoTime()} clock.
     *
     * @param args the raw command-line arguments
     * @return {@code true} if the benchmark completed; {@code false} if the arguments could not be
     *     parsed (usage help is printed) or any error occurred (it is logged)
     */
    public final boolean runBenchmark(String[] args) {
        Options options = getOptions();
        try {
            DefaultParser parser = new DefaultParser();
            CommandLine cmd = parser.parse(options, args, null, false);
            Arguments arguments = parseArguments(cmd);

            long init = System.nanoTime();
            String version = Engine.getInstance().getVersion();
            long loaded = System.nanoTime();
            logger.info(String.format("Load library %s in %.3f ms.", version, nanosToMillis(loaded - init)));

            Duration remaining = Duration.ofMinutes(arguments.getDuration());
            if (arguments.getDuration() != 0) {
                logger.info(
                        "Running {} on: {}, duration: {} minutes.",
                        getClass().getSimpleName(),
                        Device.defaultDevice(),
                        remaining.toMinutes());
            }

            int iteration = arguments.getIteration();
            while (!remaining.isNegative()) {
                Metrics metrics = new Metrics();
                logger.info(
                        "Running {} on: {}, iteration: {}.",
                        getClass().getSimpleName(),
                        Device.defaultDevice(),
                        iteration);
                progressBar = new ProgressBar("Iteration", (long) iteration);

                long begin = System.nanoTime();
                lastResult = predict(arguments, metrics, iteration);
                long totalTimeMillis = (System.nanoTime() - begin) / 1000000L;
                logger.info("Inference result: {}", lastResult);

                logRoundReport(metrics, iteration, totalTimeMillis);
                MemoryUtils.dumpMemoryInfo(metrics, arguments.getOutputDir());

                remaining = remaining.minus(Duration.ofNanos(System.nanoTime() - begin));
                // Stop once the budget is exhausted. Also checking isZero() prevents an endless loop
                // when a round measures no elapsed time and the budget therefore never shrinks.
                if (remaining.isNegative() || remaining.isZero()) {
                    break;
                }
                logger.info("{} minutes left", remaining.toMinutes());
            }
            return true;
        } catch (ParseException e) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.setLeftPadding(1);
            formatter.setWidth(120);
            formatter.printHelp(e.getMessage(), options);
        } catch (Throwable t) {
            // CLI entry-point boundary: report any failure and signal it through the return value.
            logger.error("Unexpected error", t);
        }
        return false;
    }

    /**
     * Returns the result of the most recent {@code predict} call.
     *
     * @return the last prediction result, or {@code null} if no round has run yet
     */
    public T getPredictResult() {
        return lastResult;
    }

    /**
     * Loads an image-classification model from a model zoo and records its load time.
     *
     * <p>The model name defaults to {@code "RESNET"}. The imperative zoo ({@link ModelZoo}) is used
     * when {@link Arguments#isImperative()} is set, otherwise {@link MxModelZoo}. The load time in
     * nanoseconds is added to {@code metrics} under {@code "LoadModel"}.
     *
     * @param arguments the benchmark arguments supplying model name and criteria
     * @param metrics the metrics collector that receives the load time
     * @return the loaded model
     * @throws ModelException if the model cannot be found or loaded
     * @throws IOException if reading model artifacts fails
     */
    // The zoo loaders are used as raw types (their type arguments are not visible here), so the
    // loaded model is converted to the parameterized return type unchecked.
    @SuppressWarnings({"rawtypes", "unchecked"})
    protected ZooModel<BufferedImage, Classifications> loadModel(Arguments arguments, Metrics metrics)
            throws ModelException, IOException {
        long begin = System.nanoTime();
        String modelName = arguments.getModelName();
        if (modelName == null) {
            modelName = "RESNET";
        }
        Map<String, String> criteria = arguments.getCriteria();

        ModelLoader loader;
        if (arguments.isImperative()) {
            loader = ModelZoo.getModelLoader(modelName);
        } else {
            loader = MxModelZoo.getModelLoader(modelName);
        }

        Progress progress = new ProgressBar();
        ZooModel<BufferedImage, Classifications> model = loader.loadModel(criteria, progress);
        long delta = System.nanoTime() - begin;
        logger.info("Model {} loaded in: {} ms.", model.getName(), String.format("%.3f", nanosToMillis(delta)));
        metrics.addMetric("LoadModel", (Number) delta);
        return model;
    }

    /** Logs total time, run count, model load time and latency/memory percentiles of one round. */
    private static void logRoundReport(Metrics metrics, int iteration, long totalTimeMillis) {
        int totalRuns = iteration;
        if (metrics.hasMetric("thread")) {
            // Multi-threaded benchmarks record their thread count; each thread runs every iteration.
            totalRuns *= ((Metric) metrics.getMetric("thread").get(0)).getValue().intValue();
        }
        logger.info(String.format("total time: %d ms, total runs: %d iterations", totalTimeMillis, totalRuns));

        if (metrics.hasMetric("LoadModel")) {
            long loadModelTime = ((Metric) metrics.getMetric("LoadModel").get(0)).getValue().longValue();
            logger.info("Model loading time: {} ms.", String.format("%.3f", nanosToMillis(loadModelTime)));
        }

        // Percentiles are only reported for more than one iteration.
        if (metrics.hasMetric("Inference") && iteration > 1) {
            logLatencyPercentiles(metrics, "Inference", "inference");
            logLatencyPercentiles(metrics, "Preprocess", "preprocess");
            logLatencyPercentiles(metrics, "Postprocess", "postprocess");
            if (Boolean.getBoolean("collect-memory")) {
                logP90(metrics, "Heap", "heap");
                logP90(metrics, "NonHeap", "nonHeap");
                logP90(metrics, "cpu", "cpu");
                logP90(metrics, "rss", "rss");
            }
        }
    }

    /** Logs P50/P90/P99 of a nanosecond-valued metric in milliseconds; skips absent metrics. */
    private static void logLatencyPercentiles(Metrics metrics, String metricName, String label) {
        if (!metrics.hasMetric(metricName)) {
            return;
        }
        float p50 = nanosToMillis(metrics.percentile(metricName, 50).getValue().longValue());
        float p90 = nanosToMillis(metrics.percentile(metricName, 90).getValue().longValue());
        float p99 = nanosToMillis(metrics.percentile(metricName, 99).getValue().longValue());
        logger.info(String.format("%s P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", label, p50, p90, p99));
    }

    /** Logs the raw P90 value of a metric (unit as recorded); skips absent metrics. */
    private static void logP90(Metrics metrics, String metricName, String label) {
        if (!metrics.hasMetric(metricName)) {
            return;
        }
        float value = metrics.percentile(metricName, 90).getValue().longValue();
        logger.info(String.format("%s P90: %.3f", label, value));
    }

    /** Converts a nanosecond count to (fractional) milliseconds. */
    private static float nanosToMillis(long nanos) {
        return (float) nanos / NANOS_PER_MILLI;
    }
}

