/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.QuantifiedComparisonExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.LateralJoinNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

/**
 * Rewrites an {@link ApplyNode} whose only subquery assignment is a
 * {@link QuantifiedComparisonExpression} ({@code value <op> ALL|ANY|SOME (subquery)}) into an INNER
 * {@link LateralJoinNode} against a global aggregation of the subquery, followed by a projection.
 *
 * <p>The subquery's single output column is reduced to four values: {@code min}, {@code max},
 * {@code count_all} and {@code count_non_null}. The quantified comparison is then replaced by an
 * equivalent expression over those values. {@code ApplyNode}s of any other shape are only rewritten
 * recursively.
 */
public class TransformQuantifiedComparisonApplyToLateralJoin
        implements PlanOptimizer {
    private final StandardFunctionResolution functionResolution;
    private final LogicalRowExpressions logicalRowExpressions;

    /**
     * @param functionAndTypeManager used to resolve the aggregate/comparison functions and as the
     *        function metadata source for logical-expression helpers; must not be null
     */
    public TransformQuantifiedComparisonApplyToLateralJoin(FunctionAndTypeManager functionAndTypeManager) {
        Objects.requireNonNull(functionAndTypeManager, "functionManager is null");
        this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
        // Reuse the resolution built above rather than constructing a second identical instance.
        DeterminismEvaluator determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
        FunctionMetadataManager functionMetadataManager = functionAndTypeManager;
        this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, this.functionResolution, functionMetadataManager);
    }

    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        return SimplePlanRewriter.rewriteWith(new Rewriter(functionResolution, idAllocator, variableAllocator, logicalRowExpressions), plan, null);
    }

    /**
     * Plan rewriter that performs the quantified-comparison-to-lateral-join transformation on every
     * matching {@link ApplyNode} in the plan.
     */
    private static class Rewriter
            extends SimplePlanRewriter<PlanNode> {
        private final StandardFunctionResolution functionResolution;
        private final PlanNodeIdAllocator idAllocator;
        private final VariableAllocator variableAllocator;
        private final LogicalRowExpressions logicalRowExpressions;

        public Rewriter(StandardFunctionResolution functionResolution, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, LogicalRowExpressions logicalRowExpressions) {
            this.functionResolution = Objects.requireNonNull(functionResolution, "functionResolution is null");
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.variableAllocator = Objects.requireNonNull(variableAllocator, "variableAllocator is null");
            this.logicalRowExpressions = Objects.requireNonNull(logicalRowExpressions, "logicalRowExpressions is null");
        }

        /**
         * Rewrites {@code node} when it has exactly one subquery assignment and that assignment is a
         * quantified comparison; otherwise falls back to the default recursive rewrite.
         */
        @Override
        public PlanNode visitApply(ApplyNode node, SimplePlanRewriter.RewriteContext<PlanNode> context) {
            if (node.getSubqueryAssignments().size() != 1) {
                return context.defaultRewrite(node);
            }
            RowExpression expression = Iterables.getOnlyElement(node.getSubqueryAssignments().getExpressions());
            if (!(expression instanceof QuantifiedComparisonExpression)) {
                return context.defaultRewrite(node);
            }
            return rewriteQuantifiedApplyNode(node, (QuantifiedComparisonExpression) expression, context);
        }

        /**
         * Builds {@code Project(LateralJoin(input, Aggregation(subquery)))}. The aggregation computes
         * min/max/count(*)/count(column) of the subquery's single output column. The projection
         * assigns the rewritten comparison to the original subquery-assignment variable.
         *
         * @throws IllegalStateException if the subquery's output type is not orderable
         */
        private PlanNode rewriteQuantifiedApplyNode(ApplyNode node, QuantifiedComparisonExpression quantifiedComparison, SimplePlanRewriter.RewriteContext<PlanNode> context) {
            // Rewrite the subquery first, then the input, so plan-node IDs are allocated in the same order as before.
            PlanNode subqueryPlan = context.rewrite(node.getSubquery());

            VariableReferenceExpression outputColumn = Iterables.getOnlyElement(subqueryPlan.getOutputVariables());
            Type outputColumnType = outputColumn.getType();
            // min/max are only meaningful for orderable types.
            Preconditions.checkState(outputColumnType.isOrderable(), "Subquery result type must be orderable");

            VariableReferenceExpression minValue = variableAllocator.newVariable(outputColumn.getSourceLocation(), "min", outputColumnType);
            VariableReferenceExpression maxValue = variableAllocator.newVariable(outputColumn.getSourceLocation(), "max", outputColumnType);
            VariableReferenceExpression countAllValue = variableAllocator.newVariable(outputColumn.getSourceLocation(), "count_all", BigintType.BIGINT);
            VariableReferenceExpression countNonNullValue = variableAllocator.newVariable(outputColumn.getSourceLocation(), "count_non_null", BigintType.BIGINT);

            List<RowExpression> outputColumnReferences = ImmutableList.of(outputColumn);
            Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.of(
                    minValue, aggregation(new CallExpression(quantifiedComparison.getSourceLocation(), "min", functionResolution.minFunction(outputColumnType), outputColumnType, outputColumnReferences)),
                    maxValue, aggregation(new CallExpression(quantifiedComparison.getSourceLocation(), "max", functionResolution.maxFunction(outputColumnType), outputColumnType, outputColumnReferences)),
                    // zero-argument count: counts every row
                    countAllValue, aggregation(new CallExpression(quantifiedComparison.getSourceLocation(), "count", functionResolution.countFunction(), BigintType.BIGINT, Collections.emptyList())),
                    // count over the column: compared with count_all to detect NULLs in the subquery result
                    countNonNullValue, aggregation(new CallExpression(quantifiedComparison.getSourceLocation(), "count", functionResolution.countFunction(outputColumnType), BigintType.BIGINT, outputColumnReferences)));

            subqueryPlan = new AggregationNode(
                    node.getSourceLocation(),
                    idAllocator.getNextId(),
                    subqueryPlan,
                    aggregations,
                    AggregationNode.globalAggregation(),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty());

            // The lateral join reuses the ApplyNode's id.
            LateralJoinNode lateralJoinNode = new LateralJoinNode(
                    node.getSourceLocation(),
                    node.getId(),
                    context.rewrite(node.getInput()),
                    subqueryPlan,
                    node.getCorrelation(),
                    LateralJoinNode.Type.INNER,
                    node.getOriginSubqueryError());

            RowExpression valueComparedToSubquery = rewriteUsingBounds(quantifiedComparison, minValue, maxValue, countAllValue, countNonNullValue);
            VariableReferenceExpression quantifiedComparisonVariable = Iterables.getOnlyElement(node.getSubqueryAssignments().getVariables());
            return projectExpressions(lateralJoinNode, Assignments.of(quantifiedComparisonVariable, valueComparedToSubquery));
        }

        /**
         * Wraps {@code call} as an aggregation with every optional setting empty and the distinct
         * flag set to false.
         */
        private static AggregationNode.Aggregation aggregation(CallExpression call) {
            return new AggregationNode.Aggregation(call, Optional.empty(), Optional.empty(), false, Optional.empty());
        }

        /**
         * Expresses the quantified comparison in terms of the aggregated bounds. The result is:
         * <pre>
         * CASE count_all
         *     WHEN 0 THEN emptySetResult
         *     ELSE combine(boundComparison,
         *                  CASE WHEN count_all &lt;&gt; count_non_null THEN NULL ELSE emptySetResult END)
         * END
         * </pre>
         * For {@code ALL}, {@code emptySetResult} is TRUE and {@code combine} is a conjunction. For
         * {@code ANY}/{@code SOME}, it is FALSE and {@code combine} is a disjunction. The inner CASE
         * yields NULL when the subquery produced any NULL, so the combined result follows three-valued
         * logic.
         */
        public RowExpression rewriteUsingBounds(QuantifiedComparisonExpression quantifiedComparison, VariableReferenceExpression minValue, VariableReferenceExpression maxValue, VariableReferenceExpression countAllValue, VariableReferenceExpression countNonNullValue) {
            boolean isAll = quantifiedComparison.getQuantifier() == QuantifiedComparisonExpression.Quantifier.ALL;
            RowExpression emptySetResult = isAll ? LogicalRowExpressions.TRUE_CONSTANT : LogicalRowExpressions.FALSE_CONSTANT;
            Function<List<RowExpression>, RowExpression> quantifier = isAll
                    ? logicalRowExpressions::combineConjuncts
                    : logicalRowExpressions::combineDisjuncts;

            RowExpression comparisonWithExtremeValue = getBoundComparisons(quantifiedComparison, minValue, maxValue);

            // WHEN count_all <> count_non_null THEN NULL (of type boolean)
            RowExpression subqueryHasNulls = Expressions.comparisonExpression(functionResolution, OperatorType.NOT_EQUAL, countAllValue, countNonNullValue);
            RowExpression whenClause = Expressions.specialForm(SpecialFormExpression.Form.WHEN, BooleanType.BOOLEAN, subqueryHasNulls, new ConstantExpression(null, BooleanType.BOOLEAN));
            RowExpression nullAwareResult = Expressions.searchedCaseExpression(ImmutableList.of(whenClause), Optional.of(emptySetResult));

            // WHEN 0 THEN emptySetResult: an empty subquery result decides the comparison on its own.
            RowExpression whenEmpty = Expressions.specialForm(SpecialFormExpression.Form.WHEN, BooleanType.BOOLEAN, new ConstantExpression(0L, BigintType.BIGINT), emptySetResult);

            return Expressions.buildSwitch(
                    countAllValue,
                    ImmutableList.of(whenEmpty),
                    Optional.of(quantifier.apply(ImmutableList.of(comparisonWithExtremeValue, nullAwareResult))),
                    BooleanType.BOOLEAN);
        }

        /**
         * Returns the comparison of the quantified value against the aggregated bounds.
         * <ul>
         *   <li>{@code A = ALL B} becomes {@code min = max AND A = max}.</li>
         *   <li>{@code <, <=, >, >=} compare {@code A} with either {@code min} or {@code max}, as chosen by
         *       {@link #shouldCompareValueWithLowerBound}.</li>
         * </ul>
         *
         * @throws IllegalArgumentException for any other operator/quantifier combination
         */
        private RowExpression getBoundComparisons(QuantifiedComparisonExpression quantifiedComparison, VariableReferenceExpression minValue, VariableReferenceExpression maxValue) {
            if (quantifiedComparison.getOperator() == OperatorType.EQUAL && quantifiedComparison.getQuantifier() == QuantifiedComparisonExpression.Quantifier.ALL) {
                return logicalRowExpressions.combineConjuncts(
                        Expressions.comparisonExpression(functionResolution, OperatorType.EQUAL, minValue, maxValue),
                        Expressions.comparisonExpression(functionResolution, OperatorType.EQUAL, quantifiedComparison.getValue(), maxValue));
            }
            if (EnumSet.of(OperatorType.LESS_THAN, OperatorType.LESS_THAN_OR_EQUAL, OperatorType.GREATER_THAN, OperatorType.GREATER_THAN_OR_EQUAL).contains(quantifiedComparison.getOperator())) {
                VariableReferenceExpression boundValue = shouldCompareValueWithLowerBound(quantifiedComparison) ? minValue : maxValue;
                return Expressions.comparisonExpression(functionResolution, quantifiedComparison.getOperator(), quantifiedComparison.getValue(), boundValue);
            }
            throw new IllegalArgumentException("Unsupported quantified comparison: " + quantifiedComparison);
        }

        /**
         * Chooses which bound to compare against:
         * <ul>
         *   <li>{@code A < ALL B} uses min B; {@code A > ALL B} uses max B.</li>
         *   <li>{@code A < ANY B} uses max B; {@code A > ANY B} uses min B.</li>
         * </ul>
         * The {@code <=}/{@code >=} operators behave like {@code <}/{@code >}.
         *
         * @return true to compare with the minimum, false to compare with the maximum
         * @throws IllegalArgumentException for operators other than {@code <, <=, >, >=}
         */
        private static boolean shouldCompareValueWithLowerBound(QuantifiedComparisonExpression quantifiedComparison) {
            switch (quantifiedComparison.getQuantifier()) {
                case ALL:
                    switch (quantifiedComparison.getOperator()) {
                        case LESS_THAN:
                        case LESS_THAN_OR_EQUAL:
                            return true;
                        case GREATER_THAN:
                        case GREATER_THAN_OR_EQUAL:
                            return false;
                        default:
                            break;
                    }
                    break;
                case ANY:
                case SOME:
                    switch (quantifiedComparison.getOperator()) {
                        case LESS_THAN:
                        case LESS_THAN_OR_EQUAL:
                            return false;
                        case GREATER_THAN:
                        case GREATER_THAN_OR_EQUAL:
                            return true;
                        default:
                            break;
                    }
                    break;
                default:
                    break;
            }
            throw new IllegalArgumentException("Unexpected quantifier: " + quantifiedComparison.getQuantifier());
        }

        /** Projects every output of {@code input} unchanged, plus {@code subqueryAssignments}. */
        private ProjectNode projectExpressions(PlanNode input, Assignments subqueryAssignments) {
            Assignments assignments = Assignments.builder()
                    .putAll(AssignmentUtils.identityAssignments(input.getOutputVariables()))
                    .putAll(subqueryAssignments)
                    .build();
            return new ProjectNode(idAllocator.getNextId(), input, assignments);
        }
    }
}

