/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.offline.evaluator.lookup;

import com.linkedin.feathr.common.FeatureValue;
import com.linkedin.feathr.compute.AnyNode;
import com.linkedin.feathr.compute.Lookup;
import com.linkedin.feathr.compute.NodeReference;
import com.linkedin.feathr.offline.PostTransformationUtil$;
import com.linkedin.feathr.offline.derived.strategies.SeqJoinAggregator$;
import com.linkedin.feathr.offline.derived.strategies.SequentialJoinAsDerivation$;
import com.linkedin.feathr.offline.evaluator.NodeEvaluator;
import com.linkedin.feathr.offline.graph.DataframeAndColumnMetadata;
import com.linkedin.feathr.offline.graph.DataframeAndColumnMetadata$;
import com.linkedin.feathr.offline.graph.FCMGraphTraverser;
import com.linkedin.feathr.offline.graph.NodeUtils$;
import com.linkedin.feathr.offline.join.algorithms.JoinType$;
import com.linkedin.feathr.offline.join.algorithms.SequentialJoinConditionBuilder$;
import com.linkedin.feathr.offline.join.algorithms.SparkJoinWithJoinCondition;
import com.linkedin.feathr.offline.join.algorithms.SparkJoinWithJoinCondition$;
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext;
import com.linkedin.feathr.offline.source.accessor.DataPathHandler;
import com.linkedin.feathr.offline.transformation.MvelDefinition;
import com.linkedin.feathr.offline.util.DataFrameSplitterMerger$;
import java.io.Serializable;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import scala.Array$;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.GenTraversableOnce;
import scala.collection.IterableLike;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.;
import scala.collection.immutable.List;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/**
 * Node evaluator for {@link Lookup} nodes of the compute graph.
 *
 * <p>This appears to be CFR-decompiled output of a Scala {@code object}: the {@code LookupNodeEvaluator$}
 * module class with its {@code MODULE$} singleton. It is not compilable Java as written. For example,
 * {@code new .colon.colon(elem, Nil$.MODULE$)} builds a one-element Scala {@code List}, and the
 * {@code import scala.collection.immutable.;} line is CFR's rendering of Scala's {@code ::} class.
 * Read this file as an aid to understanding the original Scala source, not as editable Java.
 *
 * <p>A lookup node uses the values of a "base" node's feature column as keys. It left-outer-joins them
 * against an "expansion" node's dataframe. It then aggregates the joined expansion values back per
 * original row via {@code SeqJoinAggregator.applyAggregationFunction}.
 */
public final class LookupNodeEvaluator$
implements NodeEvaluator {
    /** Scala module singleton; assigned by the private constructor during static initialization. */
    public static LookupNodeEvaluator$ MODULE$;

    static {
        new LookupNodeEvaluator$();
    }

    /**
     * Computes the sequential-join (lookup) feature for one lookup node.
     *
     * <p>Outline of the steps:
     * <ol>
     *   <li>Project the expansion node to its key columns plus its feature column, and prefix every
     *       column with {@code "__expansion__"}.</li>
     *   <li>Apply the default transformation to the base feature column of {@code contextDf}, then split
     *       the rows on whether that column is null.</li>
     *   <li>For the non-null rows: tag each row with a generated id, explode the left join keys,
     *       left-outer-join against the expansion data, fill defaults, and aggregate per id.</li>
     *   <li>For the null rows: add a null feature column, fill defaults, and aggregate the same way.</li>
     *   <li>Merge both halves, drop helper columns, and rename the produced column to
     *       {@code seqJoinFeatureName}.</li>
     * </ol>
     *
     * @param lookupNode         the lookup node; only its aggregation type is read here
     * @param baseNode           metadata of the node whose feature values act as lookup keys; its
     *                           {@code featureColumn} must be defined ({@code Option.get} is called on it)
     * @param baseKeyColumns     left-side key columns, handed to {@code SeqJoinAggregator.explodeLeftJoinKey}
     * @param expansionNode      metadata of the node being looked up; its {@code featureColumn} must be defined
     * @param contextDf          dataframe expected to contain the base node's feature column
     * @param seqJoinFeatureName name of the output feature column
     * @param seqJoinJoiner      joiner used for the left-outer join against the expansion data
     * @param defaultValueMap    default values keyed by feature name; the entry for the expansion feature,
     *                           if present, is substituted into the joined column
     * @param ss                 Spark session passed through to the aggregation helpers
     * @return metadata whose dataframe holds the aggregated feature in {@code seqJoinFeatureName}. Its key
     *         expression is the base node's key expression, with each entry replaced by its last
     *         {@code "__"}-separated segment.
     */
    public DataframeAndColumnMetadata processLookupNode(Lookup lookupNode, DataframeAndColumnMetadata baseNode, Seq<String> baseKeyColumns, DataframeAndColumnMetadata expansionNode, Dataset<Row> contextDf, String seqJoinFeatureName, SparkJoinWithJoinCondition seqJoinJoiner, Map<String, FeatureValue> defaultValueMap, SparkSession ss) {
        String expansionFeatureName = (String)expansionNode.featureColumn().get();
        // Expansion-side columns to keep: its key expression columns followed by its feature column.
        Seq expansionNodeCols = (Seq)expansionNode.keyExpression().$plus$plus((GenTraversableOnce)new .colon.colon((Object)((String)expansionNode.featureColumn().get()), (List)Nil$.MODULE$), Seq$.MODULE$.canBuildFrom());
        Dataset expansionNodeDF = expansionNode.df().select((Seq)expansionNodeCols.map((Function1 & Serializable & scala.Serializable)colName -> functions$.MODULE$.col(colName), Seq$.MODULE$.canBuildFrom()));
        // Prefix every expansion column with "__expansion__".
        // Presumably this avoids name clashes with the context columns after the join (TODO confirm).
        Seq expansionNodeRenamedCols = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])expansionNodeDF.columns())).map((Function1 & Serializable & scala.Serializable)c -> new StringBuilder(13).append("__expansion__").append((String)c).toString(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class))))).toSeq();
        Dataset expansionNodeDfWithRenamedCols = expansionNodeDF.toDF(expansionNodeRenamedCols);
        // Apply SequentialJoinAsDerivation's default transformation to the base feature column in place
        // (the output name equals the input name). No MVEL definitions or execution context are supplied.
        Dataset<Row> left = PostTransformationUtil$.MODULE$.transformFeatures((Seq<Tuple2<String, String>>)((Seq)new .colon.colon((Object)new Tuple2(baseNode.featureColumn().get(), baseNode.featureColumn().get()), (List)Nil$.MODULE$)), contextDf, (Map<String, MvelDefinition>)Predef$.MODULE$.Map().empty(), (Function2<DataType, String, Column>)(Function2 & Serializable & scala.Serializable)(dataType, columnName) -> SequentialJoinAsDerivation$.MODULE$.getDefaultTransformation((DataType)dataType, (String)columnName), (Option<FeathrExpressionExecutionContext>)None$.MODULE$);
        // Split rows by whether the base feature column is null. _1 holds the non-null rows and _2 the null
        // rows, judging by how the two halves are named below.
        Tuple2<Dataset<Row>, Dataset<Row>> tuple2 = DataFrameSplitterMerger$.MODULE$.splitOnNull(left, (String)baseNode.featureColumn().get());
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        // The null check above and the re-wrapping below are decompiler residue of a Scala
        // `val (a, b) = ...` pattern binding. They have no effect beyond unpacking the pair.
        Dataset coercedBaseDfWithNoNull = (Dataset)tuple2._1();
        Dataset coercedBaseDfWithNull = (Dataset)tuple2._2();
        Tuple2 tuple22 = new Tuple2((Object)coercedBaseDfWithNoNull, (Object)coercedBaseDfWithNull);
        Tuple2 tuple23 = tuple22;
        Dataset coercedBaseDfWithNoNull2 = (Dataset)tuple23._1();
        Dataset coercedBaseDfWithNull2 = (Dataset)tuple23._2();
        String groupByColumn = "__frame_seq_join_group_by_id";
        // Tag each non-null row with a generated id. The aggregation step groups by this id after the join.
        Dataset leftWithUidDF = coercedBaseDfWithNoNull2.withColumn(groupByColumn, functions$.MODULE$.monotonically_increasing_id());
        // Returns the (possibly renamed) left join key columns and the exploded left dataframe.
        // NOTE(review): the exact explode semantics live in SeqJoinAggregator; verify there.
        Tuple2<Seq<String>, Dataset<Row>> tuple24 = SeqJoinAggregator$.MODULE$.explodeLeftJoinKey(ss, (Dataset<Row>)leftWithUidDF, baseKeyColumns, seqJoinFeatureName);
        if (tuple24 == null) {
            throw new MatchError(tuple24);
        }
        // Decompiler residue of a Scala pattern binding, as above.
        Seq adjustedLeftJoinKey = (Seq)tuple24._1();
        Dataset explodedLeft = (Dataset)tuple24._2();
        Tuple2 tuple25 = new Tuple2((Object)adjustedLeftJoinKey, (Object)explodedLeft);
        Tuple2 tuple26 = tuple25;
        Seq adjustedLeftJoinKey2 = (Seq)tuple26._1();
        Dataset explodedLeft2 = (Dataset)tuple26._2();
        // Left-outer join: the adjusted left keys against the expansion node's "__expansion__"-prefixed
        // key columns.
        Dataset<Row> intermediateResult = seqJoinJoiner.join((Seq<String>)adjustedLeftJoinKey2, (Dataset<Row>)explodedLeft2, (Seq<String>)((Seq)expansionNode.keyExpression().map((Function1 & Serializable & scala.Serializable)c -> new StringBuilder(13).append("__expansion__").append((String)c).toString(), Seq$.MODULE$.canBuildFrom())), (Dataset<Row>)expansionNodeDfWithRenamedCols, JoinType$.MODULE$.left_outer());
        // The joined expansion feature appears under its prefixed name.
        String producedFeatureName = new StringBuilder(13).append("__expansion__").append(expansionFeatureName).toString();
        Option expansionFeatureDefaultValue = defaultValueMap.get((Object)expansionFeatureName);
        Dataset<Row> intermediateResultWithDefault = SeqJoinAggregator$.MODULE$.substituteDefaultValuesForSeqJoinFeature(intermediateResult, producedFeatureName, (Option<FeatureValue>)expansionFeatureDefaultValue, ss);
        String aggregationType = lookupNode.getAggregation();
        Dataset<Row> aggDf = SeqJoinAggregator$.MODULE$.applyAggregationFunction(seqJoinFeatureName, producedFeatureName, intermediateResultWithDefault, aggregationType, groupByColumn);
        // Null-base rows were never joined. Give them a null feature column, cast to the joined column's
        // type so both halves share a schema, then apply defaults and the same aggregation.
        Dataset<Row> coercedBaseDfWithNullWithDefault = SeqJoinAggregator$.MODULE$.substituteDefaultValuesForSeqJoinFeature((Dataset<Row>)coercedBaseDfWithNull2.withColumn(producedFeatureName, functions$.MODULE$.lit(null).cast(intermediateResult.schema().apply(producedFeatureName).dataType())), producedFeatureName, (Option<FeatureValue>)expansionFeatureDefaultValue, ss);
        Dataset<Row> coercedBaseDfWithNullWithAgg = SeqJoinAggregator$.MODULE$.applyAggregationFunction(seqJoinFeatureName, producedFeatureName, (Dataset<Row>)coercedBaseDfWithNullWithDefault.withColumn(groupByColumn, functions$.MODULE$.monotonically_increasing_id()), aggregationType, groupByColumn);
        Dataset<Row> finalRes = DataFrameSplitterMerger$.MODULE$.merge(aggDf, coercedBaseDfWithNullWithAgg);
        // Drop the prefixed expansion key columns and the "__base__"-prefixed base feature column.
        // NOTE(review): this method never creates the "__base__" column. Presumably a helper called above
        // adds it; verify.
        Dataset resWithDroppedCols = finalRes.drop((Seq)expansionNode.keyExpression().map((Function1 & Serializable & scala.Serializable)c -> new StringBuilder(13).append("__expansion__").append((String)c).toString(), Seq$.MODULE$.canBuildFrom())).drop(new StringBuilder(8).append("__base__").append(baseNode.featureColumn().get()).toString());
        Dataset finalResAfterDroppingCols = resWithDroppedCols.withColumnRenamed(producedFeatureName, seqJoinFeatureName);
        // Key expression: keep only the last "__"-separated segment of each base key.
        // Note that String.split takes a regex and drops trailing empty strings.
        return new DataframeAndColumnMetadata((Dataset<Row>)finalResAfterDroppingCols, (Seq<String>)((Seq)baseNode.keyExpression().map((Function1 & Serializable & scala.Serializable)x -> (String)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])x.split("__"))).last(), Seq$.MODULE$.canBuildFrom())), (Option<String>)new Some((Object)seqJoinFeatureName), DataframeAndColumnMetadata$.MODULE$.apply$default$4(), DataframeAndColumnMetadata$.MODULE$.apply$default$5());
    }

    /**
     * Returns the node ids in the concrete key of a lookup, data-source or transformation node.
     *
     * <p>NOTE(review): for a data-source node without a concrete key this returns {@code null}.
     * {@link #evaluate} calls {@code flatMap} on the result without a null check, which would throw a
     * {@code NullPointerException}. Consider returning an empty Seq, or failing with a descriptive error.
     *
     * @param node graph node to inspect
     * @return the key node ids, or {@code null} for a data-source node with no concrete key
     * @throws MatchError if the node is none of lookup, data-source or transformation
     */
    private Seq<Integer> getLookupNodeKeys(AnyNode node) {
        Seq seq;
        AnyNode anyNode = node;
        if (anyNode.isLookup()) {
            seq = (Seq)JavaConverters$.MODULE$.asScalaBufferConverter((java.util.List)anyNode.getLookup().getConcreteKey().getKey()).asScala();
        } else if (anyNode.isDataSource()) {
            seq = anyNode.getDataSource().hasConcreteKey() ? (Seq)JavaConverters$.MODULE$.asScalaBufferConverter((java.util.List)anyNode.getDataSource().getConcreteKey().getKey()).asScala() : null;
        } else if (anyNode.isTransformation()) {
            seq = (Seq)JavaConverters$.MODULE$.asScalaBufferConverter((java.util.List)anyNode.getTransformation().getConcreteKey().getKey()).asScala();
        } else {
            throw new MatchError((Object)anyNode);
        }
        return seq;
    }

    /**
     * Evaluates one lookup node and returns the resulting dataframe.
     *
     * <p>Side effect: stores the result metadata in the traverser's
     * {@code nodeIdToDataframeAndColumnMetadataMap} under the lookup node's id.
     *
     * @param node             node to evaluate; expected to be a lookup node ({@code getLookup()} is used unchecked)
     * @param graphTraverser   source of graph nodes, per-node dataframe metadata, feature names and the Spark session
     * @param contextDf        dataframe expected to contain the base node's feature column
     * @param dataPathHandlers not used by this evaluator
     * @return the dataframe of the metadata produced by {@link #processLookupNode}
     */
    @Override
    public Dataset<Row> evaluate(AnyNode node, FCMGraphTraverser graphTraverser, Dataset<Row> contextDf, List<DataPathHandler> dataPathHandlers) {
        Lookup lookUpNode = node.getLookup();
        // The base node is the first lookup key that is a node reference.
        // Option.get throws if there is no such key.
        NodeReference baseNodeRef = ((Lookup.LookupKey)((IterableLike)JavaConverters$.MODULE$.asScalaBufferConverter((java.util.List)lookUpNode.getLookupKey()).asScala()).find((Function1 & Serializable & scala.Serializable)x -> BoxesRunTime.boxToBoolean((boolean)x.isNodeReference())).get()).getNodeReference();
        DataframeAndColumnMetadata baseNode = (DataframeAndColumnMetadata)graphTraverser.nodeIdToDataframeAndColumnMetadataMap().apply((Object)BoxesRunTime.boxToInteger((int)Predef$.MODULE$.Integer2int(baseNodeRef.getId())));
        // Left join key columns come from the expansion node's concrete key ids. Each id contributes its
        // node's feature column if one is defined, otherwise that node's key expression columns.
        Seq baseKeyColumns = (Seq)this.getLookupNodeKeys((AnyNode)graphTraverser.nodes().apply(Predef$.MODULE$.Integer2int(lookUpNode.getLookupNode()))).flatMap((Function1 & Serializable & scala.Serializable)x -> ((DataframeAndColumnMetadata)graphTraverser.nodeIdToDataframeAndColumnMetadataMap().apply((Object)BoxesRunTime.boxToInteger((int)Predef$.MODULE$.Integer2int(x)))).featureColumn().isDefined() ? (Seq)new .colon.colon((Object)((String)((DataframeAndColumnMetadata)graphTraverser.nodeIdToDataframeAndColumnMetadataMap().apply((Object)BoxesRunTime.boxToInteger((int)Predef$.MODULE$.Integer2int(x)))).featureColumn().get()), (List)Nil$.MODULE$) : ((DataframeAndColumnMetadata)graphTraverser.nodeIdToDataframeAndColumnMetadataMap().apply((Object)BoxesRunTime.boxToInteger((int)Predef$.MODULE$.Integer2int(x)))).keyExpression(), Seq$.MODULE$.canBuildFrom());
        Integer expansionNodeId = lookUpNode.getLookupNode();
        DataframeAndColumnMetadata expansionNode = (DataframeAndColumnMetadata)graphTraverser.nodeIdToDataframeAndColumnMetadataMap().apply((Object)BoxesRunTime.boxToInteger((int)Predef$.MODULE$.Integer2int(expansionNodeId)));
        String seqJoinFeatureName = (String)graphTraverser.nodeIdToFeatureName().apply((Object)lookUpNode.getId());
        // Default values for the expansion node's feature, keyed by feature name.
        Map<String, FeatureValue> expansionNodeDefaultConverter = NodeUtils$.MODULE$.getDefaultConverter((Seq<AnyNode>)((Seq)new .colon.colon((Object)((AnyNode)graphTraverser.nodes().apply(Predef$.MODULE$.Integer2int(expansionNodeId))), (List)Nil$.MODULE$)));
        DataframeAndColumnMetadata lookupNodeContext = this.processLookupNode(lookUpNode, baseNode, (Seq<String>)baseKeyColumns, expansionNode, contextDf, seqJoinFeatureName, SparkJoinWithJoinCondition$.MODULE$.apply(SequentialJoinConditionBuilder$.MODULE$), expansionNodeDefaultConverter, graphTraverser.ss());
        // Record the result so later nodes that reference this lookup node can find it.
        graphTraverser.nodeIdToDataframeAndColumnMetadataMap().update((Object)BoxesRunTime.boxToInteger((int)Predef$.MODULE$.Integer2int(lookUpNode.getId())), (Object)lookupNodeContext);
        return lookupNodeContext.df();
    }

    /**
     * Evaluates the given lookup nodes in order. The dataframe returned by each evaluation becomes the
     * context dataframe of the next one.
     *
     * @return the dataframe returned for the last node, or {@code contextDf} if {@code nodes} is empty
     */
    @Override
    public Dataset<Row> batchEvaluate(Seq<AnyNode> nodes, FCMGraphTraverser graphTraverser, Dataset<Row> contextDf, List<DataPathHandler> dataPathHandlers) {
        return (Dataset)nodes.foldLeft(contextDf, (Function2 & Serializable & scala.Serializable)(updatedContextDf, node) -> MODULE$.evaluate((AnyNode)node, graphTraverser, (Dataset<Row>)updatedContextDf, dataPathHandlers));
    }

    /** Scala module constructor: publishes this instance as {@code MODULE$}. It runs once, from the static initializer. */
    private LookupNodeEvaluator$() {
        MODULE$ = this;
    }
}

