/*
 * Decompiled with CFR 0.152.
 */
package io.openlineage.spark3.agent.lifecycle.plan.column;

import io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageBuilder;
import io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageContext;
import io.openlineage.spark.agent.util.ExtensionPlanUtils;
import io.openlineage.spark.agent.util.ScalaConversionUtils;
import io.openlineage.spark.extension.scala.v1.ColumnLevelLineageNode;
import io.openlineage.spark.extension.scala.v1.ExpressionDependencyWithDelegate;
import io.openlineage.spark.extension.scala.v1.ExpressionDependencyWithIdentifier;
import io.openlineage.spark.shaded.org.apache.commons.lang3.reflect.MethodUtils;
import io.openlineage.spark3.agent.lifecycle.plan.column.CustomCollectorsUtils;
import io.openlineage.spark3.agent.lifecycle.plan.column.JdbcColumnLineageCollector;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.ExpressionDependencyVisitor;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.IcebergMergeIntoDependencyVisitor;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.UnionDependencyVisitor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import org.apache.spark.sql.catalyst.expressions.ExprId;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project;
import org.apache.spark.sql.execution.datasources.LogicalRelation;
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.collection.Seq;
import scala.runtime.BoxedUnit;

/**
 * Collects expression-level dependencies (output {@link ExprId} depends on input {@link ExprId})
 * from a Spark logical plan into a {@link ColumnLevelLineageBuilder}.
 */
public class ExpressionDependencyCollector {
    private static final Logger log = LoggerFactory.getLogger(ExpressionDependencyCollector.class);

    /** Visitors tried on every plan node; each one is applied only where {@code isDefinedAt} holds. */
    private static final List<ExpressionDependencyVisitor> expressionDependencyVisitors =
        Arrays.asList(new UnionDependencyVisitor(), new IcebergMergeIntoDependencyVisitor());

    /**
     * Whether the runtime {@link AggregateExpression} class exposes an accessible no-arg
     * {@code resultId} method. The answer cannot change while this class is loaded, so it is
     * resolved once here instead of reflecting on every aggregate expression encountered.
     * (Presumably {@code resultId} vs {@code resultIds} differs between Spark versions — TODO confirm.)
     */
    private static final boolean AGGREGATE_HAS_RESULT_ID =
        MethodUtils.getAccessibleMethod(AggregateExpression.class, "resultId") != null;

    /**
     * Visits every node of {@code plan} (via {@code LogicalPlan.foreach}) and collects the
     * expression dependencies of each node.
     *
     * @param context lineage context whose builder receives the dependencies
     * @param plan    logical plan to walk
     */
    static void collect(ColumnLevelLineageContext context, LogicalPlan plan) {
        plan.foreach(node -> {
            ExpressionDependencyCollector.collectFromNode(context, node);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Collects the expression dependencies contributed by a single plan node.
     *
     * <p>The steps run in this order:
     * <ol>
     *   <li>every registered {@link ExpressionDependencyVisitor} that accepts the node is applied;
     *   <li>custom collectors get a chance to contribute ({@link CustomCollectorsUtils});
     *   <li>node-type specific handling:
     *     <ul>
     *       <li>extension nodes ({@link ColumnLevelLineageNode}) report their own dependencies;
     *       <li>for {@link Project} / {@link Aggregate}, each output named expression depends on
     *           every named expression found in its expression tree;
     *       <li>JDBC {@link LogicalRelation}s are delegated to {@link JdbcColumnLineageCollector}.
     *     </ul>
     * </ol>
     *
     * @param context lineage context whose builder receives the dependencies
     * @param node    plan node to inspect
     */
    static void collectFromNode(ColumnLevelLineageContext context, LogicalPlan node) {
        expressionDependencyVisitors.stream()
            .filter(visitor -> visitor.isDefinedAt(node))
            .forEach(visitor -> visitor.apply(node, context.getBuilder()));
        CustomCollectorsUtils.collectExpressionDependencies(context, node);

        // Typed (not raw) so that exprId() below resolves on NamedExpression.
        List<NamedExpression> expressions = new ArrayList<>();
        if (node instanceof ColumnLevelLineageNode) {
            extensionColumnLineage(context, (ColumnLevelLineageNode) node);
        } else if (node instanceof Project) {
            expressions.addAll(ScalaConversionUtils.fromSeq(((Project) node).projectList()));
        } else if (node instanceof Aggregate) {
            expressions.addAll(ScalaConversionUtils.fromSeq(((Aggregate) node).aggregateExpressions()));
        } else if (node instanceof LogicalRelation
            && ((LogicalRelation) node).relation() instanceof JDBCRelation) {
            JdbcColumnLineageCollector.extractExpressionsFromJDBC(node, context.getBuilder());
        }

        for (NamedExpression expr : expressions) {
            // NamedExpression is a Scala trait (a Java interface), so the cast to the
            // Expression class is required to walk its expression tree.
            traverseExpression((Expression) expr, expr.exprId(), context.getBuilder());
        }
    }

    /**
     * Records the dependencies reported by an extension node implementing
     * {@link ColumnLevelLineageNode}.
     *
     * <p>Delegate dependencies are handled first. For these, the attached expression is walked
     * when it is a Catalyst {@link Expression}. Identifier dependencies come second; each listed
     * input id is added directly as a dependency of the output id.
     *
     * @param context lineage context supplying the event, OpenLineage context and builder
     * @param node    extension node to query
     */
    private static void extensionColumnLineage(ColumnLevelLineageContext context, ColumnLevelLineageNode node) {
        List<?> deps = ScalaConversionUtils.fromSeq(
            node.columnLevelLineageDependencies(
                    ExtensionPlanUtils.context(context.getEvent(), context.getOlContext()))
                .toSeq());

        deps.stream()
            .filter(ExpressionDependencyWithDelegate.class::isInstance)
            .map(ExpressionDependencyWithDelegate.class::cast)
            .filter(d -> d.expression() instanceof Expression)
            .forEach(d -> traverseExpression(
                (Expression) d.expression(),
                ExprId.apply((long) d.outputExprId().exprId()),
                context.getBuilder()));

        deps.stream()
            .filter(ExpressionDependencyWithIdentifier.class::isInstance)
            .map(ExpressionDependencyWithIdentifier.class::cast)
            .forEach(d -> {
                ExprId outputId = ExprId.apply((long) d.outputExprId().exprId());
                ScalaConversionUtils.fromSeq(d.inputExprIds().toSeq())
                    .forEach(i -> context.getBuilder().addDependency(outputId, ExprId.apply((long) i.exprId())));
            });
    }

    /**
     * Recursively walks {@code expr} and records {@code outputExprId} as depending on every
     * {@link NamedExpression} found in the tree, except one whose id equals
     * {@code outputExprId} itself (no self-dependency). For {@link AggregateExpression}s, the
     * aggregate's result id(s) are also recorded as inputs.
     *
     * @param expr         root of the expression tree to walk
     * @param outputExprId id of the output that depends on the walked expressions
     * @param builder      receives the discovered dependencies
     */
    public static void traverseExpression(Expression expr, ExprId outputExprId, ColumnLevelLineageBuilder builder) {
        if (expr instanceof NamedExpression) {
            ExprId ownId = ((NamedExpression) expr).exprId();
            if (!ownId.equals(outputExprId)) {
                builder.addDependency(outputExprId, ownId);
            }
        }
        if (expr.children() != null) {
            for (Expression child : ScalaConversionUtils.fromSeq(expr.children())) {
                traverseExpression(child, outputExprId, builder);
            }
        }
        if (expr instanceof AggregateExpression) {
            addAggregateResultDependencies((AggregateExpression) expr, outputExprId, builder);
        }
    }

    /**
     * Records {@code outputExprId} as depending on the result id(s) of {@code aggr}. If
     * {@code resultId()} is available it is used directly. Otherwise the {@code resultIds} method
     * is invoked reflectively.
     */
    private static void addAggregateResultDependencies(
        AggregateExpression aggr, ExprId outputExprId, ColumnLevelLineageBuilder builder) {
        if (AGGREGATE_HAS_RESULT_ID) {
            builder.addDependency(outputExprId, aggr.resultId());
            return;
        }
        try {
            Seq<?> resultIds = (Seq<?>) MethodUtils.invokeMethod(aggr, "resultIds");
            for (Object resultId : ScalaConversionUtils.fromSeq(resultIds)) {
                builder.addDependency(outputExprId, (ExprId) resultId);
            }
        } catch (Exception e) {
            // Best effort: lineage extraction should not fail the job if this accessor is missing
            // or behaves unexpectedly, so the failure is only logged.
            log.warn("Failed extracting resultIds from AggregateExpression", e);
        }
    }
}

