/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Rule set that rewrites inequality comparisons between the two sides of a join so that a
 * non-trivial expression over the join's right input is computed in a {@link ProjectNode}
 * placed below the join, and the comparison refers to the projected symbol instead.
 *
 * <p>Two rules are exposed:
 * <ul>
 *   <li>a {@link FilterNode} rule that looks at the predicate of a filter sitting directly on top of a join
 *       (as well as the join's own filter), and</li>
 *   <li>a {@link JoinNode} rule that looks at the join's own filter only.</li>
 * </ul>
 */
public class PushInequalityFilterExpressionBelowJoinRuleSet {
    /** Only these four ordering comparisons are considered for push-down. */
    private static final Set<ComparisonExpression.Operator> SUPPORTED_COMPARISONS = ImmutableSet.of(
            ComparisonExpression.Operator.GREATER_THAN,
            ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL,
            ComparisonExpression.Operator.LESS_THAN,
            ComparisonExpression.Operator.LESS_THAN_OR_EQUAL);
    private static final Pattern<JoinNode> JOIN_PATTERN = Patterns.join();
    private static final Capture<JoinNode> JOIN_CAPTURE = Capture.newCapture();
    /** Matches a filter whose source is a join; the join is captured as {@link #JOIN_CAPTURE}. */
    private static final Pattern<FilterNode> FILTER_PATTERN = Patterns.filter().with(Patterns.source().matching(Patterns.join().capturedAs(JOIN_CAPTURE)));

    /**
     * Returns both rules of this set: the filter-over-join rule followed by the join-only rule.
     */
    public Iterable<Rule<?>> rules() {
        return ImmutableList.of(this.pushParentInequalityFilterExpressionBelowJoinRule(), this.pushJoinInequalityFilterExpressionBelowJoinRule());
    }

    /**
     * Creates the rule that matches a {@link FilterNode} directly above a {@link JoinNode}.
     */
    public Rule<FilterNode> pushParentInequalityFilterExpressionBelowJoinRule() {
        return new PushFilterExpressionBelowJoinFilterRule();
    }

    /**
     * Creates the rule that matches a bare {@link JoinNode}.
     */
    public Rule<JoinNode> pushJoinInequalityFilterExpressionBelowJoinRule() {
        return new PushFilterExpressionBelowJoinJoinRule();
    }

    /**
     * Shared implementation of both rules.
     *
     * @param context rule context supplying the symbol and plan-node-id allocators
     * @param joinNode the join being rewritten
     * @param filterNode the filter directly above the join, or empty when invoked from the join-only rule
     * @return an empty result when no conjunct qualifies for push-down, otherwise the rewritten plan
     */
    private Rule.Result pushInequalityFilterExpressionBelowJoin(Rule.Context context, JoinNode joinNode, Optional<FilterNode> filterNode) {
        JoinNodeContext joinNodeContext = new JoinNodeContext(joinNode);
        Expression parentFilterPredicate = filterNode.map(FilterNode::getPredicate).orElse(BooleanLiteral.TRUE_LITERAL);

        // Conjuncts of the parent filter are only considered for INNER joins; for every other
        // join type all of them are kept in the parent filter untouched.
        Map<Boolean, List<Expression>> parentFilterCandidates;
        if (joinNode.getType() == JoinType.INNER) {
            parentFilterCandidates = this.extractPushDownCandidates(joinNodeContext, parentFilterPredicate);
        } else {
            parentFilterCandidates = ImmutableMap.of(true, ImmutableList.<Expression>of(), false, IrUtils.extractConjuncts(parentFilterPredicate));
        }
        Map<Boolean, List<Expression>> joinFilterCandidates = this.extractPushDownCandidates(joinNodeContext, joinNode.getFilter().orElse(BooleanLiteral.TRUE_LITERAL));

        if (parentFilterCandidates.get(true).isEmpty() && joinFilterCandidates.get(true).isEmpty()) {
            return Rule.Result.empty();
        }

        // Start from the conjuncts that stay as they are, then append the rewritten ones.
        // The parent filter is processed before the join filter, which fixes the order in which
        // new symbols are allocated.
        ImmutableList.Builder<Expression> newParentFilterConjuncts = ImmutableList.<Expression>builder().addAll(parentFilterCandidates.get(false));
        Map<Symbol, Expression> newRightProjectionsForParentFilter = this.pushDownRightComplexExpressions(joinNodeContext, context, parentFilterCandidates.get(true), newParentFilterConjuncts);

        ImmutableList.Builder<Expression> newJoinFilterConjuncts = ImmutableList.<Expression>builder().addAll(joinFilterCandidates.get(false));
        Map<Symbol, Expression> newRightProjectionsForJoinFilter = this.pushDownRightComplexExpressions(joinNodeContext, context, joinFilterCandidates.get(true), newJoinFilterConjuncts);

        Map<Symbol, Expression> allNewRightProjections = ImmutableMap.<Symbol, Expression>builder()
                .putAll(newRightProjectionsForJoinFilter)
                .putAll(newRightProjectionsForParentFilter)
                .buildOrThrow();

        // Only the symbols referenced by the parent filter are added to the join's right output
        // symbols: that filter is placed on top of the join and reads them from the join output.
        PlanNode newOutput = this.constructModifiedJoin(
                context,
                joinNode,
                this.conjunctsToFilter(newJoinFilterConjuncts.build()),
                allNewRightProjections,
                newRightProjectionsForParentFilter.keySet());

        Optional<Expression> filter = this.conjunctsToFilter(newParentFilterConjuncts.build());
        if (filter.isPresent()) {
            // NOTE(review): this assumes a non-trivial parent filter can only result when a filter node
            // was supplied; without one the only parent conjunct is TRUE_LITERAL — confirm that
            // IrUtils.combineConjuncts collapses it back to TRUE_LITERAL.
            newOutput = new FilterNode(filterNode.get().getId(), newOutput, filter.get());
        }

        // Restore the original join's output symbols if the rewrite changed them.
        if (!joinNode.getOutputSymbols().equals(newOutput.getOutputSymbols())) {
            newOutput = new ProjectNode(context.getIdAllocator().getNextId(), newOutput, Assignments.identity(joinNode.getOutputSymbols()));
        }
        return Rule.Result.ofPlanNode(newOutput);
    }

    /**
     * Combines the conjuncts into one predicate; yields empty when the combined predicate equals
     * {@link BooleanLiteral#TRUE_LITERAL}.
     */
    private Optional<Expression> conjunctsToFilter(List<Expression> conjuncts) {
        Expression combined = IrUtils.combineConjuncts(conjuncts);
        return Optional.of(combined).filter(expression -> !BooleanLiteral.TRUE_LITERAL.equals(expression));
    }

    /**
     * Splits the conjuncts of {@code filter} into push-down candidates and the rest.
     *
     * @return a map with both keys present: {@code true} maps to the conjuncts accepted by
     *         {@link #isSupportedExpression}, {@code false} to all remaining conjuncts
     */
    Map<Boolean, List<Expression>> extractPushDownCandidates(JoinNodeContext joinNodeContext, Expression filter) {
        return IrUtils.extractConjuncts(filter).stream()
                .collect(Collectors.partitioningBy(conjunct -> this.isSupportedExpression(joinNodeContext, conjunct)));
    }

    /**
     * Decides whether a conjunct qualifies for push-down. It must be a deterministic comparison
     * using one of {@link #SUPPORTED_COMPARISONS}, each operand must reference at least one symbol,
     * one operand must use only left-side symbols and the other only right-side symbols, and the
     * right-side operand must not already be a plain {@link SymbolReference}.
     */
    private boolean isSupportedExpression(JoinNodeContext joinNodeContext, Expression expression) {
        if (!(expression instanceof ComparisonExpression)) {
            return false;
        }
        if (!DeterminismEvaluator.isDeterministic(expression)) {
            return false;
        }
        ComparisonExpression comparison = (ComparisonExpression) expression;
        if (!SUPPORTED_COMPARISONS.contains(comparison.getOperator())) {
            return false;
        }

        Set<Symbol> leftComparisonSymbols = SymbolsExtractor.extractUnique(comparison.getLeft());
        Set<Symbol> rightComparisonSymbols = SymbolsExtractor.extractUnique(comparison.getRight());
        if (leftComparisonSymbols.isEmpty() || rightComparisonSymbols.isEmpty()) {
            return false;
        }

        Set<Symbol> leftSymbols = joinNodeContext.getLeftSymbols();
        Set<Symbol> rightSymbols = joinNodeContext.getRightSymbols();
        boolean aligned = leftSymbols.containsAll(leftComparisonSymbols) && rightSymbols.containsAll(rightComparisonSymbols);
        boolean flipped = rightSymbols.containsAll(leftComparisonSymbols) && leftSymbols.containsAll(rightComparisonSymbols);
        if (!aligned && !flipped) {
            return false;
        }

        // Pick the operand that belongs to the join's right side; a bare symbol reference there
        // leaves nothing to project.
        Expression buildExpression = joinNodeContext.isComparisonAligned(comparison) ? comparison.getRight() : comparison.getLeft();
        return !(buildExpression instanceof SymbolReference);
    }

    /**
     * Rewrites every conjunct via {@link #pushDownRightComplexExpression}.
     *
     * @param conjuncts comparisons to rewrite; each must be a {@link ComparisonExpression}
     * @param newConjuncts receives the rewritten comparisons
     * @return the projections (new symbol to original right-side expression) that have to be added below the join
     */
    Map<Symbol, Expression> pushDownRightComplexExpressions(JoinNodeContext joinNodeContext, Rule.Context context, List<Expression> conjuncts, ImmutableList.Builder<Expression> newConjuncts) {
        ImmutableMap.Builder<Symbol, Expression> newProjections = ImmutableMap.builder();
        for (Expression conjunct : conjuncts) {
            this.pushDownRightComplexExpression(joinNodeContext, context, newConjuncts, newProjections, conjunct);
        }
        return newProjections.buildOrThrow();
    }

    /**
     * Allocates a symbol for the right-side operand of the comparison, records the projection
     * for it, and adds a comparison with the same operator and operand order in which that
     * operand is replaced by a reference to the new symbol.
     *
     * @throws IllegalArgumentException if {@code conjunct} is not a comparison, or its right-side
     *         operand is already a symbol reference
     */
    private void pushDownRightComplexExpression(JoinNodeContext joinNodeContext, Rule.Context context, ImmutableList.Builder<Expression> newConjuncts, ImmutableMap.Builder<Symbol, Expression> newProjections, Expression conjunct) {
        Preconditions.checkArgument(conjunct instanceof ComparisonExpression, "conjunct '%s' is not a comparison", conjunct);
        ComparisonExpression comparison = (ComparisonExpression) conjunct;

        // "Aligned" means the comparison's left operand uses only symbols of the join's left input.
        boolean alignedComparison = joinNodeContext.isComparisonAligned(comparison);
        Expression rightExpression = alignedComparison ? comparison.getRight() : comparison.getLeft();
        Expression leftExpression = alignedComparison ? comparison.getLeft() : comparison.getRight();

        Symbol rightSymbol = this.symbolForExpression(context, rightExpression);
        ComparisonExpression rewritten;
        if (alignedComparison) {
            rewritten = new ComparisonExpression(comparison.getOperator(), leftExpression, rightSymbol.toSymbolReference());
        } else {
            rewritten = new ComparisonExpression(comparison.getOperator(), rightSymbol.toSymbolReference(), leftExpression);
        }
        newConjuncts.add(rewritten);
        newProjections.put(rightSymbol, rightExpression);
    }

    /**
     * Builds a copy of the join that uses {@code newJoinFilter}, appends
     * {@code newJoinRightOutputSymbols} to its right output symbols and, when there are new
     * projections, wraps its right source in a {@link ProjectNode} computing them. All other
     * properties are taken from the original join.
     */
    private JoinNode constructModifiedJoin(Rule.Context context, JoinNode originalJoinNode, Optional<Expression> newJoinFilter, Map<Symbol, Expression> newRightProjections, Set<Symbol> newJoinRightOutputSymbols) {
        PlanNode rightSource;
        if (newRightProjections.isEmpty()) {
            rightSource = originalJoinNode.getRight();
        } else {
            rightSource = new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    originalJoinNode.getRight(),
                    this.buildAssignments(originalJoinNode.getRight(), newRightProjections));
        }
        return new JoinNode(
                originalJoinNode.getId(),
                originalJoinNode.getType(),
                originalJoinNode.getLeft(),
                rightSource,
                originalJoinNode.getCriteria(),
                originalJoinNode.getLeftOutputSymbols(),
                this.concatToList(originalJoinNode.getRightOutputSymbols(), newJoinRightOutputSymbols),
                originalJoinNode.isMaySkipOutputDuplicates(),
                newJoinFilter,
                originalJoinNode.getLeftHashSymbol(),
                originalJoinNode.getRightHashSymbol(),
                originalJoinNode.getDistributionType(),
                originalJoinNode.isSpillable(),
                originalJoinNode.getDynamicFilters(),
                originalJoinNode.getReorderJoinStatsAndCost());
    }

    /**
     * Returns an immutable list holding the elements of {@code left} followed by those of {@code right}.
     */
    private <T> List<T> concatToList(Iterable<T> left, Iterable<T> right) {
        return ImmutableList.<T>builder().addAll(left).addAll(right).build();
    }

    /**
     * Builds assignments that pass through every output symbol of {@code source} and add the new projections.
     */
    private Assignments buildAssignments(PlanNode source, Map<Symbol, Expression> newRightProjections) {
        return Assignments.builder().putIdentities(source.getOutputSymbols()).putAll(newRightProjections).build();
    }

    /**
     * Allocates a new symbol for the expression, typed with the expression's type.
     *
     * @throws IllegalArgumentException if the expression is already a {@link SymbolReference}
     */
    private Symbol symbolForExpression(Rule.Context context, Expression expression) {
        Preconditions.checkArgument(!(expression instanceof SymbolReference), "expression '%s' is a SymbolReference", expression);
        return context.getSymbolAllocator().newSymbol(expression, expression.type());
    }

    /**
     * Rule for a filter directly above a join; passes both the filter and the captured join to the shared implementation.
     */
    private class PushFilterExpressionBelowJoinFilterRule
    implements Rule<FilterNode> {
        @Override
        public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
            JoinNode joinNode = captures.get(JOIN_CAPTURE);
            return PushInequalityFilterExpressionBelowJoinRuleSet.this.pushInequalityFilterExpressionBelowJoin(context, joinNode, Optional.of(filterNode));
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return FILTER_PATTERN;
        }
    }

    /**
     * Rule for a join on its own; invokes the shared implementation without a parent filter.
     */
    private class PushFilterExpressionBelowJoinJoinRule
    implements Rule<JoinNode> {
        @Override
        public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
            return PushInequalityFilterExpressionBelowJoinRuleSet.this.pushInequalityFilterExpressionBelowJoin(context, joinNode, Optional.empty());
        }

        @Override
        public Pattern<JoinNode> getPattern() {
            return JOIN_PATTERN;
        }
    }

    /**
     * Immutable snapshot of the output symbols of a join's left and right sources.
     */
    private static class JoinNodeContext {
        private final Set<Symbol> leftSymbols;
        private final Set<Symbol> rightSymbols;

        /**
         * @throws NullPointerException if {@code joinNode} is null
         */
        public JoinNodeContext(JoinNode joinNode) {
            Objects.requireNonNull(joinNode, "joinNode is null");
            this.leftSymbols = ImmutableSet.copyOf(joinNode.getLeft().getOutputSymbols());
            this.rightSymbols = ImmutableSet.copyOf(joinNode.getRight().getOutputSymbols());
        }

        public Set<Symbol> getLeftSymbols() {
            return this.leftSymbols;
        }

        public Set<Symbol> getRightSymbols() {
            return this.rightSymbols;
        }

        /**
         * Returns true when every symbol referenced by the comparison's left operand is an output
         * symbol of the join's left source. The right operand is not inspected here.
         */
        public boolean isComparisonAligned(ComparisonExpression comparison) {
            Set<Symbol> leftOperandSymbols = SymbolsExtractor.extractUnique(comparison.getLeft());
            return this.leftSymbols.containsAll(leftOperandSymbols);
        }
    }
}

