/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.PlanVisitor;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.sql.tree.WhenClause;
import com.facebook.presto.sql.util.AstUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import javax.annotation.Nullable;

/**
 * Rewrites a correlated {@code IN} subquery, represented as an {@link ApplyNode}, into a
 * LEFT join followed by an aggregation and a projection.
 *
 * <p>The rule only fires when all of the following hold:
 * <ul>
 *   <li>the apply node has a non-empty correlation;</li>
 *   <li>it has exactly one subquery assignment;</li>
 *   <li>that assignment is an {@link InPredicate} whose value and value list are both symbol
 *       references.</li>
 * </ul>
 *
 * <p>The subquery is decorrelated by pulling {@link FilterNode} predicates up through
 * {@link ProjectNode}s. Any other node must not reference a correlation variable, either
 * itself or anywhere below it; otherwise the rule does not apply.
 *
 * <p>The produced plan has this shape:
 * <pre>{@code
 * Project(input outputs..., inOutput := CASE WHEN countMatches > 0 THEN true
 *                                            WHEN countNullMatches > 0 THEN CAST(NULL AS boolean)
 *                                            ELSE false END)
 *   Aggregation(group by probe outputs,
 *               countMatches := count() filtered by (probe IS NOT NULL AND build IS NOT NULL),
 *               countNullMatches := count() filtered by (buildSideKnownNonNull IS NOT NULL AND NOT match))
 *     LeftJoin(filter := (probe IS NULL OR probe = build OR build IS NULL) AND correlatedPredicates)
 *       AssignUniqueId(input)
 *       Project(build outputs..., buildSideKnownNonNull := CAST(0 AS bigint))
 *         decorrelated subquery
 * }</pre>
 */
public class TransformCorrelatedInPredicateToJoin
        implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode().with(Pattern.nonEmpty(Patterns.Apply.correlation()));

    /** Resolves the {@code count} function used by the generated aggregations. */
    private final StandardFunctionResolution functionResolution;

    /**
     * @param functionAndTypeManager used to resolve the {@code count} aggregate function
     * @throws NullPointerException if {@code functionAndTypeManager} is null
     */
    public TransformCorrelatedInPredicateToJoin(FunctionAndTypeManager functionAndTypeManager) {
        Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.functionResolution = new FunctionResolution(functionAndTypeManager);
    }

    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    /**
     * Entry point: accepts only an apply node with a single {@code IN} subquery assignment.
     *
     * @return the rewritten plan, or {@link Rule.Result#empty()} when the rule does not apply
     */
    @Override
    public Rule.Result apply(ApplyNode apply, Captures captures, Rule.Context context) {
        Assignments subqueryAssignments = apply.getSubqueryAssignments();
        if (subqueryAssignments.size() != 1) {
            return Rule.Result.empty();
        }
        RowExpression assignment = Iterables.getOnlyElement(subqueryAssignments.getExpressions());
        Expression assignmentExpression = OriginalExpressionUtils.castToExpression(assignment);
        if (!(assignmentExpression instanceof InPredicate)) {
            return Rule.Result.empty();
        }
        InPredicate inPredicate = (InPredicate) assignmentExpression;
        VariableReferenceExpression inPredicateOutputVariable = Iterables.getOnlyElement(subqueryAssignments.getVariables());
        return apply(apply, inPredicate, inPredicateOutputVariable, context.getLookup(), context.getIdAllocator(), context.getVariableAllocator());
    }

    /**
     * Decorrelates the subquery and, if that succeeds, builds the join-based equivalent.
     *
     * @return the rewritten plan, or {@link Rule.Result#empty()} when the subquery cannot be decorrelated
     */
    private Rule.Result apply(ApplyNode apply, InPredicate inPredicate, VariableReferenceExpression inPredicateOutputVariable, Lookup lookup, PlanNodeIdAllocator idAllocator, PlanVariableAllocator variableAllocator) {
        Optional<Decorrelated> decorrelated = new DecorrelatingVisitor(lookup, apply.getCorrelation(), variableAllocator.getTypes())
                .decorrelate(apply.getSubquery());
        if (!decorrelated.isPresent()) {
            return Rule.Result.empty();
        }
        PlanNode projection = buildInPredicateEquivalent(apply, inPredicate, inPredicateOutputVariable, decorrelated.get(), idAllocator, variableAllocator);
        return Rule.Result.ofPlanNode(projection);
    }

    /**
     * Builds the Project / Aggregation / LeftJoin plan that computes the three-valued result of
     * the {@code IN} predicate for every input row.
     *
     * @throws IllegalArgumentException if the predicate's value or value list is not a symbol reference
     */
    private PlanNode buildInPredicateEquivalent(ApplyNode apply, InPredicate inPredicate, VariableReferenceExpression inPredicateOutputVariable, Decorrelated decorrelated, PlanNodeIdAllocator idAllocator, PlanVariableAllocator variableAllocator) {
        // Validate the predicate shape before allocating any plan node ids or variables.
        Preconditions.checkArgument(inPredicate.getValue() instanceof SymbolReference, "Unexpected expression: %s", inPredicate.getValue());
        Preconditions.checkArgument(inPredicate.getValueList() instanceof SymbolReference, "Unexpected expression: %s", inPredicate.getValueList());
        SymbolReference probeSideSymbolReference = (SymbolReference) inPredicate.getValue();
        SymbolReference buildSideSymbolReference = (SymbolReference) inPredicate.getValueList();

        Expression correlationCondition = ExpressionUtils.and(decorrelated.getCorrelatedPredicates());
        PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode();

        // Tag each probe row with a unique id. Grouping on the probe outputs below then collapses
        // the join fan-out back to one row per input row, even when input rows are duplicates.
        AssignUniqueId probeSide = new AssignUniqueId(idAllocator.getNextId(), apply.getInput(), variableAllocator.newVariable("unique", BigintType.BIGINT));

        // Constant marker on the build side. After the LEFT join it is NULL exactly on the
        // null-extended rows, i.e. rows where no build row matched.
        VariableReferenceExpression buildSideKnownNonNull = variableAllocator.newVariable("buildSideKnownNonNull", BigintType.BIGINT);
        ProjectNode buildSide = new ProjectNode(
                idAllocator.getNextId(),
                decorrelatedBuildSource,
                Assignments.builder()
                        .putAll(AssignmentUtils.identitiesAsSymbolReferences(decorrelatedBuildSource.getOutputVariables()))
                        .put(buildSideKnownNonNull, OriginalExpressionUtils.castToRowExpression(bigint(0L)))
                        .build());

        // Join when values are equal, or when either side is NULL (these rows can make IN return
        // NULL); the pulled-up correlated predicates must hold as well.
        Expression joinExpression = ExpressionUtils.and(
                ExpressionUtils.or(
                        new IsNullPredicate(probeSideSymbolReference),
                        new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeSideSymbolReference, buildSideSymbolReference),
                        new IsNullPredicate(buildSideSymbolReference)),
                correlationCondition);
        JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression);

        VariableReferenceExpression countMatchesVariable = variableAllocator.newVariable("countMatches", BigintType.BIGINT);
        VariableReferenceExpression countNullMatchesVariable = variableAllocator.newVariable("countNullMatches", BigintType.BIGINT);
        // Definite match: both sides non-NULL, so the join condition held through equality.
        Expression matchCondition = ExpressionUtils.and(
                new IsNotNullPredicate(probeSideSymbolReference),
                new IsNotNullPredicate(buildSideSymbolReference));
        // A real build row was joined, but only through one of the IS NULL branches.
        Expression nullMatchCondition = ExpressionUtils.and(
                new IsNotNullPredicate(new SymbolReference(buildSideKnownNonNull.getName())),
                new NotExpression(matchCondition));
        AggregationNode aggregation = new AggregationNode(
                idAllocator.getNextId(),
                leftOuterJoin,
                ImmutableMap.of(
                        countMatchesVariable, countWithFilter(matchCondition),
                        countNullMatchesVariable, countWithFilter(nullMatchCondition)),
                AggregationNode.singleGroupingSet(probeSide.getOutputVariables()),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());

        // SQL IN semantics: TRUE on any match, else NULL if some NULL comparison occurred, else FALSE.
        SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression(
                ImmutableList.of(
                        new WhenClause(isGreaterThan(countMatchesVariable, 0L), booleanConstant(true)),
                        new WhenClause(isGreaterThan(countNullMatchesVariable, 0L), booleanConstant(null))),
                Optional.of(booleanConstant(false)));

        return new ProjectNode(
                idAllocator.getNextId(),
                aggregation,
                Assignments.builder()
                        .putAll(AssignmentUtils.identitiesAsSymbolReferences(apply.getInput().getOutputVariables()))
                        .put(inPredicateOutputVariable, OriginalExpressionUtils.castToRowExpression(inPredicateEquivalent))
                        .build());
    }

    /**
     * Creates a LEFT join of {@code probeSide} and {@code buildSide} with no equi-join clauses.
     * It filters on {@code joinExpression} and outputs all variables of both sides, probe first.
     */
    private static JoinNode leftOuterJoin(PlanNodeIdAllocator idAllocator, AssignUniqueId probeSide, ProjectNode buildSide, Expression joinExpression) {
        List<VariableReferenceExpression> outputVariables = ImmutableList.<VariableReferenceExpression>builder()
                .addAll(probeSide.getOutputVariables())
                .addAll(buildSide.getOutputVariables())
                .build();
        return new JoinNode(
                idAllocator.getNextId(),
                JoinNode.Type.LEFT,
                probeSide,
                buildSide,
                ImmutableList.of(),
                outputVariables,
                Optional.of(OriginalExpressionUtils.castToRowExpression(joinExpression)),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of());
    }

    /**
     * Builds a non-distinct {@code count()} aggregation (returning BIGINT) that only counts rows
     * satisfying {@code condition}.
     */
    private AggregationNode.Aggregation countWithFilter(Expression condition) {
        CallExpression count = new CallExpression("count", functionResolution.countFunction(), BigintType.BIGINT, ImmutableList.of());
        return new AggregationNode.Aggregation(
                count,
                Optional.of(OriginalExpressionUtils.castToRowExpression(condition)),
                Optional.empty(),
                false,
                Optional.empty());
    }

    /** Returns the expression {@code variable > CAST(value AS bigint)}. */
    private static Expression isGreaterThan(VariableReferenceExpression variable, long value) {
        return new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, new SymbolReference(variable.getName()), bigint(value));
    }

    /** Returns the literal {@code value} cast to BIGINT. */
    private static Expression bigint(long value) {
        return new Cast(new LongLiteral(String.valueOf(value)), BigintType.BIGINT.toString());
    }

    /** Returns a boolean literal, or {@code CAST(NULL AS boolean)} when {@code value} is null. */
    private static Expression booleanConstant(@Nullable Boolean value) {
        if (value == null) {
            return new Cast(new NullLiteral(), BooleanType.BOOLEAN.toString());
        }
        return new BooleanLiteral(value.toString());
    }

    /**
     * Result of decorrelating a subquery: the filter predicates pulled out of it and the
     * remaining subquery plan with those filters removed.
     */
    private static class Decorrelated {
        private final List<Expression> correlatedPredicates;
        private final PlanNode decorrelatedNode;

        public Decorrelated(List<Expression> correlatedPredicates, PlanNode decorrelatedNode) {
            this.correlatedPredicates = ImmutableList.copyOf(Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null"));
            this.decorrelatedNode = Objects.requireNonNull(decorrelatedNode, "decorrelatedNode is null");
        }

        /** Predicates removed from the subquery, to be evaluated as part of the join condition. */
        public List<Expression> getCorrelatedPredicates() {
            return correlatedPredicates;
        }

        /** The subquery plan without the pulled-up filters. */
        public PlanNode getDecorrelatedNode() {
            return decorrelatedNode;
        }
    }

    /**
     * Walks a subquery plan and pulls {@link FilterNode} predicates upward through
     * {@link ProjectNode}s.
     *
     * <p>Every visit method returns {@link Optional#empty()} when the subquery cannot be
     * decorrelated. That happens when a projection references a correlation variable, or when a
     * node other than Filter/Project references one, either itself or anywhere below it.
     */
    private static class DecorrelatingVisitor
            extends InternalPlanVisitor<Optional<Decorrelated>, PlanNode> {
        private final Lookup lookup;
        private final Set<VariableReferenceExpression> correlation;
        private final TypeProvider types;

        public DecorrelatingVisitor(Lookup lookup, Iterable<VariableReferenceExpression> correlation, TypeProvider types) {
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.correlation = ImmutableSet.copyOf(Objects.requireNonNull(correlation, "correlation is null"));
            this.types = Objects.requireNonNull(types, "types is null");
        }

        /** Resolves {@code reference} through the lookup and visits the resolved node. */
        public Optional<Decorrelated> decorrelate(PlanNode reference) {
            return lookup.resolve(reference).accept(this, reference);
        }

        @Override
        public Optional<Decorrelated> visitProject(ProjectNode node, PlanNode reference) {
            if (isCorrelatedShallowly(node)) {
                return Optional.empty();
            }
            return decorrelate(node.getSource()).map(decorrelated -> {
                Assignments.Builder assignments = Assignments.builder().putAll(node.getAssignments());
                // The pulled-up predicates are evaluated above this projection, so every
                // non-correlation symbol they reference must pass through it unchanged.
                decorrelated.getCorrelatedPredicates().stream()
                        .flatMap(AstUtils::preOrder)
                        .filter(SymbolReference.class::isInstance)
                        .map(SymbolReference.class::cast)
                        .map(symbolReference -> new VariableReferenceExpression(symbolReference.getName(), types.get(symbolReference)))
                        .filter(variable -> !correlation.contains(variable))
                        .map(AssignmentUtils::identityAsSymbolReference)
                        .forEach(assignments::put);
                return new Decorrelated(
                        decorrelated.getCorrelatedPredicates(),
                        new ProjectNode(node.getId(), decorrelated.getDecorrelatedNode(), assignments.build()));
            });
        }

        @Override
        public Optional<Decorrelated> visitFilter(FilterNode node, PlanNode reference) {
            // The filter node is dropped; its predicate is carried upward instead.
            return decorrelate(node.getSource()).map(decorrelated -> new Decorrelated(
                    ImmutableList.<Expression>builder()
                            .addAll(decorrelated.getCorrelatedPredicates())
                            .add(OriginalExpressionUtils.castToExpression(node.getPredicate()))
                            .build(),
                    decorrelated.getDecorrelatedNode()));
        }

        @Override
        public Optional<Decorrelated> visitPlan(PlanNode node, PlanNode reference) {
            if (isCorrelatedRecursively(node)) {
                return Optional.empty();
            }
            // Uncorrelated subtree: keep it as-is (by its original reference) with no predicates.
            return Optional.of(new Decorrelated(ImmutableList.of(), reference));
        }

        /**
         * Returns true if {@code node}, or any node beneath it (resolved through the lookup),
         * references a correlation variable.
         */
        private boolean isCorrelatedRecursively(PlanNode node) {
            if (isCorrelatedShallowly(node)) {
                return true;
            }
            return node.getSources().stream()
                    .map(lookup::resolve)
                    .anyMatch(this::isCorrelatedRecursively);
        }

        /**
         * Returns true if the variables that {@code VariablesExtractor.extractUniqueNonRecursive}
         * reports for {@code node} include a correlation variable.
         */
        private boolean isCorrelatedShallowly(PlanNode node) {
            return VariablesExtractor.extractUniqueNonRecursive(node, types).stream()
                    .anyMatch(correlation::contains);
        }
    }
}

