/*
 * Decompiled with CFR 0.152.
 */
package com.microsoft.azure.synapse.ml.onnx;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.ValueInfo;
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities$;
import com.microsoft.azure.synapse.ml.core.utils.CloseableIterator;
import com.microsoft.azure.synapse.ml.onnx.ONNXUtils$;
import java.io.Serializable;
import java.util.Map;
import org.apache.spark.TaskContext$;
import org.apache.spark.internal.Logging;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.NotImplementedError;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterator;
import scala.collection.JavaConverters$;
import scala.collection.MapLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Map$;
import scala.jdk.CollectionConverters$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction0;
import scala.runtime.java8.JFunction1;

/**
 * Decompiled (CFR) form of the Scala singleton {@code object ONNXRuntime}.
 *
 * <p>The {@code $} class-name suffix, the static {@code MODULE$} field and the static initializer are
 * how scalac encodes an {@code object}. The class mixes in Spark's {@link Logging} trait.
 *
 * <p>Helpers for running ONNX models on Spark executors:
 * <ul>
 *   <li>{@code createOrtSession}: build an {@link OrtSession} from serialized model bytes.</li>
 *   <li>{@code selectGpuDevice}: pick a GPU id from the current Spark task's "gpu" resource.</li>
 *   <li>{@code applyModel}: score an iterator of {@link Row}s with a session.</li>
 * </ul>
 *
 * <p>NOTE(review): this is decompiler output, not hand-written Java. Constructs such as the
 * {@code (JFunction1.mcVI.sp & Serializable & scala.Serializable)} lambda casts and the absence of
 * {@code throws} clauses mirror Scala bytecode, so the file may not recompile as Java as-is.
 * Behavioral fixes belong in the original Scala source.
 */
public final class ONNXRuntime$
implements Logging {
    // The singleton instance; assigned by the private constructor, which the static block invokes.
    public static ONNXRuntime$ MODULE$;
    // Backing field for the Logging trait's logger; accessed through the generated accessor pair
    // org$apache$spark$internal$Logging$$log_() / org$apache$spark$internal$Logging$$log__$eq(...).
    private transient Logger org$apache$spark$internal$Logging$$log_;

    // Instantiate the singleton when the class is initialized.
    static {
        new ONNXRuntime$();
    }

    // ---- Logging trait forwarders ----
    // Compiler-generated: each method below delegates to the trait's static implementation
    // (the same name with a trailing '$'), passing this object as the trait instance.

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    // Getter for the trait-owned logger field declared at the top of this class.
    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    // Setter for the trait-owned logger field declared at the top of this class.
    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    /**
     * Creates an ONNX Runtime session from serialized model bytes.
     *
     * <p>If {@code gpuDeviceId} is defined, the CUDA execution provider is registered for that device
     * via {@code options.addCUDA}. If that throws an {@link OrtException} whose code is
     * {@code ORT_INVALID_ARGUMENT}, a hint is logged about installing the GPU build of ONNX Runtime.
     * Every exception from that step is rethrown, whether or not it was logged.
     *
     * @param modelContent serialized ONNX model
     * @param ortEnv       environment used to create the session
     * @param optLevel     graph optimization level; Scala default is {@code ALL_OPT}
     *                     (see {@code createOrtSession$default$3})
     * @param gpuDeviceId  optional CUDA device id (a boxed Int in Scala); Scala default is {@code None}
     *                     (see {@code createOrtSession$default$4})
     * @return the created session
     */
    public OrtSession createOrtSession(byte[] modelContent, OrtEnvironment ortEnv, OrtSession.SessionOptions.OptLevel optLevel, Option<Object> gpuDeviceId) {
        // NOTE(review): `options` is never closed, neither after createSession nor when addCUDA throws.
        // Check ONNX Runtime's SessionOptions lifecycle and close it (e.g. try/finally) if appropriate.
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        try {
            gpuDeviceId.foreach((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)x$1 -> options.addCUDA(x$1));
        }
        catch (Throwable throwable) {
            Throwable throwable2 = throwable;
            // Decompiled form of a Scala guarded case:
            //   case e: OrtException if e.getCode == ORT_INVALID_ARGUMENT => log; throw
            if (throwable2 instanceof OrtException) {
                OrtException ortException = (OrtException)throwable2;
                OrtException.OrtErrorCode ortErrorCode = ortException.getCode();
                OrtException.OrtErrorCode ortErrorCode2 = OrtException.OrtErrorCode.ORT_INVALID_ARGUMENT;
                // Null-safe equality check: true when the code equals ORT_INVALID_ARGUMENT.
                if (!(ortErrorCode != null ? !ortErrorCode.equals(ortErrorCode2) : ortErrorCode2 != null)) {
                    String err = new StringBuilder(274).append("GPU device is found on executor nodes with id ").append(gpuDeviceId.get()).append(", ").append("but adding CUDA support failed. Most likely the ONNX runtime supplied to the cluster ").append("does not support GPU. Please install com.microsoft.onnxruntime:onnxruntime_gpu:{version} ").append("instead for optimal performance. Exception details: ").append(ortException.toString()).toString();
                    this.logError((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> err);
                    // Scala Unit value of the logging branch; unused.
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                }
            }
            // Always rethrown: the logging above is a diagnostic, not recovery.
            throw throwable;
        }
        options.setOptimizationLevel(optLevel);
        return ortEnv.createSession(modelContent, options);
    }

    /** Scala default-argument accessor for {@code createOrtSession}'s {@code optLevel}: {@code ALL_OPT}. */
    public OrtSession.SessionOptions.OptLevel createOrtSession$default$3() {
        return OrtSession.SessionOptions.OptLevel.ALL_OPT;
    }

    /** Scala default-argument accessor for {@code createOrtSession}'s {@code gpuDeviceId}: {@code None}. */
    public Option<Object> createOrtSession$default$4() {
        return None$.MODULE$;
    }

    /**
     * Chooses a GPU device id for the current Spark task.
     *
     * <p>When {@code deviceType} is {@code None} or {@code Some("CUDA")}, this reads the task's
     * {@code "gpu"} resource. It parses each of that resource's address strings as an int and
     * returns the first one. If there is no "gpu" resource, or it has no addresses, it returns
     * {@code None}. For {@code Some("CPU")} and any other value it returns {@code None}.
     *
     * <p>NOTE(review): {@code TaskContext.get()} is expected to be null when not called inside a
     * running Spark task (e.g. on the driver), which would make this throw a NullPointerException.
     * A non-numeric address string makes the int parse in {@code $anonfun$selectGpuDevice$2} throw.
     *
     * @param deviceType optional device type string ("CUDA", "CPU", ...)
     * @return optional device id (a boxed Integer), or {@code None}
     */
    public Option<Object> selectGpuDevice(Option<String> deviceType) {
        None$ none$;
        Some some;
        String string;
        Option<String> option = deviceType;
        // Decompiled pattern match: `case None | Some("CUDA") =>`.
        boolean bl = None$.MODULE$.equals(option) ? true : option instanceof Some && "CUDA".equals(string = (String)(some = (Some)option).value());
        if (bl) {
            // resources().get("gpu") -> addresses -> map(_.toInt) -> headOption
            Option gpuNum = TaskContext$.MODULE$.get().resources().get((Object)"gpu").flatMap((Function1 & Serializable & scala.Serializable)x$1 -> new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps((int[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])x$1.addresses())).map((Function1 & Serializable & scala.Serializable)x$2 -> BoxesRunTime.boxToInteger((int)ONNXRuntime$.$anonfun$selectGpuDevice$2(x$2)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int())))).headOption());
            none$ = gpuNum;
        } else {
            Some some2;
            String string2;
            // Decompiled `case Some("CPU") => None; case _ => None`: both branches yield None.
            none$ = option instanceof Some && "CPU".equals(string2 = (String)(some2 = (Some)option).value()) ? None$.MODULE$ : None$.MODULE$;
        }
        return none$;
    }

    /**
     * Lazily scores {@code rows} with {@code session}, producing one output row per input row.
     *
     * <p>For each row:
     * <ol>
     *   <li>For every model input in {@code session.getInputInfo()}, read the column named
     *       {@code feedMap(inputName)} as a {@code Seq} and convert it with
     *       {@code ONNXUtils.createTensor}. Only {@link TensorInfo} inputs are supported; any other
     *       input kind throws {@link NotImplementedError}.</li>
     *   <li>Run the session. The {@code Result} is passed to {@code StreamUtilities.using}, which
     *       presumably closes it after use.</li>
     *   <li>For each {@code fetchMap} entry, look up the output whose name is the entry's value.
     *       Convert it with {@code ONNXUtils.mapOnnxValueToArray}.</li>
     *   <li>Close the input tensors.</li>
     *   <li>Emit a row made of every {@code inputSchema} column (read by name) followed by one value
     *       per {@code fetchMap} entry, in {@code fetchMap}'s iteration order.</li>
     * </ol>
     *
     * @param session     session used to run the model; closed when the returned iterator is closed
     * @param env         environment used to create input tensors; closed when the returned iterator is closed
     * @param feedMap     model input name -> row column name holding that input's values
     * @param fetchMap    map whose values are model output names; presumably keys are the output
     *                    column names used elsewhere (keys are unused in this method) — verify
     * @param inputSchema schema of the incoming rows; all of its fields are copied to the output row
     * @param rows        rows to score
     * @return a {@code CloseableIterator} over the output rows
     */
    public Iterator<Row> applyModel(OrtSession session, OrtEnvironment env, scala.collection.immutable.Map<String, String> feedMap, scala.collection.immutable.Map<String, String> fetchMap, StructType inputSchema, Iterator<Row> rows) {
        Iterator results = rows.map((Function1 & Serializable & scala.Serializable)row -> {
            // Build inputName -> OnnxTensor for every model input.
            scala.collection.mutable.Map inputTensors = (scala.collection.mutable.Map)((TraversableLike)CollectionConverters$.MODULE$.mapAsScalaMapConverter(session.getInputInfo()).asScala()).map((Function1 & Serializable & scala.Serializable)x0$1 -> {
                Tuple2 tuple2;
                Tuple2 tuple22 = x0$1;
                if (tuple22 != null) {
                    String inputName = (String)tuple22._1();
                    NodeInfo inputNodeInfo = (NodeInfo)tuple22._2();
                    // feedMap.apply throws if inputName has no mapping.
                    Seq batchedValues = (Seq)row.getAs((String)feedMap.apply((Object)inputName));
                    ValueInfo valueInfo = inputNodeInfo.getInfo();
                    if (!(valueInfo instanceof TensorInfo)) {
                        throw new NotImplementedError(new StringBuilder(54).append("Only tensor input type is supported, but got ").append(valueInfo).append(" instead.").toString());
                    }
                    TensorInfo tensorInfo = (TensorInfo)valueInfo;
                    OnnxTensor tensor = ONNXUtils$.MODULE$.createTensor(env, tensorInfo, batchedValues);
                    tuple2 = new Tuple2((Object)inputName, (Object)tensor);
                } else {
                    throw new MatchError((Object)tuple22);
                }
                Tuple2 tuple23 = tuple2;
                return tuple23;
            }, Map$.MODULE$.canBuildFrom());
            // Run the model and convert each requested output.
            // The trailing .get() presumably unwraps a scala.util.Try returned by `using`,
            // rethrowing any failure — TODO confirm against StreamUtilities.
            Seq outputBatches = (Seq)StreamUtilities$.MODULE$.using((AutoCloseable)session.run((Map)JavaConverters$.MODULE$.mutableMapAsJavaMapConverter(inputTensors).asJava()), (Function1 & Serializable & scala.Serializable)result -> ((TraversableOnce)fetchMap.map((Function1 & Serializable & scala.Serializable)x0$2 -> {
                Tuple2 tuple2 = x0$2;
                if (tuple2 == null) {
                    throw new MatchError((Object)tuple2);
                }
                // Only the fetchMap value (the ONNX output name) is used here.
                String outputName = (String)tuple2._2();
                // Position of the output among getOutputInfo() keys. This relies on that map's
                // iteration order matching the Result's indices (presumably insertion-ordered; confirm).
                // NOTE(review): indexOf yields -1 for an unknown name, and result.get(-1) will then
                // fail without naming the missing output. It is also recomputed per output per row.
                int i = ((MapLike)CollectionConverters$.MODULE$.mapAsScalaMapConverter(session.getOutputInfo()).asScala()).keysIterator().indexOf((Object)outputName);
                OnnxValue outputValue = result.get(i);
                Seq<Object> seq = ONNXUtils$.MODULE$.mapOnnxValueToArray(outputValue);
                return seq;
            }, Iterable$.MODULE$.canBuildFrom())).toSeq()).get();
            // NOTE(review): inputs are closed only on this success path. If tensor creation,
            // session.run, or output conversion above throws, already-created tensors are not
            // closed. A try/finally would fix this.
            inputTensors.valuesIterator().foreach((Function1 & Serializable & scala.Serializable)x$3 -> {
                x$3.close();
                return BoxedUnit.UNIT;
            });
            // Output row = all inputSchema columns (by name) ++ one Seq per fetchMap entry.
            Seq data = (Seq)inputSchema.map((Function1 & Serializable & scala.Serializable)f -> row.getAs(f.name()), Seq$.MODULE$.canBuildFrom());
            return Row$.MODULE$.fromSeq((Seq)data.$plus$plus((GenTraversableOnce)outputBatches, Seq$.MODULE$.canBuildFrom()));
        });
        // Closing the returned iterator closes both the session and the environment.
        // NOTE(review): if `env` is a process-wide shared OrtEnvironment, closing it here may
        // affect other sessions in the same JVM. Confirm the intended ownership.
        return new CloseableIterator(results, (Function0)(JFunction0.mcV.sp & Serializable & scala.Serializable)() -> {
            session.close();
            env.close();
        });
    }

    /**
     * Compiler-lifted body of the {@code _.toInt} lambda in {@code selectGpuDevice}.
     * Parses one GPU address string via {@code StringOps.toInt}, which throws on non-numeric input.
     */
    public static final /* synthetic */ int $anonfun$selectGpuDevice$2(String x$2) {
        return new StringOps(Predef$.MODULE$.augmentString(x$2)).toInt();
    }

    // Singleton constructor: publishes the instance and runs the Logging trait's initializer.
    private ONNXRuntime$() {
        MODULE$ = this;
        Logging.$init$((Logging)this);
    }
}

