/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.gen.CommonSubExpressionRewriter;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.CrossJoinWithArrayContainsToInnerJoin;
import com.facebook.presto.sql.planner.iterative.rule.CrossJoinWithArrayNotContainsToAntiJoin;
import com.facebook.presto.sql.planner.iterative.rule.CrossJoinWithOrFilterToInnerJoin;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Stream;

/**
 * Iterative optimizer rule that moves the evaluation of selected filter sub-expressions
 * below a cross join.
 *
 * <p>The rule matches a {@link FilterNode} whose source is an {@code INNER} {@link JoinNode}
 * with no join criteria. From the filter predicate it collects argument expressions of
 * <ul>
 *   <li>{@code EQUAL} calls that appear as disjuncts of a conjunct,</li>
 *   <li>a top-level array-contains call, and</li>
 *   <li>a top-level negated array-contains call,</li>
 * </ul>
 * keeping only those that are deterministic, are not already plain variable references,
 * reference at least one variable, and whose variables all come from a single join input.
 * Each such expression is projected on top of the join input it belongs to, and the filter
 * is rewritten to reference the newly allocated variables. An identity projection over the
 * original filter outputs is placed on top so the set of output variables is unchanged.
 */
public class PushDownFilterExpressionEvaluationThroughCrossJoin
        implements Rule<FilterNode> {
    private static final Capture<JoinNode> CHILD = Capture.newCapture();

    // Filter directly over an INNER join that has no join criteria (i.e. a cross join).
    private static final Pattern<FilterNode> PATTERN = Patterns.filter()
            .with(Patterns.source().matching(
                    Patterns.join()
                            .matching(join -> join.getCriteria().isEmpty() && join.getType().equals(JoinType.INNER))
                            .capturedAs(CHILD)));

    private final FunctionAndTypeManager functionAndTypeManager;
    private final RowExpressionDeterminismEvaluator determinismEvaluator;

    /**
     * @param functionAndTypeManager used to resolve functions and to evaluate determinism; must not be null
     * @throws NullPointerException if {@code functionAndTypeManager} is null
     */
    public PushDownFilterExpressionEvaluationThroughCrossJoin(FunctionAndTypeManager functionAndTypeManager) {
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
    }

    /**
     * Returns true when at least one of the cross-join rewrite rules (OR-filter, array-contains,
     * array-not-contains) reports a candidate expression for the given filter and join inputs.
     */
    private static boolean canRewriteToInnerJoin(FunctionResolution functionResolution, RowExpression filter, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right) {
        return CrossJoinWithOrFilterToInnerJoin.getCandidateOrExpression(filter, left, right) != null
                || CrossJoinWithArrayContainsToInnerJoin.getCandidateArrayContainsExpression(functionResolution, filter, left, right) != null
                || CrossJoinWithArrayNotContainsToAntiJoin.getCandidateArrayNotContainsExpression(functionResolution, filter, left, right) != null;
    }

    /**
     * Builds the projection assignments (new variable -> original expression) for one join input.
     *
     * @param expressions expressions to be projected on that input
     * @param expressionToVariable mapping from each expression to the variable allocated for it
     */
    private static Map<VariableReferenceExpression, RowExpression> toAssignments(Set<RowExpression> expressions, Map<RowExpression, VariableReferenceExpression> expressionToVariable) {
        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> assignments = ImmutableMap.builder();
        for (RowExpression expression : expressions) {
            assignments.put(expressionToVariable.get(expression), expression);
        }
        return assignments.build();
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * The rule is enabled for every strategy except {@code DISABLED}.
     */
    @Override
    public boolean isEnabled(Session session) {
        return !SystemSessionProperties.getPushdownFilterExpressionEvaluationThroughCrossJoinStrategy(session)
                .equals(FeaturesConfig.PushDownFilterThroughCrossJoinStrategy.DISABLED);
    }

    /**
     * Rewrites {@code Filter(CrossJoin(left, right))} into
     * {@code Project(Filter'(CrossJoin(Project(left), Project(right))))}, where the added
     * input projections compute the extracted expressions and {@code Filter'} references
     * the variables that hold their results.
     *
     * @return an empty result when nothing can be extracted, or when the strategy is
     *         {@code REWRITTEN_TO_INNER_JOIN} and the rewritten filter offers no candidate
     *         for the cross-join rewrite rules; otherwise the rewritten plan
     */
    @Override
    public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
        JoinNode joinNode = captures.get(CHILD);
        FunctionResolution functionResolution = new FunctionResolution(this.functionAndTypeManager.getFunctionAndTypeResolver());

        List<Set<RowExpression>> rowExpressionToProject = getRowExpressions(
                functionResolution,
                filterNode.getPredicate(),
                joinNode.getLeft().getOutputVariables(),
                joinNode.getRight().getOutputVariables());
        Set<RowExpression> leftExpressions = rowExpressionToProject.get(0);
        Set<RowExpression> rightExpressions = rowExpressionToProject.get(1);
        if (leftExpressions.isEmpty() && rightExpressions.isEmpty()) {
            return Rule.Result.empty();
        }

        // Allocate one fresh variable per extracted expression: left-side expressions first, then right-side.
        ImmutableMap.Builder<RowExpression, VariableReferenceExpression> allocated = ImmutableMap.builder();
        for (RowExpression expression : leftExpressions) {
            allocated.put(expression, context.getVariableAllocator().newVariable(expression));
        }
        for (RowExpression expression : rightExpressions) {
            allocated.put(expression, context.getVariableAllocator().newVariable(expression));
        }
        Map<RowExpression, VariableReferenceExpression> rewrittenExpressionMap = allocated.build();

        RowExpression rewrittenFilter = CommonSubExpressionRewriter.rewriteExpressionWithCSE(filterNode.getPredicate(), rewrittenExpressionMap);

        Map<VariableReferenceExpression, RowExpression> leftAssignment = toAssignments(leftExpressions, rewrittenExpressionMap);
        Map<VariableReferenceExpression, RowExpression> rightAssignment = toAssignments(rightExpressions, rewrittenExpressionMap);

        // Only wrap an input in a projection when there is something to compute on that side.
        PlanNode leftInput = joinNode.getLeft();
        if (!leftAssignment.isEmpty()) {
            leftInput = PlannerUtils.addProjections(joinNode.getLeft(), context.getIdAllocator(), leftAssignment);
        }
        PlanNode rightInput = joinNode.getRight();
        if (!rightAssignment.isEmpty()) {
            rightInput = PlannerUtils.addProjections(joinNode.getRight(), context.getIdAllocator(), rightAssignment);
        }

        // Under REWRITTEN_TO_INNER_JOIN the rewrite is kept only if a cross-join rewrite rule has a candidate afterwards.
        boolean requireInnerJoinRewrite = SystemSessionProperties.getPushdownFilterExpressionEvaluationThroughCrossJoinStrategy(context.getSession())
                .equals(FeaturesConfig.PushDownFilterThroughCrossJoinStrategy.REWRITTEN_TO_INNER_JOIN);
        if (requireInnerJoinRewrite
                && !canRewriteToInnerJoin(functionResolution, rewrittenFilter, leftInput.getOutputVariables(), rightInput.getOutputVariables())) {
            return Rule.Result.empty();
        }

        // Identity projection over the original filter outputs, dropping the helper variables added above.
        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> identityAssignments = ImmutableMap.builder();
        for (VariableReferenceExpression variable : filterNode.getOutputVariables()) {
            identityAssignments.put(variable, variable);
        }
        Assignments.Builder identity = Assignments.builder();
        identity.putAll(identityAssignments.build());

        // Kept as one nested expression so plan node ids are allocated in the order project, filter, join.
        return Rule.Result.ofPlanNode(
                new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        new FilterNode(
                                filterNode.getSourceLocation(),
                                context.getIdAllocator().getNextId(),
                                new JoinNode(
                                        joinNode.getSourceLocation(),
                                        context.getIdAllocator().getNextId(),
                                        joinNode.getType(),
                                        leftInput,
                                        rightInput,
                                        joinNode.getCriteria(),
                                        ImmutableList.<VariableReferenceExpression>builder()
                                                .addAll(leftInput.getOutputVariables())
                                                .addAll(rightInput.getOutputVariables())
                                                .build(),
                                        joinNode.getFilter(),
                                        joinNode.getLeftHashVariable(),
                                        joinNode.getRightHashVariable(),
                                        joinNode.getDistributionType(),
                                        joinNode.getDynamicFilters()),
                                rewrittenFilter),
                        identity.build()));
    }

    /**
     * Collects every expression worth projecting below the join.
     *
     * @return a two-element list: index 0 holds expressions computable from the left input,
     *         index 1 holds expressions computable from the right input
     */
    private List<Set<RowExpression>> getRowExpressions(FunctionResolution functionResolution, RowExpression filterPredicate, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right) {
        List<Set<RowExpression>> candidateFromOrCondition = getRowExpressionsFromOrCondition(filterPredicate, left, right);
        List<Set<RowExpression>> candidateFromArrayContains = getRowExpressionsFromArrayContains(functionResolution, filterPredicate, left, right);
        List<Set<RowExpression>> candidateFromArrayNotContains = getRowExpressionsFromArrayNotContains(functionResolution, filterPredicate, left, right);

        ImmutableSet.Builder<RowExpression> leftCandidate = ImmutableSet.builder();
        leftCandidate.addAll(candidateFromOrCondition.get(0));
        leftCandidate.addAll(candidateFromArrayContains.get(0));
        leftCandidate.addAll(candidateFromArrayNotContains.get(0));

        ImmutableSet.Builder<RowExpression> rightCandidate = ImmutableSet.builder();
        rightCandidate.addAll(candidateFromOrCondition.get(1));
        rightCandidate.addAll(candidateFromArrayContains.get(1));
        rightCandidate.addAll(candidateFromArrayNotContains.get(1));

        return ImmutableList.of(leftCandidate.build(), rightCandidate.build());
    }

    /**
     * Collects candidate arguments of {@code EQUAL} calls found among the disjuncts of each
     * conjunct of the predicate.
     *
     * @return {@code [leftExpressions, rightExpressions]}
     */
    private List<Set<RowExpression>> getRowExpressionsFromOrCondition(RowExpression filterPredicate, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right) {
        Set<RowExpression> leftRowExpression = new HashSet<>();
        Set<RowExpression> rightRowExpression = new HashSet<>();
        for (RowExpression conjunct : LogicalRowExpressions.extractConjuncts(filterPredicate)) {
            for (RowExpression disjunct : LogicalRowExpressions.extractDisjuncts(conjunct)) {
                if (disjunct instanceof CallExpression && ((CallExpression) disjunct).getDisplayName().equals("EQUAL")) {
                    addBinaryCallArguments((CallExpression) disjunct, left, right, leftRowExpression, rightRowExpression);
                }
            }
        }
        return ImmutableList.of(leftRowExpression, rightRowExpression);
    }

    /**
     * Collects candidate arguments when the whole predicate is an array-contains call.
     *
     * @return {@code [leftExpressions, rightExpressions]}; both empty if the predicate has another shape
     */
    private List<Set<RowExpression>> getRowExpressionsFromArrayContains(FunctionResolution functionResolution, RowExpression filterPredicate, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right) {
        Set<RowExpression> leftRowExpression = new HashSet<>();
        Set<RowExpression> rightRowExpression = new HashSet<>();
        if (filterPredicate instanceof CallExpression
                && functionResolution.isArrayContainsFunction(((CallExpression) filterPredicate).getFunctionHandle())) {
            addBinaryCallArguments((CallExpression) filterPredicate, left, right, leftRowExpression, rightRowExpression);
        }
        return ImmutableList.of(leftRowExpression, rightRowExpression);
    }

    /**
     * Collects candidate arguments when the whole predicate is the negation of an
     * array-contains call.
     *
     * @return {@code [leftExpressions, rightExpressions]}; both empty if the predicate has another shape
     */
    private List<Set<RowExpression>> getRowExpressionsFromArrayNotContains(FunctionResolution functionResolution, RowExpression filterPredicate, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right) {
        Set<RowExpression> leftRowExpression = new HashSet<>();
        Set<RowExpression> rightRowExpression = new HashSet<>();
        if (PlannerUtils.isNegationExpression(functionResolution, filterPredicate)) {
            RowExpression argument = filterPredicate.getChildren().get(0);
            if (argument instanceof CallExpression
                    && functionResolution.isArrayContainsFunction(((CallExpression) argument).getFunctionHandle())) {
                addBinaryCallArguments((CallExpression) argument, left, right, leftRowExpression, rightRowExpression);
            }
        }
        return ImmutableList.of(leftRowExpression, rightRowExpression);
    }

    /**
     * Offers the first two arguments of {@code call} as candidates, in argument order.
     */
    private void addBinaryCallArguments(CallExpression call, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right, Set<RowExpression> leftRowExpression, Set<RowExpression> rightRowExpression) {
        addCandidateExpression(call.getArguments().get(0), left, right, leftRowExpression, rightRowExpression);
        addCandidateExpression(call.getArguments().get(1), left, right, leftRowExpression, rightRowExpression);
    }

    /**
     * Adds {@code candidate} to the left or right set when it is worth pre-computing on that side.
     *
     * <p>A candidate is accepted only if it is not a bare variable reference (nothing to compute),
     * references at least one variable, and is deterministic. It goes to the left set if all of its
     * variables are in {@code left}; otherwise to the right set if all are in {@code right};
     * otherwise (it mixes both sides) it is ignored.
     */
    private void addCandidateExpression(RowExpression candidate, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right, Set<RowExpression> leftRowExpression, Set<RowExpression> rightRowExpression) {
        if (candidate instanceof VariableReferenceExpression) {
            return;
        }
        List<VariableReferenceExpression> variablesInExpression = VariablesExtractor.extractAll(candidate);
        if (variablesInExpression.isEmpty() || !this.determinismEvaluator.isDeterministic(candidate)) {
            return;
        }
        if (left.containsAll(variablesInExpression)) {
            leftRowExpression.add(candidate);
        } else if (right.containsAll(variablesInExpression)) {
            rightRowExpression.add(candidate);
        }
    }
}

