/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.Marker;
import io.trino.spi.predicate.Range;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;

/**
 * Optimizer rule that pushes a filter on the row number produced by a {@link RowNumberNode}
 * into that node's per-partition row limit ({@code maxRowCountPerPartition}).
 *
 * <p>An upper bound for the row number symbol is extracted from the filter predicate:
 * <ul>
 *   <li>if the bound is below 1, no row can pass and the filter is replaced by an empty
 *       {@link ValuesNode};</li>
 *   <li>otherwise the bound is merged into the row number node's limit, and when every row
 *       number from 1 to the resulting limit satisfies the predicate's row number domain, that
 *       domain is dropped from the filter (removing the filter entirely if nothing remains).</li>
 * </ul>
 */
public class PushdownFilterIntoRowNumber implements Rule<FilterNode> {
    private static final Capture<RowNumberNode> CHILD = Capture.newCapture();
    private static final Pattern<FilterNode> PATTERN = Patterns.filter().with(Patterns.source().matching(Patterns.rowNumber().capturedAs(CHILD)));
    private final Metadata metadata;
    private final DomainTranslator domainTranslator;
    private final TypeOperators typeOperators;

    public PushdownFilterIntoRowNumber(Metadata metadata, TypeOperators typeOperators) {
        this.metadata = metadata;
        this.domainTranslator = new DomainTranslator(metadata);
        this.typeOperators = typeOperators;
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites a filter sitting on top of a {@link RowNumberNode}.
     *
     * @return an empty {@link ValuesNode} when no row can pass; the (possibly limit-tightened)
     *     row number node with a simplified filter or no filter; or an empty result when nothing
     *     changes
     */
    @Override
    public Rule.Result apply(FilterNode node, Captures captures, Rule.Context context) {
        Session session = context.getSession();
        TypeProvider types = context.getSymbolAllocator().getTypes();
        // Split the predicate into per-symbol domains plus a remaining expression.
        DomainTranslator.ExtractionResult extractionResult = DomainTranslator.fromPredicate(
                metadata, typeOperators, session, node.getPredicate(), types);
        TupleDomain<Symbol> tupleDomain = extractionResult.getTupleDomain();

        RowNumberNode source = captures.get(CHILD);
        Symbol rowNumberSymbol = source.getRowNumberSymbol();
        OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberSymbol);
        if (upperBound.isEmpty()) {
            return Rule.Result.empty();
        }
        if (upperBound.getAsInt() <= 0) {
            // Row numbers start at 1, so a bound below 1 means no row can pass the filter.
            return Rule.Result.ofPlanNode(
                    new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.<Expression>of()));
        }

        RowNumberNode merged = mergeLimit(source, upperBound.getAsInt());
        boolean needRewriteSource = !merged.getMaxRowCountPerPartition().equals(source.getMaxRowCountPerPartition());
        if (needRewriteSource) {
            source = merged;
        }
        // Present either because the source already had a limit or because mergeLimit set one.
        int rowLimit = source.getMaxRowCountPerPartition().get();

        if (!allRowNumberValuesInDomain(tupleDomain, rowNumberSymbol, rowLimit)) {
            // The row number restriction is still needed; keep the filter, but keep a tighter limit too.
            return needRewriteSource
                    ? Rule.Result.ofPlanNode(new FilterNode(node.getId(), source, node.getPredicate()))
                    : Rule.Result.empty();
        }

        // Every row number the limited source can produce lies in the row number domain,
        // so that domain is redundant and can be removed from the predicate.
        TupleDomain<Symbol> newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(rowNumberSymbol));
        Expression newPredicate = ExpressionUtils.combineConjuncts(
                metadata,
                extractionResult.getRemainingExpression(),
                domainTranslator.toPredicate(newTupleDomain));
        if (newPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
            return Rule.Result.ofPlanNode(source);
        }
        if (!newPredicate.equals(node.getPredicate())) {
            return Rule.Result.ofPlanNode(new FilterNode(node.getId(), source, newPredicate));
        }
        if (needRewriteSource) {
            // The predicate could not be simplified, but the tightened limit must not be dropped.
            return Rule.Result.ofPlanNode(new FilterNode(node.getId(), source, node.getPredicate()));
        }
        return Rule.Result.empty();
    }

    /**
     * Checks whether every row number in {@code [1, upperBound]} satisfies the domain that
     * {@code tupleDomain} imposes on {@code symbol}.
     *
     * @return {@code false} when the tuple domain is none; {@code true} when it does not constrain
     *     {@code symbol}; otherwise whether the symbol's domain contains the whole range
     */
    private static boolean allRowNumberValuesInDomain(TupleDomain<Symbol> tupleDomain, Symbol symbol, long upperBound) {
        if (tupleDomain.isNone()) {
            return false;
        }
        Domain domain = tupleDomain.getDomains().get().get(symbol);
        if (domain == null) {
            return true;
        }
        Range rowNumbers = Range.range(domain.getType(), 1L, true, upperBound, true);
        return domain.getValues().contains(ValueSet.ofRanges(rowNumbers));
    }

    /**
     * Extracts the largest value {@code tupleDomain} allows for {@code symbol}, as an inclusive bound.
     *
     * @return empty when the symbol is unconstrained, its domain is all/none, the bound is
     *     unbounded above, or the bound exceeds {@link Integer#MAX_VALUE}; {@code 0} when the bound
     *     is at most 0; otherwise the inclusive upper bound
     */
    private static OptionalInt extractUpperBound(TupleDomain<Symbol> tupleDomain, Symbol symbol) {
        if (tupleDomain.isNone()) {
            return OptionalInt.empty();
        }
        Domain rowNumberDomain = tupleDomain.getDomains().get().get(symbol);
        if (rowNumberDomain == null) {
            return OptionalInt.empty();
        }
        ValueSet values = rowNumberDomain.getValues();
        if (values.isAll() || values.isNone() || values.getRanges().getRangeCount() <= 0) {
            return OptionalInt.empty();
        }
        Range span = values.getRanges().getSpan();
        if (span.getHigh().isUpperUnbounded()) {
            return OptionalInt.empty();
        }
        Verify.verify(rowNumberDomain.getType().equals(BigintType.BIGINT),
                "row number must be of type bigint, but was %s", rowNumberDomain.getType());
        long upperBound = (Long) span.getHigh().getValue();
        if (span.getHigh().getBound() == Marker.Bound.BELOW) {
            // Exclusive bound: the largest allowed value is one less; saturate instead of wrapping.
            upperBound = (upperBound == Long.MIN_VALUE) ? Long.MIN_VALUE : upperBound - 1;
        }
        if (upperBound <= 0) {
            // Row numbers start at 1: report "no rows" even for bounds below Integer.MIN_VALUE.
            return OptionalInt.of(0);
        }
        if (upperBound > Integer.MAX_VALUE) {
            // Too large to express as an int per-partition row limit.
            return OptionalInt.empty();
        }
        return OptionalInt.of(Math.toIntExact(upperBound));
    }

    /**
     * Returns a copy of {@code node} whose per-partition row limit is {@code newRowCountPerPartition},
     * or the node's existing limit when that one is smaller.
     */
    private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPartition) {
        int limit = node.getMaxRowCountPerPartition()
                .map(existing -> Math.min(existing, newRowCountPerPartition))
                .orElse(newRowCountPerPartition);
        return new RowNumberNode(
                node.getId(),
                node.getSource(),
                node.getPartitionBy(),
                node.isOrderSensitive(),
                node.getRowNumberSymbol(),
                Optional.of(limit),
                node.getHashSymbol());
    }
}

