/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.UnnestNode;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

public class CrossJoinWithArrayContainsToInnerJoin
implements Rule<FilterNode> {
    private static final Capture<JoinNode> CHILD = Capture.newCapture();
    private static final Pattern<FilterNode> PATTERN = Patterns.filter().with(Patterns.source().matching(Patterns.join().matching(x -> x.getType().equals((Object)JoinType.INNER) && x.getCriteria().isEmpty()).capturedAs(CHILD)));
    private final FunctionAndTypeManager functionAndTypeManager;

    public CrossJoinWithArrayContainsToInnerJoin(FunctionAndTypeManager functionAndTypeManager) {
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    public static RowExpression getCandidateArrayContainsExpression(FunctionResolution functionResolution, RowExpression filterPredicate, List<VariableReferenceExpression> leftInput, List<VariableReferenceExpression> rightInput) {
        List andConjuncts = LogicalRowExpressions.extractConjuncts((RowExpression)filterPredicate);
        for (RowExpression conjunct : andConjuncts) {
            if (!PlannerUtils.isSupportedArrayContainsFilter(functionResolution, conjunct, leftInput, rightInput)) continue;
            return conjunct;
        }
        return null;
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isRewriteCrossJoinArrayContainsToInnerJoinEnabled(session);
    }

    @Override
    public Rule.Result apply(FilterNode node, Captures captures, Rule.Context context) {
        JoinNode joinNode = (JoinNode)captures.get(CHILD);
        if (!joinNode.getType().equals((Object)JoinType.INNER) || !joinNode.getCriteria().isEmpty()) {
            return Rule.Result.empty();
        }
        List leftInput = joinNode.getLeft().getOutputVariables();
        List rightInput = joinNode.getRight().getOutputVariables();
        RowExpression filterExpression = node.getPredicate();
        FunctionResolution functionResolution = new FunctionResolution(this.functionAndTypeManager.getFunctionAndTypeResolver());
        RowExpression arrayContainsExpression = CrossJoinWithArrayContainsToInnerJoin.getCandidateArrayContainsExpression(functionResolution, filterExpression, leftInput, rightInput);
        if (arrayContainsExpression == null) {
            return Rule.Result.empty();
        }
        List andConjuncts = LogicalRowExpressions.extractConjuncts((RowExpression)filterExpression);
        List remainingConjuncts = (List)andConjuncts.stream().filter(x -> !x.equals((Object)arrayContainsExpression)).collect(ImmutableList.toImmutableList());
        RowExpression array = (RowExpression)((CallExpression)arrayContainsExpression).getArguments().get(0);
        RowExpression element = (RowExpression)((CallExpression)arrayContainsExpression).getArguments().get(1);
        Preconditions.checkState((boolean)(element instanceof VariableReferenceExpression), (Object)"Argument to CONTAINS is not a column");
        Preconditions.checkState((boolean)(array instanceof VariableReferenceExpression), (Object)"Argument to CONTAINS is not a column");
        VariableReferenceExpression elementVar = (VariableReferenceExpression)element;
        boolean arrayAtLeftInput = leftInput.contains(array);
        PlanNode inputWithArray = arrayAtLeftInput ? joinNode.getLeft() : joinNode.getRight();
        CallExpression arrayDistinct = Expressions.call(this.functionAndTypeManager, "array_distinct", (Type)new ArrayType(element.getType()), array);
        VariableReferenceExpression arrayDistinctVariable = context.getVariableAllocator().newVariable((RowExpression)arrayDistinct);
        PlanNode project = PlannerUtils.addProjections(inputWithArray, context.getIdAllocator(), (Map<VariableReferenceExpression, RowExpression>)ImmutableMap.of((Object)arrayDistinctVariable, (Object)arrayDistinct));
        VariableReferenceExpression unnestVariable = context.getVariableAllocator().newVariable("field", element.getType());
        UnnestNode unnest = new UnnestNode(inputWithArray.getSourceLocation(), context.getIdAllocator().getNextId(), project, project.getOutputVariables(), (Map<VariableReferenceExpression, List<VariableReferenceExpression>>)ImmutableMap.of((Object)arrayDistinctVariable, (Object)ImmutableList.of((Object)unnestVariable)), Optional.empty());
        EquiJoinClause equiJoinClause = arrayAtLeftInput ? new EquiJoinClause(unnestVariable, elementVar) : new EquiJoinClause(elementVar, unnestVariable);
        JoinNode newJoinNode = new JoinNode(joinNode.getSourceLocation(), context.getIdAllocator().getNextId(), joinNode.getType(), (PlanNode)(arrayAtLeftInput ? unnest : joinNode.getLeft()), (PlanNode)(arrayAtLeftInput ? joinNode.getRight() : unnest), (List)ImmutableList.of((Object)equiJoinClause), joinNode.getOutputVariables(), joinNode.getFilter(), Optional.empty(), Optional.empty(), joinNode.getDistributionType(), joinNode.getDynamicFilters());
        if (!remainingConjuncts.isEmpty()) {
            return Rule.Result.ofPlanNode((PlanNode)new FilterNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), (PlanNode)newJoinNode, LogicalRowExpressions.and((Collection)remainingConjuncts)));
        }
        return Rule.Result.ofPlanNode((PlanNode)newJoinNode);
    }
}

