/*
 * Decompiled with CFR 0.152.
 */
package io.openlineage.spark3.agent.lifecycle.plan.column.visitors;

import io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageBuilder;
import io.openlineage.spark.agent.util.ScalaConversionUtils;
import io.openlineage.spark3.agent.lifecycle.plan.column.ExpressionDependencyCollector;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.ExpressionDependencyVisitor;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.collection.Seq;

/**
 * Expression dependency visitor for Iceberg merge plans.
 *
 * <p>Handles logical plan nodes whose canonical class name is either
 * {@value #MERGE_ROWS_CLASS_NAME} or {@value #MERGE_INTO_CLASS_NAME}. The node classes are
 * accessed reflectively by name, so this class has no compile-time dependency on them.
 *
 * <p>For every output attribute of the node, the expressions found at the same position in the
 * node's "matched" and "not matched" output rows are handed to
 * {@link ExpressionDependencyCollector#traverseExpression} so that their dependencies are
 * registered against that output attribute.
 */
public class IcebergMergeIntoDependencyVisitor
implements ExpressionDependencyVisitor {
    private static final Logger log = LoggerFactory.getLogger(IcebergMergeIntoDependencyVisitor.class);
    private static final String MERGE_ROWS_CLASS_NAME = "org.apache.spark.sql.catalyst.plans.logical.MergeRows";
    private static final String MERGE_INTO_CLASS_NAME = "org.apache.spark.sql.catalyst.plans.logical.MergeInto";
    public static final String MATCHED_OUTPUTS = "matchedOutputs";
    public static final String NOT_MATCHED_OUTPUTS = "notMatchedOutputs";

    /**
     * Tells whether this visitor handles the given plan node.
     *
     * @param plan logical plan node to test
     * @return {@code true} when the node's canonical class name is one of the two supported
     *     merge node class names
     */
    @Override
    public boolean isDefinedAt(LogicalPlan plan) {
        String planClass = plan.getClass().getCanonicalName();
        return MERGE_INTO_CLASS_NAME.equals(planClass) || MERGE_ROWS_CLASS_NAME.equals(planClass);
    }

    /**
     * Collects column dependencies from a merge node into {@code builder}.
     *
     * <p>Any exception raised while reading the node reflectively or while traversing its
     * expressions is logged and swallowed; this method never propagates it.
     *
     * @param node merge plan node; nodes of any other class are ignored
     * @param builder lineage builder that receives the collected dependencies
     */
    @Override
    public void apply(LogicalPlan node, ColumnLevelLineageBuilder builder) {
        try {
            String nodeClass = node.getClass().getCanonicalName();
            if (MERGE_ROWS_CLASS_NAME.equals(nodeClass)) {
                this.collectFromMergeRows(node, builder);
            } else if (MERGE_INTO_CLASS_NAME.equals(nodeClass)) {
                this.collectFromMergeInto(node, builder);
            }
        }
        catch (Exception e) {
            // Best effort: a failure here must not break the caller, so it is only logged.
            log.error("Collecting dependencies for Iceberg MergeInto failed", (Throwable)e);
        }
    }

    /**
     * Registers dependencies for every output attribute from the expressions located at the
     * same position in the matched and not-matched rows.
     *
     * <p>Rows shorter than the attribute's position are skipped, and only expressions that are
     * instances of {@link NamedExpression} are traversed. For each position the matched rows are
     * processed before the not-matched rows.
     *
     * @param outputSeq output attributes of the merge node
     * @param matched output rows produced for matched records
     * @param notMatched output rows produced for not-matched records
     * @param builder lineage builder that receives the collected dependencies
     */
    void collect(Seq<Attribute> outputSeq, List<List<Expression>> matched, List<List<Expression>> notMatched, ColumnLevelLineageBuilder builder) {
        Attribute[] output = ScalaConversionUtils.fromSeq(outputSeq).toArray(new Attribute[0]);
        for (int position = 0; position < output.length; position++) {
            IcebergMergeIntoDependencyVisitor.traverseColumn(matched, position, output[position], builder);
            IcebergMergeIntoDependencyVisitor.traverseColumn(notMatched, position, output[position], builder);
        }
    }

    /** Handles a {@code MergeRows} node, supporting both shapes of its {@code matchedOutputs}. */
    @SuppressWarnings("unchecked") // element types of reflectively read sequences cannot be checked
    private void collectFromMergeRows(LogicalPlan node, ColumnLevelLineageBuilder builder) throws ReflectiveOperationException {
        Class<?> mergeRows = Class.forName(MERGE_ROWS_CLASS_NAME);
        // Each property is read once and reused for both shape detection and collection.
        Seq<?> matched = (Seq<?>)IcebergMergeIntoDependencyVisitor.readProperty(mergeRows, node, MATCHED_OUTPUTS);
        boolean triplyNested = IcebergMergeIntoDependencyVisitor.isNestedThreeTimes(matched);
        List<List<Expression>> notMatched = IcebergMergeIntoDependencyVisitor.fromSeqNestedTwice(
            (Seq<Seq<Expression>>)IcebergMergeIntoDependencyVisitor.readProperty(mergeRows, node, NOT_MATCHED_OUTPUTS));

        if (!triplyNested) {
            this.collect(node.output(), IcebergMergeIntoDependencyVisitor.fromSeqNestedTwice((Seq<Seq<Expression>>)matched), notMatched, builder);
            return;
        }

        // matchedOutputs holds one group of rows per top-level entry (presumably one entry per
        // matched clause of the merge -- confirm against the Iceberg version in use). Every
        // entry is collected; previously only the first two were, silently dropping the rest.
        // The not-matched rows are passed along with each entry, exactly as before.
        List<List<List<Expression>>> clauses =
            IcebergMergeIntoDependencyVisitor.fromSeqNestedThreeTimes((Seq<Seq<Seq<Expression>>>)matched);
        for (List<List<Expression>> clause : clauses) {
            this.collect(node.output(), clause, notMatched, builder);
        }
    }

    /** Handles a {@code MergeInto} node, whose outputs live on its {@code mergeIntoProcessor}. */
    @SuppressWarnings("unchecked") // element types of reflectively read sequences cannot be checked
    private void collectFromMergeInto(LogicalPlan node, ColumnLevelLineageBuilder builder) throws ReflectiveOperationException {
        Class<?> mergeInto = Class.forName(MERGE_INTO_CLASS_NAME);
        Class<?> mergeIntoParamsClass = Class.forName("org.apache.spark.sql.catalyst.plans.logical.MergeIntoParams");
        Object mergeIntoParams = IcebergMergeIntoDependencyVisitor.readProperty(mergeInto, node, "mergeIntoProcessor");
        Seq<Option<Seq<Expression>>> matched =
            (Seq<Option<Seq<Expression>>>)IcebergMergeIntoDependencyVisitor.readProperty(mergeIntoParamsClass, mergeIntoParams, MATCHED_OUTPUTS);
        Seq<Option<Seq<Expression>>> notMatched =
            (Seq<Option<Seq<Expression>>>)IcebergMergeIntoDependencyVisitor.readProperty(mergeIntoParamsClass, mergeIntoParams, NOT_MATCHED_OUTPUTS);
        this.collect(node.output(), IcebergMergeIntoDependencyVisitor.definedRows(matched), IcebergMergeIntoDependencyVisitor.definedRows(notMatched), builder);
    }

    /** Invokes the public no-argument method {@code name} declared by {@code owner} on {@code target}. */
    private static Object readProperty(Class<?> owner, Object target, String name) throws ReflectiveOperationException {
        return owner.getMethod(name).invoke(target);
    }

    /**
     * Detects the triply nested shape of {@code matchedOutputs}: the first element of the first
     * non-empty top-level entry is itself a {@link Seq}. Empty or absent levels yield
     * {@code false}.
     */
    private static boolean isNestedThreeTimes(Seq<?> matched) {
        if (matched.size() == 0) {
            return false;
        }
        Seq<?> first = (Seq<?>)matched.apply(0);
        if (first == null || first.size() == 0) {
            return false;
        }
        return first.apply(0) instanceof Seq;
    }

    /** Traverses, for each row long enough, the named expression found at {@code position}. */
    private static void traverseColumn(List<List<Expression>> rows, int position, Attribute target, ColumnLevelLineageBuilder builder) {
        for (List<Expression> row : rows) {
            if (row.size() <= position) {
                continue;
            }
            Expression expr = row.get(position);
            if (expr instanceof NamedExpression) {
                ExpressionDependencyCollector.traverseExpression(expr, target.exprId(), builder);
            }
        }
    }

    /** Converts the defined options to Java lists, dropping the undefined ones. */
    private static List<List<Expression>> definedRows(Seq<Option<Seq<Expression>>> rows) {
        return ScalaConversionUtils.fromSeq(rows).stream()
            .filter(Option::isDefined)
            .map(Option::get)
            .map(row -> ScalaConversionUtils.fromSeq(row))
            .collect(Collectors.toList());
    }

    /** Converts a three-level Scala sequence into nested Java lists. */
    private static List<List<List<Expression>>> fromSeqNestedThreeTimes(Seq<Seq<Seq<Expression>>> expressions) {
        return ScalaConversionUtils.fromSeq(expressions).stream()
            .map(rows -> IcebergMergeIntoDependencyVisitor.fromSeqNestedTwice(rows))
            .collect(Collectors.toList());
    }

    /** Converts a two-level Scala sequence into nested Java lists. */
    private static List<List<Expression>> fromSeqNestedTwice(Seq<Seq<Expression>> expressions) {
        return ScalaConversionUtils.fromSeq(expressions).stream()
            .map(row -> ScalaConversionUtils.fromSeq(row))
            .collect(Collectors.toList());
    }
}

