/*
 * Decompiled with CFR 0.152.
 */
package com.johnsnowlabs.ml.onnx;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.providers.OrtCUDAProviderOptions;
import com.johnsnowlabs.ml.onnx.OnnxSession;
import com.johnsnowlabs.ml.onnx.OnnxWrapper;
import com.johnsnowlabs.util.ConfigHelper$;
import com.johnsnowlabs.util.FileHelper$;
import com.johnsnowlabs.util.ZipArchiveUtil$;
import java.io.File;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.EnumSet;
import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.immutable.Map;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.util.Failure;
import scala.util.Success;
import scala.util.Try;
import scala.util.Try$;

public final class OnnxWrapper$
implements scala.Serializable {
    public static OnnxWrapper$ MODULE$;
    private final Logger logger;

    static {
        new OnnxWrapper$();
    }

    public Logger logger() {
        return this.logger;
    }

    public synchronized Tuple2<OrtSession, OrtEnvironment> com$johnsnowlabs$ml$onnx$OnnxWrapper$$withSafeOnnxModelLoader(byte[] onnxModel, Map<String, String> sessionOptions) {
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions sessionOptionsObject = sessionOptions.isEmpty() ? new OrtSession.SessionOptions() : this.mapToSessionOptionsObject(sessionOptions);
        OrtSession session = env.createSession(onnxModel, sessionOptionsObject);
        return new Tuple2((Object)session, (Object)env);
    }

    public OnnxWrapper read(String modelPath, boolean zipped, boolean useBundle, String modelName) {
        String tmpFolder = ((Object)Files.createTempDirectory(new StringBuilder(5).append((String)new StringOps(Predef$.MODULE$.augmentString(UUID.randomUUID().toString())).takeRight(12)).append("_onnx").toString(), new FileAttribute[0]).toAbsolutePath()).toString();
        String folder = zipped ? ZipArchiveUtil$.MODULE$.unzip(new File(modelPath), (Option<String>)new Some((Object)tmpFolder)) : modelPath;
        Map<String, String> sessionOptions = new OnnxSession().getSessionOptions();
        String onnxFile = useBundle ? ((Object)Paths.get(modelPath, new StringBuilder(5).append(modelName).append(".onnx").toString())).toString() : ((Object)Paths.get(folder, (String)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new File(folder).list())).head())).toString();
        File modelFile = new File(onnxFile);
        byte[] modelBytes = FileUtils.readFileToByteArray((File)modelFile);
        Tuple2<OrtSession, OrtEnvironment> tuple2 = this.com$johnsnowlabs$ml$onnx$OnnxWrapper$$withSafeOnnxModelLoader(modelBytes, sessionOptions);
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        OrtSession session = (OrtSession)tuple2._1();
        OrtEnvironment env = (OrtEnvironment)tuple2._2();
        Tuple2 tuple22 = new Tuple2((Object)session, (Object)env);
        Tuple2 tuple23 = tuple22;
        OrtSession session2 = (OrtSession)tuple23._1();
        OrtEnvironment env2 = (OrtEnvironment)tuple23._2();
        FileHelper$.MODULE$.delete(tmpFolder, FileHelper$.MODULE$.delete$default$2());
        OnnxWrapper onnxWrapper = new OnnxWrapper(modelBytes);
        onnxWrapper.com$johnsnowlabs$ml$onnx$OnnxWrapper$$ortSession_$eq(session2);
        onnxWrapper.com$johnsnowlabs$ml$onnx$OnnxWrapper$$ortEnv_$eq(env2);
        return onnxWrapper;
    }

    public boolean read$default$2() {
        return true;
    }

    public boolean read$default$3() {
        return false;
    }

    public String read$default$4() {
        return "model";
    }

    private OrtSession.SessionOptions mapToSessionOptionsObject(Map<String, String> sessionOptions) {
        EnumSet providers = OrtEnvironment.getAvailableProviders();
        return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(providers.toArray())).map((Function1 & Serializable & scala.Serializable)x -> x.toString(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class))))).contains((Object)"CUDA") ? this.mapToCUDASessionConfig(sessionOptions) : this.mapToCPUSessionConfig(sessionOptions);
    }

    /*
     * WARNING - void declaration
     */
    private OrtSession.SessionOptions mapToCUDASessionConfig(Map<String, String> sessionOptionsMap) {
        void var3_3;
        this.logger().info("Using CUDA");
        Predef$.MODULE$.println((Object)"Using CUDA");
        int gpuDeviceId = new StringOps(Predef$.MODULE$.augmentString((String)sessionOptionsMap.apply((Object)ConfigHelper$.MODULE$.onnxGpuDeviceId()))).toInt();
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        this.logger().info(new StringBuilder(32).append("ONNX session option gpuDeviceId=").append(gpuDeviceId).toString());
        OrtCUDAProviderOptions cudaOpts = new OrtCUDAProviderOptions(gpuDeviceId);
        sessionOptions.addCUDA(cudaOpts);
        return var3_3;
    }

    private OrtSession.SessionOptions mapToCPUSessionConfig(Map<String, String> sessionOptionsMap) {
        OrtSession.SessionOptions.ExecutionMode defaultExecutionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
        OrtSession.SessionOptions.OptLevel defaultOptLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT;
        this.logger().info("Using CPUs");
        Predef$.MODULE$.println((Object)"Using CPUs");
        int intraOpNumThreads = new StringOps(Predef$.MODULE$.augmentString((String)sessionOptionsMap.apply((Object)ConfigHelper$.MODULE$.onnxIntraOpNumThreads()))).toInt();
        OrtSession.SessionOptions.OptLevel optimizationLevel = this.getOptLevel$1((String)sessionOptionsMap.apply((Object)ConfigHelper$.MODULE$.onnxOptimizationLevel()), defaultOptLevel);
        OrtSession.SessionOptions.ExecutionMode executionMode = this.getExecutionMode$1((String)sessionOptionsMap.apply((Object)ConfigHelper$.MODULE$.onnxExecutionMode()), defaultExecutionMode);
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        this.logger().info(new StringBuilder(38).append("ONNX session option intraOpNumThreads=").append(intraOpNumThreads).toString());
        sessionOptions.setIntraOpNumThreads(intraOpNumThreads);
        this.logger().info(new StringBuilder(38).append("ONNX session option optimizationLevel=").append(optimizationLevel).toString());
        sessionOptions.setOptimizationLevel(optimizationLevel);
        this.logger().info(new StringBuilder(34).append("ONNX session option executionMode=").append(executionMode).toString());
        sessionOptions.setExecutionMode(executionMode);
        return sessionOptions;
    }

    private Object readResolve() {
        return MODULE$;
    }

    private final OrtSession.SessionOptions.OptLevel getOptLevel$1(String optLevel, OrtSession.SessionOptions.OptLevel defaultOptLevel$1) {
        OrtSession.SessionOptions.OptLevel optLevel2;
        Try try_ = Try$.MODULE$.apply((Function0 & Serializable & scala.Serializable)() -> OrtSession.SessionOptions.OptLevel.valueOf((String)optLevel));
        if (try_ instanceof Success) {
            OrtSession.SessionOptions.OptLevel value;
            Success success = (Success)try_;
            optLevel2 = value = (OrtSession.SessionOptions.OptLevel)success.value();
        } else if (try_ instanceof Failure) {
            this.logger().warn(new StringBuilder(51).append("Error while getting OptLevel, using default value: ").append(defaultOptLevel$1.name()).toString());
            optLevel2 = defaultOptLevel$1;
        } else {
            throw new MatchError((Object)try_);
        }
        return optLevel2;
    }

    private final OrtSession.SessionOptions.ExecutionMode getExecutionMode$1(String executionMode, OrtSession.SessionOptions.ExecutionMode defaultExecutionMode$1) {
        OrtSession.SessionOptions.ExecutionMode executionMode2;
        Try try_ = Try$.MODULE$.apply((Function0 & Serializable & scala.Serializable)() -> OrtSession.SessionOptions.ExecutionMode.valueOf((String)executionMode));
        if (try_ instanceof Success) {
            OrtSession.SessionOptions.ExecutionMode value;
            Success success = (Success)try_;
            executionMode2 = value = (OrtSession.SessionOptions.ExecutionMode)success.value();
        } else if (try_ instanceof Failure) {
            this.logger().warn(new StringBuilder(57).append("Error while getting Execution Mode, using default value: ").append(defaultExecutionMode$1.name()).toString());
            executionMode2 = defaultExecutionMode$1;
        } else {
            throw new MatchError((Object)try_);
        }
        return executionMode2;
    }

    private OnnxWrapper$() {
        MODULE$ = this;
        this.logger = LoggerFactory.getLogger((String)"OnnxWrapper");
    }
}

