/*
 * Decompiled with CFR 0.152.
 */
package ai.tripl.arc.transform;

import ai.tripl.arc.api.API;
import ai.tripl.arc.transform.SimilarityJoinTransform;
import ai.tripl.arc.transform.SimilarityJoinTransformStage;
import ai.tripl.arc.transform.SimilarityJoinTransformStage$;
import ai.tripl.arc.util.DetailException;
import ai.tripl.arc.util.log.logger.Logger;
import java.io.Serializable;
import java.util.UUID;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.MinHashLSH;
import org.apache.spark.ml.feature.MinHashLSHModel;
import org.apache.spark.ml.feature.NGram;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StringType$;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Some;
import scala.Tuple17;
import scala.collection.GenTraversableOnce;
import scala.collection.Seq;
import scala.collection.immutable.;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Map;
import scala.reflect.ClassTag$;
import scala.reflect.api.JavaUniverse;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/**
 * Companion-object singleton for {@link SimilarityJoinTransformStage}. The MODULE$ field,
 * apply/unapply over a Tuple17 and readResolve match the shape scalac emits for a case-class
 * companion.
 *
 * <p>Provides {@link #execute}, which performs an approximate fuzzy (MinHash LSH) join between two
 * registered views. It also provides the case-class factory ({@link #apply}) and extractor
 * ({@link #unapply}).
 *
 * <p>NOTE(review): this is CFR decompiler output and is not valid Java as written. Examples are the
 * empty-package import of {@code scala.collection.immutable}, a {@code public} local class inside a
 * method, and {@code new .colon.colon(...)}. Make changes in the original Scala source rather than
 * here.
 */
public final class SimilarityJoinTransformStage$
implements scala.Serializable {
    /** The singleton instance; assigned by the private constructor during static initialization. */
    public static SimilarityJoinTransformStage$ MODULE$;

    static {
        new SimilarityJoinTransformStage$();
    }

    /**
     * Executes the similarity join described by {@code stage}.
     *
     * <p>Pipeline, as visible below:
     * <ol>
     *   <li>For each side, concatenate the configured fields into one string column. Each non-null
     *       field is cast to string and followed by a space; nulls contribute "". The result is
     *       trimmed and stored under a random UUID column name to avoid clashing with input columns.</li>
     *   <li>Tokenize that column with a RegexTokenizer using pattern "" and min token length 1.
     *       Presumably this splits into single characters; confirm against RegexTokenizer semantics.
     *       Tokens are lowercased unless {@code stage.caseSensitive()}.</li>
     *   <li>Build n-grams of length {@code stage.shingleLength()}.</li>
     *   <li>Fit a CountVectorizer on the LEFT side only and apply it to both sides. Rows whose
     *       vector has no non-zero entries are dropped.</li>
     *   <li>Fit MinHashLSH on the left vectors, then run approxSimilarityJoin with distance bound
     *       {@code 1.0 - stage.threshold()}.</li>
     *   <li>Output all left columns, all right columns and {@code similarity = 1 - distCol}.</li>
     * </ol>
     * The result is optionally repartitioned and registered as temp view {@code stage.outputView()}.
     *
     * @param stage      stage configuration: views, fields, shingle length, number of hash tables,
     *                   threshold, case sensitivity, partitioning and persistence
     * @param spark      session used to resolve {@code stage.leftView()} and {@code stage.rightView()}
     *                   via {@code spark.table}
     * @param logger     not referenced in this method body
     * @param arcContext supplies the storage level used for persist() and the immutableViews flag
     * @return {@code Option.apply} of the resulting dataset, which is also registered as a temp view
     * @throws DetailException wrapping any Exception thrown inside the try block (MinHash fit,
     *         transform, join or select). Spark evaluates lazily, so some failures may only surface
     *         later (e.g. at {@code count()}) and are then not wrapped.
     */
    public Option<Dataset<Row>> execute(SimilarityJoinTransformStage stage, SparkSession spark, Logger logger, API.ARCContext arcContext) {
        // Decompiler artifact: holds the Unit results of the trailing if/else; never read.
        BoxedUnit boxedUnit;
        Dataset dataset;
        Dataset dataset2;
        // Random column name for the concatenated text, so it cannot collide with existing columns.
        String uuid = UUID.randomUUID().toString();
        RegexTokenizer regexTokenizer = ((RegexTokenizer)new RegexTokenizer().setInputCol(uuid)).setPattern("").setMinTokenLength(1).setToLowercase(!stage.caseSensitive());
        NGram nGram = ((NGram)new NGram().setInputCol(regexTokenizer.getOutputCol())).setN(stage.shingleLength());
        CountVectorizer countVectorizer = new CountVectorizer().setInputCol(nGram.getOutputCol());
        MinHashLSH minHashLSH = new MinHashLSH().setInputCol(countVectorizer.getOutputCol()).setNumHashTables(stage.numHashTables());
        // Scala reflection scaffolding (TypeTag for SparseVector) generated for the udf(...) call below.
        JavaUniverse $u = package$.MODULE$.universe();
        JavaUniverse.JavaMirror $m = package$.MODULE$.universe().runtimeMirror(this.getClass().getClassLoader());
        public final class Ai_tripl_arc_transform_SimilarityJoinTransformStage$$typecreator1$1
        extends TypeCreator {
            public <U extends Universe> Types.TypeApi apply(Mirror<U> $m$untyped) {
                Universe $u = $m$untyped.universe();
                Mirror<U> $m = $m$untyped;
                return $m.staticClass("org.apache.spark.ml.linalg.SparseVector").asType().toTypeConstructor();
            }

            public Ai_tripl_arc_transform_SimilarityJoinTransformStage$$typecreator1$1() {
            }
        }
        // UDF: true when the SparseVector has at least one non-zero entry (see $anonfun$execute$1).
        // Presumably used because MinHashLSH cannot hash all-zero vectors; confirm in Spark docs.
        UserDefinedFunction notEmptyVector = functions$.MODULE$.udf((Function1 & Serializable & scala.Serializable)v -> BoxesRunTime.boxToBoolean((boolean)SimilarityJoinTransformStage$.$anonfun$execute$1(v)), ((TypeTags)package$.MODULE$.universe()).TypeTag().Boolean(), ((TypeTags)$u).TypeTag().apply((Mirror)$m, (TypeCreator)new Ai_tripl_arc_transform_SimilarityJoinTransformStage$$typecreator1$1()));
        Dataset leftView = spark.table(stage.leftView());
        Dataset rightView = spark.table(stage.rightView());
        // Output column references, qualified by the "datasetA"/"datasetB" aliases that
        // approxSimilarityJoin gives its two sides.
        // NOTE(review): if both views share a column name, the output contains duplicate column names.
        Column[] leftOutputColumns = (Column[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])leftView.columns())).map((Function1 & Serializable & scala.Serializable)columnName -> functions$.MODULE$.col(new StringBuilder(9).append("datasetA.").append((String)columnName).toString()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));
        Column[] rightOutputColumns = (Column[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])rightView.columns())).map((Function1 & Serializable & scala.Serializable)columnName -> functions$.MODULE$.col(new StringBuilder(9).append("datasetB.").append((String)columnName).toString()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));
        // Left side: all original columns + trimmed concatenation of leftFields (nulls -> "") aliased
        // to uuid, then tokenized and n-grammed.
        Dataset leftViewFeatures = nGram.transform(regexTokenizer.transform(leftView.select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col("*"), functions$.MODULE$.trim(functions$.MODULE$.concat((Seq)stage.leftFields().map((Function1 & Serializable & scala.Serializable)field -> functions$.MODULE$.when(functions$.MODULE$.col(field).isNotNull(), (Object)functions$.MODULE$.concat((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(field).cast((DataType)StringType$.MODULE$), functions$.MODULE$.lit((Object)" ")}))).otherwise((Object)""), List$.MODULE$.canBuildFrom()))).alias(uuid)}))));
        // Persisted because it is read both by the CountVectorizer fit and by the transform below.
        leftViewFeatures.persist(arcContext.storageLevel());
        // The vocabulary is fitted on the LEFT side only.
        CountVectorizerModel countVectorizerModel = countVectorizer.fit(leftViewFeatures);
        Dataset inputLeftView = countVectorizerModel.transform(leftViewFeatures).filter(notEmptyVector.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(countVectorizer.getOutputCol())})));
        // Persisted because it is read by minHashLSH.fit and again by minHashLSHModel.transform.
        inputLeftView.persist(arcContext.storageLevel());
        // Right side: same feature pipeline, vectorized with the LEFT vocabulary. Right rows sharing
        // no n-gram with the left vocabulary yield empty vectors and are filtered out. They can never
        // appear in the output.
        Dataset inputRightView = countVectorizerModel.transform(nGram.transform(regexTokenizer.transform(rightView.select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col("*"), functions$.MODULE$.trim(functions$.MODULE$.concat((Seq)stage.rightFields().map((Function1 & Serializable & scala.Serializable)field -> functions$.MODULE$.when(functions$.MODULE$.col(field).isNotNull(), (Object)functions$.MODULE$.concat((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(field).cast((DataType)StringType$.MODULE$), functions$.MODULE$.lit((Object)" ")}))).otherwise((Object)""), List$.MODULE$.canBuildFrom()))).alias(uuid)}))))).filter(notEmptyVector.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(countVectorizer.getOutputCol())})));
        try {
            MinHashLSHModel minHashLSHModel = (MinHashLSHModel)minHashLSH.fit(inputLeftView);
            Dataset datasetA = minHashLSHModel.transform(inputLeftView);
            Dataset datasetB = minHashLSHModel.transform(inputRightView);
            // The similarity threshold is converted to a distance bound (1 - threshold). The output
            // is all left columns ++ all right columns ++ (1.0 - distCol) AS similarity.
            dataset2 = minHashLSHModel.approxSimilarityJoin(datasetA, datasetB, 1.0 - stage.threshold()).select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])leftOutputColumns)).$plus$plus((GenTraversableOnce)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])rightOutputColumns)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class))))).$plus$plus((GenTraversableOnce)new .colon.colon((Object)functions$.MODULE$.lit((Object)BoxesRunTime.boxToDouble((double)1.0)).$minus((Object)functions$.MODULE$.col("distCol")).alias("similarity"), (List)Nil$.MODULE$), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))));
        }
        catch (Exception e) {
            // Wrap the failure with the stage's detail map. `stage$2` is the decompiler's name for
            // the captured `stage` parameter.
            // NOTE(review): leftViewFeatures and inputLeftView stay persisted on this path. The
            // unpersist calls at the end of the method are not in a finally block.
            throw new DetailException(e, stage){
                private final Map<String, Object> detail;

                public Map<String, Object> detail() {
                    return this.detail;
                }
                {
                    this.detail = stage$2.stageDetail();
                }
            };
        }
        Dataset transformedDF = dataset2;
        // Repartitioning (decompiled Scala pattern match on partitionBy / numPartitions):
        //   partitionBy empty: repartition(n) if numPartitions is Some(n), otherwise leave unchanged.
        //   partitionBy non-empty: repartition(n, cols) if Some(n), otherwise repartition(cols).
        List<String> list = stage.partitionBy();
        if (Nil$.MODULE$.equals(list)) {
            Dataset dataset3;
            Option<Object> option = stage.numPartitions();
            if (option instanceof Some) {
                Some some = (Some)option;
                int numPartitions = BoxesRunTime.unboxToInt((Object)some.value());
                dataset3 = transformedDF.repartition(numPartitions);
            } else if (None$.MODULE$.equals(option)) {
                dataset3 = transformedDF;
            } else {
                throw new MatchError(option);
            }
            dataset = dataset3;
        } else {
            Dataset dataset4;
            List partitionCols = (List)list.map((Function1 & Serializable & scala.Serializable)col -> transformedDF.apply(col), List$.MODULE$.canBuildFrom());
            Option<Object> option = stage.numPartitions();
            if (option instanceof Some) {
                Some some = (Some)option;
                int numPartitions = BoxesRunTime.unboxToInt((Object)some.value());
                dataset4 = transformedDF.repartition(numPartitions, (Seq)partitionCols);
            } else if (None$.MODULE$.equals(option)) {
                dataset4 = transformedDF.repartition((Seq)partitionCols);
            } else {
                throw new MatchError(option);
            }
            dataset = dataset4;
        }
        Dataset repartitionedDF = dataset;
        // With immutableViews, createTempView is used instead of createOrReplaceTempView.
        // Presumably it fails if the view name already exists; confirm against Spark docs.
        if (arcContext.immutableViews()) {
            repartitionedDF.createTempView(stage.outputView());
        } else {
            repartitionedDF.createOrReplaceTempView(stage.outputView());
        }
        // For batch (non-streaming) results, record output column count and partition count.
        // When persist is requested, cache the result and record its row count; count() forces
        // evaluation.
        if (!repartitionedDF.isStreaming()) {
            stage.stageDetail().put((Object)"outputColumns", (Object)repartitionedDF.schema().length());
            stage.stageDetail().put((Object)"numPartitions", (Object)repartitionedDF.rdd().partitions().length);
            if (stage.persist()) {
                repartitionedDF.persist(arcContext.storageLevel());
                boxedUnit = stage.stageDetail().put((Object)"records", (Object)repartitionedDF.count());
            } else {
                boxedUnit = BoxedUnit.UNIT;
            }
        } else {
            boxedUnit = BoxedUnit.UNIT;
        }
        // Release the intermediate caches.
        // NOTE(review): when the result was not persisted/counted above, it is still lazy. Later
        // consumers will then recompute the left-side features instead of reading these caches.
        leftViewFeatures.unpersist();
        inputLeftView.unpersist();
        return Option$.MODULE$.apply((Object)repartitionedDF);
    }

    /**
     * Case-class factory: builds a {@link SimilarityJoinTransformStage} from all 17 fields, passed
     * through to the constructor in the same order.
     */
    public SimilarityJoinTransformStage apply(SimilarityJoinTransform plugin, Option<String> id, String name, Option<String> description, String leftView, List<String> leftFields, String rightView, List<String> rightFields, String outputView, boolean persist, int shingleLength, int numHashTables, double threshold, boolean caseSensitive, List<String> partitionBy, Option<Object> numPartitions, scala.collection.immutable.Map<String, String> params) {
        return new SimilarityJoinTransformStage(plugin, id, name, description, leftView, leftFields, rightView, rightFields, outputView, persist, shingleLength, numHashTables, threshold, caseSensitive, partitionBy, numPartitions, params);
    }

    /**
     * Case-class extractor: returns None for a null stage. Otherwise returns Some of a Tuple17
     * holding its fields in constructor order, with primitive fields boxed.
     */
    public Option<Tuple17<SimilarityJoinTransform, Option<String>, String, Option<String>, String, List<String>, String, List<String>, String, Object, Object, Object, Object, Object, List<String>, Option<Object>, scala.collection.immutable.Map<String, String>>> unapply(SimilarityJoinTransformStage x$0) {
        return x$0 == null ? None$.MODULE$ : new Some((Object)new Tuple17((Object)x$0.plugin(), x$0.id(), (Object)x$0.name(), x$0.description(), (Object)x$0.leftView(), x$0.leftFields(), (Object)x$0.rightView(), x$0.rightFields(), (Object)x$0.outputView(), (Object)BoxesRunTime.boxToBoolean((boolean)x$0.persist()), (Object)BoxesRunTime.boxToInteger((int)x$0.shingleLength()), (Object)BoxesRunTime.boxToInteger((int)x$0.numHashTables()), (Object)BoxesRunTime.boxToDouble((double)x$0.threshold()), (Object)BoxesRunTime.boxToBoolean((boolean)x$0.caseSensitive()), x$0.partitionBy(), x$0.numPartitions(), x$0.params()));
    }

    /** Java-serialization hook: deserialization resolves to the existing singleton MODULE$. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Body of the notEmptyVector UDF: true when the vector has at least one non-zero entry. */
    public static final /* synthetic */ boolean $anonfun$execute$1(SparseVector v) {
        return v.numNonzeros() > 0;
    }

    /** Private constructor; publishes this instance as MODULE$ (invoked once from the static initializer). */
    private SimilarityJoinTransformStage$() {
        MODULE$ = this;
    }
}

