/*
 * Decompiled with CFR 0.152.
 */
package ai.tripl.arc.transform;

import ai.tripl.arc.api.API;
import ai.tripl.arc.transform.MLTransform;
import ai.tripl.arc.transform.MLTransformStage;
import ai.tripl.arc.transform.MLTransformStage$;
import ai.tripl.arc.util.DetailException;
import ai.tripl.arc.util.log.logger.Logger;
import java.io.Serializable;
import java.net.URI;
import java.util.HashMap;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Some;
import scala.Tuple11;
import scala.collection.GenTraversableOnce;
import scala.collection.Seq;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Map;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.reflect.api.JavaUniverse;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.util.Either;
import scala.util.Left;
import scala.util.Right;

/**
 * Decompiled companion object for the {@code MLTransformStage} stage.
 *
 * <p>Holds the singleton instance ({@code MODULE$}), the {@code execute} logic that applies a
 * previously loaded Spark ML model to an input view, and the case-class style
 * {@code apply}/{@code unapply} helpers.
 *
 * <p>NOTE(review): this is CFR decompiler output of Scala bytecode and is not compilable Java
 * as-is. Examples: a PipelineModel is assigned to a CrossValidatorModel-typed local, the
 * undeclared synthetic names {@code stage$1} / {@code transformedDF$1} are referenced, and a
 * local class is declared mid-method. Treat it as a reading aid for the original Scala source.
 */
public final class MLTransformStage$
implements scala.Serializable {
    /** Singleton instance; assigned in the private constructor. */
    public static MLTransformStage$ MODULE$;

    static {
        // Constructing the object assigns MODULE$ (see the private constructor).
        new MLTransformStage$();
    }

    /**
     * Applies the stage's ML model to {@code stage2.inputView()} and registers the result as
     * {@code stage2.outputView()}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Read the input table.</li>
     *   <li>Resolve the model: Right = CrossValidatorModel, Left = PipelineModel.</li>
     *   <li>Collect the pipeline stages, using {@code bestModel()} for a CrossValidatorModel.</li>
     *   <li>Transform the data, then keep only the original input columns plus every stage's
     *       prediction and probability columns.</li>
     *   <li>Replace each probability column with the maximum value of its vector.</li>
     *   <li>Optionally repartition.</li>
     *   <li>Register the temp view.</li>
     *   <li>For non-streaming data, record details on {@code stage2.stageDetail()}, including
     *       row count and decile percentiles when the stage is persisted.</li>
     * </ol>
     *
     * @param stage2     the stage configuration (input/output views, model, partitioning, persist flag)
     * @param spark      session used to look up the input view
     * @param logger     not referenced in the visible body
     * @param arcContext supplies {@code immutableViews()} and {@code storageLevel()}
     * @return {@code Option.apply(repartitionedDF)}, i.e. Some of the output dataset
     * @throws DetailException if extracting the pipeline stages or {@code model.transform} throws;
     *                         it wraps the cause together with the stage detail map
     */
    public Option<Dataset<Row>> execute(MLTransformStage stage2, SparkSession spark, Logger logger, API.ARCContext arcContext) {
        // NOTE(review): `object` is only ever assigned, never read. It is a decompiler artifact of a
        // Scala if-expression whose value was discarded.
        Object object;
        Dataset dataset;
        Dataset dataset2;
        Transformer[] transformerArray;
        CrossValidatorModel crossValidatorModel;
        Dataset df = spark.table(stage2.inputView());
        // Unwrap the model from the Either. Both branches end up in one local.
        Either<PipelineModel, CrossValidatorModel> either = stage2.model();
        if (either instanceof Right) {
            CrossValidatorModel crossValidatorModel2;
            Right right = (Right)either;
            crossValidatorModel = crossValidatorModel2 = (CrossValidatorModel)right.value();
        } else if (either instanceof Left) {
            Left left = (Left)either;
            PipelineModel pipelineModel = (PipelineModel)left.value();
            // NOTE(review): decompiler type confusion. In the Scala source this local is presumably a
            // common supertype (e.g. a Model/Transformer), not CrossValidatorModel.
            crossValidatorModel = pipelineModel;
        } else {
            throw new MatchError(either);
        }
        CrossValidatorModel model = crossValidatorModel;
        // Extract the fitted pipeline stages. For a CrossValidatorModel, bestModel() is cast to
        // PipelineModel; a non-pipeline bestModel would throw here and be wrapped below.
        try {
            Transformer[] transformerArray2;
            Either<PipelineModel, CrossValidatorModel> either2 = stage2.model();
            if (either2 instanceof Right) {
                Right right = (Right)either2;
                CrossValidatorModel crossValidatorModel3 = (CrossValidatorModel)right.value();
                transformerArray2 = ((PipelineModel)crossValidatorModel3.bestModel()).stages();
            } else if (either2 instanceof Left) {
                Left left = (Left)either2;
                PipelineModel pipelineModel = (PipelineModel)left.value();
                transformerArray2 = pipelineModel.stages();
            } else {
                throw new MatchError(either2);
            }
            transformerArray = transformerArray2;
        }
        catch (Exception e) {
            // Wrap the failure with the stage's detail map.
            // NOTE(review): `stage$1` is a decompiler name for the captured `stage2`.
            throw new DetailException(e, stage2){
                private final Map<String, Object> detail;

                public Map<String, Object> detail() {
                    return this.detail;
                }
                {
                    this.detail = stage$1.stageDetail();
                }
            };
        }
        Transformer[] stages = transformerArray;
        // Run the model over the input dataset; failures are wrapped the same way as above.
        try {
            dataset2 = model.transform(df);
        }
        catch (Exception e) {
            throw new DetailException(e, stage2){
                private final Map<String, Object> detail;

                public Map<String, Object> detail() {
                    return this.detail;
                }
                {
                    this.detail = stage$1.stageDetail();
                }
            };
        }
        Dataset fullTransformedDF = dataset2;
        // One col(name) per field of the ORIGINAL input schema.
        Column[] inputCols = (Column[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])df.schema().fields())).map((Function1 & Serializable & scala.Serializable)f -> functions$.MODULE$.col(f.name()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));
        // For every stage that has a "predictionCol" param, take stage.get(param) (an Option) and fall
        // back to the literal "prediction" when it is empty.
        // NOTE(review): presumably get() yields only explicitly-set values, not the param default. Confirm.
        Column[] predictionCols = (Column[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.genericArrayOps(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])stages)).filter((Function1 & Serializable & scala.Serializable)stage -> BoxesRunTime.boxToBoolean((boolean)stage.hasParam("predictionCol"))))).map((Function1 & Serializable & scala.Serializable)stage -> stage.get(stage.getParam("predictionCol")), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Option.class))))).map((Function1 & Serializable & scala.Serializable)predictionCol -> predictionCol.getOrElse((Function0 & Serializable & scala.Serializable)() -> "prediction"), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Any()))).map((Function1 & Serializable & scala.Serializable)x$15 -> x$15.toString(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class))))).map((Function1 & Serializable & scala.Serializable)x$16 -> functions$.MODULE$.col(x$16), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));
        // Same as above for "probabilityCol", with fallback "probability". The lambda parameter named
        // `predictionCol` here actually holds the probability-column Option (copy/paste naming).
        Column[] probabilityCols = (Column[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.genericArrayOps(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])stages)).filter((Function1 & Serializable & scala.Serializable)stage -> BoxesRunTime.boxToBoolean((boolean)stage.hasParam("probabilityCol"))))).map((Function1 & Serializable & scala.Serializable)stage -> stage.get(stage.getParam("probabilityCol")), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Option.class))))).map((Function1 & Serializable & scala.Serializable)predictionCol -> predictionCol.getOrElse((Function0 & Serializable & scala.Serializable)() -> "probability"), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Any()))).map((Function1 & Serializable & scala.Serializable)x$17 -> x$17.toString(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class))))).map((Function1 & Serializable & scala.Serializable)x$18 -> functions$.MODULE$.col(x$18), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));
        // Project to inputCols ++ predictionCols ++ probabilityCols. This drops intermediate pipeline
        // columns such as features and raw predictions. Wrapped in an ObjectRef because the closure
        // below reassigns it.
        ObjectRef transformedDF = ObjectRef.create((Object)fullTransformedDF.select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])inputCols)).$plus$plus((GenTraversableOnce)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])predictionCols)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class))))).$plus$plus((GenTraversableOnce)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])probabilityCols)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class))))));
        // Runtime reflection setup for the UDF's TypeTag[Vector] (generated code for `udf`).
        JavaUniverse $u = package$.MODULE$.universe();
        JavaUniverse.JavaMirror $m = package$.MODULE$.universe().runtimeMirror(this.getClass().getClassLoader());
        // Synthetic TypeCreator that resolves org.apache.spark.ml.linalg.Vector for the UDF input type.
        public final class Ai_tripl_arc_transform_MLTransformStage$$typecreator1$1
        extends TypeCreator {
            public <U extends Universe> Types.TypeApi apply(Mirror<U> $m$untyped) {
                Universe $u = $m$untyped.universe();
                Mirror<U> $m = $m$untyped;
                return $m.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }

            public Ai_tripl_arc_transform_MLTransformStage$$typecreator1$1() {
            }
        }
        // UDF: Vector => Double, the maximum element of the vector (see $anonfun$execute$14).
        UserDefinedFunction maxProbability = functions$.MODULE$.udf((Function1 & Serializable & scala.Serializable)v -> BoxesRunTime.boxToDouble((double)MLTransformStage$.$anonfun$execute$14(v)), ((TypeTags)package$.MODULE$.universe()).TypeTag().Double(), ((TypeTags)$u).TypeTag().apply((Mirror)$m, (TypeCreator)new Ai_tripl_arc_transform_MLTransformStage$$typecreator1$1()));
        // Overwrite each probability column in place with its max-probability scalar. The column name
        // comes from String.valueOf(col).
        // NOTE(review): assumes Column.toString() yields the plain column name for col(name). Confirm.
        new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])probabilityCols)).foreach((Function1 & Serializable & scala.Serializable)col -> {
            transformedDF.elem = ((Dataset)transformedDF.elem).withColumn(String.valueOf(col), maxProbability.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{col})));
            return BoxedUnit.UNIT;
        });
        // Repartitioning:
        //  - partitionBy empty: repartition(numPartitions) if set, else leave the dataset unchanged.
        //  - partitionBy set: repartition by those columns, with numPartitions if set.
        List<String> list = stage2.partitionBy();
        if (Nil$.MODULE$.equals(list)) {
            Dataset dataset3;
            Option<Object> option = stage2.numPartitions();
            if (option instanceof Some) {
                Some some = (Some)option;
                int numPartitions = BoxesRunTime.unboxToInt((Object)some.value());
                dataset3 = ((Dataset)transformedDF.elem).repartition(numPartitions);
            } else if (None$.MODULE$.equals(option)) {
                dataset3 = (Dataset)transformedDF.elem;
            } else {
                throw new MatchError(option);
            }
            dataset = dataset3;
        } else {
            Dataset dataset4;
            // NOTE(review): `transformedDF$1` is a decompiler name for the captured `transformedDF`.
            List partitionCols = (List)list.map((Function1 & Serializable & scala.Serializable)col -> ((Dataset)transformedDF$1.elem).apply(col), List$.MODULE$.canBuildFrom());
            Option<Object> option = stage2.numPartitions();
            if (option instanceof Some) {
                Some some = (Some)option;
                int numPartitions = BoxesRunTime.unboxToInt((Object)some.value());
                dataset4 = ((Dataset)transformedDF.elem).repartition(numPartitions, (Seq)partitionCols);
            } else if (None$.MODULE$.equals(option)) {
                dataset4 = ((Dataset)transformedDF.elem).repartition((Seq)partitionCols);
            } else {
                throw new MatchError(option);
            }
            dataset = dataset4;
        }
        Dataset repartitionedDF = dataset;
        // With immutableViews, createTempView is used (no replace); otherwise any existing view of
        // the same name is replaced.
        if (arcContext.immutableViews()) {
            repartitionedDF.createTempView(stage2.outputView());
        } else {
            repartitionedDF.createOrReplaceTempView(stage2.outputView());
        }
        // Stage details are recorded only for batch (non-streaming) datasets.
        if (!repartitionedDF.isStreaming()) {
            stage2.stageDetail().put((Object)"outputColumns", (Object)repartitionedDF.schema().length());
            stage2.stageDetail().put((Object)"numPartitions", (Object)repartitionedDF.rdd().partitions().length);
            if (stage2.persist()) {
                repartitionedDF.persist(arcContext.storageLevel());
                stage2.stageDetail().put((Object)"records", (Object)repartitionedDF.count());
                // For each probability column, record approximate quantiles at 0.1..1.0 with relative
                // error 0.1, keyed by the column's toString().
                // NOTE(review): the (Double[]) cast on HashMap.put's return value is a decompiler
                // artifact; the result is discarded.
                HashMap approxQuantileMap = new HashMap();
                new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])probabilityCols)).foreach((Function1 & Serializable & scala.Serializable)col2 -> (Double[])approxQuantileMap.put(col2.toString(), new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(repartitionedDF.stat().approxQuantile(col2.toString(), new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 0.1))).map((Function1 & Serializable & scala.Serializable)col -> BoxesRunTime.unboxToDouble((Object)col), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Double.class)))));
                // "percentiles" is added only when at least one probability column exists.
                object = approxQuantileMap.size() > 0 ? stage2.stageDetail().put((Object)"percentiles", approxQuantileMap) : BoxedUnit.UNIT;
            } else {
                object = BoxedUnit.UNIT;
            }
        } else {
            object = BoxedUnit.UNIT;
        }
        return Option$.MODULE$.apply((Object)repartitionedDF);
    }

    /** Case-class style factory: forwards all eleven fields to the MLTransformStage constructor. */
    public MLTransformStage apply(MLTransform plugin, String name, Option<String> description, URI inputURI, Either<PipelineModel, CrossValidatorModel> model, String inputView, String outputView, scala.collection.immutable.Map<String, String> params, boolean persist, Option<Object> numPartitions, List<String> partitionBy) {
        return new MLTransformStage(plugin, name, description, inputURI, model, inputView, outputView, params, persist, numPartitions, partitionBy);
    }

    /**
     * Case-class style extractor.
     *
     * @return None for a null argument; otherwise Some of a Tuple11 holding all fields in
     *         constructor order, with {@code persist} boxed to a Boolean
     */
    public Option<Tuple11<MLTransform, String, Option<String>, URI, Either<PipelineModel, CrossValidatorModel>, String, String, scala.collection.immutable.Map<String, String>, Object, Option<Object>, List<String>>> unapply(MLTransformStage x$0) {
        return x$0 == null ? None$.MODULE$ : new Some((Object)new Tuple11((Object)x$0.plugin(), (Object)x$0.name(), x$0.description(), (Object)x$0.inputURI(), x$0.model(), (Object)x$0.inputView(), (Object)x$0.outputView(), x$0.params(), (Object)BoxesRunTime.boxToBoolean((boolean)x$0.persist()), x$0.numPartitions(), x$0.partitionBy()));
    }

    /** Returns the existing singleton so that Java deserialization does not create a second instance. */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Body of the maxProbability UDF: the maximum element of {@code v.toArray()}.
     * NOTE(review): presumably throws for an empty vector (max of an empty array). Confirm.
     */
    public static final /* synthetic */ double $anonfun$execute$14(Vector v) {
        return BoxesRunTime.unboxToDouble((Object)new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(v.toArray())).max((Ordering)Ordering.Double$.MODULE$));
    }

    /** Private constructor; publishes this instance as MODULE$. */
    private MLTransformStage$() {
        MODULE$ = this;
    }
}

