/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.FunctionRegistry;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.analyzer.TypeSignatureProvider;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.optimizations.PlanOptimizer;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.ApplyNode;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.LateralJoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.GenericLiteral;
import io.prestosql.sql.tree.NullLiteral;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.QuantifiedComparisonExpression;
import io.prestosql.sql.tree.SearchedCaseExpression;
import io.prestosql.sql.tree.SimpleCaseExpression;
import io.prestosql.sql.tree.WhenClause;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

/**
 * Plan optimizer that rewrites an {@link ApplyNode} whose only subquery assignment is a
 * {@link QuantifiedComparisonExpression} (for example {@code x < ALL (subquery)}) into:
 *
 * <ol>
 *   <li>a global {@link AggregationNode} over the subquery computing {@code min}, {@code max},
 *       {@code count(*)} and {@code count(column)} of the single subquery output column,</li>
 *   <li>an {@code INNER} {@link LateralJoinNode} between the apply input and that aggregation, and</li>
 *   <li>a {@link ProjectNode} that evaluates the quantified comparison from those four aggregates.</li>
 * </ol>
 *
 * Apply nodes that do not match this shape are rewritten with the default (children-only) rewrite.
 */
public class TransformQuantifiedComparisonApplyToLateralJoin
        implements PlanOptimizer {
    private final Metadata metadata;

    /**
     * @param metadata provides the function registry used to resolve the aggregate functions
     * @throws NullPointerException if {@code metadata} is null
     */
    public TransformQuantifiedComparisonApplyToLateralJoin(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Rewrites every matching apply node reachable from {@code plan}.
     * {@code session} and {@code warningCollector} are not used by this optimizer.
     *
     * @return the rewritten plan
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        return SimplePlanRewriter.rewriteWith(new Rewriter(idAllocator, types, symbolAllocator, this.metadata), plan, null);
    }

    /**
     * Plan rewriter performing the actual transformation. The rewrite context payload is unused
     * ({@code null} is passed by {@link #optimize}).
     */
    private static class Rewriter
            extends SimplePlanRewriter<PlanNode> {
        private static final QualifiedName MIN = QualifiedName.of("min");
        private static final QualifiedName MAX = QualifiedName.of("max");
        private static final QualifiedName COUNT = QualifiedName.of("count");

        private final PlanNodeIdAllocator idAllocator;
        private final TypeProvider types;
        private final SymbolAllocator symbolAllocator;
        private final Metadata metadata;

        /**
         * @throws NullPointerException if any argument is null
         */
        public Rewriter(PlanNodeIdAllocator idAllocator, TypeProvider types, SymbolAllocator symbolAllocator, Metadata metadata) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.types = Objects.requireNonNull(types, "types is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        /**
         * Rewrites the apply node only when it has exactly one subquery assignment and that
         * assignment is a quantified comparison; otherwise falls back to the default rewrite.
         */
        @Override
        public PlanNode visitApply(ApplyNode node, SimplePlanRewriter.RewriteContext<PlanNode> context) {
            if (node.getSubqueryAssignments().size() != 1) {
                return context.defaultRewrite(node);
            }
            Expression expression = Iterables.getOnlyElement(node.getSubqueryAssignments().getExpressions());
            if (!(expression instanceof QuantifiedComparisonExpression)) {
                return context.defaultRewrite(node);
            }
            return rewriteQuantifiedApplyNode(node, (QuantifiedComparisonExpression) expression, context);
        }

        /**
         * Builds the aggregation + lateral join + projection that replaces {@code node}.
         *
         * @throws IllegalStateException if the subquery output column type is not orderable
         * @throws IllegalArgumentException if the operator/quantifier combination is not supported
         */
        private PlanNode rewriteQuantifiedApplyNode(ApplyNode node, QuantifiedComparisonExpression quantifiedComparison, SimplePlanRewriter.RewriteContext<PlanNode> context) {
            PlanNode subqueryPlan = context.rewrite(node.getSubquery());

            // The subquery is expected to expose exactly one column; getOnlyElement fails otherwise.
            Symbol outputColumn = Iterables.getOnlyElement(subqueryPlan.getOutputSymbols());
            Type outputColumnType = this.types.get(outputColumn);
            Preconditions.checkState(outputColumnType.isOrderable(), "Subquery result type must be orderable");

            Symbol minValue = this.symbolAllocator.newSymbol(MIN.toString(), outputColumnType);
            Symbol maxValue = this.symbolAllocator.newSymbol(MAX.toString(), outputColumnType);
            Symbol countAllValue = this.symbolAllocator.newSymbol("count_all", BigintType.BIGINT);
            Symbol countNonNullValue = this.symbolAllocator.newSymbol("count_non_null", BigintType.BIGINT);

            FunctionRegistry functionRegistry = this.metadata.getFunctionRegistry();
            List<Expression> outputColumnReferences = ImmutableList.of(outputColumn.toSymbolReference());
            List<TypeSignature> outputColumnTypeSignature = ImmutableList.of(outputColumnType.getTypeSignature());

            // Global (no grouping keys) single-step aggregation producing the four summary values.
            Map<Symbol, AggregationNode.Aggregation> noAggregationsYet = null; // placeholder removed below
            PlanNode aggregatedSubquery = new AggregationNode(
                    this.idAllocator.getNextId(),
                    subqueryPlan,
                    ImmutableMap.of(
                            minValue,
                            aggregation(functionRegistry, MIN, outputColumnReferences, TypeSignatureProvider.fromTypeSignatures(outputColumnTypeSignature)),
                            maxValue,
                            aggregation(functionRegistry, MAX, outputColumnReferences, TypeSignatureProvider.fromTypeSignatures(outputColumnTypeSignature)),
                            // count with no arguments: counts every subquery row
                            countAllValue,
                            aggregation(functionRegistry, COUNT, ImmutableList.of(), ImmutableList.of()),
                            // count over the column: compared with count_all later to detect NULLs
                            countNonNullValue,
                            aggregation(functionRegistry, COUNT, outputColumnReferences, TypeSignatureProvider.fromTypeSignatures(outputColumnTypeSignature))),
                    AggregationNode.globalAggregation(),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty());

            // The lateral join reuses the id, correlation and origin subquery of the apply node it replaces.
            PlanNode lateralJoinNode = new LateralJoinNode(
                    node.getId(),
                    context.rewrite(node.getInput()),
                    aggregatedSubquery,
                    node.getCorrelation(),
                    LateralJoinNode.Type.INNER,
                    BooleanLiteral.TRUE_LITERAL,
                    node.getOriginSubquery());

            Expression valueComparedToSubquery = rewriteUsingBounds(quantifiedComparison, minValue, maxValue, countAllValue, countNonNullValue);
            Symbol quantifiedComparisonSymbol = Iterables.getOnlyElement(node.getSubqueryAssignments().getSymbols());

            return projectExpressions(lateralJoinNode, Assignments.of(quantifiedComparisonSymbol, valueComparedToSubquery));
        }

        /** Builds an unmasked aggregation calling {@code function} with the given arguments. */
        private static AggregationNode.Aggregation aggregation(FunctionRegistry functionRegistry, QualifiedName function, List<Expression> arguments, List<TypeSignatureProvider> argumentTypes) {
            return new AggregationNode.Aggregation(
                    new FunctionCall(function, arguments),
                    functionRegistry.resolveFunction(function, argumentTypes),
                    Optional.empty());
        }

        /**
         * Expresses the quantified comparison in terms of the aggregated subquery values:
         *
         * <pre>
         * CASE count_all
         *   WHEN 0 THEN emptySetResult
         *   ELSE combine(boundComparison,
         *                CASE WHEN count_all &lt;&gt; count_non_null THEN CAST(NULL AS boolean type)
         *                     ELSE emptySetResult END)
         * END
         * </pre>
         *
         * where, for {@code ALL}, {@code emptySetResult} is {@code TRUE} and {@code combine} is a
         * conjunction; for any other quantifier {@code emptySetResult} is {@code FALSE} and
         * {@code combine} is a disjunction.
         *
         * @throws IllegalArgumentException if the operator/quantifier combination is not supported
         */
        public Expression rewriteUsingBounds(QuantifiedComparisonExpression quantifiedComparison, Symbol minValue, Symbol maxValue, Symbol countAllValue, Symbol countNonNullValue) {
            boolean isAll = quantifiedComparison.getQuantifier().equals(QuantifiedComparisonExpression.Quantifier.ALL);
            Expression emptySetResult = isAll ? BooleanLiteral.TRUE_LITERAL : BooleanLiteral.FALSE_LITERAL;
            Expression comparisonWithExtremeValue = getBoundComparisons(quantifiedComparison, minValue, maxValue);

            // count_all <> count_non_null means at least one subquery row holds NULL; in that case this
            // operand is NULL, otherwise it is the neutral element of the conjunction/disjunction.
            Expression nullAwareOperand = new SearchedCaseExpression(
                    ImmutableList.of(new WhenClause(
                            new ComparisonExpression(ComparisonExpression.Operator.NOT_EQUAL, countAllValue.toSymbolReference(), countNonNullValue.toSymbolReference()),
                            new Cast(new NullLiteral(), BooleanType.BOOLEAN.toString()))),
                    Optional.of(emptySetResult));

            List<Expression> operands = ImmutableList.of(comparisonWithExtremeValue, nullAwareOperand);
            Expression nonEmptySetResult = isAll
                    ? ExpressionUtils.combineConjuncts(operands)
                    : ExpressionUtils.combineDisjuncts(operands);

            return new SimpleCaseExpression(
                    countAllValue.toSymbolReference(),
                    ImmutableList.of(new WhenClause(new GenericLiteral("bigint", "0"), emptySetResult)),
                    Optional.of(nonEmptySetResult));
        }

        /**
         * Translates the comparison against the whole subquery into a comparison against its
         * minimum and/or maximum.
         *
         * @throws IllegalArgumentException if the operator/quantifier combination is not supported
         */
        private Expression getBoundComparisons(QuantifiedComparisonExpression quantifiedComparison, Symbol minValue, Symbol maxValue) {
            if (quantifiedComparison.getOperator() == ComparisonExpression.Operator.EQUAL
                    && quantifiedComparison.getQuantifier() == QuantifiedComparisonExpression.Quantifier.ALL) {
                // A = ALL B  <=>  min B = max B AND A = max B
                return ExpressionUtils.combineConjuncts(
                        new ComparisonExpression(ComparisonExpression.Operator.EQUAL, minValue.toSymbolReference(), maxValue.toSymbolReference()),
                        new ComparisonExpression(ComparisonExpression.Operator.EQUAL, quantifiedComparison.getValue(), maxValue.toSymbolReference()));
            }
            if (EnumSet.of(
                    ComparisonExpression.Operator.LESS_THAN,
                    ComparisonExpression.Operator.LESS_THAN_OR_EQUAL,
                    ComparisonExpression.Operator.GREATER_THAN,
                    ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL).contains(quantifiedComparison.getOperator())) {
                // An ordering comparison only needs one extreme of the subquery values.
                Symbol boundValue = shouldCompareValueWithLowerBound(quantifiedComparison) ? minValue : maxValue;
                return new ComparisonExpression(quantifiedComparison.getOperator(), quantifiedComparison.getValue(), boundValue.toSymbolReference());
            }
            throw new IllegalArgumentException("Unsupported quantified comparison: " + quantifiedComparison);
        }

        /**
         * Decides which extreme an ordering comparison must be checked against.
         *
         * <pre>
         * A &lt;, &lt;= ALL B       -> min B (true)
         * A &gt;, &gt;= ALL B       -> max B (false)
         * A &lt;, &lt;= ANY/SOME B  -> max B (false)
         * A &gt;, &gt;= ANY/SOME B  -> min B (true)
         * </pre>
         *
         * @return {@code true} to compare with the minimum, {@code false} to compare with the maximum
         * @throws IllegalArgumentException if the quantifier/operator combination is not listed above
         */
        private static boolean shouldCompareValueWithLowerBound(QuantifiedComparisonExpression quantifiedComparison) {
            switch (quantifiedComparison.getQuantifier()) {
                case ALL:
                    switch (quantifiedComparison.getOperator()) {
                        case LESS_THAN:
                        case LESS_THAN_OR_EQUAL:
                            return true;
                        case GREATER_THAN:
                        case GREATER_THAN_OR_EQUAL:
                            return false;
                    }
                    break;
                case ANY:
                case SOME:
                    switch (quantifiedComparison.getOperator()) {
                        case LESS_THAN:
                        case LESS_THAN_OR_EQUAL:
                            return false;
                        case GREATER_THAN:
                        case GREATER_THAN_OR_EQUAL:
                            return true;
                    }
                    break;
            }
            throw new IllegalArgumentException("Unexpected quantifier: " + quantifiedComparison.getQuantifier());
        }

        /**
         * Wraps {@code input} in a projection that passes through all of its output symbols and
         * adds {@code subqueryAssignments}.
         */
        private ProjectNode projectExpressions(PlanNode input, Assignments subqueryAssignments) {
            Assignments assignments = Assignments.builder()
                    .putIdentities(input.getOutputSymbols())
                    .putAll(subqueryAssignments)
                    .build();
            return new ProjectNode(this.idAllocator.getNextId(), input, assignments);
        }
    }
}

