/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.core.query.aggregation.function;

import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet;
import it.unimi.dsi.fastutil.floats.FloatOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
import java.util.Collection;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.BaseSingleInputAggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.utils.ByteArray;
import org.roaringbitmap.RoaringBitmap;

/**
 * Distinct-count aggregation whose intermediate result is a plain count.
 *
 * <p>While aggregating, the distinct values seen are collected into a container held by the result
 * holder; when the result is extracted only the container's cardinality is returned, and
 * {@link #merge(Long, Long)} simply adds two such counts together.
 *
 * <p>NOTE(review): adding counts is only an exact distinct count if no value can appear in more than
 * one merged result. The class name suggests the column is partitioned by segment - confirm against
 * the callers / function documentation.
 *
 * <p>Container used per input:
 * <ul>
 *   <li>dictionary-encoded column (any type): {@link RoaringBitmap} of dictionary ids</li>
 *   <li>INT: {@link RoaringBitmap} of the raw values</li>
 *   <li>LONG / FLOAT / DOUBLE: the matching fastutil primitive open hash set</li>
 *   <li>STRING: {@code ObjectOpenHashSet<String>}</li>
 *   <li>BYTES: {@code ObjectOpenHashSet<ByteArray>} (each {@code byte[]} is wrapped in a
 *       {@link ByteArray} before being added)</li>
 * </ul>
 */
public class SegmentPartitionedDistinctCountAggregationFunction
        extends BaseSingleInputAggregationFunction<Long, Long> {

    /**
     * @param expression the single input expression whose values are counted
     */
    public SegmentPartitionedDistinctCountAggregationFunction(ExpressionContext expression) {
        super(expression);
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.SEGMENTPARTITIONEDDISTINCTCOUNT;
    }

    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /**
     * Adds the first {@code length} values of the input block to the container stored in
     * {@code aggregationResultHolder}, creating the container on first use.
     *
     * @throws IllegalStateException if the column has no dictionary and its stored type is not one of
     *     INT, LONG, FLOAT, DOUBLE, STRING or BYTES
     */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);

        // Dictionary-encoded column: track dictionary ids instead of the values themselves.
        if (blockValSet.getDictionary() != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            getOrCreateBitmap(aggregationResultHolder).addN(dictIds, 0, length);
            return;
        }

        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();
        switch (storedType) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                getOrCreateBitmap(aggregationResultHolder).addN(intValues, 0, length);
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                LongOpenHashSet longSet = getOrCreateLongSet(aggregationResultHolder);
                for (int i = 0; i < length; i++) {
                    longSet.add(longValues[i]);
                }
                break;
            }
            case FLOAT: {
                float[] floatValues = blockValSet.getFloatValuesSV();
                FloatOpenHashSet floatSet = getOrCreateFloatSet(aggregationResultHolder);
                for (int i = 0; i < length; i++) {
                    floatSet.add(floatValues[i]);
                }
                break;
            }
            case DOUBLE: {
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                DoubleOpenHashSet doubleSet = getOrCreateDoubleSet(aggregationResultHolder);
                for (int i = 0; i < length; i++) {
                    doubleSet.add(doubleValues[i]);
                }
                break;
            }
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                ObjectOpenHashSet<String> stringSet = getOrCreateObjectSet(aggregationResultHolder);
                for (int i = 0; i < length; i++) {
                    stringSet.add(stringValues[i]);
                }
                break;
            }
            case BYTES: {
                byte[][] bytesValues = blockValSet.getBytesValuesSV();
                ObjectOpenHashSet<ByteArray> bytesSet = getOrCreateObjectSet(aggregationResultHolder);
                for (int i = 0; i < length; i++) {
                    // Wrap the raw array in ByteArray before storing it in the hash set.
                    bytesSet.add(new ByteArray(bytesValues[i]));
                }
                break;
            }
            default: {
                throw illegalDataType(storedType);
            }
        }
    }

    /**
     * Adds value {@code i} of the input block to the container of group {@code groupKeyArray[i]},
     * for each of the first {@code length} rows.
     *
     * @throws IllegalStateException if the column has no dictionary and its stored type is not one of
     *     INT, LONG, FLOAT, DOUBLE, STRING or BYTES
     */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);

        // Dictionary-encoded column: track dictionary ids instead of the values themselves.
        if (blockValSet.getDictionary() != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            for (int i = 0; i < length; i++) {
                setIntValueForGroup(groupByResultHolder, groupKeyArray[i], dictIds[i]);
            }
            return;
        }

        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();
        switch (storedType) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    setIntValueForGroup(groupByResultHolder, groupKeyArray[i], intValues[i]);
                }
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    setLongValueForGroup(groupByResultHolder, groupKeyArray[i], longValues[i]);
                }
                break;
            }
            case FLOAT: {
                float[] floatValues = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    setFloatValueForGroup(groupByResultHolder, groupKeyArray[i], floatValues[i]);
                }
                break;
            }
            case DOUBLE: {
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    setDoubleValueForGroup(groupByResultHolder, groupKeyArray[i], doubleValues[i]);
                }
                break;
            }
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    setObjectValueForGroup(groupByResultHolder, groupKeyArray[i], stringValues[i]);
                }
                break;
            }
            case BYTES: {
                byte[][] bytesValues = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; i++) {
                    setObjectValueForGroup(groupByResultHolder, groupKeyArray[i], new ByteArray(bytesValues[i]));
                }
                break;
            }
            default: {
                throw illegalDataType(storedType);
            }
        }
    }

    /**
     * Same as {@link #aggregateGroupBySV} except that each row may belong to several groups:
     * value {@code i} is added to the container of every group key in {@code groupKeysArray[i]}.
     *
     * @throws IllegalStateException if the column has no dictionary and its stored type is not one of
     *     INT, LONG, FLOAT, DOUBLE, STRING or BYTES
     */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
            Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);

        // Dictionary-encoded column: track dictionary ids instead of the values themselves.
        if (blockValSet.getDictionary() != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            for (int i = 0; i < length; i++) {
                int dictId = dictIds[i];
                for (int groupKey : groupKeysArray[i]) {
                    setIntValueForGroup(groupByResultHolder, groupKey, dictId);
                }
            }
            return;
        }

        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();
        switch (storedType) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    int value = intValues[i];
                    for (int groupKey : groupKeysArray[i]) {
                        setIntValueForGroup(groupByResultHolder, groupKey, value);
                    }
                }
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    long value = longValues[i];
                    for (int groupKey : groupKeysArray[i]) {
                        setLongValueForGroup(groupByResultHolder, groupKey, value);
                    }
                }
                break;
            }
            case FLOAT: {
                float[] floatValues = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    float value = floatValues[i];
                    for (int groupKey : groupKeysArray[i]) {
                        setFloatValueForGroup(groupByResultHolder, groupKey, value);
                    }
                }
                break;
            }
            case DOUBLE: {
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    double value = doubleValues[i];
                    for (int groupKey : groupKeysArray[i]) {
                        setDoubleValueForGroup(groupByResultHolder, groupKey, value);
                    }
                }
                break;
            }
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    String value = stringValues[i];
                    for (int groupKey : groupKeysArray[i]) {
                        setObjectValueForGroup(groupByResultHolder, groupKey, value);
                    }
                }
                break;
            }
            case BYTES: {
                byte[][] bytesValues = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; i++) {
                    // Wrap once per row; the same wrapper is shared by every group of that row.
                    ByteArray value = new ByteArray(bytesValues[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        setObjectValueForGroup(groupByResultHolder, groupKey, value);
                    }
                }
                break;
            }
            default: {
                throw illegalDataType(storedType);
            }
        }
    }

    /** Returns the number of distinct entries collected in the holder (0 if nothing was aggregated). */
    @Override
    public Long extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        return extractIntermediateResult(aggregationResultHolder.getResult());
    }

    /** Returns the number of distinct entries collected for {@code groupKey} (0 if the group is empty). */
    @Override
    public Long extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        return extractIntermediateResult(groupByResultHolder.getResult(groupKey));
    }

    /** Combines two intermediate counts by adding them. */
    @Override
    public Long merge(Long intermediateResult1, Long intermediateResult2) {
        return intermediateResult1 + intermediateResult2;
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.LONG;
    }

    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.LONG;
    }

    /** The final result is the intermediate count, returned unchanged. */
    @Override
    public Long extractFinalResult(Long intermediateResult) {
        return intermediateResult;
    }

    /** Returns the bitmap stored in the holder, creating and registering an empty one if absent. */
    private static RoaringBitmap getOrCreateBitmap(AggregationResultHolder aggregationResultHolder) {
        RoaringBitmap bitmap = (RoaringBitmap) aggregationResultHolder.getResult();
        if (bitmap == null) {
            bitmap = new RoaringBitmap();
            aggregationResultHolder.setValue(bitmap);
        }
        return bitmap;
    }

    /** Returns the long set stored in the holder, creating and registering an empty one if absent. */
    private static LongOpenHashSet getOrCreateLongSet(AggregationResultHolder aggregationResultHolder) {
        LongOpenHashSet longSet = (LongOpenHashSet) aggregationResultHolder.getResult();
        if (longSet == null) {
            longSet = new LongOpenHashSet();
            aggregationResultHolder.setValue(longSet);
        }
        return longSet;
    }

    /** Returns the float set stored in the holder, creating and registering an empty one if absent. */
    private static FloatOpenHashSet getOrCreateFloatSet(AggregationResultHolder aggregationResultHolder) {
        FloatOpenHashSet floatSet = (FloatOpenHashSet) aggregationResultHolder.getResult();
        if (floatSet == null) {
            floatSet = new FloatOpenHashSet();
            aggregationResultHolder.setValue(floatSet);
        }
        return floatSet;
    }

    /** Returns the double set stored in the holder, creating and registering an empty one if absent. */
    private static DoubleOpenHashSet getOrCreateDoubleSet(AggregationResultHolder aggregationResultHolder) {
        DoubleOpenHashSet doubleSet = (DoubleOpenHashSet) aggregationResultHolder.getResult();
        if (doubleSet == null) {
            doubleSet = new DoubleOpenHashSet();
            aggregationResultHolder.setValue(doubleSet);
        }
        return doubleSet;
    }

    /**
     * Returns the object set stored in the holder, creating and registering an empty one if absent.
     * The cast is unchecked: the element type is whatever the caller has been storing for this column.
     */
    @SuppressWarnings("unchecked")
    private static <T> ObjectOpenHashSet<T> getOrCreateObjectSet(AggregationResultHolder aggregationResultHolder) {
        ObjectOpenHashSet<T> objectSet = (ObjectOpenHashSet<T>) aggregationResultHolder.getResult();
        if (objectSet == null) {
            objectSet = new ObjectOpenHashSet<>();
            aggregationResultHolder.setValue(objectSet);
        }
        return objectSet;
    }

    /** Adds {@code value} to the bitmap of {@code groupKey}, creating the bitmap on first use. */
    private static void setIntValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, int value) {
        RoaringBitmap bitmap = (RoaringBitmap) groupByResultHolder.getResult(groupKey);
        if (bitmap == null) {
            bitmap = new RoaringBitmap();
            groupByResultHolder.setValueForKey(groupKey, bitmap);
        }
        bitmap.add(value);
    }

    /** Adds {@code value} to the long set of {@code groupKey}, creating the set on first use. */
    private static void setLongValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, long value) {
        LongOpenHashSet longSet = (LongOpenHashSet) groupByResultHolder.getResult(groupKey);
        if (longSet == null) {
            longSet = new LongOpenHashSet();
            groupByResultHolder.setValueForKey(groupKey, longSet);
        }
        longSet.add(value);
    }

    /** Adds {@code value} to the float set of {@code groupKey}, creating the set on first use. */
    private static void setFloatValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, float value) {
        FloatOpenHashSet floatSet = (FloatOpenHashSet) groupByResultHolder.getResult(groupKey);
        if (floatSet == null) {
            floatSet = new FloatOpenHashSet();
            groupByResultHolder.setValueForKey(groupKey, floatSet);
        }
        floatSet.add(value);
    }

    /** Adds {@code value} to the double set of {@code groupKey}, creating the set on first use. */
    private static void setDoubleValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, double value) {
        DoubleOpenHashSet doubleSet = (DoubleOpenHashSet) groupByResultHolder.getResult(groupKey);
        if (doubleSet == null) {
            doubleSet = new DoubleOpenHashSet();
            groupByResultHolder.setValueForKey(groupKey, doubleSet);
        }
        doubleSet.add(value);
    }

    /**
     * Adds {@code value} to the object set of {@code groupKey}, creating the set on first use.
     * Shared by the STRING and BYTES paths; the cast is unchecked for the same reason as in
     * {@link #getOrCreateObjectSet(AggregationResultHolder)}.
     */
    @SuppressWarnings("unchecked")
    private static <T> void setObjectValueForGroup(GroupByResultHolder groupByResultHolder, int groupKey, T value) {
        ObjectOpenHashSet<T> objectSet = (ObjectOpenHashSet<T>) groupByResultHolder.getResult(groupKey);
        if (objectSet == null) {
            objectSet = new ObjectOpenHashSet<>();
            groupByResultHolder.setValueForKey(groupKey, objectSet);
        }
        objectSet.add(value);
    }

    /** Builds the exception thrown for stored types that none of the aggregate methods handle. */
    private static IllegalStateException illegalDataType(FieldSpec.DataType storedType) {
        return new IllegalStateException("Illegal data type for PARTITIONED_DISTINCT_COUNT aggregation function: " + storedType);
    }

    /**
     * Converts a container produced by the aggregate methods into its distinct count.
     *
     * @param result a {@link RoaringBitmap}, a {@link Collection}, or {@code null} when nothing was
     *     aggregated
     * @return 0 for {@code null}, the bitmap's long cardinality, or the collection's size
     */
    private static long extractIntermediateResult(@Nullable Object result) {
        if (result == null) {
            return 0L;
        }
        if (result instanceof RoaringBitmap) {
            return ((RoaringBitmap) result).getLongCardinality();
        }
        // Every non-bitmap container created by this class is a fastutil set, i.e. a Collection.
        assert (result instanceof Collection);
        return ((Collection<?>) result).size();
    }
}

