/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.core.query.aggregation.function;

import com.clearspring.analytics.stream.cardinality.HyperLogLog;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.BaseSingleInputAggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.segment.spi.index.reader.Dictionary;
import org.apache.pinot.spi.data.FieldSpec;
import org.roaringbitmap.PeekableIntIterator;
import org.roaringbitmap.RoaringBitmap;

/**
 * Aggregation function that estimates the number of distinct values of a single expression with a
 * {@link HyperLogLog} sketch, producing a {@code Long} cardinality as the final result.
 *
 * <p>Input handling depends on the stored type of the expression:
 * <ul>
 *   <li>{@code BYTES}: every value is deserialized as a HyperLogLog and merged into the result.</li>
 *   <li>Dictionary-encoded values: dictionary ids are collected in a {@link RoaringBitmap}; they are converted to a
 *       HyperLogLog (by looking up each distinct dictionary value) only when the result is extracted.</li>
 *   <li>Raw {@code INT}/{@code LONG}/{@code FLOAT}/{@code DOUBLE}/{@code STRING} values are offered directly.</li>
 * </ul>
 * An optional second literal argument sets the {@code log2m} used for every HyperLogLog this class creates
 * (default 8).
 */
public class DistinctCountHLLAggregationFunction
extends BaseSingleInputAggregationFunction<HyperLogLog, Long> {
    /** {@code log2m} used when the function is called with a single argument. */
    private static final int DEFAULT_LOG2M = 8;

    /** Precision parameter passed to every {@link HyperLogLog} created by this function. */
    protected final int _log2m;

    /**
     * Creates the function from its call arguments.
     *
     * @param arguments the expression to count, optionally followed by a literal {@code log2m}
     * @throws IllegalArgumentException if more than two arguments are given, or the second argument has no
     *         literal value
     */
    public DistinctCountHLLAggregationFunction(List<ExpressionContext> arguments) {
        super(arguments.get(0));
        int numExpressions = arguments.size();
        Preconditions.checkArgument(numExpressions <= 2, "DistinctCountHLL expects 1 or 2 arguments, got: %s",
                numExpressions);
        if (numExpressions == 2) {
            ExpressionContext log2mExpression = arguments.get(1);
            // Report a descriptive error instead of an NPE when log2m is not given as a literal.
            Preconditions.checkArgument(log2mExpression.getLiteral() != null,
                    "DistinctCountHLL expects the second argument (log2m) to be a literal, got: %s", log2mExpression);
            this._log2m = log2mExpression.getLiteral().getIntValue();
        } else {
            this._log2m = DEFAULT_LOG2M;
        }
    }

    /** Returns the {@code log2m} precision used for newly created HyperLogLogs. */
    public int getLog2m() {
        return this._log2m;
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.DISTINCTCOUNTHLL;
    }

    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /**
     * Aggregates the first {@code length} rows into a single result (no group-by).
     *
     * <p>The holder ends up containing either a {@link HyperLogLog} or, for dictionary-encoded input, a wrapper of
     * collected dictionary ids that {@link #extractAggregationResult} converts later.
     */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(this._expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();

        // BYTES values are serialized HyperLogLogs: merge them into the result.
        if (storedType == FieldSpec.DataType.BYTES) {
            byte[][] bytesValues = blockValSet.getBytesValuesSV();
            try {
                HyperLogLog hyperLogLog = (HyperLogLog) aggregationResultHolder.getResult();
                int start = 0;
                if (hyperLogLog == null) {
                    // With no rows there is nothing to seed the result with; bytesValues[0] is not guaranteed to
                    // belong to this block. Leaving the holder empty makes extraction return an empty sketch.
                    if (length == 0) {
                        return;
                    }
                    // Use the first deserialized sketch as the result instead of merging it into an empty one.
                    hyperLogLog = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(bytesValues[0]);
                    aggregationResultHolder.setValue(hyperLogLog);
                    start = 1;
                }
                for (int i = start; i < length; i++) {
                    hyperLogLog.addAll(ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(bytesValues[i]));
                }
            } catch (Exception e) {
                throw new RuntimeException("Caught exception while merging HyperLogLogs", e);
            }
            return;
        }

        // Dictionary-encoded: only record the distinct dictionary ids here; values are looked up and offered to a
        // HyperLogLog once, at extraction time.
        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            getDictIdBitmap(aggregationResultHolder, dictionary).addN(dictIds, 0, length);
            return;
        }

        // Raw values: offer each one (boxed) directly to the HyperLogLog.
        HyperLogLog hyperLogLog = getHyperLogLog(aggregationResultHolder);
        switch (storedType) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    hyperLogLog.offer(intValues[i]);
                }
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    hyperLogLog.offer(longValues[i]);
                }
                break;
            }
            case FLOAT: {
                float[] floatValues = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    hyperLogLog.offer(floatValues[i]);
                }
                break;
            }
            case DOUBLE: {
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    hyperLogLog.offer(doubleValues[i]);
                }
                break;
            }
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    hyperLogLog.offer(stringValues[i]);
                }
                break;
            }
            default:
                throw new IllegalStateException("Illegal data type for DISTINCT_COUNT_HLL aggregation function: " + storedType);
        }
    }

    /**
     * Aggregates the first {@code length} rows, where row {@code i} belongs to group {@code groupKeyArray[i]}.
     */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(this._expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();

        // BYTES values are serialized HyperLogLogs: merge each into its group's sketch.
        if (storedType == FieldSpec.DataType.BYTES) {
            byte[][] bytesValues = blockValSet.getBytesValuesSV();
            try {
                for (int i = 0; i < length; i++) {
                    HyperLogLog value = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(bytesValues[i]);
                    int groupKey = groupKeyArray[i];
                    HyperLogLog hyperLogLog = (HyperLogLog) groupByResultHolder.getResult(groupKey);
                    if (hyperLogLog != null) {
                        hyperLogLog.addAll(value);
                    } else {
                        // A freshly deserialized sketch is owned by this group alone, so it can be stored directly.
                        groupByResultHolder.setValueForKey(groupKey, value);
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException("Caught exception while merging HyperLogLogs", e);
            }
            return;
        }

        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            for (int i = 0; i < length; i++) {
                getDictIdBitmap(groupByResultHolder, groupKeyArray[i], dictionary).add(dictIds[i]);
            }
            return;
        }

        switch (storedType) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    getHyperLogLog(groupByResultHolder, groupKeyArray[i]).offer(intValues[i]);
                }
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    getHyperLogLog(groupByResultHolder, groupKeyArray[i]).offer(longValues[i]);
                }
                break;
            }
            case FLOAT: {
                float[] floatValues = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    getHyperLogLog(groupByResultHolder, groupKeyArray[i]).offer(floatValues[i]);
                }
                break;
            }
            case DOUBLE: {
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    getHyperLogLog(groupByResultHolder, groupKeyArray[i]).offer(doubleValues[i]);
                }
                break;
            }
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    getHyperLogLog(groupByResultHolder, groupKeyArray[i]).offer(stringValues[i]);
                }
                break;
            }
            default:
                throw new IllegalStateException("Illegal data type for DISTINCT_COUNT_HLL aggregation function: " + storedType);
        }
    }

    /**
     * Aggregates the first {@code length} rows, where row {@code i} contributes to every group in
     * {@code groupKeysArray[i]}.
     */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(this._expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();

        // BYTES values are serialized HyperLogLogs: merge each into every one of its groups' sketches.
        if (storedType == FieldSpec.DataType.BYTES) {
            byte[][] bytesValues = blockValSet.getBytesValuesSV();
            try {
                for (int i = 0; i < length; i++) {
                    HyperLogLog value = ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(bytesValues[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        HyperLogLog hyperLogLog = (HyperLogLog) groupByResultHolder.getResult(groupKey);
                        if (hyperLogLog != null) {
                            hyperLogLog.addAll(value);
                        } else {
                            // Store a fresh copy per new group: a single instance shared by several groups would be
                            // mutated by later addAll calls on any one of them.
                            groupByResultHolder.setValueForKey(groupKey, ObjectSerDeUtils.HYPER_LOG_LOG_SER_DE.deserialize(bytesValues[i]));
                        }
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException("Caught exception while merging HyperLogLogs", e);
            }
            return;
        }

        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            for (int i = 0; i < length; i++) {
                setDictIdForGroupKeys(groupByResultHolder, groupKeysArray[i], dictionary, dictIds[i]);
            }
            return;
        }

        switch (storedType) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], intValues[i]);
                }
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], longValues[i]);
                }
                break;
            }
            case FLOAT: {
                float[] floatValues = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], floatValues[i]);
                }
                break;
            }
            case DOUBLE: {
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], doubleValues[i]);
                }
                break;
            }
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], stringValues[i]);
                }
                break;
            }
            default:
                throw new IllegalStateException("Illegal data type for DISTINCT_COUNT_HLL aggregation function: " + storedType);
        }
    }

    /** Returns the aggregated HyperLogLog, or an empty one (with {@code _log2m}) if nothing was aggregated. */
    @Override
    public HyperLogLog extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        Object result = aggregationResultHolder.getResult();
        return toHyperLogLog(result);
    }

    /** Returns the HyperLogLog for {@code groupKey}, or an empty one (with {@code _log2m}) if the group is empty. */
    @Override
    public HyperLogLog extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        Object result = groupByResultHolder.getResult(groupKey);
        return toHyperLogLog(result);
    }

    /**
     * Merges {@code intermediateResult2} into {@code intermediateResult1} in place and returns it.
     *
     * <p>Sketches of different sizes are only accepted when one side's cardinality estimate is 0; that side is
     * simply discarded.
     *
     * @throws IllegalStateException if the sketches differ in size and both have a non-zero cardinality
     */
    @Override
    public HyperLogLog merge(HyperLogLog intermediateResult1, HyperLogLog intermediateResult2) {
        if (intermediateResult1.sizeof() != intermediateResult2.sizeof()) {
            if (intermediateResult1.cardinality() == 0L) {
                return intermediateResult2;
            }
            Preconditions.checkState(intermediateResult2.cardinality() == 0L, "Cannot merge HyperLogLogs of different sizes");
            return intermediateResult1;
        }
        try {
            intermediateResult1.addAll(intermediateResult2);
        } catch (Exception e) {
            throw new RuntimeException("Caught exception while merging HyperLogLogs", e);
        }
        return intermediateResult1;
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.LONG;
    }

    /** Returns the sketch's cardinality estimate. */
    @Override
    public Long extractFinalResult(HyperLogLog intermediateResult) {
        return intermediateResult.cardinality();
    }

    /**
     * Adds two final distinct counts.
     *
     * <p>NOTE(review): the sum equals the combined distinct count only when the two inputs were computed over
     * disjoint sets of values — confirm that callers guarantee this.
     */
    @Override
    public Long mergeFinalResult(Long finalResult1, Long finalResult2) {
        return finalResult1 + finalResult2;
    }

    /**
     * Returns the dictionary-id bitmap stored in {@code aggregationResultHolder}, creating and storing a new one
     * (bound to {@code dictionary}) if the holder is empty.
     */
    protected static RoaringBitmap getDictIdBitmap(AggregationResultHolder aggregationResultHolder, Dictionary dictionary) {
        DictIdsWrapper dictIdsWrapper = (DictIdsWrapper) aggregationResultHolder.getResult();
        if (dictIdsWrapper == null) {
            dictIdsWrapper = new DictIdsWrapper(dictionary);
            aggregationResultHolder.setValue(dictIdsWrapper);
        }
        return dictIdsWrapper._dictIdBitmap;
    }

    /**
     * Returns the HyperLogLog stored in {@code aggregationResultHolder}, creating and storing a new one with
     * {@code _log2m} if the holder is empty.
     */
    protected HyperLogLog getHyperLogLog(AggregationResultHolder aggregationResultHolder) {
        HyperLogLog hyperLogLog = (HyperLogLog) aggregationResultHolder.getResult();
        if (hyperLogLog == null) {
            hyperLogLog = new HyperLogLog(this._log2m);
            aggregationResultHolder.setValue(hyperLogLog);
        }
        return hyperLogLog;
    }

    /**
     * Returns the dictionary-id bitmap for {@code groupKey}, creating and storing a new one (bound to
     * {@code dictionary}) if the group has no value yet.
     */
    protected static RoaringBitmap getDictIdBitmap(GroupByResultHolder groupByResultHolder, int groupKey, Dictionary dictionary) {
        DictIdsWrapper dictIdsWrapper = (DictIdsWrapper) groupByResultHolder.getResult(groupKey);
        if (dictIdsWrapper == null) {
            dictIdsWrapper = new DictIdsWrapper(dictionary);
            groupByResultHolder.setValueForKey(groupKey, dictIdsWrapper);
        }
        return dictIdsWrapper._dictIdBitmap;
    }

    /**
     * Returns the HyperLogLog for {@code groupKey}, creating and storing a new one with {@code _log2m} if the group
     * has no value yet.
     */
    protected HyperLogLog getHyperLogLog(GroupByResultHolder groupByResultHolder, int groupKey) {
        HyperLogLog hyperLogLog = (HyperLogLog) groupByResultHolder.getResult(groupKey);
        if (hyperLogLog == null) {
            hyperLogLog = new HyperLogLog(this._log2m);
            groupByResultHolder.setValueForKey(groupKey, hyperLogLog);
        }
        return hyperLogLog;
    }

    /** Adds {@code dictId} to the dictionary-id bitmap of every group in {@code groupKeys}. */
    private static void setDictIdForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, Dictionary dictionary, int dictId) {
        for (int groupKey : groupKeys) {
            getDictIdBitmap(groupByResultHolder, groupKey, dictionary).add(dictId);
        }
    }

    /** Offers {@code value} to the HyperLogLog of every group in {@code groupKeys}. */
    private void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, Object value) {
        for (int groupKey : groupKeys) {
            getHyperLogLog(groupByResultHolder, groupKey).offer(value);
        }
    }

    /**
     * Converts a raw holder value into a HyperLogLog: {@code null} becomes an empty sketch, a dictionary-id wrapper
     * is converted, and anything else is assumed to already be a HyperLogLog.
     */
    private HyperLogLog toHyperLogLog(Object result) {
        if (result == null) {
            return new HyperLogLog(this._log2m);
        }
        if (result instanceof DictIdsWrapper) {
            return convertToHyperLogLog((DictIdsWrapper) result);
        }
        return (HyperLogLog) result;
    }

    /** Builds a HyperLogLog by offering the dictionary value of every collected dictionary id. */
    private HyperLogLog convertToHyperLogLog(DictIdsWrapper dictIdsWrapper) {
        HyperLogLog hyperLogLog = new HyperLogLog(this._log2m);
        Dictionary dictionary = dictIdsWrapper._dictionary;
        PeekableIntIterator iterator = dictIdsWrapper._dictIdBitmap.getIntIterator();
        while (iterator.hasNext()) {
            hyperLogLog.offer(dictionary.get(iterator.next()));
        }
        return hyperLogLog;
    }

    /** Pairs a dictionary with the set of its ids seen so far, so values can be looked up at extraction time. */
    private static final class DictIdsWrapper {
        final Dictionary _dictionary;
        final RoaringBitmap _dictIdBitmap;

        private DictIdsWrapper(Dictionary dictionary) {
            this._dictionary = dictionary;
            this._dictIdBitmap = new RoaringBitmap();
        }
    }
}

