/*
 * Decompiled with CFR 0.152.
 */
package ai.tripl.arc.transform;

import ai.tripl.arc.api.API;
import ai.tripl.arc.transform.SimilarityJoinTransform;
import ai.tripl.arc.transform.SimilarityJoinTransformStage;
import ai.tripl.arc.util.DetailException;
import ai.tripl.arc.util.log.logger.Logger;
import java.util.UUID;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.MinHashLSH;
import org.apache.spark.ml.feature.MinHashLSHModel;
import org.apache.spark.ml.feature.NGram;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StringType$;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.Tuple16;
import scala.collection.GenTraversableOnce;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.Map;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/**
 * Singleton holder (decompiled Scala companion object) for {@code SimilarityJoinTransformStage}.
 *
 * <p>Contains the stage's execution logic ({@link #execute}) together with the
 * {@code apply}/{@code unapply} factory and extractor methods for the stage type.
 * The single instance is published through {@link #MODULE$}.
 */
public final class SimilarityJoinTransformStage$
implements Serializable {
    /** The single instance of this class; assigned by the private constructor. */
    public static final SimilarityJoinTransformStage$ MODULE$;

    // Creating one instance at class-initialisation time populates MODULE$ (the constructor stores `this`).
    static {
        new SimilarityJoinTransformStage$();
    }

    /**
     * Runs an approximate similarity join between the stage's left and right views and
     * registers the result as a temporary view.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Build a feature pipeline: {@code RegexTokenizer} (empty pattern, minimum token length 1,
     *       lower-casing unless {@code stage.caseSensitive()}) feeding an {@code NGram} of size
     *       {@code stage.shingleLength()}, a {@code CountVectorizer}, and a {@code MinHashLSH}
     *       with {@code stage.numHashTables()} hash tables.</li>
     *   <li>For each view, add a temporary column (named with a random UUID) holding the trimmed
     *       concatenation of the configured fields, each cast to string and followed by a single
     *       space when non-null, or replaced by an empty string when null.</li>
     *   <li>Fit the {@code CountVectorizer} on the left view only, transform both views with that
     *       model, and drop rows whose feature vector has no non-zero entries.</li>
     *   <li>Fit {@code MinHashLSH} on the left view, transform both sides, and call
     *       {@code approxSimilarityJoin} with {@code 1.0 - stage.threshold()}.</li>
     *   <li>Select every original left column ({@code datasetA.*}), every original right column
     *       ({@code datasetB.*}) and a {@code similarity} column computed as {@code 1.0 - distCol}.</li>
     *   <li>Optionally repartition according to {@code stage.partitionBy()} and
     *       {@code stage.numPartitions()}.</li>
     *   <li>Register the result under {@code stage.outputView()}, record statistics in
     *       {@code stage.stageDetail()} for non-streaming results, and unpersist the two cached
     *       intermediates.</li>
     * </ol>
     *
     * @param stage      stage configuration (views, fields, thresholds, partitioning, persistence)
     * @param spark      session used to look up the left and right views by name
     * @param logger     not referenced in this method body
     * @param arcContext supplies the storage level used for persisting and the immutable-views flag
     * @return the (possibly repartitioned) joined dataset wrapped via {@code Option.apply}
     * @throws DetailException wrapping any {@code Exception} raised while fitting the LSH model,
     *         transforming, joining, or reading {@code stage.partitionBy()}; its detail map is
     *         {@code stage.stageDetail()}
     * @throws MatchError if {@code stage.numPartitions()} is neither {@code Some} nor {@code None}
     */
    public Option<Dataset<Row>> execute(SimilarityJoinTransformStage stage, SparkSession spark, Logger logger, API.ARCContext arcContext) {
        Option<Object> option;
        block18: {
            BoxedUnit boxedUnit;
            Dataset dataset;
            Dataset inputLeftView;
            Dataset leftViewFeatures;
            block15: {
                Dataset dataset2;
                block17: {
                    List partitionCols;
                    Dataset transformedDF;
                    block16: {
                        List<String> list;
                        block11: {
                            Option<Object> option2;
                            block14: {
                                Dataset dataset3;
                                block13: {
                                    block12: {
                                        // Random column name for the temporary concatenated-text column.
                                        String uuid = UUID.randomUUID().toString();
                                        // Empty split pattern with minimum token length 1 — presumably yields one token per character; confirm against RegexTokenizer docs.
                                        RegexTokenizer regexTokenizer = ((RegexTokenizer)new RegexTokenizer().setInputCol(uuid)).setPattern("").setMinTokenLength(1).setToLowercase(!stage.caseSensitive());
                                        NGram nGram = ((NGram)new NGram().setInputCol(regexTokenizer.getOutputCol())).setN(stage.shingleLength());
                                        CountVectorizer countVectorizer = new CountVectorizer().setInputCol(nGram.getOutputCol());
                                        MinHashLSH minHashLSH = new MinHashLSH().setInputCol(countVectorizer.getOutputCol()).setNumHashTables(stage.numHashTables());
                                        // UDF returning true when a sparse vector has at least one non-zero entry; used to filter out empty feature vectors.
                                        UserDefinedFunction notEmptyVector = functions$.MODULE$.udf((Object)new Serializable(){
                                            public static final long serialVersionUID = 0L;

                                            public final boolean apply(SparseVector v) {
                                                return v.numNonzeros() > 0;
                                            }
                                        }, DataTypes.BooleanType);
                                        Dataset leftView = spark.table(stage.leftView());
                                        Dataset rightView = spark.table(stage.rightView());
                                        // Output projection for the left side: each original left column referenced as "datasetA.<name>".
                                        Column[] leftOutputColumns = (Column[])Predef$.MODULE$.refArrayOps((Object[])leftView.columns()).map((Function1)new Serializable(){
                                            public static final long serialVersionUID = 0L;

                                            public final Column apply(String columnName) {
                                                return functions$.MODULE$.col(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"datasetA.", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{columnName})));
                                            }
                                        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));
                                        // Output projection for the right side: each original right column referenced as "datasetB.<name>".
                                        Column[] rightOutputColumns = (Column[])Predef$.MODULE$.refArrayOps((Object[])rightView.columns()).map((Function1)new Serializable(){
                                            public static final long serialVersionUID = 0L;

                                            public final Column apply(String columnName) {
                                                return functions$.MODULE$.col(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"datasetB.", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{columnName})));
                                            }
                                        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));
                                        // Left view plus the temporary text column (trimmed concatenation of leftFields), then tokenised and n-grammed.
                                        leftViewFeatures = nGram.transform(regexTokenizer.transform(leftView.select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col("*"), functions$.MODULE$.trim(functions$.MODULE$.concat((Seq)stage.leftFields().map((Function1)new Serializable(){
                                            public static final long serialVersionUID = 0L;

                                            // Non-null field -> its string value followed by a space; null field -> empty string.
                                            public final Column apply(String field) {
                                                return functions$.MODULE$.when(functions$.MODULE$.col(field).isNotNull(), (Object)functions$.MODULE$.concat((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(field).cast((DataType)StringType$.MODULE$), functions$.MODULE$.lit((Object)" ")}))).otherwise((Object)"");
                                            }
                                        }, List$.MODULE$.canBuildFrom()))).alias(uuid)}))));
                                        // Persisted because it is used both to fit the CountVectorizer and to be transformed by the fitted model.
                                        leftViewFeatures.persist(arcContext.storageLevel());
                                        // The vocabulary is fitted on the left view only; the same model is applied to the right view below.
                                        CountVectorizerModel countVectorizerModel = countVectorizer.fit(leftViewFeatures);
                                        inputLeftView = countVectorizerModel.transform(leftViewFeatures).filter(notEmptyVector.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(countVectorizer.getOutputCol())})));
                                        // Persisted because it is used both to fit the MinHashLSH model and as the left join input.
                                        inputLeftView.persist(arcContext.storageLevel());
                                        // Right view goes through the same text-column / tokenise / n-gram / vectorise / non-empty filter chain, using rightFields.
                                        Dataset inputRightView = countVectorizerModel.transform(nGram.transform(regexTokenizer.transform(rightView.select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col("*"), functions$.MODULE$.trim(functions$.MODULE$.concat((Seq)stage.rightFields().map((Function1)new Serializable(){
                                            public static final long serialVersionUID = 0L;

                                            // Non-null field -> its string value followed by a space; null field -> empty string.
                                            public final Column apply(String field) {
                                                return functions$.MODULE$.when(functions$.MODULE$.col(field).isNotNull(), (Object)functions$.MODULE$.concat((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(field).cast((DataType)StringType$.MODULE$), functions$.MODULE$.lit((Object)" ")}))).otherwise((Object)"");
                                            }
                                        }, List$.MODULE$.canBuildFrom()))).alias(uuid)}))))).filter(notEmptyVector.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(countVectorizer.getOutputCol())})));
                                        try {
                                            MinHashLSHModel minHashLSHModel = (MinHashLSHModel)minHashLSH.fit(inputLeftView);
                                            Dataset datasetA = minHashLSHModel.transform(inputLeftView);
                                            Dataset datasetB = minHashLSHModel.transform(inputRightView);
                                            // Join is called with 1.0 - threshold; the output keeps the original left and right columns plus similarity = 1.0 - distCol.
                                            transformedDF = minHashLSHModel.approxSimilarityJoin(datasetA, datasetB, 1.0 - stage.threshold()).select((Seq)Predef$.MODULE$.wrapRefArray((Object[])Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])leftOutputColumns).$plus$plus((GenTraversableOnce)Predef$.MODULE$.refArrayOps((Object[])rightOutputColumns), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))).$plus$plus((GenTraversableOnce)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.lit((Object)BoxesRunTime.boxToDouble((double)1.0)).$minus((Object)functions$.MODULE$.col("distCol")).alias("similarity")})), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))));
                                            list = stage.partitionBy();
                                            // Non-empty partitionBy: jump to the column-based repartitioning after block11.
                                            if (!Nil$.MODULE$.equals(list)) break block11;
                                        }
                                        catch (Exception exception) {
                                            // Wrap any failure with the stage's detail map.
                                            // NOTE(review): the unpersist() calls at the end of this method are not reached on this
                                            // path, so the two persisted intermediates stay cached — confirm whether that is intended.
                                            throw new DetailException(stage, exception){
                                                private final Map<String, Object> detail;

                                                public Map<String, Object> detail() {
                                                    return this.detail;
                                                }
                                                {
                                                    this.detail = stage$1.stageDetail();
                                                }
                                            };
                                        }
                                        // Empty partitionBy: repartition by count only when numPartitions is defined, otherwise keep the dataset as is.
                                        option2 = stage.numPartitions();
                                        if (!(option2 instanceof Some)) break block12;
                                        Some some = (Some)option2;
                                        int numPartitions = BoxesRunTime.unboxToInt((Object)some.x());
                                        dataset3 = transformedDF.repartition(numPartitions);
                                        break block13;
                                    }
                                    if (!None$.MODULE$.equals(option2)) break block14;
                                    dataset3 = transformedDF;
                                }
                                dataset = dataset3;
                                break block15;
                            }
                            // numPartitions was neither Some nor None.
                            throw new MatchError(option2);
                        }
                        // Non-empty partitionBy: resolve each name to a column of the joined dataset.
                        partitionCols = (List)list.map((Function1)new Serializable(transformedDF){
                            public static final long serialVersionUID = 0L;
                            private final Dataset transformedDF$1;

                            public final Column apply(String col) {
                                return this.transformedDF$1.apply(col);
                            }
                            {
                                this.transformedDF$1 = transformedDF$1;
                            }
                        }, List$.MODULE$.canBuildFrom());
                        // Repartition by the columns, with an explicit partition count when numPartitions is defined.
                        option = stage.numPartitions();
                        if (!(option instanceof Some)) break block16;
                        Some some = (Some)option;
                        int numPartitions = BoxesRunTime.unboxToInt((Object)some.x());
                        dataset2 = transformedDF.repartition(numPartitions, (Seq)partitionCols);
                        break block17;
                    }
                    // Anything other than Some/None leaves block18 and reaches the MatchError at the end of the method.
                    if (!None$.MODULE$.equals(option)) break block18;
                    dataset2 = transformedDF.repartition((Seq)partitionCols);
                }
                dataset = dataset2;
            }
            Dataset repartitionedDF = dataset;
            // With immutable views the non-replacing createTempView is used; otherwise any existing view of that name is replaced.
            if (arcContext.immutableViews()) {
                repartitionedDF.createTempView(stage.outputView());
            } else {
                repartitionedDF.createOrReplaceTempView(stage.outputView());
            }
            // Statistics are only recorded for non-streaming results.
            if (repartitionedDF.isStreaming()) {
                boxedUnit = BoxedUnit.UNIT;
            } else {
                stage.stageDetail().put((Object)"outputColumns", (Object)repartitionedDF.schema().length());
                stage.stageDetail().put((Object)"numPartitions", (Object)repartitionedDF.rdd().partitions().length);
                // When persistence is requested, cache the result and record its row count.
                if (stage.persist()) {
                    repartitionedDF.persist(arcContext.storageLevel());
                    boxedUnit = stage.stageDetail().put((Object)"records", (Object)repartitionedDF.count());
                } else {
                    boxedUnit = BoxedUnit.UNIT;
                }
            }
            // Release the two intermediates persisted earlier in this method.
            leftViewFeatures.unpersist();
            inputLeftView.unpersist();
            return Option$.MODULE$.apply((Object)repartitionedDF);
        }
        // Reached only via `break block18`: numPartitions was neither Some nor None.
        throw new MatchError(option);
    }

    /**
     * Factory: constructs a {@code SimilarityJoinTransformStage} by passing all sixteen arguments,
     * in the same order, to its constructor.
     *
     * @return a new stage instance built from the supplied values
     */
    public SimilarityJoinTransformStage apply(SimilarityJoinTransform plugin, String name, Option<String> description, String leftView, List<String> leftFields, String rightView, List<String> rightFields, String outputView, boolean persist, int shingleLength, int numHashTables, double threshold, boolean caseSensitive, List<String> partitionBy, Option<Object> numPartitions, scala.collection.immutable.Map<String, String> params) {
        return new SimilarityJoinTransformStage(plugin, name, description, leftView, leftFields, rightView, rightFields, outputView, persist, shingleLength, numHashTables, threshold, caseSensitive, partitionBy, numPartitions, params);
    }

    /**
     * Extractor: decomposes a stage into a 16-tuple of its fields, in the same order as the
     * parameters of {@link #apply}. Primitive fields are boxed.
     *
     * @param x$0 the stage to decompose; may be {@code null}
     * @return {@code None} when {@code x$0} is {@code null}, otherwise {@code Some} of the tuple
     */
    public Option<Tuple16<SimilarityJoinTransform, String, Option<String>, String, List<String>, String, List<String>, String, Object, Object, Object, Object, Object, List<String>, Option<Object>, scala.collection.immutable.Map<String, String>>> unapply(SimilarityJoinTransformStage x$0) {
        return x$0 == null ? None$.MODULE$ : new Some((Object)new Tuple16((Object)x$0.plugin(), (Object)x$0.name(), x$0.description(), (Object)x$0.leftView(), x$0.leftFields(), (Object)x$0.rightView(), x$0.rightFields(), (Object)x$0.outputView(), (Object)BoxesRunTime.boxToBoolean((boolean)x$0.persist()), (Object)BoxesRunTime.boxToInteger((int)x$0.shingleLength()), (Object)BoxesRunTime.boxToInteger((int)x$0.numHashTables()), (Object)BoxesRunTime.boxToDouble((double)x$0.threshold()), (Object)BoxesRunTime.boxToBoolean((boolean)x$0.caseSensitive()), x$0.partitionBy(), x$0.numPartitions(), x$0.params()));
    }

    /** Serialization hook: always resolves to the shared {@link #MODULE$} instance. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Private constructor; publishes the new instance as {@link #MODULE$}. */
    private SimilarityJoinTransformStage$() {
        MODULE$ = this;
    }
}

