/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.DateType;
import com.facebook.presto.common.type.IntegerType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.CastType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.UnnestNode;
import com.facebook.presto.sql.relational.Expressions;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Rewrites a filter over a cross join (an INNER {@link JoinNode} without equi-join criteria) into an equi-join.
 * The rule applies when one conjunct of the filter predicate is a disjunction of equalities, each comparing a
 * left-side variable with a right-side variable, e.g. {@code l.a = r.x OR l.b = r.y}.
 *
 * <p>Each join input is replicated once per disjunct by unnesting the constant array {@code [1, ..., n]}. For
 * the replica with index {@code i}, the join key is the variable used in the {@code i}-th equality. The two sides
 * are then joined on {@code (joinKey, index)}. A pair of rows may satisfy several disjuncts. A SWITCH filter
 * therefore keeps a joined row only for the first disjunct that evaluates to true, so each originally matching
 * pair is emitted once.
 *
 * <p>The other conjuncts of the original predicate are applied in a filter above the join. An identity projection
 * then restores the original filter's output variables.
 *
 * <p>The join key keeps the operands' type when every compared variable on both sides has one common type.
 * Otherwise the keys are compared as VARCHAR.
 */
public class CrossJoinWithOrFilterToInnerJoin
implements Rule<FilterNode> {
    /** Types accepted for variables compared by the OR'd equalities. */
    private static final List<Type> SUPPORTED_JOIN_KEY_TYPE = ImmutableList.of(BigintType.BIGINT, IntegerType.INTEGER, VarcharType.VARCHAR, DateType.DATE);
    private static final Capture<JoinNode> CHILD = Capture.newCapture();
    private static final Pattern<FilterNode> PATTERN = Patterns.filter()
            .with(Patterns.source().matching(Patterns.join()
                    .matching(join -> join.getType() == JoinType.INNER && join.getCriteria().isEmpty())
                    .capturedAs(CHILD)));
    private final FunctionAndTypeManager functionAndTypeManager;

    public CrossJoinWithOrFilterToInnerJoin(FunctionAndTypeManager functionAndTypeManager) {
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    /**
     * Returns true when {@code rowExpression} is an EQUAL call with the following properties:
     * both operands have a supported join-key type, one operand is a left input variable, and the
     * other operand is a right input variable.
     */
    private static boolean isValidExpression(RowExpression rowExpression, List<VariableReferenceExpression> leftInput, List<VariableReferenceExpression> rightInput) {
        if (!(rowExpression instanceof CallExpression) || !((CallExpression) rowExpression).getDisplayName().equals("EQUAL")) {
            return false;
        }
        List<RowExpression> arguments = ((CallExpression) rowExpression).getArguments();
        RowExpression argument0 = arguments.get(0);
        RowExpression argument1 = arguments.get(1);
        if (!SUPPORTED_JOIN_KEY_TYPE.containsAll(ImmutableList.of(argument0.getType(), argument1.getType()))) {
            return false;
        }
        return (leftInput.contains(argument0) && rightInput.contains(argument1))
                || (leftInput.contains(argument1) && rightInput.contains(argument0));
    }

    /**
     * Finds the first conjunct of {@code filterPredicate} whose disjuncts all pass the equality check
     * in {@code isValidExpression}.
     *
     * @param filterPredicate the predicate of the filter above the cross join
     * @param leftInput output variables of the join's left input
     * @param rightInput output variables of the join's right input
     * @return the qualifying conjunct, or {@code null} when none qualifies
     */
    public static RowExpression getCandidateOrExpression(RowExpression filterPredicate, List<VariableReferenceExpression> leftInput, List<VariableReferenceExpression> rightInput) {
        for (RowExpression conjunct : LogicalRowExpressions.extractConjuncts(filterPredicate)) {
            List<RowExpression> disjuncts = LogicalRowExpressions.extractDisjuncts(conjunct);
            if (!disjuncts.isEmpty() && disjuncts.stream().allMatch(disjunct -> isValidExpression(disjunct, leftInput, rightInput))) {
                return conjunct;
            }
        }
        return null;
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isRewriteCrossJoinOrToInnerJoinEnabled(session);
    }

    /**
     * Rewrites one join input. Each row is replicated once per disjunct, and a join key column is added
     * that, for the replica with index {@code i} (1-based), holds the {@code i}-th variable of
     * {@code variablesInOrCondition}, cast to {@code finalJoinKeyType} when its type differs.
     *
     * @param variablesInOrCondition this side's variable for each disjunct, in disjunct order (may repeat)
     * @param joinInput the original join input
     * @param finalJoinKeyType type of the produced join key
     * @return the rewritten node together with its unnest-index and join-key variables
     */
    private RewrittenJoinInput rewriteJoinInput(List<VariableReferenceExpression> variablesInOrCondition, PlanNode joinInput, Type finalJoinKeyType, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator) {
        int disjunctCount = variablesInOrCondition.size();

        // Cast each distinct variable whose type differs from the join key type exactly once. Variables that
        // already have that type are used as-is, and repeated variables reuse their first cast.
        // The builder keeps insertion order, so the generated projections are deterministic.
        Map<VariableReferenceExpression, VariableReferenceExpression> castVariableMap = new HashMap<>();
        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> castExpressions = ImmutableMap.builder();
        for (VariableReferenceExpression variable : variablesInOrCondition) {
            if (variable.getType().equals(finalJoinKeyType) || castVariableMap.containsKey(variable)) {
                continue;
            }
            CallExpression castExpression = Expressions.call("CAST", this.functionAndTypeManager.lookupCast(CastType.CAST, variable.getType(), finalJoinKeyType), finalJoinKeyType, variable);
            VariableReferenceExpression castVariable = variableAllocator.newVariable(castExpression);
            castVariableMap.put(variable, castVariable);
            castExpressions.put(castVariable, castExpression);
        }

        // Constant array [1, ..., n]; unnesting it replicates every input row once per disjunct.
        ImmutableList.Builder<RowExpression> indexConstants = ImmutableList.builder();
        for (int index = 1; index <= disjunctCount; index++) {
            indexConstants.add(Expressions.constant((long) index, IntegerType.INTEGER));
        }
        CallExpression arrayConstruct = Expressions.call(this.functionAndTypeManager, "array_constructor", new ArrayType(IntegerType.INTEGER), indexConstants.build());
        VariableReferenceExpression arrayVariable = variableAllocator.newVariable(arrayConstruct);
        PlanNode project = PlannerUtils.addProjections(joinInput, idAllocator, ImmutableMap.<VariableReferenceExpression, RowExpression>builder()
                .put(arrayVariable, arrayConstruct)
                .putAll(castExpressions.build())
                .build());

        VariableReferenceExpression unnestVariable = variableAllocator.newVariable("field", IntegerType.INTEGER);
        List<VariableReferenceExpression> replicatedVariables = project.getOutputVariables().stream()
                .filter(variable -> !variable.equals(arrayVariable))
                .collect(ImmutableList.toImmutableList());
        Map<VariableReferenceExpression, List<VariableReferenceExpression>> unnestVariables = ImmutableMap.of(arrayVariable, ImmutableList.of(unnestVariable));
        UnnestNode unnest = new UnnestNode(joinInput.getSourceLocation(), idAllocator.getNextId(), project, replicatedVariables, unnestVariables, Optional.empty());

        // SWITCH(index) WHEN 1 THEN key_1 ... WHEN n THEN key_n ELSE NULL selects the join key of each replica.
        ImmutableList.Builder<RowExpression> switchArguments = ImmutableList.builder();
        switchArguments.add(unnestVariable);
        for (int i = 0; i < disjunctCount; i++) {
            VariableReferenceExpression variable = variablesInOrCondition.get(i);
            RowExpression keyForIndex = castVariableMap.getOrDefault(variable, variable);
            switchArguments.add(new SpecialFormExpression(SpecialFormExpression.Form.WHEN, finalJoinKeyType, Expressions.constant((long) i + 1, IntegerType.INTEGER), keyForIndex));
        }
        switchArguments.add(Expressions.constantNull(finalJoinKeyType));
        SpecialFormExpression joinKeyExpression = new SpecialFormExpression(SpecialFormExpression.Form.SWITCH, finalJoinKeyType, switchArguments.build());

        VariableReferenceExpression newJoinVariable = variableAllocator.newVariable(joinKeyExpression);
        PlanNode rewrittenInput = PlannerUtils.addProjections(unnest, idAllocator, variableAllocator, ImmutableList.<RowExpression>of(joinKeyExpression), ImmutableList.of(newJoinVariable));
        return new RewrittenJoinInput(rewrittenInput, unnestVariable, newJoinVariable);
    }

    /**
     * Returns whichever operand of the EQUAL call {@code rowExpression} is contained in {@code candidate}.
     *
     * @throws IllegalArgumentException if {@code rowExpression} is not an EQUAL call
     * @throws IllegalStateException if neither operand is in {@code candidate}
     */
    private VariableReferenceExpression getVariableInEqualComparison(RowExpression rowExpression, List<VariableReferenceExpression> candidate) {
        Preconditions.checkArgument(
                rowExpression instanceof CallExpression && ((CallExpression) rowExpression).getDisplayName().equals("EQUAL"),
                "expected an EQUAL comparison: %s",
                rowExpression);
        List<RowExpression> arguments = ((CallExpression) rowExpression).getArguments();
        RowExpression argument0 = arguments.get(0);
        RowExpression argument1 = arguments.get(1);
        if (candidate.contains(argument0)) {
            return (VariableReferenceExpression) argument0;
        }
        if (candidate.contains(argument1)) {
            return (VariableReferenceExpression) argument1;
        }
        throw new IllegalStateException("argument does not exist in candidate list");
    }

    @Override
    public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
        JoinNode joinNode = captures.get(CHILD);
        // Defensive re-check of the pattern condition.
        if (joinNode.getType() != JoinType.INNER || !joinNode.getCriteria().isEmpty()) {
            return Rule.Result.empty();
        }
        List<VariableReferenceExpression> leftVariables = joinNode.getLeft().getOutputVariables();
        List<VariableReferenceExpression> rightVariables = joinNode.getRight().getOutputVariables();
        RowExpression candidateOrExpression = getCandidateOrExpression(filterNode.getPredicate(), leftVariables, rightVariables);
        if (candidateOrExpression == null) {
            return Rule.Result.empty();
        }

        // Conjuncts other than the rewritten OR are re-applied above the new join.
        List<RowExpression> remainingConjuncts = LogicalRowExpressions.extractConjuncts(filterNode.getPredicate()).stream()
                .filter(conjunct -> !conjunct.equals(candidateOrExpression))
                .collect(ImmutableList.toImmutableList());
        List<RowExpression> equalExpressionList = LogicalRowExpressions.extractDisjuncts(candidateOrExpression);
        List<VariableReferenceExpression> leftOrVariables = equalExpressionList.stream()
                .map(equality -> getVariableInEqualComparison(equality, leftVariables))
                .collect(ImmutableList.toImmutableList());
        List<VariableReferenceExpression> rightOrVariables = equalExpressionList.stream()
                .map(equality -> getVariableInEqualComparison(equality, rightVariables))
                .collect(ImmutableList.toImmutableList());
        if (leftOrVariables.isEmpty() || rightOrVariables.isEmpty()) {
            return Rule.Result.empty();
        }
        if (leftOrVariables.stream().anyMatch(variable -> !SUPPORTED_JOIN_KEY_TYPE.contains(variable.getType()))
                || rightOrVariables.stream().anyMatch(variable -> !SUPPORTED_JOIN_KEY_TYPE.contains(variable.getType()))) {
            return Rule.Result.empty();
        }

        // Keep the native key type only when every compared variable on both sides shares one type.
        Type joinKeyType = VarcharType.VARCHAR;
        List<Type> leftOrPredicateTypes = leftOrVariables.stream().map(VariableReferenceExpression::getType).distinct().collect(ImmutableList.toImmutableList());
        List<Type> rightOrPredicateTypes = rightOrVariables.stream().map(VariableReferenceExpression::getType).distinct().collect(ImmutableList.toImmutableList());
        if (leftOrPredicateTypes.size() == 1 && rightOrPredicateTypes.size() == 1 && leftOrPredicateTypes.get(0).equals(rightOrPredicateTypes.get(0))) {
            joinKeyType = leftOrPredicateTypes.get(0);
        }

        RewrittenJoinInput leftJoinInput = rewriteJoinInput(leftOrVariables, joinNode.getLeft(), joinKeyType, context.getVariableAllocator(), context.getIdAllocator());
        RewrittenJoinInput rightJoinInput = rewriteJoinInput(rightOrVariables, joinNode.getRight(), joinKeyType, context.getVariableAllocator(), context.getIdAllocator());
        List<VariableReferenceExpression> joinOutput = ImmutableList.<VariableReferenceExpression>builder()
                .add(leftJoinInput.getJoinKey())
                .add(leftJoinInput.getUnnestIndex())
                .addAll(joinNode.getOutputVariables())
                .build();
        JoinNode newJoinNode = new JoinNode(
                joinNode.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                joinNode.getType(),
                leftJoinInput.getNode(),
                rightJoinInput.getNode(),
                ImmutableList.of(
                        new EquiJoinClause(leftJoinInput.getJoinKey(), rightJoinInput.getJoinKey()),
                        new EquiJoinClause(leftJoinInput.getUnnestIndex(), rightJoinInput.getUnnestIndex())),
                joinOutput,
                joinNode.getFilter(),
                Optional.empty(),
                Optional.empty(),
                joinNode.getDistributionType(),
                joinNode.getDynamicFilters());

        // A row pair can satisfy several disjuncts and so match on several indexes. For index i, keep the row only
        // if disjuncts 1..i-1 are not true and disjunct i holds, so each matching pair survives exactly once.
        ImmutableList.Builder<RowExpression> dedupArguments = ImmutableList.builder();
        dedupArguments.add(leftJoinInput.getUnnestIndex());
        for (int i = 0; i < equalExpressionList.size(); i++) {
            ImmutableList.Builder<RowExpression> matchCondition = ImmutableList.builder();
            for (int j = 0; j < i; j++) {
                matchCondition.add(Expressions.not(this.functionAndTypeManager, Expressions.coalesceNullToFalse(equalExpressionList.get(j))));
            }
            matchCondition.add(equalExpressionList.get(i));
            dedupArguments.add(new SpecialFormExpression(SpecialFormExpression.Form.WHEN, BooleanType.BOOLEAN, Expressions.constant((long) i + 1, IntegerType.INTEGER), LogicalRowExpressions.and(matchCondition.build())));
        }
        dedupArguments.add(Expressions.constantNull(BooleanType.BOOLEAN));
        SpecialFormExpression dedupFilter = new SpecialFormExpression(SpecialFormExpression.Form.SWITCH, BooleanType.BOOLEAN, dedupArguments.build());

        PlanNode result = new FilterNode(joinNode.getSourceLocation(), context.getIdAllocator().getNextId(), newJoinNode, dedupFilter);
        if (!remainingConjuncts.isEmpty()) {
            result = new FilterNode(filterNode.getSourceLocation(), context.getIdAllocator().getNextId(), result, LogicalRowExpressions.and(remainingConjuncts));
        }

        // Drop the helper columns (join key, unnest index) by projecting back to the original outputs.
        Assignments.Builder identity = Assignments.builder();
        identity.putAll(filterNode.getOutputVariables().stream()
                .collect(ImmutableMap.toImmutableMap(variable -> variable, variable -> variable)));
        ProjectNode projectUnusedOutput = new ProjectNode(context.getIdAllocator().getNextId(), result, identity.build());
        return Rule.Result.ofPlanNode(projectUnusedOutput);
    }

    /** A rewritten join input plus its replica-index variable and computed join-key variable. */
    private static final class RewrittenJoinInput {
        private final PlanNode node;
        private final VariableReferenceExpression unnestIndex;
        private final VariableReferenceExpression joinKey;

        private RewrittenJoinInput(PlanNode node, VariableReferenceExpression unnestIndex, VariableReferenceExpression joinKey) {
            this.node = node;
            this.unnestIndex = unnestIndex;
            this.joinKey = joinKey;
        }

        public PlanNode getNode() {
            return this.node;
        }

        public VariableReferenceExpression getJoinKey() {
            return this.joinKey;
        }

        public VariableReferenceExpression getUnnestIndex() {
            return this.unnestIndex;
        }
    }
}

