/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.IsNull;
import io.trino.sql.planner.LogicalPlanner;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.ImplementLimitWithTies;
import io.trino.sql.planner.iterative.rule.Util;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.EnforceSingleRowNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.WindowNode;
import java.lang.invoke.LambdaMetafactory;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

public class DecorrelateUnnest
implements Rule<CorrelatedJoinNode> {
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin().with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation())).with(Patterns.CorrelatedJoin.filter().equalTo((Object)Booleans.TRUE)).matching(node -> node.getType() == JoinType.INNER || node.getType() == JoinType.LEFT);
    private final Metadata metadata;

    public DecorrelateUnnest(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        Optional<PlanNode> subqueryUnnest;
        PlanNode searchRoot = correlatedJoinNode.getSubquery();
        Optional<PlanNode> enforceSingleRow = PlanNodeSearcher.searchFrom(searchRoot, context.getLookup()).where(EnforceSingleRowNode.class::isInstance).recurseOnlyWhen(planNode -> false).findFirst();
        if (enforceSingleRow.isPresent()) {
            searchRoot = ((EnforceSingleRowNode)enforceSingleRow.get()).getSource();
        }
        if ((subqueryUnnest = PlanNodeSearcher.searchFrom(searchRoot, context.getLookup()).where(node -> DecorrelateUnnest.isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), context.getLookup())).recurseOnlyWhen(node -> {
            TopNNode topNNode;
            LimitNode limitNode;
            return node instanceof ProjectNode || node instanceof LimitNode && (limitNode = (LimitNode)node).getCount() > 0L || node instanceof TopNNode && (topNNode = (TopNNode)node).getCount() > 0L;
        }).findFirst()).isEmpty()) {
            return Rule.Result.empty();
        }
        UnnestNode unnestNode = (UnnestNode)subqueryUnnest.get();
        Symbol uniqueSymbol = context.getSymbolAllocator().newSymbol("unique", (Type)BigintType.BIGINT);
        PlanNode input = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), uniqueSymbol);
        PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
        if (unnestSource instanceof ProjectNode) {
            ProjectNode sourceProjection = (ProjectNode)unnestSource;
            input = new ProjectNode(sourceProjection.getId(), input, Assignments.builder().putIdentities(input.getOutputSymbols()).putAll(sourceProjection.getAssignments()).build());
        }
        JoinType unnestJoinType = JoinType.LEFT;
        if (enforceSingleRow.isEmpty() && correlatedJoinNode.getType() == JoinType.INNER && unnestNode.getJoinType() == JoinType.INNER) {
            unnestJoinType = JoinType.INNER;
        }
        Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol().orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", (Type)BigintType.BIGINT));
        UnnestNode rewrittenUnnest = new UnnestNode(context.getIdAllocator().getNextId(), input, input.getOutputSymbols(), unnestNode.getMappings(), Optional.of(ordinalitySymbol), unnestJoinType);
        PlanNode rewrittenPlan = Rewriter.rewriteNodeSequence(correlatedJoinNode.getSubquery(), input.getOutputSymbols(), ordinalitySymbol, uniqueSymbol, rewrittenUnnest, context.getSession(), this.metadata, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator());
        if (unnestNode.getJoinType() == JoinType.INNER && rewrittenUnnest.getJoinType() == JoinType.LEFT) {
            Assignments.Builder assignments = Assignments.builder().putIdentities(correlatedJoinNode.getInput().getOutputSymbols());
            for (Symbol subquerySymbol : correlatedJoinNode.getSubquery().getOutputSymbols()) {
                assignments.put(subquerySymbol, IrExpressions.ifExpression(new IsNull(ordinalitySymbol.toSymbolReference()), new Constant(subquerySymbol.type(), null), subquerySymbol.toSymbolReference()));
            }
            rewrittenPlan = new ProjectNode(context.getIdAllocator().getNextId(), rewrittenPlan, assignments.build());
        }
        return Rule.Result.ofPlanNode(Util.restrictOutputs(context.getIdAllocator(), rewrittenPlan, (Set<Symbol>)ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols())).orElse(rewrittenPlan));
    }

    /*
     * Unable to fully structure code
     */
    private static boolean isSupportedUnnest(PlanNode node, List<Symbol> correlation, Lookup lookup) {
        if (!(node instanceof UnnestNode)) {
            return false;
        }
        unnestNode = (UnnestNode)node;
        unnestSymbols = (List)unnestNode.getMappings().stream().map((Function<UnnestNode.Mapping, Symbol>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;)Ljava/lang/Object;, getInput(), (Lio/trino/sql/planner/plan/UnnestNode$Mapping;)Lio/trino/sql/planner/Symbol;)()).collect(ImmutableList.toImmutableList());
        unnestSource = lookup.resolve(unnestNode.getSource());
        if (ImmutableSet.copyOf(correlation).containsAll((Collection)unnestSymbols)) ** GOTO lbl-1000
        if (unnestSource instanceof ProjectNode) {
            projectNode = (ProjectNode)unnestSource;
            ** if (!ImmutableSet.copyOf(correlation).containsAll(SymbolsExtractor.extractUnique(projectNode.getAssignments().getExpressions()))) goto lbl-1000
        }
        ** GOTO lbl-1000
lbl-1000:
        // 2 sources

        {
            v0 = true;
            ** GOTO lbl14
        }
lbl-1000:
        // 2 sources

        {
            v0 = false;
        }
lbl14:
        // 2 sources

        basedOnCorrelation = v0;
        return QueryCardinalityUtil.isScalar(unnestNode.getSource(), lookup) != false && unnestNode.getReplicateSymbols().isEmpty() != false && basedOnCorrelation != false && (unnestNode.getJoinType() == JoinType.INNER || unnestNode.getJoinType() == JoinType.LEFT);
    }

    private static class Rewriter
    extends PlanVisitor<RewriteResult, Void> {
        private final List<Symbol> leftOutputs;
        private final Symbol ordinalitySymbol;
        private final Symbol uniqueSymbol;
        private final PlanNode sequenceSource;
        private final Session session;
        private final Metadata metadata;
        private final Lookup lookup;
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;

        private Rewriter(List<Symbol> leftOutputs, Symbol ordinalitySymbol, Symbol uniqueSymbol, PlanNode sequenceSource, Session session, Metadata metadata, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
            this.leftOutputs = ImmutableList.copyOf((Collection)Objects.requireNonNull(leftOutputs, "leftOutputs is null"));
            this.ordinalitySymbol = Objects.requireNonNull(ordinalitySymbol, "ordinalitySymbol is null");
            this.uniqueSymbol = Objects.requireNonNull(uniqueSymbol, "uniqueSymbol is null");
            this.sequenceSource = Objects.requireNonNull(sequenceSource, "sequenceSource is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        public static PlanNode rewriteNodeSequence(PlanNode root, List<Symbol> leftOutputs, Symbol ordinalitySymbol, Symbol uniqueSymbol, PlanNode sequenceSource, Session session, Metadata metadata, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
            return new Rewriter(leftOutputs, ordinalitySymbol, uniqueSymbol, sequenceSource, session, metadata, lookup, idAllocator, symbolAllocator).rewrite(root).getPlan();
        }

        private RewriteResult rewrite(PlanNode node) {
            return this.lookup.resolve(node).accept(this, null);
        }

        @Override
        protected RewriteResult visitPlan(PlanNode node, Void context) {
            throw new IllegalStateException("Unexpected node type: " + node.getClass().getSimpleName());
        }

        @Override
        public RewriteResult visitUnnest(UnnestNode node, Void context) {
            return new RewriteResult(this.sequenceSource, Optional.empty());
        }

        @Override
        public RewriteResult visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            PlanNode sourceNode;
            Symbol rowNumberSymbol;
            RewriteResult source = this.rewrite(node.getSource());
            if (QueryCardinalityUtil.isScalar(source.getPlan(), this.lookup)) {
                return source;
            }
            if (source.getRowNumberSymbol().isPresent()) {
                rowNumberSymbol = source.getRowNumberSymbol().get();
                sourceNode = source.getPlan();
            } else {
                rowNumberSymbol = this.symbolAllocator.newSymbol("row_number", (Type)BigintType.BIGINT);
                sourceNode = new RowNumberNode(this.idAllocator.getNextId(), source.getPlan(), (List<Symbol>)ImmutableList.of((Object)this.uniqueSymbol), false, rowNumberSymbol, Optional.of(2), Optional.empty());
            }
            Expression predicate = IrExpressions.ifExpression(new Comparison(Comparison.Operator.GREATER_THAN, rowNumberSymbol.toSymbolReference(), new Constant((Type)BigintType.BIGINT, 1L)), new Cast(LogicalPlanner.failFunction(this.metadata, (ErrorCodeSupplier)StandardErrorCode.SUBQUERY_MULTIPLE_ROWS, "Scalar sub-query has returned multiple rows"), (Type)BooleanType.BOOLEAN), Booleans.TRUE);
            return new RewriteResult(new FilterNode(this.idAllocator.getNextId(), sourceNode, predicate), Optional.of(rowNumberSymbol));
        }

        @Override
        public RewriteResult visitLimit(LimitNode node, Void context) {
            PlanNode sourceNode;
            Symbol rowNumberSymbol;
            RewriteResult source = this.rewrite(node.getSource());
            if (node.isWithTies()) {
                return new RewriteResult(ImplementLimitWithTies.rewriteLimitWithTiesWithPartitioning(node, source.getPlan(), this.session, this.metadata, this.idAllocator, this.symbolAllocator, (List<Symbol>)ImmutableList.of((Object)this.uniqueSymbol)), Optional.empty());
            }
            if (source.getRowNumberSymbol().isPresent()) {
                rowNumberSymbol = source.getRowNumberSymbol().get();
                sourceNode = source.getPlan();
            } else {
                rowNumberSymbol = this.symbolAllocator.newSymbol("row_number", (Type)BigintType.BIGINT);
                sourceNode = new RowNumberNode(this.idAllocator.getNextId(), source.getPlan(), (List<Symbol>)ImmutableList.of((Object)this.uniqueSymbol), false, rowNumberSymbol, Optional.empty(), Optional.empty());
            }
            return new RewriteResult(new FilterNode(this.idAllocator.getNextId(), sourceNode, new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, rowNumberSymbol.toSymbolReference(), new Constant((Type)BigintType.BIGINT, node.getCount()))), Optional.of(rowNumberSymbol));
        }

        @Override
        public RewriteResult visitTopN(TopNNode node, Void context) {
            RewriteResult source = this.rewrite(node.getSource());
            Symbol rowNumberSymbol = this.symbolAllocator.newSymbol("row_number", (Type)BigintType.BIGINT);
            WindowNode.Function rowNumberFunction = new WindowNode.Function(this.metadata.resolveBuiltinFunction("row_number", (List<TypeSignatureProvider>)ImmutableList.of()), (List<Expression>)ImmutableList.of(), Optional.empty(), WindowNode.Frame.DEFAULT_FRAME, false, false);
            WindowNode windowNode = new WindowNode(this.idAllocator.getNextId(), source.getPlan(), new DataOrganizationSpecification((List<Symbol>)ImmutableList.of((Object)this.uniqueSymbol), Optional.of(node.getOrderingScheme())), (Map<Symbol, WindowNode.Function>)ImmutableMap.of((Object)rowNumberSymbol, (Object)rowNumberFunction), Optional.empty(), (Set<Symbol>)ImmutableSet.of(), 0);
            return new RewriteResult(new FilterNode(this.idAllocator.getNextId(), windowNode, new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, rowNumberSymbol.toSymbolReference(), new Constant((Type)BigintType.BIGINT, node.getCount()))), Optional.of(rowNumberSymbol));
        }

        @Override
        public RewriteResult visitProject(ProjectNode node, Void context) {
            RewriteResult source = this.rewrite(node.getSource());
            Assignments.Builder assignments = Assignments.builder().putAll(node.getAssignments()).putIdentities(this.leftOutputs).putIdentity(this.ordinalitySymbol);
            source.getRowNumberSymbol().ifPresent(assignments::putIdentity);
            return new RewriteResult(new ProjectNode(node.getId(), source.getPlan(), assignments.build()), source.getRowNumberSymbol());
        }
    }

    private static class RewriteResult {
        PlanNode plan;
        Optional<Symbol> rowNumberSymbol;

        public RewriteResult(PlanNode plan, Optional<Symbol> rowNumberSymbol) {
            this.plan = Objects.requireNonNull(plan, "plan is null");
            this.rowNumberSymbol = Objects.requireNonNull(rowNumberSymbol, "rowNumberSymbol is null");
        }

        public PlanNode getPlan() {
            return this.plan;
        }

        public Optional<Symbol> getRowNumberSymbol() {
            return this.rowNumberSymbol;
        }
    }
}

