/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.ImplementLimitWithTies;
import io.trino.sql.planner.iterative.rule.Util;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.EnforceSingleRowNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.GenericLiteral;
import io.trino.sql.tree.IfExpression;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.StringLiteral;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Planner rule that rewrites a {@link CorrelatedJoinNode} whose subquery is built around a
 * correlated {@link UnnestNode} into an {@link UnnestNode} placed directly over the join input.
 *
 * <p>The rule matches INNER and LEFT correlated joins with a non-empty correlation list and a
 * {@code TRUE} filter. In the subquery it accepts an optional root {@link EnforceSingleRowNode},
 * followed by any sequence of {@link ProjectNode}, {@link LimitNode} (count &gt; 0) and
 * {@link TopNNode} (count &gt; 0), ending in an {@link UnnestNode} accepted by
 * {@link #isSupportedUnnest}. Every node of that sequence is re-created on top of the rewritten
 * unnest by {@link Rewriter}; per-input-row semantics are kept by partitioning on a unique id
 * symbol assigned to the input rows.
 */
public class DecorrelateUnnest
implements Rule<CorrelatedJoinNode> {
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
            .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
            .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
            .matching(node -> node.getType() == CorrelatedJoinNode.Type.INNER || node.getType() == CorrelatedJoinNode.Type.LEFT);

    private final Metadata metadata;

    /**
     * @param metadata used to resolve the {@code fail} and {@code row_number} functions during the rewrite
     * @throws NullPointerException if {@code metadata} is null
     */
    public DecorrelateUnnest(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Attempts the decorrelation.
     *
     * @return {@link Rule.Result#empty()} when no supported correlated unnest is found in the
     *         subquery; otherwise the rewritten plan restricted to the outputs of the original
     *         correlated join
     */
    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        PlanNode searchRoot = correlatedJoinNode.getSubquery();

        // 1. Look for EnforceSingleRowNode at the subquery root only (recursion is disabled).
        Optional<EnforceSingleRowNode> enforceSingleRow = PlanNodeSearcher.searchFrom(searchRoot, context.getLookup())
                .where(EnforceSingleRowNode.class::isInstance)
                .recurseOnlyWhen(planNode -> false)
                .findFirst();
        if (enforceSingleRow.isPresent()) {
            searchRoot = enforceSingleRow.get().getSource();
        }

        // 2. Look for the correlated UnnestNode below projections / positive limits / positive TopNs.
        Optional<UnnestNode> subqueryUnnest = PlanNodeSearcher.searchFrom(searchRoot, context.getLookup())
                .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), context.getLookup()))
                .recurseOnlyWhen(DecorrelateUnnest::canSearchThrough)
                .findFirst();
        if (subqueryUnnest.isEmpty()) {
            return Rule.Result.empty();
        }
        UnnestNode unnestNode = subqueryUnnest.get();

        // Tag every input row with a unique id; the Rewriter partitions on it to keep
        // limit / TopN / single-row checks scoped to one input row.
        Symbol uniqueSymbol = context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT);
        PlanNode input = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), uniqueSymbol);

        // If the original unnest read from a projection, replay that projection's assignments
        // on top of the input so the unnested symbols are available to the rewritten unnest.
        PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
        if (unnestSource instanceof ProjectNode) {
            ProjectNode sourceProjection = (ProjectNode) unnestSource;
            input = new ProjectNode(
                    sourceProjection.getId(),
                    input,
                    Assignments.builder()
                            .putIdentities(input.getOutputSymbols())
                            .putAll(sourceProjection.getAssignments())
                            .build());
        }

        // The rewritten unnest may be INNER only when nothing requires keeping input rows that
        // produce no unnested rows: no EnforceSingleRowNode, INNER correlated join, INNER unnest.
        JoinNode.Type unnestJoinType = JoinNode.Type.LEFT;
        if (enforceSingleRow.isEmpty()
                && correlatedJoinNode.getType() == CorrelatedJoinNode.Type.INNER
                && unnestNode.getJoinType() == JoinNode.Type.INNER) {
            unnestJoinType = JoinNode.Type.INNER;
        }

        // Always emit ordinality: its null-ness is used below to detect rows padded by a LEFT unnest.
        Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol()
                .orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BigintType.BIGINT));

        UnnestNode rewrittenUnnest = new UnnestNode(
                context.getIdAllocator().getNextId(),
                input,
                input.getOutputSymbols(),
                unnestNode.getMappings(),
                Optional.of(ordinalitySymbol),
                unnestJoinType,
                Optional.empty());

        // Re-create the remaining subquery nodes on top of the rewritten unnest.
        PlanNode rewrittenPlan = Rewriter.rewriteNodeSequence(
                correlatedJoinNode.getSubquery(),
                input.getOutputSymbols(),
                ordinalitySymbol,
                uniqueSymbol,
                rewrittenUnnest,
                this.metadata,
                context.getLookup(),
                context.getIdAllocator(),
                context.getSymbolAllocator());

        // The original unnest was INNER but was rewritten as LEFT: null out every subquery output
        // on rows where ordinality is null, i.e. rows that exist only because of the LEFT unnest.
        if (unnestNode.getJoinType() == JoinNode.Type.INNER && rewrittenUnnest.getJoinType() == JoinNode.Type.LEFT) {
            Assignments.Builder assignments = Assignments.builder()
                    .putIdentities(correlatedJoinNode.getInput().getOutputSymbols());
            for (Symbol subquerySymbol : correlatedJoinNode.getSubquery().getOutputSymbols()) {
                assignments.put(
                        subquerySymbol,
                        new IfExpression(
                                new IsNullPredicate(ordinalitySymbol.toSymbolReference()),
                                new Cast(new NullLiteral(), TypeSignatureTranslator.toSqlType(context.getSymbolAllocator().getTypes().get(subquerySymbol))),
                                subquerySymbol.toSymbolReference()));
            }
            rewrittenPlan = new ProjectNode(context.getIdAllocator().getNextId(), rewrittenPlan, assignments.build());
        }

        // Drop helper symbols (unique id, ordinality, row number) not produced by the original join.
        return Rule.Result.ofPlanNode(
                Util.restrictOutputs(context.getIdAllocator(), rewrittenPlan, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                        .orElse(rewrittenPlan));
    }

    /**
     * Tells the unnest search whether it may descend below {@code node}: only projections,
     * limits with a positive count and TopN nodes with a positive count are traversed.
     */
    private static boolean canSearchThrough(PlanNode node) {
        return node instanceof ProjectNode
                || node instanceof LimitNode && ((LimitNode) node).getCount() > 0L
                || node instanceof TopNNode && ((TopNNode) node).getCount() > 0L;
    }

    /**
     * Checks whether {@code node} is an {@link UnnestNode} this rule can decorrelate.
     *
     * <p>All of the following must hold: the unnest source is scalar, there are no replicate
     * symbols, the unnest is based on correlation, the join type is INNER or LEFT, and the
     * filter is absent or {@code TRUE}. "Based on correlation" means either all unnested input
     * symbols are correlation symbols, or the source is a {@link ProjectNode} whose assignment
     * expressions reference only correlation symbols.
     */
    private static boolean isSupportedUnnest(PlanNode node, List<Symbol> correlation, Lookup lookup) {
        if (!(node instanceof UnnestNode)) {
            return false;
        }
        UnnestNode unnestNode = (UnnestNode) node;

        List<Symbol> unnestSymbols = unnestNode.getMappings().stream()
                .map(UnnestNode.Mapping::getInput)
                .collect(ImmutableList.toImmutableList());
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());

        ImmutableSet<Symbol> correlationSymbols = ImmutableSet.copyOf(correlation);
        boolean basedOnCorrelation = correlationSymbols.containsAll(unnestSymbols)
                || unnestSource instanceof ProjectNode
                        && correlationSymbols.containsAll(SymbolsExtractor.extractUnique(((ProjectNode) unnestSource).getAssignments().getExpressions()));

        return QueryCardinalityUtil.isScalar(unnestNode.getSource(), lookup)
                && unnestNode.getReplicateSymbols().isEmpty()
                && basedOnCorrelation
                && (unnestNode.getJoinType() == JoinNode.Type.INNER || unnestNode.getJoinType() == JoinNode.Type.LEFT)
                && (unnestNode.getFilter().isEmpty() || unnestNode.getFilter().get().equals(BooleanLiteral.TRUE_LITERAL));
    }

    /**
     * Immutable pair returned by each {@link Rewriter} step: the rewritten plan and, when one is
     * available for reuse by nodes above, the symbol carrying the per-input-row row number.
     */
    private static final class RewriteResult {
        private final PlanNode plan;
        private final Optional<Symbol> rowNumberSymbol;

        public RewriteResult(PlanNode plan, Optional<Symbol> rowNumberSymbol) {
            this.plan = Objects.requireNonNull(plan, "plan is null");
            this.rowNumberSymbol = Objects.requireNonNull(rowNumberSymbol, "rowNumberSymbol is null");
        }

        public PlanNode getPlan() {
            return this.plan;
        }

        public Optional<Symbol> getRowNumberSymbol() {
            return this.rowNumberSymbol;
        }
    }

    /**
     * Visitor that re-creates the subquery's node sequence bottom-up on top of the rewritten
     * unnest ({@code sequenceSource}). Any node type without a dedicated {@code visit*} method
     * is rejected with an {@link IllegalStateException}.
     */
    private static final class Rewriter
    extends PlanVisitor<RewriteResult, Void> {
        private final List<Symbol> leftOutputs;
        private final Symbol ordinalitySymbol;
        private final Symbol uniqueSymbol;
        private final PlanNode sequenceSource;
        private final Metadata metadata;
        private final Lookup lookup;
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;

        private Rewriter(List<Symbol> leftOutputs, Symbol ordinalitySymbol, Symbol uniqueSymbol, PlanNode sequenceSource, Metadata metadata, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
            this.leftOutputs = ImmutableList.copyOf(Objects.requireNonNull(leftOutputs, "leftOutputs is null"));
            this.ordinalitySymbol = Objects.requireNonNull(ordinalitySymbol, "ordinalitySymbol is null");
            this.uniqueSymbol = Objects.requireNonNull(uniqueSymbol, "uniqueSymbol is null");
            this.sequenceSource = Objects.requireNonNull(sequenceSource, "sequenceSource is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        /**
         * Rewrites the node sequence rooted at {@code root}, substituting {@code sequenceSource}
         * for the unnest at its bottom, and returns the resulting plan.
         */
        public static PlanNode rewriteNodeSequence(PlanNode root, List<Symbol> leftOutputs, Symbol ordinalitySymbol, Symbol uniqueSymbol, PlanNode sequenceSource, Metadata metadata, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
            Rewriter rewriter = new Rewriter(leftOutputs, ordinalitySymbol, uniqueSymbol, sequenceSource, metadata, lookup, idAllocator, symbolAllocator);
            return rewriter.rewrite(root).getPlan();
        }

        // Resolves the node through the lookup before dispatching to the matching visit method.
        private RewriteResult rewrite(PlanNode node) {
            return this.lookup.resolve(node).accept(this, null);
        }

        @Override
        protected RewriteResult visitPlan(PlanNode node, Void context) {
            throw new IllegalStateException("Unexpected node type: " + node.getClass().getSimpleName());
        }

        /** Bottom of the sequence: the original unnest is replaced by the rewritten one. */
        @Override
        public RewriteResult visitUnnest(UnnestNode node, Void context) {
            return new RewriteResult(this.sequenceSource, Optional.empty());
        }

        /**
         * Replaces the single-row enforcement with a filter that calls {@code fail(...)} when
         * the row number exceeds 1. If the rewritten source is already scalar it is returned
         * unchanged.
         */
        @Override
        public RewriteResult visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            RewriteResult source = this.rewrite(node.getSource());
            if (QueryCardinalityUtil.isScalar(source.getPlan(), this.lookup)) {
                return source;
            }

            Symbol rowNumberSymbol;
            PlanNode sourceNode;
            if (source.getRowNumberSymbol().isPresent()) {
                // A node below already numbered the rows; reuse that numbering.
                rowNumberSymbol = source.getRowNumberSymbol().get();
                sourceNode = source.getPlan();
            } else {
                // Number rows keyed on the unique symbol; Optional.of(2) is passed as the row
                // cap, which is enough to tell "one row" from "more than one".
                rowNumberSymbol = this.symbolAllocator.newSymbol("row_number", BigintType.BIGINT);
                sourceNode = new RowNumberNode(this.idAllocator.getNextId(), source.getPlan(), ImmutableList.of(this.uniqueSymbol), false, rowNumberSymbol, Optional.of(2), Optional.empty());
            }

            ResolvedFunction fail = this.metadata.resolveFunction(QualifiedName.of("fail"), TypeSignatureProvider.fromTypes(VarcharType.VARCHAR));
            Expression tooManyRows = new ComparisonExpression(
                    ComparisonExpression.Operator.GREATER_THAN,
                    rowNumberSymbol.toSymbolReference(),
                    new GenericLiteral("BIGINT", "1"));
            Expression failure = new Cast(
                    new FunctionCall(
                            fail.toQualifiedName(),
                            ImmutableList.of(new Cast(new StringLiteral("Scalar sub-query has returned multiple rows"), TypeSignatureTranslator.toSqlType(VarcharType.VARCHAR)))),
                    TypeSignatureTranslator.toSqlType(BooleanType.BOOLEAN));
            IfExpression predicate = new IfExpression(tooManyRows, failure, BooleanLiteral.TRUE_LITERAL);

            return new RewriteResult(new FilterNode(this.idAllocator.getNextId(), sourceNode, predicate), Optional.of(rowNumberSymbol));
        }

        /**
         * Replaces the limit with a {@code row_number <= count} filter keyed on the unique
         * symbol. A limit WITH TIES is delegated to
         * {@link ImplementLimitWithTies#rewriteLimitWithTiesWithPartitioning} and exposes no
         * row number symbol.
         */
        @Override
        public RewriteResult visitLimit(LimitNode node, Void context) {
            RewriteResult source = this.rewrite(node.getSource());
            if (node.isWithTies()) {
                PlanNode withTies = ImplementLimitWithTies.rewriteLimitWithTiesWithPartitioning(node, source.getPlan(), this.metadata, this.idAllocator, this.symbolAllocator, ImmutableList.of(this.uniqueSymbol));
                return new RewriteResult(withTies, Optional.empty());
            }

            Symbol rowNumberSymbol;
            PlanNode sourceNode;
            if (source.getRowNumberSymbol().isPresent()) {
                // A node below already numbered the rows; reuse that numbering.
                rowNumberSymbol = source.getRowNumberSymbol().get();
                sourceNode = source.getPlan();
            } else {
                rowNumberSymbol = this.symbolAllocator.newSymbol("row_number", BigintType.BIGINT);
                sourceNode = new RowNumberNode(this.idAllocator.getNextId(), source.getPlan(), ImmutableList.of(this.uniqueSymbol), false, rowNumberSymbol, Optional.empty(), Optional.empty());
            }

            Expression withinLimit = new ComparisonExpression(
                    ComparisonExpression.Operator.LESS_THAN_OR_EQUAL,
                    rowNumberSymbol.toSymbolReference(),
                    new GenericLiteral("BIGINT", Long.toString(node.getCount())));
            return new RewriteResult(new FilterNode(this.idAllocator.getNextId(), sourceNode, withinLimit), Optional.of(rowNumberSymbol));
        }

        /**
         * Replaces the TopN with a {@code row_number()} window keyed on the unique symbol and
         * ordered by the TopN ordering scheme, followed by a {@code row_number <= count} filter.
         * A fresh row number symbol is always allocated here because the numbering depends on
         * this node's ordering.
         */
        @Override
        public RewriteResult visitTopN(TopNNode node, Void context) {
            RewriteResult source = this.rewrite(node.getSource());
            Symbol rowNumberSymbol = this.symbolAllocator.newSymbol("row_number", BigintType.BIGINT);

            WindowNode.Function rowNumberFunction = new WindowNode.Function(
                    this.metadata.resolveFunction(QualifiedName.of("row_number"), ImmutableList.of()),
                    ImmutableList.of(),
                    WindowNode.Frame.DEFAULT_FRAME,
                    false);
            WindowNode windowNode = new WindowNode(
                    this.idAllocator.getNextId(),
                    source.getPlan(),
                    new WindowNode.Specification(ImmutableList.of(this.uniqueSymbol), Optional.of(node.getOrderingScheme())),
                    ImmutableMap.of(rowNumberSymbol, rowNumberFunction),
                    Optional.empty(),
                    ImmutableSet.of(),
                    0);

            Expression withinCount = new ComparisonExpression(
                    ComparisonExpression.Operator.LESS_THAN_OR_EQUAL,
                    rowNumberSymbol.toSymbolReference(),
                    new GenericLiteral("BIGINT", Long.toString(node.getCount())));
            return new RewriteResult(new FilterNode(this.idAllocator.getNextId(), windowNode, withinCount), Optional.of(rowNumberSymbol));
        }

        /**
         * Re-creates the projection (same node id) and additionally passes through the left
         * outputs, the ordinality symbol and, when present, the row number symbol so that nodes
         * above can still reference them.
         */
        @Override
        public RewriteResult visitProject(ProjectNode node, Void context) {
            RewriteResult source = this.rewrite(node.getSource());
            Assignments.Builder assignments = Assignments.builder()
                    .putAll(node.getAssignments())
                    .putIdentities(this.leftOutputs)
                    .putIdentity(this.ordinalitySymbol);
            source.getRowNumberSymbol().ifPresent(assignments::putIdentity);
            return new RewriteResult(new ProjectNode(node.getId(), source.getPlan(), assignments.build()), source.getRowNumberSymbol());
        }
    }
}

