/*
 * Decompiled with CFR 0.152.
 */
package com.johnsnowlabs.ml.onnx;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.providers.OrtCUDAProviderOptions;
import com.johnsnowlabs.ml.onnx.OnnxWrapper;
import com.johnsnowlabs.util.FileHelper$;
import com.johnsnowlabs.util.ZipArchiveUtil$;
import java.io.File;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.EnumSet;
import java.util.NoSuchElementException;
import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;

public final class OnnxWrapper$
implements scala.Serializable {
    public static OnnxWrapper$ MODULE$;
    private final Logger logger;

    static {
        new OnnxWrapper$();
    }

    public Logger logger() {
        return this.logger;
    }

    public synchronized Tuple2<OrtSession, OrtEnvironment> com$johnsnowlabs$ml$onnx$OnnxWrapper$$withSafeOnnxModelLoader(byte[] onnxModel, Option<OrtSession.SessionOptions> sessionOptions) {
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = sessionOptions.isDefined() ? (OrtSession.SessionOptions)sessionOptions.get() : new OrtSession.SessionOptions();
        EnumSet providers = OrtEnvironment.getAvailableProviders();
        if (new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(providers.toArray())).map((Function1 & Serializable & scala.Serializable)x -> x.toString(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class))))).contains((Object)"CUDA")) {
            this.logger().info("using CUDA");
            int gpuDeviceId = 0;
            OrtCUDAProviderOptions cudaOpts = new OrtCUDAProviderOptions(gpuDeviceId);
            opts.addCUDA(cudaOpts);
        } else {
            this.logger().info("using CPUs");
            opts.setIntraOpNumThreads(6);
            opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
            opts.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL);
        }
        OrtSession session = env.createSession(onnxModel, opts);
        return new Tuple2((Object)session, (Object)env);
    }

    private Option<OrtSession.SessionOptions> withSafeOnnxModelLoader$default$2() {
        return None$.MODULE$;
    }

    public OnnxWrapper read(String modelPath, boolean zipped, boolean useBundle, String modelName, Option<OrtSession.SessionOptions> sessionOptions) {
        Tuple3 tuple3;
        Tuple3 tuple32;
        String folder;
        String tmpFolder = ((Object)Files.createTempDirectory(new StringBuilder(5).append((String)new StringOps(Predef$.MODULE$.augmentString(UUID.randomUUID().toString())).takeRight(12)).append("_onnx").toString(), new FileAttribute[0]).toAbsolutePath()).toString();
        String string = folder = zipped ? ZipArchiveUtil$.MODULE$.unzip(new File(modelPath), (Option<String>)new Some((Object)tmpFolder)) : modelPath;
        if (useBundle) {
            String onnxFile = ((Object)Paths.get(modelPath, new StringBuilder(5).append(modelName).append(".onnx").toString())).toString();
            File modelFile = new File(onnxFile);
            byte[] modelBytes = FileUtils.readFileToByteArray((File)modelFile);
            Tuple2<OrtSession, OrtEnvironment> tuple2 = this.com$johnsnowlabs$ml$onnx$OnnxWrapper$$withSafeOnnxModelLoader(modelBytes, sessionOptions);
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            OrtSession session = (OrtSession)tuple2._1();
            OrtEnvironment env = (OrtEnvironment)tuple2._2();
            Tuple2 tuple22 = new Tuple2((Object)session, (Object)env);
            Tuple2 tuple23 = tuple22;
            OrtSession session2 = (OrtSession)tuple23._1();
            OrtEnvironment env2 = (OrtEnvironment)tuple23._2();
            tuple32 = new Tuple3((Object)session2, (Object)env2, (Object)modelBytes);
        } else {
            String modelFile = (String)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new File(folder).list())).head();
            File fullPath = Paths.get(folder, modelFile).toFile();
            byte[] modelBytes = FileUtils.readFileToByteArray((File)fullPath);
            Tuple2<OrtSession, OrtEnvironment> tuple2 = this.com$johnsnowlabs$ml$onnx$OnnxWrapper$$withSafeOnnxModelLoader(modelBytes, sessionOptions);
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            OrtSession session = (OrtSession)tuple2._1();
            OrtEnvironment env = (OrtEnvironment)tuple2._2();
            Tuple2 tuple24 = new Tuple2((Object)session, (Object)env);
            Tuple2 tuple25 = tuple24;
            OrtSession session3 = (OrtSession)tuple25._1();
            OrtEnvironment env3 = (OrtEnvironment)tuple25._2();
            tuple32 = tuple3 = new Tuple3((Object)session3, (Object)env3, (Object)modelBytes);
        }
        if (tuple3 == null) {
            throw new MatchError((Object)tuple3);
        }
        OrtSession session = (OrtSession)tuple3._1();
        OrtEnvironment env = (OrtEnvironment)tuple3._2();
        byte[] modelBytes = (byte[])tuple3._3();
        Tuple3 tuple33 = new Tuple3((Object)session, (Object)env, (Object)modelBytes);
        Tuple3 tuple34 = tuple33;
        OrtSession session4 = (OrtSession)tuple34._1();
        OrtEnvironment env4 = (OrtEnvironment)tuple34._2();
        byte[] modelBytes2 = (byte[])tuple34._3();
        FileHelper$.MODULE$.delete(tmpFolder, FileHelper$.MODULE$.delete$default$2());
        OnnxWrapper onnxWrapper = new OnnxWrapper(modelBytes2);
        onnxWrapper.com$johnsnowlabs$ml$onnx$OnnxWrapper$$m_session_$eq(session4);
        onnxWrapper.com$johnsnowlabs$ml$onnx$OnnxWrapper$$m_env_$eq(env4);
        return onnxWrapper;
    }

    public boolean read$default$2() {
        return true;
    }

    public boolean read$default$3() {
        return false;
    }

    public String read$default$4() {
        return "model";
    }

    public Option<OrtSession.SessionOptions> read$default$5() {
        return None$.MODULE$;
    }

    private Object readResolve() {
        return MODULE$;
    }

    private OnnxWrapper$() {
        MODULE$ = this;
        this.logger = LoggerFactory.getLogger((String)"OnnxWrapper");
    }
}

