/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.core.query.aggregation.groupby;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.TreeMap;
import org.apache.commons.collections.comparators.ComparableComparator;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.response.broker.GroupByResult;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.query.aggregation.function.MinAggregationFunction;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.util.GroupByUtils;

public class AggregationGroupByTrimmingService {
    private final AggregationFunction[] _aggregationFunctions;
    private final int _numGroupByExpressions;
    private final int _limit;
    private final int _trimSize;
    private final int _trimThreshold;

    /**
     * Creates a trimming service configured from the given query.
     *
     * <p>Captures the query's aggregation functions, the number of group-by expressions and the limit. The trim
     * size is whatever {@code GroupByUtils.getTableCapacity(limit)} returns, and the trim threshold is four times
     * the trim size.
     *
     * @param queryContext query to read the settings from; its group-by expressions must be non-null and its limit
     *                     positive (both are checked only with {@code assert}, i.e. only when assertions are enabled)
     */
    public AggregationGroupByTrimmingService(QueryContext queryContext) {
        _aggregationFunctions = queryContext.getAggregationFunctions();

        List<ExpressionContext> groupByExpressions = queryContext.getGroupByExpressions();
        assert groupByExpressions != null;
        _numGroupByExpressions = groupByExpressions.size();

        int limit = queryContext.getLimit();
        _limit = limit;
        assert limit > 0;

        int trimSize = GroupByUtils.getTableCapacity(limit);
        _trimSize = trimSize;
        // Trimming of intermediate results only kicks in above four times the trim size.
        _trimThreshold = trimSize * 4;
    }

    /**
     * Splits a map of group key to per-function intermediate results into one map per aggregation function,
     * trimming each of them when there are too many groups.
     *
     * <p>When the number of groups is at most the trim threshold, every group is copied into every per-function
     * map. Otherwise each aggregation function gets its own sorter (created for non-comparable results with the
     * trim size), every group is offered to every sorter, and each sorter's surviving entries are dumped into that
     * function's map.
     *
     * @param intermediateResultsMap group key to array of intermediate results, indexed by aggregation function
     * @return fixed-size list (an {@code Arrays.asList} view) with one map of group key to intermediate result per
     *         aggregation function, in the same order as the aggregation functions
     */
    @SuppressWarnings("unchecked")
    public List<Map<String, Object>> trimIntermediateResultsMap(Map<String, Object[]> intermediateResultsMap) {
        final int functionCount = _aggregationFunctions.length;
        final int groupCount = intermediateResultsMap.size();
        Map<String, Object>[] perFunctionMaps = new Map[functionCount];

        if (groupCount <= _trimThreshold) {
            // Few enough groups: no trimming, just pivot from "per group" to "per function".
            for (int f = 0; f < functionCount; f++) {
                perFunctionMaps[f] = new HashMap<String, Object>(groupCount);
            }
            for (Map.Entry<String, Object[]> group : intermediateResultsMap.entrySet()) {
                String key = group.getKey();
                Object[] values = group.getValue();
                for (int f = 0; f < functionCount; f++) {
                    perFunctionMaps[f].put(key, values[f]);
                }
            }
            return Arrays.asList(perFunctionMaps);
        }

        // Too many groups: keep only what each function's sorter retains.
        Sorter[] sorters = new Sorter[functionCount];
        for (int f = 0; f < functionCount; f++) {
            sorters[f] = getSorter(_trimSize, _aggregationFunctions[f], false);
        }
        for (Map.Entry<String, Object[]> group : intermediateResultsMap.entrySet()) {
            String key = group.getKey();
            Object[] values = group.getValue();
            for (int f = 0; f < functionCount; f++) {
                sorters[f].add(key, values[f]);
            }
        }
        for (int f = 0; f < functionCount; f++) {
            Map<String, Object> retained = new HashMap<String, Object>(_trimSize);
            sorters[f].dumpToMap(retained);
            perFunctionMaps[f] = retained;
        }
        return Arrays.asList(perFunctionMaps);
    }

    /**
     * Trims the final results of every aggregation function down to the query limit and converts them into
     * {@link GroupByResult} lists.
     *
     * <p>For each aggregation function a sorter for comparable values is created with the limit as its size, all
     * entries of that function's final result map are offered to it, and the survivors are dumped into the
     * function's result list. An empty final result map yields an empty list without creating a sorter.
     *
     * @param finalResultMaps one map of group key to final result per aggregation function, indexed like the
     *                        aggregation functions
     * @return array with one list of group-by results per aggregation function
     */
    @SuppressWarnings("unchecked")
    public List<GroupByResult>[] trimFinalResults(Map<String, Comparable>[] finalResultMaps) {
        final int functionCount = _aggregationFunctions.length;
        List<GroupByResult>[] resultsPerFunction = new List[functionCount];

        for (int f = 0; f < functionCount; f++) {
            LinkedList<GroupByResult> results = new LinkedList<GroupByResult>();
            resultsPerFunction[f] = results;

            Map<String, Comparable> finalResults = finalResultMaps[f];
            if (!finalResults.isEmpty()) {
                Sorter sorter = getSorter(_limit, _aggregationFunctions[f], true);
                for (Map.Entry<String, Comparable> group : finalResults.entrySet()) {
                    sorter.add(group.getKey(), group.getValue());
                }
                sorter.dumpToGroupByResults(results, _numGroupByExpressions);
            }
        }
        return resultsPerFunction;
    }

    /**
     * Creates the sorter used to trim the results of one aggregation function.
     *
     * <p>The two sorter kinds discard from opposite ends of their ordering: {@code ComparableSorter} evicts the
     * head of its heap (the least element under its comparator), while {@code NonComparableSorter} drops the last
     * entry of its tree map (the greatest key under its comparator). They therefore receive opposite comparators
     * for the same aggregation function:
     * <ul>
     *   <li>comparable + MIN, or non-comparable + non-MIN: reverse of natural order</li>
     *   <li>comparable + non-MIN, or non-comparable + MIN: natural order ({@code ComparableComparator})</li>
     * </ul>
     *
     * @param trimSize            maximum number of entries the sorter should retain
     * @param aggregationFunction function whose results are sorted; treated as "min order" when it is an instance
     *                            of {@code MinAggregationFunction}
     * @param isComparable        {@code true} to sort values that are themselves {@code Comparable},
     *                            {@code false} to sort intermediate results via their extracted final result
     * @return a new sorter of the matching kind
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Sorter getSorter(int trimSize, AggregationFunction aggregationFunction, boolean isComparable) {
        boolean minOrder = aggregationFunction instanceof MinAggregationFunction;

        Comparator<? super Comparable> ordering;
        if (isComparable == minOrder) {
            ordering = Collections.reverseOrder();
        } else {
            ordering = (Comparator<? super Comparable>) new ComparableComparator();
        }

        if (isComparable) {
            return new ComparableSorter(trimSize, ordering);
        }
        return new NonComparableSorter(trimSize, ordering, aggregationFunction);
    }

    /**
     * Sorter for intermediate results that are not directly comparable.
     *
     * <p>Each result is ordered by the final result the aggregation function extracts from it. Pairs of
     * (group key, intermediate result) are bucketed by that key in a {@link TreeMap}; keys that sort first under
     * the supplied comparator are the ones that are kept, and the bucket of the last key is discarded as soon as
     * the remaining buckets alone hold at least {@code trimSize} pairs.
     */
    private static class NonComparableSorter
    implements Sorter {
        private final int _trimSize;
        private final Comparator<? super Comparable> _comparator;
        private final AggregationFunction _aggregationFunction;
        private final TreeMap<Comparable, List<ImmutablePair<String, Object>>> _treeMap;
        // Number of pairs currently stored across all buckets of _treeMap.
        private int _numValuesHeld = 0;

        /**
         * @param trimSize            number of pairs to retain
         * @param comparator          ordering of the extracted keys; keys that compare lower are preferred
         * @param aggregationFunction function used to extract the sort key from an intermediate result
         */
        public NonComparableSorter(int trimSize, Comparator<? super Comparable> comparator, AggregationFunction aggregationFunction) {
            _trimSize = trimSize;
            _comparator = comparator;
            _aggregationFunction = aggregationFunction;
            _treeMap = new TreeMap<Comparable, List<ImmutablePair<String, Object>>>(comparator);
        }

        /**
         * Offers a (group key, intermediate result) pair to the sorter.
         *
         * <p>While fewer than {@code trimSize} pairs are held the pair is always stored. Afterwards it is stored
         * only if its key sorts strictly before the current last key; the last bucket is then evicted if the
         * sorter still holds at least {@code trimSize} pairs without it.
         *
         * @param groupKey group key of the pair
         * @param result   intermediate result; the sort key is obtained via
         *                 {@code AggregationFunction.extractFinalResult}
         */
        @Override
        @SuppressWarnings("unchecked")
        public void add(String groupKey, Object result) {
            Comparable newKey = (Comparable) _aggregationFunction.extractFinalResult(result);

            if (_numValuesHeld < _trimSize) {
                store(newKey, groupKey, result);
                return;
            }

            Map.Entry<Comparable, List<ImmutablePair<String, Object>>> worstEntry = _treeMap.lastEntry();
            Comparable worstKey = worstEntry.getKey();
            if (_comparator.compare(newKey, worstKey) >= 0) {
                // Not strictly better than the worst retained key: reject.
                return;
            }
            store(newKey, groupKey, result);

            // The new pair went into a different bucket (its key is strictly smaller), so the worst bucket's size
            // is unchanged. Drop that bucket once everything before it already fills the trim size.
            int worstBucketSize = worstEntry.getValue().size();
            if (_numValuesHeld - worstBucketSize >= _trimSize) {
                _treeMap.remove(worstKey);
                // Keep the counter in sync with the map; without this the eviction check can never match again
                // and the map would grow without bound.
                _numValuesHeld -= worstBucketSize;
            }
        }

        /** Appends the pair to the bucket of {@code key}, creating the bucket if needed, and bumps the counter. */
        private void store(Comparable key, String groupKey, Object result) {
            List<ImmutablePair<String, Object>> bucket = _treeMap.get(key);
            if (bucket == null) {
                bucket = new ArrayList<ImmutablePair<String, Object>>();
                _treeMap.put(key, bucket);
            }
            bucket.add(new ImmutablePair<String, Object>(groupKey, result));
            _numValuesHeld++;
        }

        /**
         * Copies at most {@code trimSize} retained pairs into {@code dest}, visiting buckets in key order and
         * pairs within a bucket in insertion order.
         *
         * @param dest map receiving group key to intermediate result
         */
        @Override
        public void dumpToMap(Map<String, Object> dest) {
            int remaining = _trimSize;
            for (List<ImmutablePair<String, Object>> bucket : _treeMap.values()) {
                for (ImmutablePair<String, Object> pair : bucket) {
                    if (remaining == 0) {
                        return;
                    }
                    dest.put(pair.left, pair.right);
                    remaining--;
                }
            }
        }

        /**
         * Not supported by this sorter.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public void dumpToGroupByResults(LinkedList<GroupByResult> dest, int numGroupByExpressions) {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Sorter for results that are themselves {@link Comparable}.
     *
     * <p>Keeps at most {@code trimSize} (group key, result) pairs in a {@link PriorityQueue} ordered by the
     * supplied comparator. The head of the queue is the least pair under that comparator and is the one replaced
     * when a greater pair arrives while the queue is full.
     */
    private static class ComparableSorter
    implements Sorter {
        private final int _trimSize;
        private final Comparator<? super Comparable> _comparator;
        private final PriorityQueue<GroupKeyResultPair> _heap;

        /**
         * @param trimSize   maximum number of pairs to retain; also used as the queue's initial capacity, so it
         *                   must be at least 1 ({@code PriorityQueue} rejects smaller capacities)
         * @param comparator ordering applied to the pairs; pairs that compare greater are preferred
         */
        public ComparableSorter(int trimSize, Comparator<? super Comparable> comparator) {
            _trimSize = trimSize;
            _comparator = comparator;
            // The element type must be GroupKeyResultPair to match the field; the comparator is still usable
            // because GroupKeyResultPair is itself a Comparable.
            _heap = new PriorityQueue<GroupKeyResultPair>(trimSize, comparator);
        }

        /**
         * Offers a (group key, result) pair to the sorter.
         *
         * @param groupKey group key of the pair
         * @param result   result value; must be a {@code Comparable}, otherwise the cast fails with a
         *                 {@code ClassCastException}
         */
        @Override
        @SuppressWarnings({"unchecked", "rawtypes"})
        public void add(String groupKey, Object result) {
            GroupKeyResultPair candidate = new GroupKeyResultPair(groupKey, (Comparable) result);
            if (_heap.size() < _trimSize) {
                _heap.add(candidate);
                return;
            }
            // Queue is full: replace the head only if the candidate is strictly greater under the comparator.
            GroupKeyResultPair head = _heap.peek();
            if (_comparator.compare(candidate, head) > 0) {
                _heap.poll();
                _heap.add(candidate);
            }
        }

        /**
         * Drains all retained pairs into {@code dest}; the sorter is empty afterwards.
         *
         * @param dest map receiving group key to result
         */
        @Override
        public void dumpToMap(Map<String, Object> dest) {
            GroupKeyResultPair pair;
            while ((pair = _heap.poll()) != null) {
                dest.put(pair._groupKey, pair._result);
            }
        }

        /**
         * Drains all retained pairs into {@code dest} as {@link GroupByResult}s; the sorter is empty afterwards.
         *
         * <p>Pairs are polled from the head of the queue and each one is inserted at the front of {@code dest},
         * so the drained results end up in the reverse of polling order, ahead of anything already in the list.
         * With a single group-by expression the group key is used as the only group value; otherwise it is split
         * on the NUL character, preserving empty tokens.
         *
         * @param dest                  list receiving the results
         * @param numGroupByExpressions number of group-by expressions encoded in each group key
         */
        @Override
        public void dumpToGroupByResults(LinkedList<GroupByResult> dest, int numGroupByExpressions) {
            boolean singleExpression = numGroupByExpressions == 1;
            GroupKeyResultPair pair;
            while ((pair = _heap.poll()) != null) {
                List<String> group;
                if (singleExpression) {
                    group = Collections.singletonList(pair._groupKey);
                } else {
                    group = Arrays.asList(StringUtils.splitPreserveAllTokens(pair._groupKey, '\u0000'));
                }
                GroupByResult groupByResult = new GroupByResult();
                groupByResult.setGroup(group);
                groupByResult.setValue(AggregationFunctionUtils.getSerializableValue(pair._result));
                dest.addFirst(groupByResult);
            }
        }

        /** A group key together with its result; ordered by the result's natural ordering. */
        private static class GroupKeyResultPair
        implements Comparable<GroupKeyResultPair> {
            private final String _groupKey;
            private final Comparable<? super Comparable> _result;

            public GroupKeyResultPair(String groupKey, Comparable<? super Comparable> result) {
                _groupKey = groupKey;
                _result = result;
            }

            @Override
            public int compareTo(GroupKeyResultPair o) {
                return _result.compareTo(o._result);
            }
        }
    }

    private static interface Sorter {
        public void add(String var1, Object var2);

        public void dumpToMap(Map<String, Object> var1);

        public void dumpToGroupByResults(LinkedList<GroupByResult> var1, int var2);
    }
}

