/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.StatsProvider;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.metadata.TableHandle;
import io.trino.metadata.TableProperties;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Constant;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.Type;
import io.trino.sql.DynamicFilters;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.NodeRef;
import io.trino.sql.planner.ConnectorExpressionTranslator;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.IrTypeAnalyzer;
import io.trino.sql.planner.LayoutConstraintEvaluator;
import io.trino.sql.planner.LiteralEncoder;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.Rules;
import io.trino.sql.planner.iterative.rule.SimplifyExpressions;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import org.assertj.core.util.VisibleForTesting;

/**
 * Optimizer rule that pushes the predicate of a {@link FilterNode} into the {@link TableScanNode}
 * directly beneath it by calling {@link Metadata#applyFilter}.
 * <p>
 * The filter predicate is split into dynamic filters, deterministic conjuncts and non-deterministic
 * conjuncts. Only the deterministic conjuncts are offered to the connector, both as a
 * {@link TupleDomain} summary and as a {@link ConnectorExpression}. Dynamic filters,
 * non-deterministic conjuncts, and whatever the connector reports as not enforced are kept in a
 * {@link FilterNode} above the rewritten table scan.
 */
public class PushPredicateIntoTableScan
implements Rule<FilterNode> {
    /** Captures the table scan that is the direct source of the matched filter. */
    private static final Capture<TableScanNode> TABLE_SCAN = Capture.newCapture();
    /** Matches a {@link FilterNode} whose source is a {@link TableScanNode}. */
    private static final Pattern<FilterNode> PATTERN = Patterns.filter()
            .with(Patterns.source().matching(Patterns.tableScan().capturedAs(TABLE_SCAN)));

    private final PlannerContext plannerContext;
    private final IrTypeAnalyzer typeAnalyzer;
    // When true, and part of the deterministic predicate cannot be expressed as a tuple domain, the
    // connector additionally receives a predicate built from LayoutConstraintEvaluator#isCandidate.
    private final boolean pruneWithPredicateExpression;

    /**
     * @param plannerContext planner services (metadata, functions, types); must not be null
     * @param typeAnalyzer analyzer used to type IR expressions; must not be null
     * @param pruneWithPredicateExpression whether to hand the connector a candidate-pruning predicate
     * @throws NullPointerException if {@code plannerContext} or {@code typeAnalyzer} is null
     */
    public PushPredicateIntoTableScan(PlannerContext plannerContext, IrTypeAnalyzer typeAnalyzer, boolean pruneWithPredicateExpression) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
        this.pruneWithPredicateExpression = pruneWithPredicateExpression;
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /** The rule only runs when the session allows pushdown into connectors. */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isAllowPushdownIntoConnectors(session);
    }

    @Override
    public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
        TableScanNode tableScan = captures.get(TABLE_SCAN);
        Optional<PlanNode> rewritten = pushFilterIntoTableScan(
                filterNode,
                tableScan,
                this.pruneWithPredicateExpression,
                context.getSession(),
                context.getSymbolAllocator(),
                this.plannerContext,
                this.typeAnalyzer,
                context.getStatsProvider(),
                new DomainTranslator(this.plannerContext));
        // A rewrite that is equivalent to the input is reported as "no change".
        if (rewritten.isEmpty() || arePlansSame(filterNode, tableScan, rewritten.get())) {
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(rewritten.get());
    }

    /**
     * Returns true when {@code rewritten} is a filter with the same predicate as {@code filter} on top
     * of a table scan with the same enforced constraint and table handle as {@code tableScan}.
     */
    private boolean arePlansSame(FilterNode filter, TableScanNode tableScan, PlanNode rewritten) {
        if (!(rewritten instanceof FilterNode)) {
            return false;
        }
        FilterNode rewrittenFilter = (FilterNode) rewritten;
        if (!Objects.equals(filter.getPredicate(), rewrittenFilter.getPredicate())) {
            return false;
        }
        PlanNode rewrittenSource = rewrittenFilter.getSource();
        if (!(rewrittenSource instanceof TableScanNode)) {
            return false;
        }
        TableScanNode rewrittenTableScan = (TableScanNode) rewrittenSource;
        return Objects.equals(tableScan.getEnforcedConstraint(), rewrittenTableScan.getEnforcedConstraint())
                && Objects.equals(tableScan.getTable(), rewrittenTableScan.getTable());
    }

    /**
     * Attempts to push the predicate of {@code filterNode} into the table scan {@code node}.
     *
     * @param filterNode the filter whose predicate should be pushed into {@code node}
     * @param node the table scan receiving the predicate
     * @param pruneWithPredicateExpression whether to also hand the connector a predicate built from
     *        {@link LayoutConstraintEvaluator#isCandidate}
     * @param session current session
     * @param symbolAllocator allocator providing the symbol types
     * @param plannerContext planner services
     * @param typeAnalyzer analyzer used to type translated expressions
     * @param statsProvider source of statistics for the rewritten table scan
     * @param domainTranslator converts tuple domains back into IR predicates
     * @return {@link Optional#empty()} when pushdown is disabled for the session or the connector does
     *         not accept the filter; a {@link ValuesNode} with no rows when the predicate is known to
     *         match nothing; otherwise the (possibly rewritten) table scan, wrapped in a
     *         {@link FilterNode} when part of the predicate must still be evaluated by the engine
     */
    public static Optional<PlanNode> pushFilterIntoTableScan(
            FilterNode filterNode,
            TableScanNode node,
            boolean pruneWithPredicateExpression,
            Session session,
            SymbolAllocator symbolAllocator,
            PlannerContext plannerContext,
            IrTypeAnalyzer typeAnalyzer,
            StatsProvider statsProvider,
            DomainTranslator domainTranslator) {
        // Checked again here because this public entry point may be called directly, not only via apply().
        if (!SystemSessionProperties.isAllowPushdownIntoConnectors(session)) {
            return Optional.empty();
        }

        SplitExpression splitExpression = splitExpression(plannerContext, filterNode.getPredicate());

        // Only deterministic conjuncts are decomposed into a tuple domain plus a remaining expression.
        DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.getExtractionResult(
                plannerContext,
                session,
                splitExpression.getDeterministicPredicate(),
                symbolAllocator.getTypes());

        // Re-key the extracted domain by column handle and keep it within what the scan already enforces.
        TupleDomain<ColumnHandle> newDomain = decomposedPredicate.getTupleDomain()
                .transformKeys(node.getAssignments()::get)
                .intersect(node.getEnforcedConstraint());

        ConnectorExpressionTranslator.ConnectorExpressionTranslation expressionTranslation = ConnectorExpressionTranslator.translateConjuncts(
                session,
                decomposedPredicate.getRemainingExpression(),
                symbolAllocator.getTypes(),
                plannerContext,
                typeAnalyzer);

        Map<String, ColumnHandle> connectorExpressionAssignments = node.getAssignments().entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));
        ImmutableBiMap<ColumnHandle, Symbol> assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse();

        Constraint constraint;
        if (pruneWithPredicateExpression && !BooleanLiteral.TRUE_LITERAL.equals(decomposedPredicate.getRemainingExpression())) {
            // The evaluator tests candidate values against the deterministic predicate combined with
            // the simplified domain being pushed down.
            LayoutConstraintEvaluator evaluator = new LayoutConstraintEvaluator(
                    plannerContext,
                    typeAnalyzer,
                    session,
                    symbolAllocator.getTypes(),
                    node.getAssignments(),
                    IrUtils.combineConjuncts(
                            plannerContext.getMetadata(),
                            splitExpression.getDeterministicPredicate(),
                            domainTranslator.toPredicate(newDomain.simplify().transformKeys(assignments::get))));
            constraint = new Constraint(newDomain, expressionTranslation.connectorExpression(), connectorExpressionAssignments, evaluator::isCandidate, evaluator.getArguments());
        } else {
            constraint = new Constraint(newDomain, expressionTranslation.connectorExpression(), connectorExpressionAssignments);
        }

        // newDomain is a subset of the enforced constraint, so "contains" means the domains are equal:
        // there is nothing new to push, so only simplify the filter above the untouched scan.
        if (constraint.predicate().isEmpty()
                && Constant.TRUE.equals(expressionTranslation.connectorExpression())
                && newDomain.contains(node.getEnforcedConstraint())) {
            Expression resultingPredicate = createResultingPredicate(
                    plannerContext,
                    session,
                    symbolAllocator,
                    typeAnalyzer,
                    splitExpression.getDynamicFilter(),
                    BooleanLiteral.TRUE_LITERAL,
                    splitExpression.getNonDeterministicPredicate(),
                    decomposedPredicate.getRemainingExpression());
            if (!BooleanLiteral.TRUE_LITERAL.equals(resultingPredicate)) {
                return Optional.of(new FilterNode(filterNode.getId(), node, resultingPredicate));
            }
            return Optional.of(node);
        }

        if (newDomain.isNone()) {
            return Optional.of(new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of()));
        }

        Optional<ConstraintApplicationResult<TableHandle>> result = plannerContext.getMetadata().applyFilter(session, node.getTable(), constraint);
        if (result.isEmpty()) {
            return Optional.empty();
        }
        ConstraintApplicationResult<TableHandle> applicationResult = result.get();

        TableHandle newTable = applicationResult.getHandle();
        TableProperties newTableProperties = plannerContext.getMetadata().getTableProperties(session, newTable);
        Optional<TableProperties.TablePartitioning> newTablePartitioning = newTableProperties.getTablePartitioning();
        if (newTableProperties.getPredicate().isNone()) {
            return Optional.of(new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of()));
        }

        TupleDomain<ColumnHandle> remainingFilter = applicationResult.getRemainingFilter();
        Optional<ConnectorExpression> remainingConnectorExpression = applicationResult.getRemainingExpression();
        boolean precalculateStatistics = applicationResult.isPrecalculateStatistics();

        verifyTablePartitioning(session, plannerContext.getMetadata(), node, newTablePartitioning);

        TableScanNode tableScan = new TableScanNode(
                node.getId(),
                newTable,
                node.getOutputSymbols(),
                node.getAssignments(),
                computeEnforced(newDomain, remainingFilter),
                Rules.deriveTableStatisticsForPushdown(statsProvider, session, precalculateStatistics, filterNode),
                node.isUpdateTarget(),
                node.getUseConnectorNodePartitioning());

        Expression remainingDecomposedPredicate;
        if (remainingConnectorExpression.isEmpty() || remainingConnectorExpression.get().equals(expressionTranslation.connectorExpression())) {
            // The connector enforced none of the connector expression (or none was reported back).
            remainingDecomposedPredicate = decomposedPredicate.getRemainingExpression();
        } else {
            remainingDecomposedPredicate = translateRemainingConnectorExpression(
                    remainingConnectorExpression.get(),
                    expressionTranslation.remainingExpression(),
                    assignments,
                    session,
                    symbolAllocator,
                    plannerContext,
                    typeAnalyzer);
        }

        Expression resultingPredicate = createResultingPredicate(
                plannerContext,
                session,
                symbolAllocator,
                typeAnalyzer,
                splitExpression.getDynamicFilter(),
                domainTranslator.toPredicate(remainingFilter.transformKeys(assignments::get)),
                splitExpression.getNonDeterministicPredicate(),
                remainingDecomposedPredicate);
        if (!BooleanLiteral.TRUE_LITERAL.equals(resultingPredicate)) {
            return Optional.of(new FilterNode(filterNode.getId(), tableScan, resultingPredicate));
        }
        return Optional.of(tableScan);
    }

    /**
     * Translates the connector expression the connector reported as remaining back into IR, optimizes
     * it with {@link IrExpressionInterpreter}, and conjoins it with {@code untranslatedRemainder} (the
     * part of the predicate that was never translated into a connector expression).
     */
    private static Expression translateRemainingConnectorExpression(
            ConnectorExpression remainingConnectorExpression,
            Expression untranslatedRemainder,
            ImmutableBiMap<ColumnHandle, Symbol> assignments,
            Session session,
            SymbolAllocator symbolAllocator,
            PlannerContext plannerContext,
            IrTypeAnalyzer typeAnalyzer) {
        Map<String, Symbol> variableMappings = assignments.values().stream()
                .collect(ImmutableMap.toImmutableMap(Symbol::getName, Function.identity()));
        LiteralEncoder literalEncoder = new LiteralEncoder(plannerContext);
        Expression translated = ConnectorExpressionTranslator.translate(session, remainingConnectorExpression, plannerContext, variableMappings, literalEncoder);
        Map<NodeRef<Expression>, Type> translatedTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), translated);
        Object optimized = new IrExpressionInterpreter(translated, plannerContext, session, translatedTypes).optimize(NoOpSymbolResolver.INSTANCE);
        Expression optimizedExpression = literalEncoder.toExpression(optimized, translatedTypes.get(NodeRef.of(translated)));
        return IrUtils.combineConjuncts(plannerContext.getMetadata(), optimizedExpression, untranslatedRemainder);
    }

    /**
     * Verifies that applying the filter did not change the table partitioning. The check only runs
     * when the scan already has a connector-node-partitioning decision.
     */
    private static void verifyTablePartitioning(Session session, Metadata metadata, TableScanNode oldTableScan, Optional<TableProperties.TablePartitioning> newTablePartitioning) {
        if (oldTableScan.getUseConnectorNodePartitioning().isEmpty()) {
            return;
        }
        Optional<TableProperties.TablePartitioning> oldTablePartitioning = metadata.getTableProperties(session, oldTableScan.getTable()).getTablePartitioning();
        Verify.verify(newTablePartitioning.equals(oldTablePartitioning), "Partitioning must not change after predicate is pushed down");
    }

    /**
     * Splits the conjuncts of {@code predicate} into dynamic filters, deterministic conjuncts and
     * non-deterministic conjuncts, combining each group back into a single conjunction.
     */
    private static SplitExpression splitExpression(PlannerContext plannerContext, Expression predicate) {
        Metadata metadata = plannerContext.getMetadata();
        List<Expression> dynamicFilters = new ArrayList<>();
        List<Expression> deterministicPredicates = new ArrayList<>();
        List<Expression> nonDeterministicPredicates = new ArrayList<>();
        for (Expression conjunct : IrUtils.extractConjuncts(predicate)) {
            if (DynamicFilters.isDynamicFilter(conjunct)) {
                dynamicFilters.add(conjunct);
            } else if (DeterminismEvaluator.isDeterministic(conjunct, metadata)) {
                deterministicPredicates.add(conjunct);
            } else {
                nonDeterministicPredicates.add(conjunct);
            }
        }
        return new SplitExpression(
                IrUtils.combineConjuncts(metadata, dynamicFilters),
                IrUtils.combineConjuncts(metadata, deterministicPredicates),
                IrUtils.combineConjuncts(metadata, nonDeterministicPredicates));
    }

    /**
     * Conjoins the parts of the predicate that remain after pushdown and simplifies the result with
     * {@link SimplifyExpressions#rewrite}.
     */
    static Expression createResultingPredicate(PlannerContext plannerContext, Session session, SymbolAllocator symbolAllocator, IrTypeAnalyzer typeAnalyzer, Expression dynamicFilter, Expression unenforcedConstraints, Expression nonDeterministicPredicate, Expression remainingDecomposedPredicate) {
        Expression combined = IrUtils.combineConjuncts(plannerContext.getMetadata(), dynamicFilter, unenforcedConstraints, nonDeterministicPredicate, remainingDecomposedPredicate);
        return SimplifyExpressions.rewrite(combined, session, symbolAllocator, plannerContext, typeAnalyzer);
    }

    /**
     * Computes the part of {@code predicate} that the connector enforces: every column domain of
     * {@code predicate} whose column does not appear in {@code unenforced}.
     * <p>
     * NOTE(review): {@code predicate.getDomains().get()} throws if {@code predicate} is
     * {@link TupleDomain#none()}; the caller in this class only invokes this with a non-none domain.
     *
     * @throws IllegalArgumentException if {@code unenforced} is none, or if an unenforced domain is
     *         wider than the corresponding predicate domain
     */
    public static TupleDomain<ColumnHandle> computeEnforced(TupleDomain<ColumnHandle> predicate, TupleDomain<ColumnHandle> unenforced) {
        Preconditions.checkArgument(!unenforced.isNone(), "Unexpected unenforced none tuple domain");
        Map<ColumnHandle, Domain> predicateDomains = predicate.getDomains().get();
        Map<ColumnHandle, Domain> unenforcedDomains = unenforced.getDomains().get();
        ImmutableMap.Builder<ColumnHandle, Domain> enforcedDomainsBuilder = ImmutableMap.builder();
        for (Map.Entry<ColumnHandle, Domain> entry : predicateDomains.entrySet()) {
            ColumnHandle predicateColumnHandle = entry.getKey();
            Domain predicateDomain = entry.getValue();
            if (unenforcedDomains.containsKey(predicateColumnHandle)) {
                // The column is left unenforced; its leftover domain must not be wider than what was asked for.
                Domain unenforcedDomain = unenforcedDomains.get(predicateColumnHandle);
                Preconditions.checkArgument(
                        predicateDomain.contains(unenforcedDomain),
                        "Unexpected unenforced domain %s on column %s. Expected all, none, or a domain equal to or narrower than %s",
                        unenforcedDomain,
                        predicateColumnHandle,
                        predicateDomain);
            } else {
                enforcedDomainsBuilder.put(predicateColumnHandle, predicateDomain);
            }
        }
        return TupleDomain.withColumnDomains(enforcedDomainsBuilder.buildOrThrow());
    }

    // Guava's annotation is referenced by its qualified name because the file imports
    // org.assertj.core.util.VisibleForTesting, a test-library (AssertJ) type that production code
    // should not depend on; that import can be dropped once nothing uses it.
    @com.google.common.annotations.VisibleForTesting
    public boolean getPruneWithPredicateExpression() {
        return this.pruneWithPredicateExpression;
    }

    /** The three groups of conjuncts produced by {@link #splitExpression}; each part is non-null. */
    private static class SplitExpression {
        private final Expression dynamicFilter;
        private final Expression deterministicPredicate;
        private final Expression nonDeterministicPredicate;

        public SplitExpression(Expression dynamicFilter, Expression deterministicPredicate, Expression nonDeterministicPredicate) {
            this.dynamicFilter = Objects.requireNonNull(dynamicFilter, "dynamicFilter is null");
            this.deterministicPredicate = Objects.requireNonNull(deterministicPredicate, "deterministicPredicate is null");
            this.nonDeterministicPredicate = Objects.requireNonNull(nonDeterministicPredicate, "nonDeterministicPredicate is null");
        }

        public Expression getDynamicFilter() {
            return this.dynamicFilter;
        }

        public Expression getDeterministicPredicate() {
            return this.deterministicPredicate;
        }

        public Expression getNonDeterministicPredicate() {
            return this.nonDeterministicPredicate;
        }
    }
}

