/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.StatsCalculator;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.DistinctLimitNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils;
import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.ChildReplacer;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.SortNode;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.tree.Join;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Plan optimizer that adds a key prefilter below a grouped aggregation that
 * sits directly under a {@code LIMIT} (optionally with one projection between
 * them).
 *
 * <p>The rewrite is only attempted when
 * {@link SystemSessionProperties#isPrefilterForGroupbyLimit(Session)} is
 * enabled for the session. For a matching {@code Limit -> [Project ->]
 * Aggregation} shape, the aggregation source is replaced by:
 *
 * <pre>
 * Project(original source outputs)
 *   Filter(IF(CARDINALITY(map) = count, element_at(map, hash(keys)), TRUE))
 *     CrossJoin
 *       Project(original source outputs + hash(keys))
 *       Aggregation(MAP_AGG(hash(keys), TRUE))
 *         Project(hash(keys), TRUE)
 *           DistinctLimit(count, keys)
 *             clone of the original source
 * </pre>
 */
public class PrefilterForLimitingAggregation
implements PlanOptimizer {
    private final Metadata metadata;
    private final StatsCalculator statsCalculator;

    /**
     * @param metadata metadata used to resolve functions/operators and to clone plan nodes
     * @param statsCalculator stats calculator handed to the low-cardinality check
     */
    public PrefilterForLimitingAggregation(Metadata metadata, StatsCalculator statsCalculator) {
        this.metadata = metadata;
        this.statsCalculator = statsCalculator;
    }

    /**
     * Rewrites {@code plan} with the prefilter rewriter when the session
     * property is enabled; otherwise returns {@code plan} untouched.
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        if (SystemSessionProperties.isPrefilterForGroupbyLimit(session)) {
            return SimplePlanRewriter.rewriteWith(new Rewriter(session, this.metadata, types, this.statsCalculator, idAllocator, variableAllocator), plan);
        }
        return plan;
    }

    /**
     * Rewriter that looks for {@code Limit -> [Project ->] Aggregation} and
     * injects the prefilter below the aggregation.
     */
    private static class Rewriter
    extends SimplePlanRewriter<Void> {
        private final Session session;
        private final Metadata metadata;
        private final TypeProvider types;
        private final StatsCalculator statsCalculator;
        private final PlanNodeIdAllocator idAllocator;
        private final VariableAllocator variableAllocator;

        private Rewriter(Session session, Metadata metadata, TypeProvider types, StatsCalculator statsCalculator, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator) {
            this.session = session;
            this.metadata = metadata;
            this.types = types;
            this.statsCalculator = statsCalculator;
            this.idAllocator = idAllocator;
            this.variableAllocator = variableAllocator;
        }

        /**
         * Returns the sort node unchanged and does not descend into its
         * subtree, so nothing below a {@code SortNode} is rewritten.
         * NOTE(review): presumably intentional to keep the prefilter away from
         * ordered subplans - confirm.
         */
        @Override
        public PlanNode visitSort(SortNode sortNode, SimplePlanRewriter.RewriteContext<Void> context) {
            return sortNode;
        }

        /**
         * Rewrites the subtree below the limit first, then tries to add a
         * prefilter when the (rewritten) source is an aggregation, or a
         * projection directly over an aggregation.
         *
         * @return the limit with a prefiltered aggregation, the limit re-parented
         *         onto a rewritten source, or the original limit if nothing changed
         */
        public PlanNode visitLimit(LimitNode limitNode, SimplePlanRewriter.RewriteContext<Void> context) {
            PlanNode source = SimplePlanRewriter.rewriteWith(this, limitNode.getSource());

            AggregationNode aggregationNode = null;
            if (source instanceof ProjectNode && ((ProjectNode) source).getSource() instanceof AggregationNode) {
                aggregationNode = (AggregationNode) ((ProjectNode) source).getSource();
            } else if (source instanceof AggregationNode) {
                aggregationNode = (AggregationNode) source;
            }

            if (aggregationNode != null && !aggregationNode.getGroupingKeys().isEmpty()) {
                Optional<TableScanNode> scanNode = PlannerUtils.getTableScanNodeWithOnlyFilterAndProject(aggregationNode.getSource());
                if (scanNode.isPresent()
                        && !AggregationNodeUtils.isAllLowCardinalityGroupByKeys(aggregationNode, scanNode.get(), this.session, this.statsCalculator, this.types, limitNode.getCount())) {
                    PlanNode rewrittenAggregation = this.addPrefilter(aggregationNode, limitNode.getCount());
                    if (rewrittenAggregation != aggregationNode) {
                        if (source == aggregationNode) {
                            return ChildReplacer.replaceChildren(limitNode, ImmutableList.of(rewrittenAggregation));
                        }
                        // The source is the projection over the aggregation: re-parent it first.
                        PlanNode rewrittenProject = ChildReplacer.replaceChildren(source, ImmutableList.of(rewrittenAggregation));
                        return ChildReplacer.replaceChildren(limitNode, ImmutableList.of(rewrittenProject));
                    }
                }
            }

            // No prefilter was added here, but the subtree may still have been
            // rewritten further down; only rebuild the limit in that case so the
            // rewritten subtree is not dropped.
            if (source != limitNode.getSource()) {
                return ChildReplacer.replaceChildren(limitNode, ImmutableList.of(source));
            }
            return limitNode;
        }

        /**
         * Builds the prefiltered source for {@code aggregationNode} and returns
         * the aggregation re-parented onto it.
         *
         * @param aggregationNode aggregation whose source should be prefiltered
         * @param count the limit count; used as the distinct-limit size and as
         *              the expected map cardinality in the filter
         * @return the aggregation over the filtered source, or
         *         {@code aggregationNode} itself when it has no grouping keys
         */
        private PlanNode addPrefilter(AggregationNode aggregationNode, long count) {
            List<VariableReferenceExpression> keys = aggregationNode.getGroupingKeys().stream().collect(Collectors.toList());
            if (keys.isEmpty()) {
                return aggregationNode;
            }

            // Right side: a clone of the source, reduced to at most `count` distinct key combinations.
            PlanNode originalSource = aggregationNode.getSource();
            PlanNode keySource = PlannerUtils.clonePlanNode(originalSource, this.session, this.metadata, this.idAllocator, keys, new HashMap<VariableReferenceExpression, VariableReferenceExpression>());
            DistinctLimitNode timedDistinctLimitNode = new DistinctLimitNode(
                    Optional.empty(),
                    this.idAllocator.getNextId(),
                    keySource,
                    count,
                    false,
                    keys,
                    Optional.empty(),
                    SystemSessionProperties.getPrefilterForGroupbyLimitTimeoutMS(this.session));

            FunctionAndTypeManager functionAndTypeManager = this.metadata.getFunctionAndTypeManager();
            // NOTE(review): assumes getHashExpression yields a value for a non-empty key list - confirm.
            RowExpression leftHashExpression = PlannerUtils.getHashExpression(functionAndTypeManager, keys).get();
            RowExpression rightHashExpression = PlannerUtils.getHashExpression(functionAndTypeManager, timedDistinctLimitNode.getOutputVariables()).get();

            // Collapse the selected keys into a single map(bigint hash -> TRUE) row.
            Type mapType = PlannerUtils.createMapType(functionAndTypeManager, BigintType.BIGINT, BooleanType.BOOLEAN);
            PlanNode rightProjectNode = PlannerUtils.projectExpressions(
                    timedDistinctLimitNode,
                    this.idAllocator,
                    this.variableAllocator,
                    ImmutableList.of(rightHashExpression, Expressions.constant(Boolean.TRUE, BooleanType.BOOLEAN)),
                    ImmutableList.of());
            VariableReferenceExpression mapAggVariable = this.variableAllocator.newVariable("expr", mapType);
            PlanNode crossJoinRhs = PlannerUtils.addAggregation(
                    rightProjectNode,
                    functionAndTypeManager,
                    this.idAllocator,
                    this.variableAllocator,
                    "MAP_AGG",
                    mapType,
                    ImmutableList.of(),
                    mapAggVariable,
                    rightProjectNode.getOutputVariables().get(0),
                    rightProjectNode.getOutputVariables().get(1));

            // Left side: the original source with the key hash added as an extra projection.
            PlanNode crossJoinLhs = PlannerUtils.addProjections(originalSource, this.idAllocator, this.variableAllocator, ImmutableList.of(leftHashExpression), ImmutableList.of());

            ImmutableList.Builder<VariableReferenceExpression> crossJoinOutput = ImmutableList.builder();
            crossJoinOutput.addAll(crossJoinLhs.getOutputVariables());
            crossJoinOutput.addAll(crossJoinRhs.getOutputVariables());
            JoinNode crossJoin = new JoinNode(
                    Optional.empty(),
                    this.idAllocator.getNextId(),
                    JoinNodeUtils.typeConvert(Join.Type.CROSS),
                    crossJoinLhs,
                    crossJoinRhs,
                    ImmutableList.of(),
                    crossJoinOutput.build(),
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty(),
                    Optional.of(JoinNode.DistributionType.REPLICATED),
                    ImmutableMap.of());

            VariableReferenceExpression mapVariable = crossJoinRhs.getOutputVariables().get(0);
            // The hash is read from the last output variable of the left side.
            VariableReferenceExpression lookupVariable = crossJoinLhs.getOutputVariables().get(crossJoinLhs.getOutputVariables().size() - 1);

            // IF(CARDINALITY(map) = count, element_at(map, hash), TRUE):
            // filter by map membership only when the map holds exactly `count`
            // entries; otherwise let every row through.
            RowExpression cardinality = Expressions.call(functionAndTypeManager, "CARDINALITY", BigintType.BIGINT, mapVariable);
            RowExpression countExpr = Expressions.constant(count, BigintType.BIGINT);
            FunctionHandle equalsFunctionHandle = this.metadata.getFunctionAndTypeManager().resolveOperator(OperatorType.EQUAL, TypeSignatureProvider.fromTypes(BigintType.BIGINT, BigintType.BIGINT));
            RowExpression foundAllEntries = Expressions.call(OperatorType.EQUAL.name(), equalsFunctionHandle, BooleanType.BOOLEAN, cardinality, countExpr);
            RowExpression mapElementAt = Expressions.call(functionAndTypeManager, "element_at", BooleanType.BOOLEAN, mapVariable, lookupVariable);
            RowExpression check = Expressions.specialForm(
                    SpecialFormExpression.Form.IF,
                    BooleanType.BOOLEAN,
                    foundAllEntries,
                    mapElementAt,
                    Expressions.constant(Boolean.TRUE, BooleanType.BOOLEAN));
            FilterNode filterNode = new FilterNode(Optional.empty(), this.idAllocator.getNextId(), crossJoin, check);

            // Project back to exactly the original source outputs, dropping the hash and map columns.
            Assignments.Builder originalOutputs = Assignments.builder();
            for (VariableReferenceExpression variable : originalSource.getOutputVariables()) {
                originalOutputs.put(variable, variable);
            }
            PlanNode filteredSource = new ProjectNode(Optional.empty(), this.idAllocator.getNextId(), filterNode, originalOutputs.build(), ProjectNode.Locality.LOCAL);
            return ChildReplacer.replaceChildren(aggregationNode, ImmutableList.of(filteredSource));
        }
    }
}

