/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.feature;

import java.io.Serializable;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.feature.LSH;
import org.apache.spark.ml.feature.LSHModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.util.MLTestingUtils$;
import org.apache.spark.ml.util.SchemaUtils$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import scala.Function1;
import scala.Function2;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.runtime.BoxesRunTime;

/**
 * Helpers that fit an {@link LSH} estimator on a dataset and measure how the
 * resulting {@link LSHModel} behaves. Every public helper returns a pair of
 * ratios packed into a {@code Tuple2} of doubles.
 *
 * <p>The class follows the singleton layout with a {@code MODULE$} holder: the
 * static initializer constructs the only instance and the private constructor
 * publishes it through {@link #MODULE$}.
 *
 * <p>None of the ratios is guarded against an empty denominator; plain
 * {@code double} division is used, so a zero count yields NaN or Infinity.
 */
public final class LSHTest$ {
    /** The single instance; assigned by the private constructor. */
    public static LSHTest$ MODULE$;

    static {
        // Constructing the instance is what populates MODULE$.
        new LSHTest$();
    }

    private LSHTest$() {
        MODULE$ = this;
    }

    /** Wraps the given columns the same way the call sites expect a {@code Seq} of columns. */
    private static Seq columnSeq(Column ... columns) {
        return (Seq)Predef$.MODULE$.wrapRefArray((Object[])columns);
    }

    /** Builds the column reference {@code <alias>.<column>}. */
    private static Column qualified(String alias, String column) {
        return functions$.MODULE$.col(alias + "." + column);
    }

    /**
     * Fits {@code lsh} on {@code dataset}, hashes the dataset, and measures how
     * often bucket membership disagrees with the distance between rows.
     *
     * <p>The transformed dataset is cross-joined with itself. A pair is
     * "same bucket" when {@code hashDistance} of the two output values is
     * {@code 0.0}. Before measuring, the output column type and the number of
     * hash values in the first row are checked.
     *
     * @param dataset data to fit and transform
     * @param lsh     estimator under test
     * @param distFP  same-bucket pairs whose key distance is strictly greater
     *                than this count as false positives
     * @param distFN  different-bucket pairs whose key distance is strictly less
     *                than this count as false negatives
     * @return (false positives / same-bucket pairs,
     *         false negatives / different-bucket pairs)
     */
    public <T extends LSHModel<T>> Tuple2<Object, Object> calculateLSHProperty(Dataset<?> dataset, LSH<T> lsh, double distFP, double distFN) {
        LSHModel model = lsh.fit(dataset);
        String inputCol = model.getInputCol();
        String outputCol = model.getOutputCol();
        Dataset hashed = model.transform(dataset);
        MLTestingUtils$.MODULE$.checkCopyAndUids(lsh, (Model<?>)model);

        // The output column must be an array of vectors.
        SchemaUtils$.MODULE$.checkColumnType(
                hashed.schema(),
                model.getOutputCol(),
                (DataType)DataTypes.createArrayType((DataType)new VectorUDT()),
                SchemaUtils$.MODULE$.checkColumnType$default$4());

        // The first row must carry one hash value per hash table.
        Seq noFurtherColumns = (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[0]);
        Row firstRow = (Row)hashed.select(outputCol, noFurtherColumns).head();
        Seq firstHashes = (Seq)firstRow.get(0);
        Predef$.MODULE$.assert(firstHashes.length() == model.getNumHashTables());

        Dataset allPairs = hashed.as("a").crossJoin(hashed.as("b"));
        UserDefinedFunction keyDistanceUdf = functions$.MODULE$.udf(
                (Function2 & Serializable & scala.Serializable)(left, right) -> BoxesRunTime.boxToDouble((double)model.keyDistance(left, right)),
                DataTypes.DoubleType);
        UserDefinedFunction sameBucketUdf = functions$.MODULE$.udf(
                (Function2 & Serializable & scala.Serializable)(left, right) -> BoxesRunTime.boxToBoolean((boolean)LSHTest$.$anonfun$calculateLSHProperty$2(model, left, right)),
                DataTypes.BooleanType);

        Column bucketMatch = sameBucketUdf.apply(columnSeq(qualified("a", outputCol), qualified("b", outputCol)));
        Dataset withBuckets = allPairs.withColumn("same_bucket", bucketMatch);
        Column pairDistance = keyDistanceUdf.apply(columnSeq(qualified("a", inputCol), qualified("b", inputCol)));
        Dataset scored = withBuckets.withColumn("distance", pairDistance);

        Dataset sameBucketPairs = scored.filter(functions$.MODULE$.col("same_bucket"));
        Dataset otherPairs = scored.filter(functions$.MODULE$.col("same_bucket").unary_$bang());

        Column tooFar = functions$.MODULE$.col("distance").$greater((Object)BoxesRunTime.boxToDouble((double)distFP));
        double falsePositives = sameBucketPairs.filter(tooFar).count();
        Column tooClose = functions$.MODULE$.col("distance").$less((Object)BoxesRunTime.boxToDouble((double)distFN));
        double falseNegatives = otherPairs.filter(tooClose).count();

        double sameBucketTotal = (double)sameBucketPairs.count();
        double otherTotal = (double)otherPairs.count();
        return new Tuple2.mcDD.sp(falsePositives / sameBucketTotal, falseNegatives / otherTotal);
    }

    /**
     * Fits {@code lsh} on {@code dataset} and compares the model's approximate
     * nearest neighbours of {@code key} with the {@code k} rows obtained by
     * sorting the dataset on {@code keyDistance} to {@code key}.
     *
     * <p>The schema of the approximate result is asserted to be the model's
     * transformed schema plus a double column named {@code distCol}. Unless
     * {@code singleProbe} is set, the approximate result is also asserted to
     * hold exactly {@code k} rows.
     *
     * @param lsh         estimator under test
     * @param dataset     data to fit and search
     * @param key         query vector
     * @param k           number of neighbours requested
     * @param singleProbe forwarded to {@code approxNearestNeighbors}
     * @return (matches / approximate rows, matches / expected rows), where
     *         matches is the row count of the join of both results on the
     *         input column
     */
    public <T extends LSHModel<T>> Tuple2<Object, Object> calculateApproxNearestNeighbors(LSH<T> lsh, Dataset<?> dataset, Vector key, int k, boolean singleProbe) {
        LSHModel model = lsh.fit(dataset);
        UserDefinedFunction distanceToKey = functions$.MODULE$.udf(
                (Function1 & Serializable & scala.Serializable)candidate -> BoxesRunTime.boxToDouble((double)model.keyDistance(candidate, key)),
                DataTypes.DoubleType);

        // Reference answer: order every row by its distance to the key, keep k.
        Column distanceOrdering = distanceToKey.apply(columnSeq(functions$.MODULE$.col(model.getInputCol())));
        Dataset expected = dataset.sort(columnSeq(distanceOrdering)).limit(k);

        Dataset actual = model.approxNearestNeighbors(dataset, key, k, singleProbe, "distCol");
        Predef$.MODULE$.assert(actual.schema().sameType((DataType)model.transformSchema(dataset.schema()).add("distCol", DataTypes.DoubleType)));

        boolean exactSizeRequired = !singleProbe;
        if (exactSizeRequired) {
            Predef$.MODULE$.assert(actual.count() == (long)k);
        }

        double matches = expected.join(actual, model.getInputCol()).count();
        double actualTotal = (double)actual.count();
        double expectedTotal = (double)expected.count();
        return new Tuple2.mcDD.sp(matches / actualTotal, matches / expectedTotal);
    }

    /**
     * Fits {@code lsh} on {@code datasetA} and compares the model's approximate
     * similarity join of the two datasets with the cross join filtered on
     * {@code keyDistance < threshold}.
     *
     * <p>The approximate result is checked for a double {@code distCol} column
     * and for {@code datasetA} / {@code datasetB} columns whose types match the
     * model's transformed schema of the corresponding input.
     *
     * @param lsh       estimator under test
     * @param datasetA  left side of the join; also the data the model is fit on
     * @param datasetB  right side of the join
     * @param threshold distance bound, applied with strict less-than
     * @return (approximate rows under the threshold / approximate rows,
     *         approximate rows under the threshold / expected rows)
     */
    public <T extends LSHModel<T>> Tuple2<Object, Object> calculateApproxSimilarityJoin(LSH<T> lsh, Dataset<?> datasetA, Dataset<?> datasetB, double threshold) {
        LSHModel model = lsh.fit(datasetA);
        String inputCol = model.getInputCol();
        UserDefinedFunction keyDistanceUdf = functions$.MODULE$.udf(
                (Function2 & Serializable & scala.Serializable)(left, right) -> BoxesRunTime.boxToDouble((double)model.keyDistance(left, right)),
                DataTypes.DoubleType);

        // Reference answer: every (a, b) pair closer than the threshold.
        Dataset everyPair = datasetA.as("a").crossJoin(datasetB.as("b"));
        Column pairDistance = keyDistanceUdf.apply(columnSeq(qualified("a", inputCol), qualified("b", inputCol)));
        Dataset expected = everyPair.filter(pairDistance.$less((Object)BoxesRunTime.boxToDouble((double)threshold)));

        Dataset actual = model.approxSimilarityJoin(datasetA, datasetB, threshold);
        SchemaUtils$.MODULE$.checkColumnType(
                actual.schema(),
                "distCol",
                DataTypes.DoubleType,
                SchemaUtils$.MODULE$.checkColumnType$default$4());
        Predef$.MODULE$.assert(actual.schema().apply("datasetA").dataType().sameType((DataType)model.transformSchema(datasetA.schema())));
        Predef$.MODULE$.assert(actual.schema().apply("datasetB").dataType().sameType((DataType)model.transformSchema(datasetB.schema())));

        Column underThreshold = functions$.MODULE$.col("distCol").$less((Object)BoxesRunTime.boxToDouble((double)threshold));
        double withinThreshold = actual.filter(underThreshold).count();
        double actualTotal = (double)actual.count();
        double expectedTotal = (double)expected.count();
        return new Tuple2.mcDD.sp(withinThreshold / actualTotal, withinThreshold / expectedTotal);
    }

    /**
     * Body of the "same bucket" predicate used by
     * {@code calculateLSHProperty}: true when the model reports a hash
     * distance of exactly {@code 0.0} between the two hash sequences.
     */
    public static final /* synthetic */ boolean $anonfun$calculateLSHProperty$2(LSHModel model$1, Seq x, Seq y) {
        double bucketDistance = model$1.hashDistance(x, y);
        return bucketDistance == 0.0;
    }
}

