/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.spi.function.FunctionId;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.Range;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.ChildReplacer;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.TopNRankingNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.planner.plan.WindowNode;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;

public class WindowFilterPushDown
implements PlanOptimizer {
    private final PlannerContext plannerContext;

    public WindowFilterPushDown(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public PlanNode optimize(PlanNode plan, PlanOptimizer.Context context) {
        Objects.requireNonNull(plan, "plan is null");
        return SimplePlanRewriter.rewriteWith(new Rewriter(context.idAllocator(), this.plannerContext, context.session()), plan, null);
    }

    private static class Rewriter
    extends SimplePlanRewriter<Void> {
        private final PlanNodeIdAllocator idAllocator;
        private final PlannerContext plannerContext;
        private final Session session;
        private final FunctionId rowNumberFunctionId;
        private final FunctionId rankFunctionId;
        private final DomainTranslator domainTranslator;

        private Rewriter(PlanNodeIdAllocator idAllocator, PlannerContext plannerContext, Session session) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.rowNumberFunctionId = plannerContext.getMetadata().resolveBuiltinFunction("row_number", (List<TypeSignatureProvider>)ImmutableList.of()).functionId();
            this.rankFunctionId = plannerContext.getMetadata().resolveBuiltinFunction("rank", (List<TypeSignatureProvider>)ImmutableList.of()).functionId();
            this.domainTranslator = new DomainTranslator(plannerContext.getMetadata());
        }

        @Override
        public PlanNode visitWindow(WindowNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            PlanNode rewrittenSource = context.rewrite(node.getSource());
            if (this.canReplaceWithRowNumber(node)) {
                return new RowNumberNode(this.idAllocator.getNextId(), rewrittenSource, node.getPartitionBy(), false, (Symbol)Iterables.getOnlyElement(node.getWindowFunctions().keySet()), Optional.empty(), Optional.empty());
            }
            return ChildReplacer.replaceChildren(node, (List<PlanNode>)ImmutableList.of((Object)rewrittenSource));
        }

        @Override
        public PlanNode visitLimit(LimitNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            if (node.isWithTies() || node.requiresPreSortedInputs()) {
                return context.defaultRewrite(node);
            }
            if (node.getCount() == 0L) {
                return node;
            }
            if (node.getCount() > Integer.MAX_VALUE) {
                return context.defaultRewrite(node);
            }
            PlanNode source = context.rewrite(node.getSource());
            int limit = Math.toIntExact(node.getCount());
            if (source instanceof RowNumberNode) {
                RowNumberNode rowNumberNode = Rewriter.mergeLimit((RowNumberNode)source, limit);
                if (rowNumberNode.getPartitionBy().isEmpty()) {
                    return rowNumberNode;
                }
                source = rowNumberNode;
            } else if (source instanceof WindowNode) {
                Optional<TopNRankingNode.RankingType> rankingType;
                WindowNode windowNode = (WindowNode)source;
                if (SystemSessionProperties.isOptimizeTopNRanking(this.session) && (rankingType = this.toTopNRankingType(windowNode)).isPresent()) {
                    TopNRankingNode topNRankingNode = this.convertToTopNRanking(windowNode, rankingType.get(), limit);
                    if (rankingType.get() == TopNRankingNode.RankingType.ROW_NUMBER && windowNode.getPartitionBy().isEmpty()) {
                        return topNRankingNode;
                    }
                    source = topNRankingNode;
                }
            }
            return ChildReplacer.replaceChildren(node, (List<PlanNode>)ImmutableList.of((Object)source));
        }

        @Override
        public PlanNode visitFilter(FilterNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            PlanNode source = context.rewrite(node.getSource());
            TupleDomain<Symbol> tupleDomain = DomainTranslator.getExtractionResult(this.plannerContext, this.session, node.getPredicate()).getTupleDomain();
            if (source instanceof RowNumberNode) {
                Symbol rowNumberSymbol = ((RowNumberNode)source).getRowNumberSymbol();
                OptionalInt upperBound = Rewriter.extractUpperBound(tupleDomain, rowNumberSymbol);
                if (upperBound.isPresent()) {
                    if (upperBound.getAsInt() <= 0) {
                        return new ValuesNode(node.getId(), node.getOutputSymbols(), (List<Expression>)ImmutableList.of());
                    }
                    source = Rewriter.mergeLimit((RowNumberNode)source, upperBound.getAsInt());
                    return this.rewriteFilterSource(node, source, rowNumberSymbol, ((RowNumberNode)source).getMaxRowCountPerPartition().get());
                }
            } else if (source instanceof WindowNode) {
                Symbol rankingSymbol;
                OptionalInt upperBound;
                Optional<TopNRankingNode.RankingType> rankingType;
                WindowNode windowNode = (WindowNode)source;
                if (SystemSessionProperties.isOptimizeTopNRanking(this.session) && (rankingType = this.toTopNRankingType(windowNode)).isPresent() && (upperBound = Rewriter.extractUpperBound(tupleDomain, rankingSymbol = (Symbol)((Map.Entry)Iterables.getOnlyElement(windowNode.getWindowFunctions().entrySet())).getKey())).isPresent()) {
                    if (upperBound.getAsInt() <= 0) {
                        return new ValuesNode(node.getId(), node.getOutputSymbols(), (List<Expression>)ImmutableList.of());
                    }
                    source = this.convertToTopNRanking(windowNode, rankingType.get(), upperBound.getAsInt());
                    return this.rewriteFilterSource(node, source, rankingSymbol, upperBound.getAsInt());
                }
            }
            return ChildReplacer.replaceChildren(node, (List<PlanNode>)ImmutableList.of((Object)source));
        }

        private PlanNode rewriteFilterSource(FilterNode filterNode, PlanNode source, Symbol rankingSymbol, int upperBound) {
            DomainTranslator.ExtractionResult extractionResult = DomainTranslator.getExtractionResult(this.plannerContext, this.session, filterNode.getPredicate());
            TupleDomain<Symbol> tupleDomain = extractionResult.getTupleDomain();
            if (!Rewriter.allRankingValuesInDomain(tupleDomain, rankingSymbol, upperBound)) {
                return new FilterNode(filterNode.getId(), source, filterNode.getPredicate());
            }
            TupleDomain newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(rankingSymbol));
            Expression newPredicate = IrUtils.combineConjuncts(extractionResult.getRemainingExpression(), this.domainTranslator.toPredicate((TupleDomain<Symbol>)newTupleDomain));
            if (newPredicate.equals(Booleans.TRUE)) {
                return source;
            }
            return new FilterNode(filterNode.getId(), source, newPredicate);
        }

        private static boolean allRankingValuesInDomain(TupleDomain<Symbol> tupleDomain, Symbol symbol, long upperBound) {
            if (tupleDomain.isNone()) {
                return false;
            }
            Domain domain = (Domain)((Map)tupleDomain.getDomains().get()).get(symbol);
            if (domain == null) {
                return true;
            }
            return domain.getValues().contains(ValueSet.ofRanges((Range)Range.range((Type)domain.getType(), (Object)1L, (boolean)true, (Object)upperBound, (boolean)true), (Range[])new Range[0]));
        }

        private static OptionalInt extractUpperBound(TupleDomain<Symbol> tupleDomain, Symbol symbol) {
            if (tupleDomain.isNone()) {
                return OptionalInt.empty();
            }
            Domain domain = (Domain)((Map)tupleDomain.getDomains().get()).get(symbol);
            if (domain == null) {
                return OptionalInt.empty();
            }
            ValueSet values = domain.getValues();
            if (values.isAll() || values.isNone() || values.getRanges().getRangeCount() <= 0) {
                return OptionalInt.empty();
            }
            Range span = values.getRanges().getSpan();
            if (span.isHighUnbounded()) {
                return OptionalInt.empty();
            }
            Verify.verify((boolean)domain.getType().equals((Object)BigintType.BIGINT));
            long upperBound = (Long)span.getHighBoundedValue();
            if (!span.isHighInclusive()) {
                --upperBound;
            }
            if (upperBound >= Integer.MIN_VALUE && upperBound <= Integer.MAX_VALUE) {
                return OptionalInt.of(Math.toIntExact(upperBound));
            }
            return OptionalInt.empty();
        }

        private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPartition) {
            if (node.getMaxRowCountPerPartition().isPresent()) {
                newRowCountPerPartition = Math.min(node.getMaxRowCountPerPartition().get(), newRowCountPerPartition);
            }
            return new RowNumberNode(node.getId(), node.getSource(), node.getPartitionBy(), node.isOrderSensitive(), node.getRowNumberSymbol(), Optional.of(newRowCountPerPartition), node.getHashSymbol());
        }

        private TopNRankingNode convertToTopNRanking(WindowNode windowNode, TopNRankingNode.RankingType rankingType, int limit) {
            return new TopNRankingNode(this.idAllocator.getNextId(), windowNode.getSource(), windowNode.getSpecification(), rankingType, (Symbol)Iterables.getOnlyElement(windowNode.getWindowFunctions().keySet()), limit, false, Optional.empty());
        }

        private boolean canReplaceWithRowNumber(WindowNode node) {
            if (node.getWindowFunctions().size() != 1) {
                return false;
            }
            Symbol rankingSymbol = (Symbol)((Map.Entry)Iterables.getOnlyElement(node.getWindowFunctions().entrySet())).getKey();
            FunctionId functionId = node.getWindowFunctions().get(rankingSymbol).getResolvedFunction().functionId();
            return functionId.equals((Object)this.rowNumberFunctionId) && node.getOrderingScheme().isEmpty();
        }

        private Optional<TopNRankingNode.RankingType> toTopNRankingType(WindowNode node) {
            if (node.getWindowFunctions().size() != 1 || node.getOrderingScheme().isEmpty()) {
                return Optional.empty();
            }
            Symbol rankingSymbol = (Symbol)((Map.Entry)Iterables.getOnlyElement(node.getWindowFunctions().entrySet())).getKey();
            FunctionId functionId = node.getWindowFunctions().get(rankingSymbol).getResolvedFunction().functionId();
            if (functionId.equals((Object)this.rowNumberFunctionId)) {
                return Optional.of(TopNRankingNode.RankingType.ROW_NUMBER);
            }
            if (functionId.equals((Object)this.rankFunctionId)) {
                return Optional.of(TopNRankingNode.RankingType.RANK);
            }
            return Optional.empty();
        }
    }
}

