/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.BuiltInFunctionHandle;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.Expressions;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Combines several scalar {@code approx_percentile} aggregations over the same column into a single
 * array-valued {@code approx_percentile} call.
 *
 * <p>Shape of the rewrite produced by {@link #apply}:
 * <pre>{@code
 * Aggregation[p1 := approx_percentile(x, 0.1), p2 := approx_percentile(x, 0.9), <others>]
 *   Source
 * }</pre>
 * becomes
 * <pre>{@code
 * Project[p1 := element_at(r, 1), p2 := element_at(r, 2), <other outputs as identity>]
 *   Aggregation[r := approx_percentile(x, arr), <others>]
 *     Project[arr := array_constructor(0.1, 0.9), <source outputs as identity>]
 *       Source
 * }</pre>
 *
 * <p>Two aggregations are combined only if they have the same first argument, the same function
 * handle, the same mask / order-by / filter / distinct flag, and identical arguments everywhere
 * except the percentile position.
 */
public class CombineApproxPercentileFunctions
        implements Rule<AggregationNode> {
    private static final String APPROX_PERCENTILE = "approx_percentile";
    private static final String ARRAY_CONSTRUCTOR = "array_constructor";
    private static final String ELEMENT_AT = "element_at";
    // A group is combined only when its size is strictly below this limit; larger groups are left
    // untouched. NOTE(review): presumably tied to the argument limit of array_constructor — confirm.
    private static final int ARRAY_SIZE_LIMIT = 254;
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching(CombineApproxPercentileFunctions::hasMultipleApproxPercentile);

    private final FunctionAndTypeManager functionAndTypeManager;

    public CombineApproxPercentileFunctions(FunctionAndTypeManager functionAndTypeManager) {
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    /**
     * Cheap pre-filter for the pattern: true when the node holds more than one aggregation whose
     * display name is {@code approx_percentile}.
     */
    private static boolean hasMultipleApproxPercentile(AggregationNode aggregation) {
        return aggregation.getAggregations().values().stream()
                .filter(agg -> agg.getCall().getDisplayName().equals(APPROX_PERCENTILE))
                .count() > 1L;
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isCombineApproxPercentileEnabled(session);
    }

    /**
     * Returns the index of the percentile argument for a built-in {@code approx_percentile} handle.
     *
     * <p>Two-argument calls, and three-argument calls whose second argument is {@code double},
     * carry the percentile at index 1. Four-argument calls, and three-argument calls whose second
     * argument is {@code bigint} (the weight, per the Presto docs), carry it at index 2.
     *
     * @throws IllegalStateException if the handle is not built-in or the signature is none of the above
     */
    private static int getPercentilePosition(FunctionHandle functionHandle) {
        Preconditions.checkState(functionHandle instanceof BuiltInFunctionHandle, "expected a built-in function handle: %s", functionHandle);
        List<TypeSignature> argumentTypes = ((BuiltInFunctionHandle) functionHandle).getSignature().getArgumentTypes();
        int arity = argumentTypes.size();
        if (arity == 2 || (arity == 3 && argumentTypes.get(1).getBase().equals("double"))) {
            return 1;
        }
        Preconditions.checkState(
                arity == 4 || (arity == 3 && argumentTypes.get(1).getBase().equals("bigint")),
                "unexpected approx_percentile signature: %s",
                argumentTypes);
        return 2;
    }

    /**
     * Returns true when the two aggregations can share one array-valued approx_percentile call.
     *
     * <p>Mask, order-by, filter, distinct flag, function handle and arity must match. All arguments
     * except the percentile argument must also be equal.
     */
    private static boolean aggregationCanMerge(AggregationNode.Aggregation aggregation1, AggregationNode.Aggregation aggregation2) {
        boolean sameModifiers = aggregation1.getMask().equals(aggregation2.getMask())
                && aggregation1.getOrderBy().equals(aggregation2.getOrderBy())
                && aggregation1.getFilter().equals(aggregation2.getFilter())
                && aggregation1.isDistinct() == aggregation2.isDistinct();
        if (!sameModifiers) {
            return false;
        }
        CallExpression call1 = aggregation1.getCall();
        CallExpression call2 = aggregation2.getCall();
        if (!call1.getFunctionHandle().equals(call2.getFunctionHandle()) || call1.getArguments().size() != call2.getArguments().size()) {
            return false;
        }
        // Handles are equal at this point, so the percentile index is the same for both calls.
        int percentilePosition = getPercentilePosition(call1.getFunctionHandle());
        List<RowExpression> arguments1 = call1.getArguments();
        List<RowExpression> arguments2 = call2.getArguments();
        for (int i = 0; i < arguments1.size(); i++) {
            if (i != percentilePosition && !arguments1.get(i).equals(arguments2.get(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns a copy of {@code arguments} with the element at {@code percentilePosition} replaced
     * by {@code percentileArgument}.
     */
    private static List<RowExpression> changePercentileArgument(List<RowExpression> arguments, RowExpression percentileArgument, int percentilePosition) {
        ImmutableList.Builder<RowExpression> newArguments = ImmutableList.builder();
        for (int i = 0; i < arguments.size(); i++) {
            newArguments.add(i == percentilePosition ? percentileArgument : arguments.get(i));
        }
        return newArguments.build();
    }

    /**
     * Greedily partitions {@code candidateAggregations} into groups of mutually mergeable
     * aggregations.
     *
     * <p>Each unassigned aggregation, taken in list order, seeds a group. Every later unassigned
     * aggregation that can merge with the seed joins that group. Every input appears in exactly
     * one output group, possibly a singleton.
     */
    private static List<List<AggregationNode.Aggregation>> createMergeableAggregations(List<AggregationNode.Aggregation> candidateAggregations) {
        ImmutableList.Builder<List<AggregationNode.Aggregation>> result = ImmutableList.builder();
        Set<AggregationNode.Aggregation> assigned = new HashSet<>();
        for (int i = 0; i < candidateAggregations.size(); i++) {
            AggregationNode.Aggregation seed = candidateAggregations.get(i);
            if (assigned.contains(seed)) {
                continue;
            }
            ImmutableList.Builder<AggregationNode.Aggregation> group = ImmutableList.builder();
            assigned.add(seed);
            group.add(seed);
            for (int j = i + 1; j < candidateAggregations.size(); j++) {
                AggregationNode.Aggregation other = candidateAggregations.get(j);
                if (!assigned.contains(other) && aggregationCanMerge(seed, other)) {
                    assigned.add(other);
                    group.add(other);
                }
            }
            result.add(group.build());
        }
        return result.build();
    }

    /**
     * Builds {@code array_constructor(p1, ..., pn)} from the percentile arguments of the given
     * aggregations, in list order. The list is assumed non-empty.
     */
    private CallExpression createArrayPercentile(List<AggregationNode.Aggregation> aggregations) {
        List<RowExpression> percentiles = aggregations.stream()
                .map(aggregation -> aggregation.getArguments().get(getPercentilePosition(aggregation.getFunctionHandle())))
                .collect(Collectors.toList());
        return Expressions.call(functionAndTypeManager, ARRAY_CONSTRUCTOR, new ArrayType(percentiles.get(0).getType()), percentiles);
    }

    /**
     * Builds the array-valued approx_percentile aggregation for a group.
     *
     * <p>It reuses the first group member's arguments and modifiers, with the percentile argument
     * replaced by {@code arrayVariableReference}. The return type is {@code array(T)}, where
     * {@code T} is the type of the first argument.
     */
    private AggregationNode.Aggregation createArrayAggregation(List<AggregationNode.Aggregation> candidateList, VariableReferenceExpression arrayVariableReference) {
        AggregationNode.Aggregation template = candidateList.get(0);
        List<RowExpression> templateArguments = template.getCall().getArguments();
        int percentilePosition = getPercentilePosition(template.getFunctionHandle());
        List<RowExpression> newArguments = changePercentileArgument(templateArguments, arrayVariableReference, percentilePosition);
        Type columnType = templateArguments.get(0).getType();
        CallExpression approxPercentileCall = Expressions.call(functionAndTypeManager, APPROX_PERCENTILE, new ArrayType(columnType), newArguments);
        return new AggregationNode.Aggregation(approxPercentileCall, template.getFilter(), template.getOrderBy(), template.isDistinct(), template.getMask());
    }

    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        Map<VariableReferenceExpression, AggregationNode.Aggregation> originalAggregations = aggregationNode.getAggregations();

        // Scalar (non-array) approx_percentile calls. Only built-in handles are considered, because
        // getPercentilePosition requires them; other handles are left alone rather than failing the query.
        List<AggregationNode.Aggregation> approxPercentiles = originalAggregations.values().stream()
                .filter(aggregation -> aggregation.getCall().getDisplayName().equals(APPROX_PERCENTILE)
                        && !(aggregation.getCall().getType() instanceof ArrayType)
                        && aggregation.getFunctionHandle() instanceof BuiltInFunctionHandle)
                .collect(Collectors.toList());

        // An aggregation that occurs more than once (identical call under several output variables)
        // cannot be mapped back to a single variable below, so it is skipped.
        Map<AggregationNode.Aggregation, Long> occurrences = approxPercentiles.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        List<AggregationNode.Aggregation> candidates = approxPercentiles.stream()
                .filter(aggregation -> occurrences.get(aggregation) == 1L)
                .collect(ImmutableList.toImmutableList());

        // Bucket by first argument, then by function handle. Linked maps keep the output deterministic.
        Map<RowExpression, Map<FunctionHandle, List<AggregationNode.Aggregation>>> byColumnAndHandle = candidates.stream()
                .collect(Collectors.groupingBy(
                        aggregation -> aggregation.getCall().getArguments().get(0),
                        LinkedHashMap::new,
                        Collectors.groupingBy(AggregationNode.Aggregation::getFunctionHandle, LinkedHashMap::new, Collectors.toList())));

        ImmutableList.Builder<List<AggregationNode.Aggregation>> allGroups = ImmutableList.builder();
        for (Map<FunctionHandle, List<AggregationNode.Aggregation>> byHandle : byColumnAndHandle.values()) {
            for (List<AggregationNode.Aggregation> sameColumnAndHandle : byHandle.values()) {
                allGroups.addAll(createMergeableAggregations(sameColumnAndHandle));
            }
        }
        List<List<AggregationNode.Aggregation>> mergeGroups = allGroups.build().stream()
                .filter(group -> group.size() > 1 && group.size() < ARRAY_SIZE_LIMIT)
                .collect(Collectors.toList());
        if (mergeGroups.isEmpty()) {
            return Rule.Result.empty();
        }

        Set<AggregationNode.Aggregation> combinedAggregations = mergeGroups.stream()
                .flatMap(Collection::stream)
                .collect(Collectors.toSet());
        Map<AggregationNode.Aggregation, VariableReferenceExpression> variableOfAggregation = new HashMap<>();
        Set<VariableReferenceExpression> combinedVariables = new HashSet<>();
        originalAggregations.forEach((variable, aggregation) -> {
            if (combinedAggregations.contains(aggregation)) {
                variableOfAggregation.put(aggregation, variable);
                combinedVariables.add(variable);
            }
        });

        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> newAggregations = ImmutableMap.builder();
        Assignments.Builder sourceProjectAssignments = Assignments.builder();
        Assignments.Builder outputProjectAssignments = Assignments.builder();
        for (List<AggregationNode.Aggregation> group : mergeGroups) {
            CallExpression arrayExpression = createArrayPercentile(group);
            VariableReferenceExpression arrayVariable = context.getVariableAllocator().newVariable(arrayExpression);
            sourceProjectAssignments.put(arrayVariable, arrayExpression);

            AggregationNode.Aggregation arrayAggregation = createArrayAggregation(group, arrayVariable);
            VariableReferenceExpression arrayResult = context.getVariableAllocator().newVariable(arrayAggregation.getCall());
            newAggregations.put(arrayResult, arrayAggregation);

            // Recover each original output with element_at; the index is 1-based.
            for (int i = 0; i < group.size(); i++) {
                AggregationNode.Aggregation original = group.get(i);
                Type elementType = original.getArguments().get(0).getType();
                RowExpression element = Expressions.call(
                        functionAndTypeManager,
                        ELEMENT_AT,
                        elementType,
                        ImmutableList.<RowExpression>of(arrayResult, Expressions.constant((long) i + 1L, BigintType.BIGINT)));
                outputProjectAssignments.put(variableOfAggregation.get(original), element);
            }
        }

        // Aggregations that were not combined pass through unchanged.
        originalAggregations.forEach((variable, aggregation) -> {
            if (!combinedVariables.contains(variable)) {
                newAggregations.put(variable, aggregation);
            }
        });
        for (VariableReferenceExpression variable : aggregationNode.getOutputVariables()) {
            if (!combinedVariables.contains(variable)) {
                outputProjectAssignments.put(variable, variable);
            }
        }
        for (VariableReferenceExpression variable : aggregationNode.getSource().getOutputVariables()) {
            sourceProjectAssignments.put(variable, variable);
        }

        // Nested construction keeps the original id allocation order: outer project, aggregation, inner project.
        return Rule.Result.ofPlanNode(new ProjectNode(
                context.getIdAllocator().getNextId(),
                new AggregationNode(
                        aggregationNode.getSourceLocation(),
                        context.getIdAllocator().getNextId(),
                        new ProjectNode(context.getIdAllocator().getNextId(), aggregationNode.getSource(), sourceProjectAssignments.build()),
                        newAggregations.build(),
                        aggregationNode.getGroupingSets(),
                        ImmutableList.of(),
                        aggregationNode.getStep(),
                        aggregationNode.getHashVariable(),
                        aggregationNode.getGroupIdVariable()),
                outputProjectAssignments.build()));
    }
}

