/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.offline.join;

import com.linkedin.feathr.offline.client.DataFrameColName$;
import com.linkedin.feathr.offline.config.FeatureJoinConfig;
import com.linkedin.feathr.offline.job.DataFrameStatFunctions;
import com.linkedin.feathr.offline.join.DataFrameKeyCombiner;
import com.linkedin.feathr.offline.join.DataFrameKeyCombiner$;
import com.linkedin.feathr.offline.join.PreprocessedObservation;
import com.linkedin.feathr.offline.join.algorithms.SaltedSparkJoin;
import com.linkedin.feathr.offline.join.algorithms.SaltedSparkJoin$;
import com.linkedin.feathr.offline.swa.SlidingWindowFeatureUtils$;
import com.linkedin.feathr.offline.util.SourceUtils$;
import com.linkedin.feathr.offline.util.datetime.DateTimeInterval;
import java.io.Serializable;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.functions$;
import org.apache.spark.util.sketch.BloomFilter;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.;
import scala.collection.immutable.List;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.Iterable$;
import scala.math.package$;
import scala.runtime.BoxesRunTime;

/**
 * Singleton helpers used to pre-process the observation DataFrame before feature joins:
 * combining join keys into single columns, attaching a uid column, optionally building bloom
 * filters over the combined join-key columns, collecting sliding-window-related columns, and
 * finding frequent join-key values for salted joins.
 *
 * <p>NOTE(review): this is CFR decompiler output of a Scala {@code object} (see the file header).
 * Several constructs are not valid Java (e.g. the {@code scala.collection.immutable.} import,
 * {@code new .colon.colon(...)}, a {@code Some} assigned to a {@code None$} local), so this file is
 * not expected to compile as Java; consult the original Scala source before making code changes.
 */
public final class OptimizerUtils$ {
    // Singleton instance; assigned by the private constructor, which the static initializer runs.
    public static OptimizerUtils$ MODULE$;
    // Logger, initialized lazily on first call to log().
    private transient Logger log;
    // Upper bound on the expected item count used to size bloom filters (50000000L, set in the constructor).
    private final long maxExpectedItemForBloomfilter;
    // False-positive probability passed when creating bloom filters (0.05, set in the constructor).
    private final double bloomFilterFPP;
    // Becomes true once `log` has been initialized.
    private volatile transient boolean bitmap$trans$0;

    // Eagerly creates the singleton; the constructor stores it in MODULE$.
    static {
        new OptimizerUtils$();
    }

    /** Returns the upper bound on the expected item count used when sizing bloom filters. */
    public long maxExpectedItemForBloomfilter() {
        return this.maxExpectedItemForBloomfilter;
    }

    /** Returns the false-positive probability used when creating bloom filters. */
    public double bloomFilterFPP() {
        return this.bloomFilterFPP;
    }

    /**
     * Initializes {@code log} at most once while holding this instance's monitor.
     * The flag is checked again inside the lock, so concurrent first callers do not both initialize it.
     */
    private Logger log$lzycompute() {
        OptimizerUtils$ optimizerUtils$ = this;
        synchronized (optimizerUtils$) {
            if (!this.bitmap$trans$0) {
                this.log = LogManager.getLogger((String)this.getClass().getName());
                this.bitmap$trans$0 = true;
            }
        }
        return this.log;
    }

    /** Returns the logger, initializing it on first use (fast path skips the lock once initialized). */
    private Logger log() {
        return !this.bitmap$trans$0 ? this.log$lzycompute() : this.log;
    }

    /**
     * Prepares the observation DataFrame for feature joins.
     *
     * <p>For every join stage, the key tags referenced by that stage are combined into a single
     * column via {@link DataFrameKeyCombiner}. A uid column is then added. Bloom filters and
     * salted-join frequent items are derived from the result, and the pieces are packaged into a
     * {@link PreprocessedObservation}.
     *
     * @param inputDF observation data
     * @param joinConfig join configuration; only {@code settings()} is read (for the sliding-window time range)
     * @param joinStages join stages; the first element of each pair holds key tag ids, which are used as
     *     indices into {@code keyTagList}. The second element is not read in this class.
     * @param keyTagList key tag strings indexed by the ids in {@code joinStages}
     * @param rowBloomFilterThreshold row-count threshold that controls bloom filter generation
     *     (see generateBloomFilters)
     * @param saltedJoinParameters salted-join settings; when empty, no frequent-item DataFrames are produced
     * @param columnsToPreserve extra columns to keep in the key-and-uid-only DataFrame
     * @return a PreprocessedObservation built, in constructor-argument order, from: the bloom filters,
     *     the key/uid-only DataFrame, the uid-augmented DataFrame without combined key columns, the
     *     sliding-window observation time range, the sliding-window-related column names, and the
     *     salted-join frequent-item DataFrames
     * @throws MatchError if an intermediate tuple from a pattern match is null (decompiled match guard)
     */
    public PreprocessedObservation preProcessObservation(Dataset<Row> inputDF, FeatureJoinConfig joinConfig, Seq<Tuple2<Seq<Object>, Seq<String>>> joinStages, Seq<String> keyTagList, Option<Object> rowBloomFilterThreshold, Option<SaltedSparkJoin.JoinParameters> saltedJoinParameters, Seq<String> columnsToPreserve) {
        // (optional observation time range, optional time-column SQL expression) for sliding-window aggregation.
        Tuple2<Option<DateTimeInterval>, Option<String>> tuple2 = SlidingWindowFeatureUtils$.MODULE$.getObsSwaDataTimeRange(inputDF, joinConfig.settings());
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        // The tuple re-wrapping below is decompiled pattern-match destructuring.
        Option swaObsTime = (Option)tuple2._1();
        Option obsSWATimeExpr = (Option)tuple2._2();
        Tuple2 tuple22 = new Tuple2((Object)swaObsTime, (Object)obsSWATimeExpr);
        Tuple2 tuple23 = tuple22;
        Option swaObsTime2 = (Option)tuple23._1();
        Option obsSWATimeExpr2 = (Option)tuple23._2();
        // Top-level columns referenced by the time expression and the key tags; empty when there is no time expression.
        Seq<String> extraColumnsInSlickJoin = this.getSlidingWindowRelatedColumns((Option<String>)obsSWATimeExpr2, joinStages, keyTagList);
        // Maps each stage's key-tag-id sequence to the name of its combined key column; filled by the fold below.
        HashMap keyTagsToBloomFilterColumnMap = new HashMap();
        // Fold accumulator: (the map above, DataFrame with combined key columns added so far, list of combined column names).
        Tuple3 tuple3 = (Tuple3)joinStages.foldLeft((Object)new Tuple3((Object)keyTagsToBloomFilterColumnMap, inputDF, (Object)Nil$.MODULE$), (Function2 & Serializable & scala.Serializable)(accFilterMapDF, joinStage) -> {
            boolean x$3;
            Seq x$2;
            Dataset x$1;
            Seq keyTags = (Seq)joinStage._1();
            // Resolve key tag ids to their strings (keyTagList is applied as an index -> value function).
            Seq stringKeyTagList = (Seq)keyTags.map((Function1)keyTagList, Seq$.MODULE$.canBuildFrom());
            HashMap accFilterMap = (HashMap)accFilterMapDF._1();
            Dataset df = (Dataset)accFilterMapDF._2();
            DataFrameKeyCombiner qual$1 = DataFrameKeyCombiner$.MODULE$.apply();
            // Returns (combined key column name, DataFrame with that column added); third argument is combine's default.
            Tuple2<String, Dataset<Row>> tuple2 = qual$1.combine((Dataset<Row>)(x$1 = df), (Seq<String>)(x$2 = stringKeyTagList), x$3 = qual$1.combine$default$3());
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            String bfKeyColName = (String)tuple2._1();
            Dataset contextDFWithKeys = (Dataset)tuple2._2();
            Tuple2 tuple22 = new Tuple2((Object)bfKeyColName, (Object)contextDFWithKeys);
            Tuple2 tuple23 = tuple22;
            String bfKeyColName2 = (String)tuple23._1();
            Dataset contextDFWithKeys2 = (Dataset)tuple23._2();
            // Later stages with identical keyTags overwrite this entry, but every stage's name is appended to the list.
            accFilterMap.put((Object)keyTags, (Object)bfKeyColName2);
            return new Tuple3((Object)accFilterMap, (Object)contextDFWithKeys2, ((TraversableLike)accFilterMapDF._3()).$plus$plus((GenTraversableOnce)new .colon.colon((Object)bfKeyColName2, (List)Nil$.MODULE$), Seq$.MODULE$.canBuildFrom()));
        });
        if (tuple3 == null) {
            throw new MatchError((Object)tuple3);
        }
        Dataset withKeyDF = (Dataset)tuple3._2();
        Seq joinKeyColumnNames = (Seq)tuple3._3();
        Tuple2 tuple24 = new Tuple2((Object)withKeyDF, (Object)joinKeyColumnNames);
        Tuple2 tuple25 = tuple24;
        // Input data plus one combined key column per join stage.
        Dataset withKeyDF2 = (Dataset)tuple25._1();
        // Combined key column names, one per join stage (in stage order).
        Seq joinKeyColumnNames2 = (Seq)tuple25._2();
        // Combined key column names taken from the map values, one per distinct keyTags. Their order follows the
        // HashMap's iteration order, which generateBloomFilters relies on when zipping filters with the map.
        Seq origJoinKeyColumns = ((TraversableOnce)keyTagsToBloomFilterColumnMap.map((Function1 & Serializable & scala.Serializable)x$4 -> (String)x$4._2(), Iterable$.MODULE$.canBuildFrom())).toSeq();
        // Uid column followed by the sliding-window-related columns.
        Seq extraColumns = (Seq)((TraversableLike)new .colon.colon((Object)DataFrameColName$.MODULE$.UidColumnName(), (List)Nil$.MODULE$)).$plus$plus(extraColumnsInSlickJoin, Seq$.MODULE$.canBuildFrom());
        // NOTE(review): Spark documents monotonically_increasing_id as non-deterministic (it depends on partitioning).
        // keyAndUidOnlyDF and withUidDF below both derive from this DataFrame; if they are evaluated separately
        // without caching, confirm that their uids still line up.
        Dataset withKeyAndUidDF = withKeyDF2.withColumn(DataFrameColName$.MODULE$.UidColumnName(), functions$.MODULE$.monotonically_increasing_id());
        // De-duplicated: uid + SWA columns + combined key columns + caller-requested columns.
        Seq allSelectedColumns = (Seq)((SeqLike)((TraversableLike)extraColumns.$plus$plus((GenTraversableOnce)origJoinKeyColumns, Seq$.MODULE$.canBuildFrom())).$plus$plus(columnsToPreserve, Seq$.MODULE$.canBuildFrom())).distinct();
        // head() is safe: allSelectedColumns always contains the uid column.
        Dataset keyAndUidOnlyDF = withKeyAndUidDF.select((String)allSelectedColumns.head(), (Seq)allSelectedColumns.tail());
        Option<Map<Seq<Object>, BloomFilter>> bloomFilters = this.generateBloomFilters(rowBloomFilterThreshold, (Dataset<Row>)withKeyAndUidDF, joinStages, (Seq<String>)origJoinKeyColumns, (HashMap<Seq<Object>, String>)keyTagsToBloomFilterColumnMap);
        // Full observation data with the uid column but without the temporary combined key columns.
        Dataset withUidDF = withKeyAndUidDF.drop(joinKeyColumnNames2);
        // Frequent items are computed on the keyed DataFrame without the uid column.
        Map<Seq<Object>, Dataset<Row>> saltedJoinFrequentItemDFs = this.generateSaltedJoinFrequentItems((Map<Seq<Object>, String>)keyTagsToBloomFilterColumnMap.toMap(Predef$.MODULE$.$conforms()), (Dataset<Row>)withKeyDF2, saltedJoinParameters);
        return new PreprocessedObservation(bloomFilters, (Dataset<Row>)keyAndUidOnlyDF, (Dataset<Row>)withUidDF, (Option<DateTimeInterval>)swaObsTime2, extraColumnsInSlickJoin, saltedJoinFrequentItemDFs);
    }

    /**
     * Optionally builds one bloom filter per combined join-key column.
     *
     * <p>Filters are disabled when {@code rowBloomFilterThreshold} is defined and equal to 0. Otherwise:
     * <ul>
     *   <li>if the estimated row count is at most 0, filters are generated (no join-stage check on this branch);</li>
     *   <li>otherwise filters are generated only when there is at least one join stage and the threshold is
     *       absent, equal to -1, or positive and strictly greater than the estimate.</li>
     * </ul>
     * Any other threshold value (e.g. below -1) disables filters.
     *
     * @param rowBloomFilterThreshold optional row-count threshold (unboxed as an int)
     * @param contextDF DataFrame containing the combined key columns
     * @param joinStages join stages; only checked for non-emptiness
     * @param origJoinKeyColumns combined key column names. They must be in the same order as
     *     {@code keyTagsToBloomFilterColumnMap.toSeq()}, because filters are paired with map entries by position.
     * @param keyTagsToBloomFilterColumnMap key-tag-id sequence to combined key column name
     * @return {@code Some(keyTags -> filter)} when filters are generated, otherwise {@code None}
     */
    private Option<Map<Seq<Object>, BloomFilter>> generateBloomFilters(Option<Object> rowBloomFilterThreshold, Dataset<Row> contextDF, Seq<Tuple2<Seq<Object>, Seq<String>>> joinStages, Seq<String> origJoinKeyColumns, HashMap<Seq<Object>, String> keyTagsToBloomFilterColumnMap) {
        None$ none$;
        boolean generateBloomfilters;
        boolean bl;
        boolean forceDisable;
        // Estimated row count, capped at maxExpectedItemForBloomfilter.
        long estimatedSize = package$.MODULE$.min(SourceUtils$.MODULE$.estimateRDDRow(contextDF.rdd(), SourceUtils$.MODULE$.estimateRDDRow$default$2()), this.maxExpectedItemForBloomfilter());
        // An explicit threshold of 0 disables bloom filters outright (bl2 is an unused decompiler temporary).
        boolean bl2 = forceDisable = rowBloomFilterThreshold.isDefined() && BoxesRunTime.unboxToInt((Object)rowBloomFilterThreshold.get()) == 0;
        if (estimatedSize <= 0L) {
            // Decompiler artifact: only `bl` is assigned here, leaving `generateBloomfilters` unassigned on this path.
            // Presumably the Scala source yields true for generateBloomfilters — confirm against the original.
            bl = true;
        } else {
            int threshold;
            boolean thresholdMet = rowBloomFilterThreshold.isEmpty() ? true : (threshold = BoxesRunTime.unboxToInt((Object)rowBloomFilterThreshold.get())) == -1 || threshold > 0 && estimatedSize < (long)threshold;
            bl = generateBloomfilters = thresholdMet && joinStages.nonEmpty();
        }
        if (!forceDisable && generateBloomfilters) {
            // Size the filters by the estimate when it is positive, otherwise by the configured maximum.
            long expectItemNum = estimatedSize > 0L ? estimatedSize : this.maxExpectedItemForBloomfilter();
            // Presumably returns one filter per column, in origJoinKeyColumns order — TODO confirm in DataFrameStatFunctions.
            Seq<BloomFilter> filters = new DataFrameStatFunctions(contextDF).batchCreateBloomFilter(origJoinKeyColumns, expectItemNum, this.bloomFilterFPP());
            // Pair each ((keyTags, columnName), filter) by position and keep keyTags -> filter.
            Map filterMap = ((TraversableOnce)((TraversableLike)keyTagsToBloomFilterColumnMap.toSeq().zip(filters, Seq$.MODULE$.canBuildFrom())).map((Function1 & Serializable & scala.Serializable)x0$1 -> {
                BloomFilter filter;
                Tuple2 tuple2;
                block3: {
                    Tuple2 tuple22;
                    block2: {
                        tuple22 = x0$1;
                        if (tuple22 == null) break block2;
                        tuple2 = (Tuple2)tuple22._1();
                        filter = (BloomFilter)tuple22._2();
                        if (tuple2 != null) break block3;
                    }
                    throw new MatchError((Object)tuple22);
                }
                Seq tag = (Seq)tuple2._1();
                Tuple2 tuple23 = new Tuple2((Object)tag, (Object)filter);
                return tuple23;
            }, Seq$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms());
            none$ = new Some((Object)filterMap);
        } else {
            none$ = None$.MODULE$;
        }
        None$ bloomFilters = none$;
        return bloomFilters;
    }

    /**
     * Collects the top-level column names needed for sliding-window joins.
     *
     * <p>When a time expression is present, the result is the distinct union of the top-level columns
     * referenced by the time expression and the columns referenced by every key tag used in any join stage.
     * Parse failures are swallowed without logging:
     * <ul>
     *   <li>a failure on the time expression contributes no time columns;</li>
     *   <li>a failure on any key tag contributes no key columns at all.</li>
     * </ul>
     *
     * @param obsSWATimeExpr optional SQL expression for the observation time column
     * @param joinStages join stages whose first element holds key tag ids (indices into keyTagList)
     * @param keyTagList key tag SQL expressions indexed by id
     * @return the referenced top-level column names, or an empty Seq when there is no time expression
     */
    private Seq<String> getSlidingWindowRelatedColumns(Option<String> obsSWATimeExpr, Seq<Tuple2<Seq<Object>, Seq<String>>> joinStages, Seq<String> keyTagList) {
        return (Seq)obsSWATimeExpr.map((Function1 & Serializable & scala.Serializable)timeExpr -> {
            Seq seq;
            Seq seq2;
            // The session is needed only for its SQL parser.
            SparkSession ss = SparkSession$.MODULE$.builder().getOrCreate();
            try {
                seq2 = OptimizerUtils$.getTopLevelReferencedFields$1(timeExpr, ss);
            }
            catch (Exception exception) {
                // Best-effort: an unparsable time expression contributes no columns.
                seq2 = (Seq)Nil$.MODULE$;
            }
            Seq swaTimeColumns = seq2;
            try {
                seq = (Seq)joinStages.flatMap((Function1 & Serializable & scala.Serializable)joinStage -> {
                    Seq keyTags = (Seq)joinStage._1();
                    return (Seq)keyTags.flatMap((Function1 & Serializable & scala.Serializable)keyTagId -> OptimizerUtils$.getTopLevelReferencedFields$1((String)keyTagList.apply(BoxesRunTime.unboxToInt((Object)keyTagId)), ss), Seq$.MODULE$.canBuildFrom());
                }, Seq$.MODULE$.canBuildFrom());
            }
            catch (Exception exception) {
                // Best-effort: one failing key tag discards the key columns of all stages.
                seq = (Seq)Nil$.MODULE$;
            }
            Seq swaJoinKeyColumns = seq;
            return (Seq)((SeqLike)swaTimeColumns.$plus$plus((GenTraversableOnce)swaJoinKeyColumns, Seq$.MODULE$.canBuildFrom())).distinct();
        }).getOrElse((Function0 & Serializable & scala.Serializable)() -> (Seq)Nil$.MODULE$);
    }

    /**
     * Computes, for each combined key column, a DataFrame of frequent key values for salted joins.
     *
     * <p>Columns whose frequent-items DataFrame is empty are omitted, and this is logged at info level.
     * Each column's DataFrame is evaluated via {@code head(1)} to test for emptiness.
     *
     * @param keyTagsToColumnMap key-tag-id sequence to combined key column name
     * @param withKeyDF DataFrame containing the combined key columns
     * @param saltedJoinParameters salted-join settings; supplies the estimator and frequent-item threshold
     * @return keyTags -> frequent-items DataFrame, or an empty map when {@code saltedJoinParameters} is empty
     * @throws MatchError if a map entry is null (decompiled match guard)
     */
    private Map<Seq<Object>, Dataset<Row>> generateSaltedJoinFrequentItems(Map<Seq<Object>, String> keyTagsToColumnMap, Dataset<Row> withKeyDF, Option<SaltedSparkJoin.JoinParameters> saltedJoinParameters) {
        if (saltedJoinParameters.isEmpty()) {
            return (Map)Predef$.MODULE$.Map().apply((Seq)Nil$.MODULE$);
        }
        return (Map)keyTagsToColumnMap.flatMap((Function1 & Serializable & scala.Serializable)x0$1 -> {
            Iterable iterable;
            Tuple2 tuple2 = x0$1;
            if (tuple2 != null) {
                Seq keyTags = (Seq)tuple2._1();
                String keyColumnName = (String)tuple2._2();
                Dataset<Row> frequentItemsDf = SaltedSparkJoin$.MODULE$.getFrequentItemsDataFrame(((SaltedSparkJoin.JoinParameters)saltedJoinParameters.get()).estimator(), withKeyDF, keyColumnName, ((SaltedSparkJoin.JoinParameters)saltedJoinParameters.get()).frequentItemThreshold());
                // Keep the entry only if at least one frequent item exists (head(1) triggers evaluation).
                if (new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])frequentItemsDf.head(1))).nonEmpty()) {
                    iterable = Option$.MODULE$.option2Iterable((Option)new Some((Object)new Tuple2((Object)keyTags, frequentItemsDf)));
                } else {
                    MODULE$.log().info(new StringBuilder(40).append("Salted Join: no frequent items for key: ").append(keyColumnName).toString());
                    iterable = Option$.MODULE$.option2Iterable((Option)None$.MODULE$);
                }
            } else {
                throw new MatchError((Object)tuple2);
            }
            Iterable iterable2 = iterable;
            return iterable2;
        }, Map$.MODULE$.canBuildFrom());
    }

    /**
     * Parses a SQL expression with the session's SQL parser and returns the top-level field names it references.
     * Each referenced attribute name is split on '.', and only the first segment is kept.
     * Presumably this is a local function of getSlidingWindowRelatedColumns, lifted by the Scala compiler
     * (it is only called from there).
     *
     * @param sqlExpr SQL expression to parse
     * @param ss session whose parser is used
     * @return top-level referenced field names (callers de-duplicate)
     */
    private static final Seq getTopLevelReferencedFields$1(String sqlExpr, SparkSession ss) {
        return ((TraversableOnce)ss.sessionState().sqlParser().parseExpression(sqlExpr).references().map((Function1 & Serializable & scala.Serializable)x$5 -> (String)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])x$5.name().split("\\."))).head(), scala.collection.Iterable$.MODULE$.canBuildFrom())).toSeq();
    }

    // Registers the singleton and sets the bloom filter sizing constants.
    private OptimizerUtils$() {
        MODULE$ = this;
        this.maxExpectedItemForBloomfilter = 50000000L;
        this.bloomFilterFPP = 0.05;
    }
}

