/*
 * Decompiled with CFR 0.152.
 */
package io.openlineage.spark3.agent.lifecycle.plan.column;

import io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageContext;
import io.openlineage.spark.agent.util.ExtensionPlanUtils;
import io.openlineage.spark.agent.util.ScalaConversionUtils;
import io.openlineage.spark.extension.scala.v1.ColumnLevelLineageNode;
import io.openlineage.spark.extension.scala.v1.OutputDatasetField;
import io.openlineage.spark3.agent.lifecycle.plan.column.CustomCollectorsUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.ExprId;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project;

/**
 * Static helpers that walk a Spark {@link LogicalPlan} and register its output fields
 * (expression id plus field name) on the builder held by a {@link ColumnLevelLineageContext}.
 */
public class OutputFieldsCollector {
    /**
     * Registers the output fields of {@code plan} on the context's builder.
     *
     * <p>When {@code plan} is a {@link ColumnLevelLineageNode}, the outputs reported by the node
     * itself are used; otherwise the expressions returned by
     * {@link #getOutputExpressionsFromRoot(LogicalPlan)} are added. Afterwards
     * {@code CustomCollectorsUtils.collectOutputs} is invoked for the same plan. If the builder
     * still reports no outputs, every child of {@code plan} is processed recursively.
     *
     * @param context lineage context whose builder receives the outputs
     * @param plan    logical plan node to inspect
     */
    public static void collect(ColumnLevelLineageContext context, LogicalPlan plan) {
        if (plan instanceof ColumnLevelLineageNode) {
            extensionColumnLineage(context, (ColumnLevelLineageNode) plan);
        } else {
            getOutputExpressionsFromRoot(plan)
                .forEach(expr -> context.getBuilder().addOutput(expr.exprId(), expr.name()));
        }
        CustomCollectorsUtils.collectOutputs(context, plan);

        // The check happens once, before descending: when this node yielded nothing,
        // all children are visited (each child repeats the same check for its own subtree).
        if (!context.getBuilder().hasOutputs()) {
            ScalaConversionUtils.fromSeq(plan.children())
                .forEach(childPlan -> collect(context, childPlan));
        }
    }

    /**
     * Adds the outputs declared by an extension-provided lineage node to the context's builder.
     * Only elements that are {@link OutputDatasetField} instances are registered; anything else
     * returned by the node is skipped.
     *
     * @param context lineage context supplying the event, the OpenLineage context and the builder
     * @param node    plan node that exposes its own column-level lineage outputs
     */
    private static void extensionColumnLineage(ColumnLevelLineageContext context, ColumnLevelLineageNode node) {
        ScalaConversionUtils.fromSeq(
                node.columnLevelLineageOutputs(
                    ExtensionPlanUtils.context(context.getEvent(), context.getOlContext())).toSeq())
            .stream()
            .filter(df -> df instanceof OutputDatasetField)
            .forEach(o -> {
                OutputDatasetField of = (OutputDatasetField) o;
                // Convert the extension's expression id into a Spark ExprId.
                context.getBuilder().addOutput(ExprId.apply((long) of.exprId().exprId()), of.field());
            });
    }

    /**
     * Returns the output expressions of the given plan node only (children are not visited).
     *
     * <p>The result starts with the attributes of {@code plan.output()}. For an {@link Aggregate}
     * its aggregate expressions are appended, and for a {@link Project} its project list is
     * appended.
     *
     * @param plan logical plan node to inspect
     * @return a new, mutable list of expressions; never {@code null}
     */
    static List<NamedExpression> getOutputExpressionsFromRoot(LogicalPlan plan) {
        // Collect into an explicit ArrayList: Collectors.toList() gives no mutability guarantee,
        // and both this method and getOutputExpressionsFromTree append to the result.
        List<NamedExpression> expressions = ScalaConversionUtils.fromSeq(plan.output()).stream()
            .filter(attr -> attr instanceof Attribute)
            .collect(Collectors.toCollection(ArrayList::new));

        if (plan instanceof Aggregate) {
            expressions.addAll(ScalaConversionUtils.fromSeq(((Aggregate) plan).aggregateExpressions()));
        } else if (plan instanceof Project) {
            expressions.addAll(ScalaConversionUtils.fromSeq(((Project) plan).projectList()));
        }
        return expressions;
    }

    /**
     * Returns the output expressions of {@code plan}, falling back to its children when the node
     * itself yields none. The fallback is recursive, so each child applies the same rule to its
     * own subtree.
     *
     * @param plan root of the (sub)tree to inspect
     * @return a mutable list of expressions, possibly empty; never {@code null}
     */
    static List<NamedExpression> getOutputExpressionsFromTree(LogicalPlan plan) {
        List<NamedExpression> expressions = getOutputExpressionsFromRoot(plan);
        // getOutputExpressionsFromRoot never returns null, so only emptiness needs checking.
        if (expressions.isEmpty()) {
            ScalaConversionUtils.fromSeq(plan.children())
                .forEach(childPlan -> expressions.addAll(getOutputExpressionsFromTree(childPlan)));
        }
        return expressions;
    }
}

