/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.linkedin.coral.calcite.$internal.com.google.common.base.Preconditions;
import com.linkedin.coral.calcite.$internal.com.google.common.collect.ImmutableList;
import com.linkedin.coral.calcite.$internal.com.google.common.collect.Iterables;
import com.linkedin.coral.calcite.$internal.com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

/**
 * Planner rule that rewrites an {@link Aggregate} containing distinct aggregate calls
 * (such as {@code COUNT(DISTINCT x)}) into a plan without distinct aggregate calls.
 *
 * <p>The strategy depends on the shape of the aggregate and on {@link #useGroupingSets}:
 * <ul>
 *   <li>every call is distinct, all calls share one argument list and filter, and the group
 *       type is {@code SIMPLE}: a "select distinct" followed by a plain aggregate;</li>
 *   <li>{@code useGroupingSets} is set: one aggregate over grouping sets followed by an
 *       aggregate whose calls are filtered by the grouping set each row came from;</li>
 *   <li>exactly one distinct call, no FILTER clauses, and only COUNT/SUM/SUM0/MIN/MAX
 *       non-distinct calls: a two-phase aggregate;</li>
 *   <li>otherwise: an aggregate for the non-distinct calls plus one "select distinct"
 *       aggregate per distinct (arguments, filter) combination, inner-joined on the group
 *       keys.</li>
 * </ul>
 */
public final class AggregateExpandDistinctAggregatesRule
extends RelOptRule {
    /** Instance over {@link LogicalAggregate} that rewrites using grouping sets. */
    public static final AggregateExpandDistinctAggregatesRule INSTANCE = new AggregateExpandDistinctAggregatesRule(LogicalAggregate.class, true, RelFactories.LOGICAL_BUILDER);
    /** Instance over {@link LogicalAggregate} that does not use grouping sets (general case uses joins). */
    public static final AggregateExpandDistinctAggregatesRule JOIN = new AggregateExpandDistinctAggregatesRule(LogicalAggregate.class, false, RelFactories.LOGICAL_BUILDER);
    /** Whether to rewrite with grouping sets when the single-distinct-set form does not apply. */
    public final boolean useGroupingSets;

    /**
     * Creates the rule.
     *
     * @param clazz             aggregate class to match
     * @param useGroupingSets   whether to rewrite using grouping sets rather than joins
     * @param relBuilderFactory builder for the relational expressions the rule creates
     */
    public AggregateExpandDistinctAggregatesRule(Class<? extends Aggregate> clazz, boolean useGroupingSets, RelBuilderFactory relBuilderFactory) {
        super(operand(clazz, any()), relBuilderFactory, null);
        this.useGroupingSets = useGroupingSets;
    }

    /**
     * Creates the rule from a join factory.
     *
     * @deprecated use {@link #AggregateExpandDistinctAggregatesRule(Class, boolean, RelBuilderFactory)}
     */
    @Deprecated
    public AggregateExpandDistinctAggregatesRule(Class<? extends LogicalAggregate> clazz, boolean useGroupingSets, RelFactories.JoinFactory joinFactory) {
        this(clazz, useGroupingSets, RelBuilder.proto(Contexts.of((Object) joinFactory)));
    }

    /**
     * Creates the rule from a join factory, without grouping sets.
     *
     * @deprecated use {@link #AggregateExpandDistinctAggregatesRule(Class, boolean, RelBuilderFactory)}
     */
    @Deprecated
    public AggregateExpandDistinctAggregatesRule(Class<? extends LogicalAggregate> clazz, RelFactories.JoinFactory joinFactory) {
        this(clazz, false, RelBuilder.proto(Contexts.of((Object) joinFactory)));
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate) call.rel(0);
        if (!aggregate.containsDistinctCall()) {
            return;
        }

        // Classify the calls. A LinkedHashSet keeps the distinct (arguments, filter)
        // combinations in a deterministic order.
        int nonDistinctAggCallCount = 0;
        int filterCount = 0;
        int unsupportedNonDistinctAggCallCount = 0;
        Set<Pair<List<Integer>, Integer>> argLists = new LinkedHashSet<>();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            if (aggCall.filterArg >= 0) {
                ++filterCount;
            }
            if (aggCall.isDistinct()) {
                argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
                continue;
            }
            ++nonDistinctAggCallCount;
            // Only these kinds can be split into the two phases of convertSingletonDistinct.
            switch (aggCall.getAggregation().getKind()) {
                case COUNT:
                case SUM:
                case SUM0:
                case MIN:
                case MAX:
                    break;
                default:
                    ++unsupportedNonDistinctAggCallCount;
                    break;
            }
        }
        int distinctAggCallCount = aggregate.getAggCallList().size() - nonDistinctAggCallCount;
        Preconditions.checkState(argLists.size() > 0, "containsDistinctCall lied");

        // All calls distinct over the same arguments and filter: "select distinct" + aggregate.
        if (nonDistinctAggCallCount == 0 && argLists.size() == 1 && aggregate.getGroupType() == Aggregate.Group.SIMPLE) {
            Pair<List<Integer>, Integer> pair = Iterables.getOnlyElement(argLists);
            RelBuilder relBuilder = call.builder();
            convertMonopole(relBuilder, aggregate, pair.left, pair.right);
            call.transformTo(relBuilder.build());
            return;
        }

        if (useGroupingSets) {
            rewriteUsingGroupingSets(call, aggregate);
            return;
        }

        // One distinct call, no filters, splittable non-distinct calls: two-phase aggregate.
        if (distinctAggCallCount == 1 && filterCount == 0 && unsupportedNonDistinctAggCallCount == 0 && nonDistinctAggCallCount > 0) {
            RelBuilder relBuilder = call.builder();
            convertSingletonDistinct(relBuilder, aggregate, argLists);
            call.transformTo(relBuilder.build());
            return;
        }

        // General case. refs holds, for each output field of the original aggregate, the
        // expression that produces it; entries of distinct calls are filled in by doRewrite.
        List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
        List<RexInputRef> refs = new ArrayList<>();
        List<String> fieldNames = aggregate.getRowType().getFieldNames();
        ImmutableBitSet groupSet = aggregate.getGroupSet();
        int groupCount = aggregate.getGroupCount();
        for (int g : Util.range(groupCount)) {
            refs.add(RexInputRef.of(g, aggFields));
        }

        // The non-distinct calls are computed by one aggregate over the original input.
        List<AggregateCall> newAggCallList = new ArrayList<>();
        int i = -1;
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            ++i;
            if (aggCall.isDistinct()) {
                refs.add(null);
                continue;
            }
            refs.add(new RexInputRef(groupCount + newAggCallList.size(), aggFields.get(groupCount + i).getType()));
            newAggCallList.add(aggCall);
        }

        RelBuilder relBuilder = call.builder();
        relBuilder.push(aggregate.getInput());
        // n is the number of relational expressions built so far; with no non-distinct calls
        // the extra aggregate (and the join with it) is skipped.
        int n = 0;
        if (!newAggCallList.isEmpty()) {
            RelBuilder.GroupKey groupKey = relBuilder.groupKey(groupSet, aggregate.getGroupSets());
            relBuilder.aggregate(groupKey, newAggCallList);
            ++n;
        }
        for (Pair<List<Integer>, Integer> argList : argLists) {
            doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs);
        }
        relBuilder.project(refs, fieldNames);
        call.transformTo(relBuilder.build());
    }

    /**
     * Rewrites an aggregate with exactly one distinct call (no FILTER clauses) into two
     * aggregates. The bottom one groups by the original keys plus the distinct call's
     * arguments and computes the non-distinct calls. The top one groups by the original keys,
     * applies the distinct call's function non-distinctly, and rolls up the partial results.
     *
     * @param relBuilder builder; the rewritten expression is left on its stack
     * @param aggregate  original aggregate
     * @param argLists   distinct (arguments, filter) combinations; must contain exactly one
     * @return {@code relBuilder}
     */
    private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
        Preconditions.checkArgument(argLists.size() == 1);
        relBuilder.push(aggregate.getInput());
        List<AggregateCall> originalAggCalls = aggregate.getAggCallList();
        ImmutableBitSet originalGroupSet = aggregate.getGroupSet();

        // Bottom keys are sorted, so a key's rank in this set is its field index in the
        // bottom aggregate's output.
        TreeSet<Integer> bottomGroups = new TreeSet<>(originalGroupSet.asList());
        for (AggregateCall aggCall : originalAggCalls) {
            if (aggCall.isDistinct()) {
                bottomGroups.addAll(aggCall.getArgList());
                break;
            }
        }
        ImmutableBitSet bottomGroupSet = ImmutableBitSet.of(bottomGroups);

        List<AggregateCall> bottomAggregateCalls = new ArrayList<>();
        for (AggregateCall aggCall : originalAggCalls) {
            if (aggCall.isDistinct()) {
                continue;
            }
            bottomAggregateCalls.add(AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.getArgList(), -1, aggCall.collation, bottomGroupSet.cardinality(), relBuilder.peek(), null, aggCall.name));
        }
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), bottomGroupSet, null, bottomAggregateCalls));

        List<AggregateCall> topAggregateCalls = new ArrayList<>();
        int nonDistinctAggCallProcessedSoFar = 0;
        for (AggregateCall aggCall : originalAggCalls) {
            AggregateCall newCall;
            if (aggCall.isDistinct()) {
                List<Integer> newArgList = new ArrayList<>();
                for (int arg : aggCall.getArgList()) {
                    newArgList.add(bottomGroups.headSet(arg).size());
                }
                newCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), newArgList, -1, aggCall.collation, originalGroupSet.cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
            } else {
                // Partial results follow the bottom keys in call order. Partial COUNTs must be
                // summed, so COUNT is rolled up with SqlSumEmptyIsZeroAggFunction.
                int arg = bottomGroups.size() + nonDistinctAggCallProcessedSoFar;
                ImmutableList<Integer> newArgs = ImmutableList.of(arg);
                SqlAggFunction rollup = aggCall.getAggregation().getKind() == SqlKind.COUNT
                    ? new SqlSumEmptyIsZeroAggFunction()
                    : aggCall.getAggregation();
                newCall = AggregateCall.create(rollup, false, aggCall.isApproximate(), aggCall.ignoreNulls(), newArgs, -1, aggCall.collation, originalGroupSet.cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
                ++nonDistinctAggCallProcessedSoFar;
            }
            topAggregateCalls.add(newCall);
        }

        // Top group set: positions (in the bottom output) of the original group keys.
        Set<Integer> topGroupSet = new HashSet<>();
        int position = 0;
        for (int bottomGroup : bottomGroups) {
            if (originalGroupSet.get(bottomGroup)) {
                topGroupSet.add(position);
            }
            ++position;
        }
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), ImmutableBitSet.of(topGroupSet), null, topAggregateCalls));
        return relBuilder;
    }

    /**
     * Rewrites the aggregate with a single grouping-sets aggregate followed by a filtered
     * aggregate, and transforms {@code call} to the result.
     *
     * <p>The bottom aggregate has one grouping set for the original keys (used by the
     * non-distinct calls). It also has one per distinct (arguments, filter) combination, plus
     * a {@code GROUPING} call identifying each row's grouping set. The top aggregate filters
     * each call to the rows of its grouping set.
     */
    private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate) {
        Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            if (aggCall.isDistinct()) {
                groupSetTreeSet.add(distinctGroupSet(aggCall, aggregate.getGroupSet()));
            } else {
                groupSetTreeSet.add(aggregate.getGroupSet());
            }
        }
        ImmutableList<ImmutableBitSet> groupSets = ImmutableList.copyOf(groupSetTreeSet);
        ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);

        // Bottom calls: the non-distinct calls, adapted to the full group set.
        List<AggregateCall> bottomAggCalls = new ArrayList<>();
        for (Pair<AggregateCall, String> namedCall : aggregate.getNamedAggCalls()) {
            AggregateCall aggCall = namedCall.left;
            if (aggCall.isDistinct()) {
                continue;
            }
            AggregateCall newAggCall = aggCall.adaptTo(aggregate.getInput(), aggCall.getArgList(), aggCall.filterArg, aggregate.getGroupCount(), fullGroupSet.cardinality());
            bottomAggCalls.add(newAggCall.rename(namedCall.right));
        }

        RelBuilder relBuilder = call.builder();
        relBuilder.push(aggregate.getInput());
        int groupCount = fullGroupSet.cardinality();
        // filters maps each grouping set to the field that will hold its boolean indicator
        // ("$g_<value>") after the project below.
        Map<ImmutableBitSet, Integer> filters = new LinkedHashMap<>();
        int z = groupCount + bottomAggCalls.size();
        bottomAggCalls.add(AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, false, ImmutableIntList.copyOf(fullGroupSet), -1, RelCollations.EMPTY, groupSets.size(), relBuilder.peek(), null, "$g"));
        for (Ord<ImmutableBitSet> groupSet : Ord.zip(groupSets)) {
            filters.put(groupSet.e, z + groupSet.i);
        }
        relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets), bottomAggCalls);
        RelNode distinct = relBuilder.peek();

        // Replace the GROUPING value by one boolean "$g = v" column per grouping set.
        if (!filters.isEmpty()) {
            List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
            RexNode nodeZ = nodes.remove(nodes.size() - 1);
            for (Map.Entry<ImmutableBitSet, Integer> entry : filters.entrySet()) {
                long v = groupValue(fullGroupSet, entry.getKey());
                nodes.add(relBuilder.alias(relBuilder.equals(nodeZ, relBuilder.literal(v)), "$g_" + v));
            }
            relBuilder.project(nodes);
        }

        // Top calls: non-distinct results are picked with MIN from the rows of the original
        // grouping set; distinct calls aggregate their (now unique) remapped arguments over
        // the rows of their own grouping set.
        int x = groupCount;
        List<AggregateCall> newCalls = new ArrayList<>();
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            SqlAggFunction aggregation;
            List<Integer> newArgList;
            int newFilterArg;
            if (aggregateCall.isDistinct()) {
                aggregation = aggregateCall.getAggregation();
                newArgList = remap(fullGroupSet, aggregateCall.getArgList());
                newFilterArg = filters.get(distinctGroupSet(aggregateCall, aggregate.getGroupSet()));
            } else {
                aggregation = SqlStdOperatorTable.MIN;
                newArgList = ImmutableIntList.of(x++);
                newFilterArg = filters.get(aggregate.getGroupSet());
            }
            newCalls.add(AggregateCall.create(aggregation, false, aggregateCall.isApproximate(), aggregateCall.ignoreNulls(), newArgList, newFilterArg, aggregateCall.collation, aggregate.getGroupCount(), distinct, null, aggregateCall.name));
        }
        relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, aggregate.getGroupSet()), remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
        relBuilder.convert(aggregate.getRowType(), true);
        call.transformTo(relBuilder.build());
    }

    /**
     * Returns the grouping set that yields the distinct values for a distinct call. It
     * contains the call's arguments, its filter column (if any), and the original group keys.
     */
    private static ImmutableBitSet distinctGroupSet(AggregateCall aggCall, ImmutableBitSet groupSet) {
        return ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(groupSet);
    }

    /**
     * Computes the bitmask compared with the {@code GROUPING} call for rows of
     * {@code groupSet}. A bit is set for every key of {@code fullGroupSet} not in
     * {@code groupSet*}; the lowest key maps to the most significant bit.
     */
    private static long groupValue(ImmutableBitSet fullGroupSet, ImmutableBitSet groupSet) {
        long v = 0L;
        long x = 1L << fullGroupSet.cardinality() - 1;
        assert (fullGroupSet.contains(groupSet));
        for (int i : fullGroupSet) {
            if (!groupSet.get(i)) {
                v |= x;
            }
            x >>= 1;
        }
        return v;
    }

    /** Maps each bit of {@code bitSet} to its position within {@code groupSet}. */
    private static ImmutableBitSet remap(ImmutableBitSet groupSet, ImmutableBitSet bitSet) {
        ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
        for (int bit : bitSet) {
            builder.set(remap(groupSet, bit));
        }
        return builder.build();
    }

    /** Remaps each bit set of {@code bitSets}; see {@link #remap(ImmutableBitSet, ImmutableBitSet)}. */
    private static ImmutableList<ImmutableBitSet> remap(ImmutableBitSet groupSet, Iterable<ImmutableBitSet> bitSets) {
        ImmutableList.Builder<ImmutableBitSet> builder = ImmutableList.builder();
        for (ImmutableBitSet bitSet : bitSets) {
            builder.add(remap(groupSet, bitSet));
        }
        return builder.build();
    }

    /** Maps each argument ordinal to its position within {@code groupSet}. */
    private static List<Integer> remap(ImmutableBitSet groupSet, List<Integer> argList) {
        ImmutableIntList list = ImmutableIntList.of();
        for (int arg : argList) {
            list = list.append(remap(groupSet, arg));
        }
        return list;
    }

    /** Returns the position of {@code arg} within {@code groupSet}, or -1 if {@code arg} is negative. */
    private static int remap(ImmutableBitSet groupSet, int arg) {
        return arg < 0 ? -1 : groupSet.indexOf(arg);
    }

    /**
     * Rewrites an aggregate whose calls are all distinct over the same arguments and filter.
     * The rewrite is a "select distinct" of the group keys and arguments, followed by the same
     * calls applied non-distinctly.
     *
     * @return {@code relBuilder}, with the rewritten expression on its stack
     */
    private RelBuilder convertMonopole(RelBuilder relBuilder, Aggregate aggregate, List<Integer> argList, int filterArg) {
        Map<Integer, Integer> sourceOf = new HashMap<>();
        createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);
        List<AggregateCall> newAggCalls = Lists.newArrayList(aggregate.getAggCallList());
        rewriteAggCalls(newAggCalls, argList, sourceOf);
        int cardinality = aggregate.getGroupSet().cardinality();
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), ImmutableBitSet.range(cardinality), null, newAggCalls));
        return relBuilder;
    }

    /**
     * Builds the aggregate for one distinct (arguments, filter) combination, joins it to the
     * expression built so far (unless it is the first), and points the matching entries of
     * {@code refs} at its output fields.
     *
     * @param n         number of relational expressions already built (0 means none)
     * @param argList   arguments of the distinct calls to rewrite
     * @param filterArg filter of the distinct calls to rewrite, or -1
     * @param refs      output expressions of the original aggregate; updated in place
     */
    private void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, List<Integer> argList, int filterArg, List<RexInputRef> refs) {
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        List<RelDataTypeField> leftFields = n == 0 ? null : relBuilder.peek().getRowType().getFieldList();
        Map<Integer, Integer> sourceOf = new HashMap<>();
        createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);

        List<AggregateCall> aggCallList = new ArrayList<>();
        int groupCount = aggregate.getGroupCount();
        int i = groupCount - 1;
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            ++i;
            // Only rewrite calls whose arguments AND filter match the "select distinct" just
            // built; calls with the same arguments but another filter get their own rewrite.
            if (!aggregateCall.isDistinct() || !aggregateCall.getArgList().equals(argList) || aggregateCall.filterArg != filterArg) {
                continue;
            }
            List<Integer> newArgs = new ArrayList<>(aggregateCall.getArgList().size());
            for (Integer arg : aggregateCall.getArgList()) {
                newArgs.add(sourceOf.get(arg));
            }
            // createSelectDistinct already applied the filter by projecting
            // CASE WHEN filter THEN arg ELSE NULL END. The filter column itself has no entry
            // in sourceOf unless it is also a key/argument, so the new call takes no filter
            // (as in rewriteAggCalls).
            AggregateCall newAggCall = AggregateCall.create(aggregateCall.getAggregation(), false, aggregateCall.isApproximate(), aggregateCall.ignoreNulls(), newArgs, -1, aggregateCall.collation, aggregateCall.getType(), aggregateCall.getName());
            assert (refs.get(i) == null);
            int offset = n == 0 ? 0 : leftFields.size();
            refs.set(i, new RexInputRef(offset + groupCount + aggCallList.size(), newAggCall.getType()));
            aggCallList.add(newAggCall);
        }

        // The "select distinct" output starts with the group keys, so group by 0..groupCount-1.
        Map<Integer, Integer> map = new HashMap<>();
        for (Integer key : aggregate.getGroupSet()) {
            map.put(key, map.size());
        }
        ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
        assert (newGroupSet.equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality())));
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), newGroupSet, null, aggCallList));
        if (n == 0) {
            return;
        }

        // Join on the group keys; IS NOT DISTINCT FROM lets null keys match.
        List<RelDataTypeField> distinctFields = relBuilder.peek().getRowType().getFieldList();
        List<RexNode> conditions = new ArrayList<>();
        for (int g = 0; g < groupCount; ++g) {
            conditions.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, RexInputRef.of(g, leftFields), new RexInputRef(leftFields.size() + g, distinctFields.get(g).getType())));
        }
        relBuilder.join(JoinRelType.INNER, conditions);
    }

    /**
     * Replaces, in place, each distinct call over {@code argList} by a non-distinct call whose
     * arguments are remapped through {@code sourceOf} and which has no filter.
     */
    private static void rewriteAggCalls(List<AggregateCall> newAggCalls, List<Integer> argList, Map<Integer, Integer> sourceOf) {
        for (int i = 0; i < newAggCalls.size(); ++i) {
            AggregateCall aggCall = newAggCalls.get(i);
            if (!aggCall.isDistinct() || !aggCall.getArgList().equals(argList)) {
                continue;
            }
            List<Integer> newArgs = new ArrayList<>(aggCall.getArgList().size());
            for (Integer arg : aggCall.getArgList()) {
                newArgs.add(sourceOf.get(arg));
            }
            newAggCalls.set(i, AggregateCall.create(aggCall.getAggregation(), false, aggCall.isApproximate(), aggCall.ignoreNulls(), newArgs, -1, aggCall.collation, aggCall.getType(), aggCall.getName()));
        }
    }

    /**
     * Pushes onto {@code relBuilder} a "select distinct" of the group keys and {@code argList}
     * over the aggregate's input. It records in {@code sourceOf} the output position of each
     * key and argument.
     *
     * <p>When {@code filterArg >= 0}, each argument is projected as
     * {@code CASE WHEN filter THEN arg ELSE NULL END}. This applies the filter, except for
     * aggregate functions that need to see null values.
     */
    private RelBuilder createSelectDistinct(RelBuilder relBuilder, Aggregate aggregate, List<Integer> argList, int filterArg, Map<Integer, Integer> sourceOf) {
        relBuilder.push(aggregate.getInput());
        List<Pair<RexNode, String>> projects = new ArrayList<>();
        List<RelDataTypeField> childFields = relBuilder.peek().getRowType().getFieldList();
        for (int i : aggregate.getGroupSet()) {
            sourceOf.put(i, projects.size());
            projects.add(RexInputRef.of2(i, childFields));
        }
        for (Integer arg : argList) {
            if (filterArg >= 0) {
                RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
                RexInputRef filterRef = RexInputRef.of(filterArg, childFields);
                Pair<RexNode, String> argRef = RexInputRef.of2(arg, childFields);
                RexNode condition = rexBuilder.makeCall(SqlStdOperatorTable.CASE, filterRef, argRef.left, rexBuilder.ensureType(argRef.left.getType(), rexBuilder.makeNullLiteral(argRef.left.getType()), true));
                sourceOf.put(arg, projects.size());
                projects.add(Pair.of(condition, "i$" + argRef.right));
                continue;
            }
            if (sourceOf.get(arg) != null) {
                continue;
            }
            sourceOf.put(arg, projects.size());
            projects.add(RexInputRef.of2(arg, childFields));
        }
        relBuilder.project(Pair.left(projects), Pair.right(projects));
        relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), ImmutableBitSet.range(projects.size()), null, ImmutableList.of()));
        return relBuilder;
    }
}

