/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.UnnestNode;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Rewrites a cross join whose filter contains {@code NOT contains(array, element)} into a left join
 * against the unnested array, keeping only rows for which the join found no match.
 *
 * <p>The rule fires when a {@link FilterNode} sits directly on an INNER {@link JoinNode} with no
 * equi-join criteria, one conjunct of the filter is a negated array-contains call accepted by
 * {@link PlannerUtils#isSupportedArrayContainsFilter}, and the join input producing the array resolves
 * (through zero or more {@link ProjectNode}s) to an {@link EnforceSingleRowNode}.
 *
 * <pre>
 * Filter(NOT contains(array, element) AND rest)
 *   InnerJoin(no criteria)
 *     ...
 * </pre>
 * becomes
 * <pre>
 * (restricted to the original join's outputs)
 *   Filter(rest AND field IS NULL)
 *     LeftJoin(element = field)
 *       input producing element
 *       Unnest(array_distinct(remove_nulls(array)) -> field)
 *         input producing array + projection of array_distinct(...)
 * </pre>
 *
 * <p>NOTE(review): when {@code contains(array, element)} evaluates to NULL (presumably for a NULL
 * element, a NULL array, or an array holding NULLs without a match — confirm against the function's
 * definition), the original filter drops the row. The rewritten plan keeps a row whenever the left
 * join finds no match. Confirm this difference is acceptable or guarded elsewhere.
 */
public class CrossJoinWithArrayNotContainsToAntiJoin
        implements Rule<FilterNode>
{
    private static final Capture<JoinNode> JOIN = Capture.newCapture();
    private static final Capture<List<PlanNode>> JOIN_CHILDREN = Capture.newCapture();
    private static final Pattern<FilterNode> PATTERN = Patterns.filter()
            .with(Patterns.source().matching(Patterns.join()
                    .matching(JoinNode::isCrossJoin)
                    .capturedAs(JOIN)
                    .with(Patterns.sources().capturedAs(JOIN_CHILDREN))));

    // Kept package-private for compatibility; not referenced by this class itself.
    final Metadata metadata;
    private final FunctionAndTypeManager functionAndTypeManager;

    public CrossJoinWithArrayNotContainsToAntiJoin(Metadata metadata, FunctionAndTypeManager functionAndTypeManager)
    {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    /**
     * Finds the first conjunct of {@code filterPredicate} of the form {@code NOT <expr>} where
     * {@code <expr>} is accepted by {@link PlannerUtils#isSupportedArrayContainsFilter} for the given inputs.
     *
     * @param functionResolution resolver used to recognise the negation and contains calls
     * @param filterPredicate the filter predicate to split into conjuncts
     * @param leftInput output variables of the join's left input
     * @param rightInput output variables of the join's right input
     * @return the matching {@code NOT ...} conjunct, or {@code null} if there is none
     */
    public static RowExpression getCandidateArrayNotContainsExpression(FunctionResolution functionResolution, RowExpression filterPredicate, List<VariableReferenceExpression> leftInput, List<VariableReferenceExpression> rightInput)
    {
        List<RowExpression> conjuncts = LogicalRowExpressions.extractConjuncts(filterPredicate);
        for (RowExpression conjunct : conjuncts) {
            if (PlannerUtils.isNegationExpression(functionResolution, conjunct)
                    && PlannerUtils.isSupportedArrayContainsFilter(functionResolution, conjunct.getChildren().get(0), leftInput, rightInput)) {
                return conjunct;
            }
        }
        return null;
    }

    @Override
    public Pattern<FilterNode> getPattern()
    {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return SystemSessionProperties.isRewriteCrossJoinArrayNotContainsToAntiJoinEnabled(session);
    }

    @Override
    public Rule.Result apply(FilterNode node, Captures captures, Rule.Context context)
    {
        JoinNode joinNode = captures.get(JOIN);
        if (joinNode.getType() != JoinType.INNER || !joinNode.getCriteria().isEmpty()) {
            return Rule.Result.empty();
        }

        List<VariableReferenceExpression> leftColumns = joinNode.getLeft().getOutputVariables();
        List<VariableReferenceExpression> rightColumns = joinNode.getRight().getOutputVariables();
        RowExpression filterExpression = node.getPredicate();
        FunctionResolution functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
        RowExpression arrayNotContainsExpression = getCandidateArrayNotContainsExpression(functionResolution, filterExpression, leftColumns, rightColumns);
        if (arrayNotContainsExpression == null) {
            return Rule.Result.empty();
        }

        // The candidate is NOT(contains(array, element)); unwrap the negation to reach the call.
        CallExpression arrayContainsExpression = (CallExpression) arrayNotContainsExpression.getChildren().get(0);
        RowExpression array = arrayContainsExpression.getArguments().get(0);
        RowExpression element = arrayContainsExpression.getArguments().get(1);
        Preconditions.checkState(element instanceof VariableReferenceExpression, "Argument to CONTAINS is not a column");
        Preconditions.checkState(array instanceof VariableReferenceExpression, "Argument to CONTAINS is not a column");

        boolean arrayAtLeftInput = leftColumns.contains(array);
        PlanNode inputWithArray = arrayAtLeftInput ? joinNode.getLeft() : joinNode.getRight();
        if (!isFromScalarSubquery(context, inputWithArray)) {
            return Rule.Result.empty();
        }
        PlanNode inputWithElement = arrayAtLeftInput ? joinNode.getRight() : joinNode.getLeft();

        // Project array_distinct(remove_nulls(array)) onto the array side and unnest it into "field".
        Type elementType = element.getType();
        RowExpression arrayDistinct = Expressions.call(
                functionAndTypeManager,
                "array_distinct",
                new ArrayType(elementType),
                Expressions.call(functionAndTypeManager, "remove_nulls", new ArrayType(elementType), array));
        VariableReferenceExpression arrayDistinctVariable = context.getVariableAllocator().newVariable(arrayDistinct);
        PlanNode project = PlannerUtils.addProjections(
                inputWithArray,
                context.getIdAllocator(),
                ImmutableMap.<VariableReferenceExpression, RowExpression>of(arrayDistinctVariable, arrayDistinct));
        VariableReferenceExpression unnestVariable = context.getVariableAllocator().newVariable("field", elementType);
        UnnestNode unnest = new UnnestNode(
                inputWithArray.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                project,
                project.getOutputVariables(),
                ImmutableMap.<VariableReferenceExpression, List<VariableReferenceExpression>>of(arrayDistinctVariable, ImmutableList.of(unnestVariable)),
                Optional.empty());

        // LEFT join the element side to the unnested values on element = field.
        EquiJoinClause equiJoinClause = new EquiJoinClause((VariableReferenceExpression) element, unnestVariable);
        List<VariableReferenceExpression> newOutputColumns = Stream.concat(
                        inputWithElement.getOutputVariables().stream(),
                        unnest.getOutputVariables().stream())
                .collect(ImmutableList.toImmutableList());
        JoinNode newJoinNode = new JoinNode(
                joinNode.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                JoinType.LEFT,
                inputWithElement,
                unnest,
                ImmutableList.of(equiJoinClause),
                newOutputColumns,
                joinNode.getFilter(),
                Optional.empty(),
                Optional.empty(),
                joinNode.getDistributionType(),
                joinNode.getDynamicFilters());

        // Replace the NOT contains conjunct with "field IS NULL" (no match found by the left join).
        // Built as a fresh immutable list: Collectors.toList() does not guarantee a mutable result.
        RowExpression noMatch = Expressions.specialForm(
                SpecialFormExpression.Form.IS_NULL,
                BooleanType.BOOLEAN,
                ImmutableList.<RowExpression>of(unnestVariable));
        List<RowExpression> newConjuncts = Stream.concat(
                        LogicalRowExpressions.extractConjuncts(filterExpression).stream()
                                .filter(conjunct -> !conjunct.equals(arrayNotContainsExpression)),
                        Stream.of(noMatch))
                .collect(ImmutableList.toImmutableList());
        FilterNode filterNode = new FilterNode(
                node.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                newJoinNode,
                LogicalRowExpressions.and(newConjuncts));

        return Rule.Result.ofPlanNode(PlannerUtils.restrictOutput(filterNode, context.getIdAllocator(), joinNode.getOutputVariables()));
    }

    /**
     * Returns true if {@code node}, after resolving through the lookup and skipping any chain of
     * {@link ProjectNode}s (following each one's first source), is an {@link EnforceSingleRowNode}.
     */
    private boolean isFromScalarSubquery(Rule.Context context, PlanNode node)
    {
        PlanNode current = context.getLookup().resolve(node);
        while (current instanceof ProjectNode) {
            current = context.getLookup().resolve(current.getSources().get(0));
        }
        return current instanceof EnforceSingleRowNode;
    }
}

