/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.offline.evaluator.transformation;

import com.linkedin.feathr.common.FeatureDerivationFunction;
import com.linkedin.feathr.common.FeatureTypeConfig;
import com.linkedin.feathr.common.FeatureTypes;
import com.linkedin.feathr.common.FeatureValue;
import com.linkedin.feathr.common.tensor.TensorType;
import com.linkedin.feathr.compute.AnyNode;
import com.linkedin.feathr.compute.NodeReference;
import com.linkedin.feathr.compute.NodeReferenceArray;
import com.linkedin.feathr.compute.Transformation;
import com.linkedin.feathr.exception.ErrorLabel;
import com.linkedin.feathr.exception.FrameFeatureTransformationException;
import com.linkedin.feathr.offline.derived.functions.MvelFeatureDerivationFunction;
import com.linkedin.feathr.offline.derived.functions.SimpleMvelDerivationFunction;
import com.linkedin.feathr.offline.graph.FCMGraphTraverser;
import com.linkedin.feathr.offline.graph.NodeUtils$;
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext;
import com.linkedin.feathr.offline.transformation.FDSConversionUtils$;
import com.linkedin.feathr.offline.transformation.FeatureColumnFormat$;
import com.linkedin.feathr.offline.util.CoercionUtilsScala$;
import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils$;
import java.io.Serializable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.apache.spark.sql.catalyst.encoders.RowEncoder$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.collection.GenSeq;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.immutable.;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.List;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Map$;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/**
 * Decompiled JVM form of a Scala singleton {@code object BaseDerivedFeatureOperator}: the single
 * instance is published in {@link #MODULE$} by the private constructor, which the static
 * initializer invokes once.
 *
 * <p>Provides {@link #applyDerivationFunction}, which evaluates a derived-feature
 * {@link Transformation} node row-by-row over a Spark {@code Dataset<Row>} using a
 * {@link FeatureDerivationFunction} and appends the result as a new column.
 *
 * <p>NOTE(review): this is CFR output, not hand-written Java. {@code new .colon.colon(x, Nil$.MODULE$)}
 * (and the truncated import {@code scala.collection.immutable.;}) appear to be the decompiler's
 * rendering of Scala's {@code ::} list-cons class, i.e. the one-element list {@code x :: Nil}.
 * The file does not compile as-is; prefer changing the originating Scala source if available.
 */
public final class BaseDerivedFeatureOperator$ {
    /** The singleton instance (assigned by the private constructor). */
    public static BaseDerivedFeatureOperator$ MODULE$;

    // Eagerly create the singleton; the constructor stores it into MODULE$.
    static {
        new BaseDerivedFeatureOperator$();
    }

    /**
     * Computes a derived feature for every row of {@code contextDf} and returns the frame with the
     * feature appended as a new, nullable column.
     *
     * <p>Steps:
     * <ol>
     *   <li>Resolve the output column name: {@code node.getFeatureName()}, or, when that is null,
     *       the name mapped from the node id in {@code graphTraverser.nodeIdToFeatureName()}.</li>
     *   <li>Drop any existing column with that name from {@code contextDf}.</li>
     *   <li>Resolve the input feature names (sorted by string ordering) and the input nodes'
     *       {@link FeatureTypeConfig}s from the node's input references.</li>
     *   <li>Derive the output column's Spark {@link DataType} from the feature's tensor type.</li>
     *   <li>For each row: coerce every input column to a {@link FeatureValue}, call
     *       {@code derivationFunction.getFeatures} with the values ordered by input name, convert
     *       the first returned value to its FDS representation (null when absent) and append it.</li>
     *   <li>If the graph's name for this node differs from {@code node.getFeatureName()}, register
     *       the graph name as a RAW-format column and rename the output column to it.</li>
     * </ol>
     *
     * @param node the transformation node being evaluated
     * @param derivationFunction maps the input feature values to the derived value(s); if it is an
     *     {@link MvelFeatureDerivationFunction}, its MVEL context is set from {@code graphTraverser}
     * @param graphTraverser supplies the node-id-to-name mapping, node lookup, MVEL context and the
     *     feature column format registry (which this method may update)
     * @param contextDf input frame; presumably must contain a column for each input feature, since
     *     values are coerced from the row by input feature name — confirm against
     *     {@code CoercionUtilsScala.coerceFieldToFeatureValue}
     * @return {@code contextDf} (minus any stale output column) with the derived column appended,
     *     possibly renamed to the graph alias
     * @throws FrameFeatureTransformationException labelled {@link ErrorLabel#FEATHR_USER_ERROR} when
     *     computing the feature for a row fails; it is thrown from inside the row-mapping lambda, so
     *     presumably surfaces only when the returned Dataset is evaluated
     */
    public Dataset<Row> applyDerivationFunction(Transformation node, FeatureDerivationFunction derivationFunction, FCMGraphTraverser graphTraverser, Dataset<Row> contextDf) {
        // Result; assigned in the alias / no-alias branch at the end.
        Dataset dataset;
        String featureName = node.getFeatureName() == null ? (String)graphTraverser.nodeIdToFeatureName().apply((Object)node.getId()) : node.getFeatureName();
        // Drop a pre-existing column with the output name so the appended column is the only one with that name.
        Dataset inputDf = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])contextDf.columns())).contains((Object)featureName) ? contextDf.drop(featureName) : contextDf;
        NodeReferenceArray inputs = node.getInputs();
        // Names of the input features, sorted by string ordering. Values are later passed to the
        // derivation function sorted by these names as well (see the sortBy in the row mapper).
        String[] inputFeatureNames = (String[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(inputs.toArray())).map((Function1 & Serializable & scala.Serializable)input -> {
            NodeReference inp = (NodeReference)input;
            return (String)graphTraverser.nodeIdToFeatureName().apply((Object)inp.getId());
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class))))).sorted((Ordering)Ordering.String$.MODULE$);
        // The AnyNode for each input reference, used only to build inputFeatureTypeConfigs.
        Seq inputNodes = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(inputs.toArray())).map((Function1 & Serializable & scala.Serializable)input -> {
            NodeReference inp = (NodeReference)input;
            return (AnyNode)graphTraverser.nodes().apply(Predef$.MODULE$.Integer2int(inp.getId()));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(AnyNode.class))))).toSeq();
        Map<String, FeatureTypeConfig> inputFeatureTypeConfigs = NodeUtils$.MODULE$.getFeatureTypeConfigsMap((Seq<AnyNode>)inputNodes);
        // Type config(s) declared on this transformation node, keyed by feature name.
        // The `new .colon.colon(node, Nil)` argument is the decompiled form of Scala's `node :: Nil`.
        Map<String, FeatureTypeConfig> featureTypeConfigs = NodeUtils$.MODULE$.getFeatureTypeConfigsMapForTransformationNodes((Seq<Transformation>)((Seq)new .colon.colon((Object)node, (List)Nil$.MODULE$)));
        // NOTE(review): the fallback here is new FeatureTypeConfig(UNSPECIFIED), while the row mapper
        // falls back to FeatureTypeConfig.UNDEFINED_TYPE_CONFIG for the same key — confirm these agree.
        FeatureTypeConfig featureTypeConfig = (FeatureTypeConfig)featureTypeConfigs.getOrElse((Object)featureName, (Function0 & Serializable & scala.Serializable)() -> new FeatureTypeConfig(FeatureTypes.UNSPECIFIED));
        // Spark type of the new output column, derived from the feature's tensor type.
        TensorType tensorType = FeaturizedDatasetUtils$.MODULE$.lookupTensorTypeForNonFMLFeatureRef(featureName, FeatureTypes.UNSPECIFIED, featureTypeConfig);
        DataType newSchema = FeaturizedDatasetUtils$.MODULE$.tensorTypeToDataFrameSchema(tensorType);
        StructType inputSchema = inputDf.schema();
        // Read once here and captured by the row-mapping lambda below.
        Option<FeathrExpressionExecutionContext> mvelContext = graphTraverser.mvelExpressionContext();
        // Output schema = input schema followed by exactly one nullable column named featureName.
        StructType outputSchema = StructType$.MODULE$.apply((Seq)inputSchema.union((GenSeq)StructType$.MODULE$.apply((Seq)new .colon.colon((Object)new StructField(featureName, newSchema, true, StructField$.MODULE$.apply$default$4()), (List)Nil$.MODULE$)), Seq$.MODULE$.canBuildFrom()));
        ExpressionEncoder encoder = RowEncoder$.MODULE$.apply(outputSchema);
        // Row-by-row evaluation of the derivation function.
        Dataset outputDf = inputDf.map((Function1 & Serializable & scala.Serializable)row -> {
            Row row2;
            try {
                FeatureDerivationFunction featureDerivationFunction;
                // Coerced input values keyed by input feature name.
                scala.collection.mutable.Map contextFeatureValues = Map$.MODULE$.empty();
                // NOTE(review): `map` is used only for its side effect (put into contextFeatureValues);
                // the resulting Array[Option] is discarded — `foreach` would state the intent better.
                new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])inputFeatureNames)).map((Function1 & Serializable & scala.Serializable)inputFeatureName -> {
                    FeatureTypeConfig featureTypeConfig = (FeatureTypeConfig)inputFeatureTypeConfigs.getOrElse(inputFeatureName, (Function0 & Serializable & scala.Serializable)() -> FeatureTypeConfig.UNDEFINED_TYPE_CONFIG);
                    FeatureValue featureValue = CoercionUtilsScala$.MODULE$.coerceFieldToFeatureValue((Row)row, inputSchema, (String)inputFeatureName, featureTypeConfig);
                    return contextFeatureValues.put(inputFeatureName, (Object)featureValue);
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Option.class)));
                // Input values ordered by feature name, each wrapped with Option.apply (None for a null value).
                Seq featureValues = (Seq)((TraversableLike)contextFeatureValues.toSeq().sortBy((Function1 & Serializable & scala.Serializable)x$1 -> (String)x$1._1(), (Ordering)Ordering.String$.MODULE$)).map((Function1 & Serializable & scala.Serializable)fv -> Option$.MODULE$.apply(fv._2()), Seq$.MODULE$.canBuildFrom());
                // For MVEL-based functions, install the traverser's MVEL context before evaluating.
                // NOTE(review): this mutates the captured derivationFunction on every row; the
                // assignment is row-invariant and could be done once.
                FeatureDerivationFunction featureDerivationFunction2 = derivationFunction;
                if (featureDerivationFunction2 instanceof MvelFeatureDerivationFunction) {
                    MvelFeatureDerivationFunction mvelFeatureDerivationFunction = (MvelFeatureDerivationFunction)featureDerivationFunction2;
                    mvelFeatureDerivationFunction.mvelContext_$eq(mvelContext);
                    featureDerivationFunction = mvelFeatureDerivationFunction;
                } else {
                    featureDerivationFunction = featureDerivationFunction2;
                }
                FeatureDerivationFunction derivedFunc = featureDerivationFunction;
                Seq<Option<FeatureValue>> unlinkedOutput = derivedFunc.getFeatures((Seq<Option<FeatureValue>>)featureValues);
                FeatureTypes featureType = ((FeatureTypeConfig)featureTypeConfigs.getOrElse((Object)featureName, (Function0 & Serializable & scala.Serializable)() -> FeatureTypeConfig.UNDEFINED_TYPE_CONFIG)).getFeatureType();
                // Convert each derived value to the column's FDS representation:
                // - declared type TENSOR and not a SimpleMvelDerivationFunction -> the value's tensor data;
                // - otherwise -> the value's term vector (converted to a Scala map);
                // - absent (None) -> null.
                Seq fdFeatureValue = (Seq)unlinkedOutput.map((Function1 & Serializable & scala.Serializable)fv -> {
                    Object object;
                    if (fv.isDefined()) {
                        FeatureTypes featureTypes = featureType;
                        FeatureTypes featureTypes2 = FeatureTypes.TENSOR;
                        // The negated null-safe equals below is the decompiled form of `featureType == FeatureTypes.TENSOR`.
                        object = !(featureTypes != null ? !((Object)((Object)featureTypes)).equals((Object)featureTypes2) : featureTypes2 != null) && !(derivationFunction instanceof SimpleMvelDerivationFunction) ? FDSConversionUtils$.MODULE$.rawToFDSRow(((FeatureValue)fv.get()).getAsTensorData(), newSchema) : FDSConversionUtils$.MODULE$.rawToFDSRow(JavaConverters$.MODULE$.mapAsScalaMapConverter(((FeatureValue)fv.get()).getAsTermVector()).asScala(), newSchema);
                    } else {
                        object = null;
                    }
                    return object;
                }, Seq$.MODULE$.canBuildFrom());
                // Copy the input row's fields, then append fdFeatureValue(0) as the last column.
                // outputSchema has exactly one extra column, so only the first derived value is used;
                // any further values are ignored, and an empty result makes fdFeatureValue.apply(0)
                // throw (caught and wrapped below).
                row2 = Row$.MODULE$.fromSeq((Seq)outputSchema.indices().map((Function1 & Serializable & scala.Serializable)i -> BaseDerivedFeatureOperator$.$anonfun$applyDerivationFunction$11(inputSchema, fdFeatureValue, row, BoxesRunTime.unboxToInt((Object)i)), IndexedSeq$.MODULE$.canBuildFrom()));
            }
            catch (Exception e) {
                // Report any per-row failure as a user error naming the feature, keeping the cause.
                throw new FrameFeatureTransformationException(ErrorLabel.FEATHR_USER_ERROR, new StringBuilder(34).append("Fail to calculate derived feature ").append(featureName).toString(), (Throwable)e);
            }
            return row2;
        }, (Encoder)encoder);
        // If the graph knows this node under a different name (an alias), register that name with
        // RAW column format and rename the output column to it.
        // NOTE(review): when node.getFeatureName() is null, featureName was already taken from
        // nodeIdToFeatureName above, so (for a non-null mapped name) this branch renames the column
        // to itself and still marks it RAW — confirm this is intended.
        Object object = graphTraverser.nodeIdToFeatureName().apply((Object)node.getId());
        String string = node.getFeatureName();
        if (object == null ? string != null : !object.equals(string)) {
            String featureAlias = (String)graphTraverser.nodeIdToFeatureName().apply((Object)node.getId());
            graphTraverser.featureColumnFormatsMap().update((Object)featureAlias, (Object)FeatureColumnFormat$.MODULE$.RAW());
            dataset = outputDf.withColumnRenamed(featureName, featureAlias);
        } else {
            dataset = outputDf;
        }
        return dataset;
    }

    /**
     * Synthetic helper for the output-row lambda in {@link #applyDerivationFunction}.
     *
     * @return the input row's field {@code i} when {@code i < inputSchema$1.size()}, otherwise the
     *     derived value at offset {@code i - inputSchema$1.size()} of {@code fdFeatureValue$1}
     */
    public static final /* synthetic */ Object $anonfun$applyDerivationFunction$11(StructType inputSchema$1, Seq fdFeatureValue$1, Row row$1, int i) {
        return i >= inputSchema$1.size() ? fdFeatureValue$1.apply(i - inputSchema$1.size()) : row$1.get(i);
    }

    /** Private constructor run by the static initializer; publishes the singleton into MODULE$. */
    private BaseDerivedFeatureOperator$() {
        MODULE$ = this;
    }
}

