/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.inference;

import ai.djl.ModelException;
import ai.djl.examples.inference.util.AbstractBenchmark;
import ai.djl.examples.inference.util.Arguments;
import ai.djl.examples.util.MemoryUtils;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.util.BufferedImageUtils;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MultithreadedBenchmark extends AbstractBenchmark<Classifications> {

    private static final Logger logger = LoggerFactory.getLogger(MultithreadedBenchmark.class);

    /**
     * Entry point: runs the benchmark and exits the JVM with status {@code 0} when {@code
     * runBenchmark} reports success, or {@code -1} otherwise.
     *
     * @param args the command line arguments, forwarded to {@code runBenchmark}
     */
    public static void main(String[] args) {
        if (new MultithreadedBenchmark().runBenchmark(args)) {
            System.exit(0);
        }
        System.exit(-1);
    }

    /**
     * Runs {@code iteration} classifications of the configured image on each of {@code
     * arguments.getThreads()} worker threads, all sharing a single loaded model.
     *
     * <p>Each worker gets its own {@link Predictor}. The model and every predictor are created by
     * this method and are closed before it returns.
     *
     * @param arguments the benchmark arguments; supplies the image file and the thread count and
     *     is passed to {@code loadModel}
     * @param metrics the metrics sink; the thread count is recorded under {@code "thread"} and
     *     every predictor reports into it
     * @param iteration the number of predictions each worker runs
     * @return the result of the last successfully completed worker (in submission order), or
     *     {@code null} if no worker completed successfully or the wait was interrupted
     * @throws IOException if reading the image or loading the model fails
     * @throws ModelException if the model cannot be loaded
     * @throws IllegalArgumentException if the configured thread count is less than 1
     */
    @Override
    public Classifications predict(Arguments arguments, Metrics metrics, int iteration)
            throws IOException, ModelException {
        int numOfThreads = arguments.getThreads();
        // Fail fast with a clear message instead of an opaque error from the thread pool,
        // and before paying for image decoding and model loading.
        if (numOfThreads < 1) {
            throw new IllegalArgumentException(
                    "Number of threads must be at least 1, but was " + numOfThreads);
        }

        Path imageFile = arguments.getImageFile();
        BufferedImage img = BufferedImageUtils.fromFile(imageFile);

        // The model is loaded here, so this method owns it and must close it.
        try (ZooModel<BufferedImage, Classifications> model = loadModel(arguments, metrics)) {
            logger.info("Multithreaded inference with {} threads.", numOfThreads);
            metrics.addMetric("thread", (Number) numOfThreads);
            return runWorkers(model, img, metrics, iteration, numOfThreads);
        }
    }

    /**
     * Creates one {@link PredictorCallable} per thread, runs them concurrently on a fixed-size
     * pool, and collects their results. The pool is shut down and every predictor created here is
     * closed before returning, even if creation or execution fails.
     *
     * @return the last successful worker result in submission order, or {@code null} if none
     */
    private static Classifications runWorkers(
            ZooModel<BufferedImage, Classifications> model,
            BufferedImage img,
            Metrics metrics,
            int iteration,
            int numOfThreads) {
        List<PredictorCallable> callables = new ArrayList<>(numOfThreads);
        ExecutorService executorService = Executors.newFixedThreadPool(numOfThreads);
        Classifications classification = null;
        int successThreads = 0;
        try {
            for (int i = 0; i < numOfThreads; ++i) {
                // Only the first worker collects memory info.
                callables.add(new PredictorCallable(model, img, metrics, iteration, i, i == 0));
            }
            List<Future<Classifications>> futures = executorService.invokeAll(callables);
            for (Future<Classifications> future : futures) {
                try {
                    classification = future.get();
                    ++successThreads;
                } catch (ExecutionException e) {
                    logger.error("Benchmark worker failed.", e);
                }
            }
        } catch (InterruptedException e) {
            logger.error("Interrupted while waiting for benchmark workers.", e);
            // Restore the interrupt status so callers can observe the interruption.
            Thread.currentThread().interrupt();
        } finally {
            executorService.shutdown();
            // NOTE(review): on interruption, invokeAll cancels unfinished tasks, but a worker may
            // still be inside predict() when its predictor is closed here — confirm this is safe.
            for (PredictorCallable callable : callables) {
                callable.close();
            }
        }
        if (successThreads != numOfThreads) {
            logger.error("Only {}/{} threads finished.", successThreads, numOfThreads);
        }
        return classification;
    }

    /**
     * A worker that runs {@code iteration} predictions on one image using its own predictor,
     * which is created from the shared model at construction time.
     */
    private static final class PredictorCallable implements Callable<Classifications> {

        private final Predictor<BufferedImage, Classifications> predictor;
        private final BufferedImage img;
        private final Metrics metrics;
        private final int iteration;
        private final String workerId;
        private final boolean collectMemory;

        /**
         * Creates a worker with a fresh predictor from {@code model}.
         *
         * @param model the shared model to create a predictor from
         * @param img the image to classify on every iteration
         * @param metrics the metrics sink attached to the predictor and used for memory info
         * @param iteration the number of predictions to run
         * @param workerId numeric worker id, rendered as two digits in log messages
         * @param collectMemory whether to collect memory info after each prediction
         */
        PredictorCallable(
                ZooModel<BufferedImage, Classifications> model,
                BufferedImage img,
                Metrics metrics,
                int iteration,
                int workerId,
                boolean collectMemory) {
            this.predictor = model.newPredictor();
            this.img = img;
            this.metrics = metrics;
            this.iteration = iteration;
            this.workerId = String.format("%02d", workerId);
            this.collectMemory = collectMemory;
            // NOTE(review): all workers share one Metrics instance — confirm Metrics is safe
            // for concurrent updates.
            this.predictor.setMetrics(metrics);
        }

        /**
         * Runs the predictions.
         *
         * @return the result of the last prediction, or {@code null} if {@code iteration <= 0}
         * @throws TranslateException if a prediction fails
         */
        @Override
        public Classifications call() throws TranslateException {
            Classifications result = null;
            for (int i = 0; i < iteration; ++i) {
                result = predictor.predict(img);
                if (collectMemory) {
                    MemoryUtils.collectMemoryInfo(metrics);
                }
                logger.trace("Worker-{}: {} iteration finished.", workerId, i + 1);
            }
            logger.debug("Worker-{}: finished.", workerId);
            return result;
        }

        /** Releases this worker's predictor. */
        void close() {
            predictor.close();
        }
    }
}

