/*
 * Decompiled with CFR 0.152.
 */
package ai.tripl.arc.transform;

import ai.tripl.arc.api.API;
import ai.tripl.arc.api.API$ResponseType$DoubleResponse$;
import ai.tripl.arc.api.API$ResponseType$IntegerResponse$;
import ai.tripl.arc.transform.TensorFlowServingTransform;
import ai.tripl.arc.transform.TensorFlowServingTransformStage;
import ai.tripl.arc.util.DetailException;
import ai.tripl.arc.util.log.logger.Logger;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.InputStream;
import java.io.Serializable;
import java.net.URI;
import org.apache.http.HttpEntity;
import org.apache.http.client.RedirectStrategy;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.conn.HttpClientConnectionManager;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.impl.client.LaxRedirectStrategy;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.apache.spark.sql.catalyst.encoders.RowEncoder$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.NullType$;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Some;
import scala.Tuple15;
import scala.Tuple2;
import scala.collection.BufferedIterator;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterator;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Map;
import scala.io.Codec$;
import scala.io.Source$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/**
 * Singleton module ({@code MODULE$}) for {@link TensorFlowServingTransformStage}; this appears to be
 * the Scala companion object, recovered here by decompilation.
 * <p>
 * {@code execute} batches the rows of an input view, POSTs the values of one column as a JSON
 * {@code {"instances": [...]}} request to {@code stage.uri()}, and appends each element of the
 * response's {@code predictions} array to its row as a new nullable {@code result} column. The
 * payload shape suggests TensorFlow Serving's REST predict API; confirm against the caller.
 * {@code apply}/{@code unapply} are the case-class construction/extraction helpers.
 * <p>
 * NOTE(review): this is decompiler output. Synthetic names such as {@code stage$2} and
 * {@code $anonfun$execute$2}, and the {@code .colon.colon} constructor form, are not valid Java.
 * Behavioral fixes should be made in the original Scala source, not in this file.
 */
public final class TensorFlowServingTransformStage$
implements scala.Serializable {
    // Singleton instance; assigned in the private constructor, which the static initializer invokes.
    public static TensorFlowServingTransformStage$ MODULE$;

    static {
        new TensorFlowServingTransformStage$();
    }

    /**
     * Runs the transform for one stage and registers the result as the temp view
     * {@code stage.outputView()}.
     * <p>
     * Output schema: every column of the input view, followed by a nullable {@code result} column.
     * The column is IntegerType for {@code IntegerResponse}, DoubleType for
     * {@code DoubleResponse}, and StringType otherwise.
     * <p>
     * CFR notes: enabled force condition propagation; lifted jumps to return sites.
     *
     * @param stage      stage configuration (input/output views, uri, batch size, partitioning, persist)
     * @param spark      session used to read the input view and to cache the output view
     * @param logger     not referenced in this method body
     * @param arcContext supplies {@code immutableViews()} and {@code storageLevel()}
     * @return the (possibly repartitioned) output Dataset, wrapped with {@code Option.apply}
     * @throws DetailException if the input view has no column named {@code stage.inputField()},
     *                         or if constructing the {@code mapPartitions} plan throws
     */
    public Option<Dataset<Row>> execute(TensorFlowServingTransformStage stage, SparkSession spark, Logger logger, API.ARCContext arcContext) {
        BoxedUnit boxedUnit;
        Dataset dataset;
        Dataset dataset2;
        StructType structType;
        Dataset df = spark.table(stage.inputView());
        // Copy stage fields into locals so the partition closure below captures plain values.
        URI stageUri = stage.uri();
        String stageInputField = stage.inputField();
        int stageBatchSize = stage.batchSize();
        Option<String> stageSignatureName = stage.signatureName();
        API.ResponseType stageResponseType = stage.responseType();
        // Fail fast on the driver if the configured input column is absent from the input view.
        if (!new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])df.columns())).contains((Object)stage.inputField())) {
            // Decompiled anonymous DetailException subclass; `stage$2` is a decompiler name for the
            // captured `stage`. The exception message is not visible in this file.
            throw new DetailException(stage, df){
                private final Map<String, Object> detail;

                public Map<String, Object> detail() {
                    return this.detail;
                }
                {
                    this.detail = stage$2.stageDetail();
                }
            };
        }
        // Output schema = input fields ::: List(result). The `result` field is appended last, and its
        // type depends on the configured response type.
        API.ResponseType responseType = stage.responseType();
        if (API$ResponseType$IntegerResponse$.MODULE$.equals(responseType)) {
            List list = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])df.schema().fields())).toList();
            structType = StructType$.MODULE$.apply((Seq)new .colon.colon((Object)new StructField("result", (DataType)IntegerType$.MODULE$, true, StructField$.MODULE$.$lessinit$greater$default$4()), (List)Nil$.MODULE$).$colon$colon$colon(list));
        } else if (API$ResponseType$DoubleResponse$.MODULE$.equals(responseType)) {
            List list = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])df.schema().fields())).toList();
            structType = StructType$.MODULE$.apply((Seq)new .colon.colon((Object)new StructField("result", (DataType)DoubleType$.MODULE$, true, StructField$.MODULE$.$lessinit$greater$default$4()), (List)Nil$.MODULE$).$colon$colon$colon(list));
        } else {
            List list = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])df.schema().fields())).toList();
            structType = StructType$.MODULE$.apply((Seq)new .colon.colon((Object)new StructField("result", (DataType)StringType$.MODULE$, true, StructField$.MODULE$.$lessinit$greater$default$4()), (List)Nil$.MODULE$).$colon$colon$colon(list));
        }
        StructType tensorFlowResponseSchema = structType;
        ExpressionEncoder typedEncoder = RowEncoder$.MODULE$.apply(tensorFlowResponseSchema);
        // NOTE(review): mapPartitions is lazy. HTTP/JSON errors raised inside the partition function
        // presumably surface at action time, not in the catch below, so only plan-construction
        // failures are wrapped here. Confirm.
        try {
            dataset2 = df.mapPartitions((Function1 & Serializable & scala.Serializable)partition -> {
                NullType$ nullType$;
                int n;
                // A new pooled HTTP client per partition, with at most 50 connections in total,
                // following redirects (including for POST) via LaxRedirectStrategy.
                // NOTE(review): neither httpClient nor the connection manager is closed in this
                // function. Confirm whether they leak when there are many partitions/tasks.
                PoolingHttpClientConnectionManager poolingHttpClientConnectionManager = new PoolingHttpClientConnectionManager();
                poolingHttpClientConnectionManager.setMaxTotal(50);
                CloseableHttpClient httpClient = HttpClients.custom().setConnectionManager((HttpClientConnectionManager)poolingHttpClientConnectionManager).setRedirectStrategy((RedirectStrategy)new LaxRedirectStrategy()).build();
                URI uri = stageUri;
                // HTTP status codes treated as success: List(200, 201, 202).
                int n2 = 200;
                int n3 = 201;
                int n4 = 202;
                List validStatusCodes = Nil$.MODULE$.$colon$colon((Object)BoxesRunTime.boxToInteger((int)n4)).$colon$colon((Object)BoxesRunTime.boxToInteger((int)n3)).$colon$colon((Object)BoxesRunTime.boxToInteger((int)n2));
                ObjectMapper objectMapper = new ObjectMapper();
                // Peek at the first row (without consuming it) to resolve the input column's index
                // and data type. An empty partition gets index 0 and NullType. Those values are
                // never used, because such a partition yields no batches. The MatchError branches
                // below are unreachable: they are the decompiled default case of a Boolean match.
                BufferedIterator bufferedPartition = partition.buffered();
                boolean bl = bufferedPartition.hasNext();
                if (bl) {
                    n = ((Row)bufferedPartition.head()).fieldIndex(stageInputField);
                } else if (!bl) {
                    n = 0;
                } else {
                    throw new MatchError((Object)BoxesRunTime.boxToBoolean((boolean)bl));
                }
                int fieldIndex = n;
                boolean bl2 = bufferedPartition.hasNext();
                if (bl2) {
                    nullType$ = ((Row)bufferedPartition.head()).schema().apply(fieldIndex).dataType();
                } else if (!bl2) {
                    nullType$ = NullType$.MODULE$;
                } else {
                    throw new MatchError((Object)BoxesRunTime.boxToBoolean((boolean)bl2));
                }
                // Declared as NullType$ by the decompiler, but it holds the input column's DataType.
                NullType$ dataType = nullType$;
                // Send the partition in batches of stageBatchSize rows; each batch is one HTTP request.
                Iterator.GroupedIterator groupedPartition = bufferedPartition.grouped(stageBatchSize);
                return groupedPartition.flatMap(arg_0 -> TensorFlowServingTransformStage$.$anonfun$execute$2(stageSignatureName, (DataType)dataType, objectMapper, fieldIndex, uri, httpClient, validStatusCodes, stageResponseType, arg_0));
            }, (Encoder)typedEncoder);
        }
        catch (Exception e) {
            // Decompiled anonymous DetailException subclass wrapping the cause; `stage$2` is a
            // decompiler name for the captured `stage`.
            throw new DetailException(e, stage){
                private final Map<String, Object> detail;

                public Map<String, Object> detail() {
                    return this.detail;
                }
                {
                    this.detail = stage$2.stageDetail();
                }
            };
        }
        Dataset transformedDF = dataset2;
        // Repartitioning:
        //   - no partitionBy columns: repartition(numPartitions) if set, otherwise leave as-is;
        //   - partitionBy columns: repartition by those columns, with numPartitions if set.
        List<String> list = stage.partitionBy();
        if (Nil$.MODULE$.equals(list)) {
            Option<Object> option = stage.numPartitions();
            if (option instanceof Some) {
                Some some = (Some)option;
                int numPartitions = BoxesRunTime.unboxToInt((Object)some.value());
                dataset = transformedDF.repartition(numPartitions);
            } else {
                if (!None$.MODULE$.equals(option)) throw new MatchError(option);
                dataset = transformedDF;
            }
        } else {
            List partitionCols = (List)list.map((Function1 & Serializable & scala.Serializable)col -> transformedDF.apply(col), List$.MODULE$.canBuildFrom());
            Option<Object> option = stage.numPartitions();
            if (option instanceof Some) {
                Some some = (Some)option;
                int numPartitions = BoxesRunTime.unboxToInt((Object)some.value());
                dataset = transformedDF.repartition(numPartitions, (Seq)partitionCols);
            } else {
                if (!None$.MODULE$.equals(option)) throw new MatchError(option);
                dataset = transformedDF.repartition((Seq)partitionCols);
            }
        }
        Dataset repartitionedDF = dataset;
        // With immutableViews the view is created with createTempView instead of
        // createOrReplaceTempView.
        if (arcContext.immutableViews()) {
            repartitionedDF.createTempView(stage.outputView());
        } else {
            repartitionedDF.createOrReplaceTempView(stage.outputView());
        }
        // Stage metrics and optional caching apply to batch (non-streaming) Datasets only.
        // `boxedUnit` is a decompiler artifact and is never read.
        if (!repartitionedDF.isStreaming()) {
            stage.stageDetail().put((Object)"outputColumns", (Object)repartitionedDF.schema().length());
            stage.stageDetail().put((Object)"numPartitions", (Object)repartitionedDF.rdd().partitions().length);
            if (stage.persist()) {
                // NOTE(review): count() presumably forces evaluation of the whole view, i.e. all
                // HTTP calls, at this point. Confirm this is intended.
                spark.catalog().cacheTable(stage.outputView(), arcContext.storageLevel());
                boxedUnit = stage.stageDetail().put((Object)"records", (Object)repartitionedDF.count());
                return Option$.MODULE$.apply((Object)repartitionedDF);
            } else {
                boxedUnit = BoxedUnit.UNIT;
            }
            return Option$.MODULE$.apply((Object)repartitionedDF);
        } else {
            boxedUnit = BoxedUnit.UNIT;
        }
        return Option$.MODULE$.apply((Object)repartitionedDF);
    }

    /**
     * Constructs a {@link TensorFlowServingTransformStage} from its fifteen fields, in declaration order.
     */
    public TensorFlowServingTransformStage apply(TensorFlowServingTransform plugin, Option<String> id, String name, Option<String> description, String inputView, String outputView, URI uri, Option<String> signatureName, API.ResponseType responseType, int batchSize, String inputField, scala.collection.immutable.Map<String, String> params, boolean persist, Option<Object> numPartitions, List<String> partitionBy) {
        return new TensorFlowServingTransformStage(plugin, id, name, description, inputView, outputView, uri, signatureName, responseType, batchSize, inputField, params, persist, numPartitions, partitionBy);
    }

    /**
     * Extracts the fifteen fields of a stage as a {@code Tuple15}.
     * {@code batchSize} and {@code persist} are boxed.
     *
     * @return {@code None} if {@code x$0} is null, otherwise {@code Some(tuple)}
     */
    public Option<Tuple15<TensorFlowServingTransform, Option<String>, String, Option<String>, String, String, URI, Option<String>, API.ResponseType, Object, String, scala.collection.immutable.Map<String, String>, Object, Option<Object>, List<String>>> unapply(TensorFlowServingTransformStage x$0) {
        if (x$0 == null) {
            return None$.MODULE$;
        }
        return new Some((Object)new Tuple15((Object)x$0.plugin(), x$0.id(), (Object)x$0.name(), x$0.description(), (Object)x$0.inputView(), (Object)x$0.outputView(), (Object)x$0.uri(), x$0.signatureName(), (Object)x$0.responseType(), (Object)BoxesRunTime.boxToInteger((int)x$0.batchSize()), (Object)x$0.inputField(), x$0.params(), (Object)BoxesRunTime.boxToBoolean((boolean)x$0.persist()), x$0.numPartitions(), x$0.partitionBy()));
    }

    // Java serialization hook: deserialization resolves to the existing singleton.
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Decompiled body of the per-batch lambda used in {@code execute}. It sends one batch of rows
     * as a single JSON POST and returns the rows with their prediction appended.
     * <p>
     * Request body: {@code {"signature_name": <name, only if present>, "instances": [v1, v2, ...]}},
     * where each {@code v} is the row's value at {@code fieldIndex$1}.
     * <ul>
     *   <li>StringType values are parsed as JSON text via {@code readTree}, so the column must
     *       contain valid JSON.</li>
     *   <li>Integer, Long, Float, Double and Decimal values are added as numbers.</li>
     *   <li>Any other type throws {@code MatchError}.</li>
     * </ul>
     *
     * @param groupedRow one batch of input rows
     * @return the batch rows, each extended with one result value read from
     *         {@code predictions[index]} using asInt/asDouble/asText according to the response type
     * @throws Exception (raw) carrying the response body if the HTTP status is not 200/201/202
     */
    public static final /* synthetic */ GenTraversableOnce $anonfun$execute$2(Option stageSignatureName$1, DataType dataType$1, ObjectMapper objectMapper$1, int fieldIndex$1, URI uri$1, CloseableHttpClient httpClient$1, List validStatusCodes$1, API.ResponseType stageResponseType$1, Seq groupedRow) {
        List list;
        // NOTE(review): the `true` argument is presumably Jackson's exact-BigDecimal flag, so that
        // Decimal values keep their scale. Confirm.
        JsonNodeFactory jsonNodeFactory = new JsonNodeFactory(true);
        ObjectNode node = jsonNodeFactory.objectNode();
        stageSignatureName$1.foreach((Function1 & Serializable & scala.Serializable)x$29 -> node.put("signature_name", x$29));
        ArrayNode instancesArray = node.putArray("instances");
        groupedRow.foreach((Function1 & Serializable & scala.Serializable)row -> {
            DataType dataType = dataType$1;
            if (dataType instanceof StringType) {
                return instancesArray.add(objectMapper$1.readTree(row.getString(fieldIndex$1)));
            }
            if (dataType instanceof IntegerType) {
                return instancesArray.add(row.getInt(fieldIndex$1));
            }
            if (dataType instanceof LongType) {
                return instancesArray.add(row.getLong(fieldIndex$1));
            }
            if (dataType instanceof FloatType) {
                return instancesArray.add(row.getFloat(fieldIndex$1));
            }
            if (dataType instanceof DoubleType) {
                return instancesArray.add(row.getDouble(fieldIndex$1));
            }
            if (dataType instanceof DecimalType) {
                return instancesArray.add(row.getDecimal(fieldIndex$1));
            }
            throw new MatchError((Object)dataType);
        });
        HttpPost post = new HttpPost(uri$1);
        try {
            // NOTE(review): StringEntity(String) uses HttpCore's default content type/charset.
            // Confirm it is acceptable for non-ASCII payloads and for the server's expectations.
            post.setEntity((HttpEntity)new StringEntity(objectMapper$1.writeValueAsString((Object)node)));
            CloseableHttpResponse response = httpClient$1.execute((HttpUriRequest)post);
            InputStream responseEntity = response.getEntity().getContent();
            // NOTE(review): the body is decoded with Codec.fallbackSystemCodec() (platform
            // default), not with the charset the response declares.
            String body = Source$.MODULE$.fromInputStream(responseEntity, Codec$.MODULE$.fallbackSystemCodec()).mkString();
            // NOTE(review): close() is only reached on the success path. If getContent()/mkString()
            // throws, the response is not closed, although post.releaseConnection() in the finally
            // block still runs.
            response.close();
            if (!validStatusCodes$1.contains((Object)BoxesRunTime.boxToInteger((int)response.getStatusLine().getStatusCode()))) {
                throw new Exception(body);
            }
            // NOTE(review): if the body has no "predictions" field, rootNode.get(...) returns null
            // and presumably fails with an NPE here rather than a descriptive error.
            JsonNode rootNode = objectMapper$1.readTree(body);
            list = ((TraversableOnce)JavaConverters$.MODULE$.iterableAsScalaIterableConverter((Iterable)rootNode.get("predictions")).asScala()).toList();
        }
        finally {
            post.releaseConnection();
        }
        // `response` here is the List of prediction nodes, not the HTTP response above.
        List response = list;
        return (GenTraversableOnce)((TraversableLike)groupedRow.zipWithIndex(Seq$.MODULE$.canBuildFrom())).map((Function1 & Serializable & scala.Serializable)x0$1 -> {
            Tuple2 tuple2 = x0$1;
            if (tuple2 != null) {
                Row row = (Row)tuple2._1();
                int index = tuple2._2$mcI$sp();
                API.ResponseType responseType = stageResponseType$1;
                // Pair each row with the prediction at the same position in the batch.
                // NOTE(review): response.apply(index) assumes at least one prediction per row;
                // fewer predictions than rows presumably throws IndexOutOfBoundsException.
                Seq result = API$ResponseType$IntegerResponse$.MODULE$.equals(responseType) ? (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{((JsonNode)response.apply(index)).asInt()})) : (API$ResponseType$DoubleResponse$.MODULE$.equals(responseType) ? (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{((JsonNode)response.apply(index)).asDouble()})) : (Seq)new .colon.colon((Object)((JsonNode)response.apply(index)).asText(), (List)Nil$.MODULE$));
                // Output row = original values followed by the single result value, matching the
                // schema built in execute().
                return Row$.MODULE$.fromSeq((Seq)row.toSeq().$plus$plus((GenTraversableOnce)result, Seq$.MODULE$.canBuildFrom()));
            }
            throw new MatchError((Object)tuple2);
        }, Seq$.MODULE$.canBuildFrom());
    }

    // Private constructor: invoked once from the static initializer to publish MODULE$.
    private TensorFlowServingTransformStage$() {
        MODULE$ = this;
    }
}

