/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.core.query.aggregation.function;

import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.apache.datasketches.common.Util;
import org.apache.datasketches.cpc.CpcSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.BaseSingleInputAggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.local.customobject.CpcSketchAccumulator;
import org.apache.pinot.segment.local.customobject.CustomObjectAccumulator;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.segment.spi.index.reader.Dictionary;
import org.apache.pinot.spi.data.FieldSpec;
import org.roaringbitmap.PeekableIntIterator;
import org.roaringbitmap.RoaringBitmap;

/**
 * DISTINCTCOUNTCPCSKETCH aggregation function: estimates the number of distinct values of a single expression with an
 * Apache DataSketches {@link CpcSketch}.
 *
 * <p>How input values are collected depends on the stored type of the expression:
 * <ul>
 *   <li>{@code BYTES}: every value is deserialized as a CPC sketch and merged into a {@link CpcSketchAccumulator}.</li>
 *   <li>Dictionary-encoded: dictionary ids are collected in a {@link RoaringBitmap} ({@link DictIdsWrapper}); the
 *   corresponding values are only fed into a sketch when the result is extracted.</li>
 *   <li>Raw {@code INT}/{@code LONG}/{@code FLOAT}/{@code DOUBLE}/{@code STRING}: values are fed directly into a
 *   {@link CpcSketch}, which is wrapped into an accumulator when the result is extracted.</li>
 * </ul>
 *
 * <p>The optional second argument must be a literal: either an integer log2 of the nominal entries, or a parameter
 * string such as {@code "nominalEntries=4096;accumulatorThreshold=2"}.
 */
public class DistinctCountCPCSketchAggregationFunction
    extends BaseSingleInputAggregationFunction<CpcSketchAccumulator, Comparable> {
    /** Default accumulator threshold handed to {@link CpcSketchAccumulator} (its semantics are defined there). */
    private static final int DEFAULT_ACCUMULATOR_THRESHOLD = 2;
    /** Default log2 of the nominal entries used when no second argument is supplied (2^12 = 4096). */
    private static final int DEFAULT_LG_NOMINAL_ENTRIES = 12;

    protected int _accumulatorThreshold = DEFAULT_ACCUMULATOR_THRESHOLD;
    protected int _lgNominalEntries;

    /**
     * Creates the function from its query arguments.
     *
     * @param arguments the aggregated expression, optionally followed by a literal that is either an integer log2 of
     *                  the nominal entries or a {@code key=value;key=value} parameter string
     * @throws IllegalArgumentException if more than two arguments are given, the second argument is not a literal, or
     *                                  the parameter string is malformed
     */
    public DistinctCountCPCSketchAggregationFunction(List<ExpressionContext> arguments) {
        super(arguments.get(0));
        int numExpressions = arguments.size();
        Preconditions.checkArgument(numExpressions <= 2, "DistinctCountCPC expects 1 or 2 arguments, got: %s",
            numExpressions);
        if (numExpressions == 2) {
            ExpressionContext secondArgument = arguments.get(1);
            Preconditions.checkArgument(secondArgument.getType() == ExpressionContext.Type.LITERAL,
                "CPC Sketch Aggregation Function expects the second argument to be a literal (parameters), but got: %s",
                secondArgument.getType());
            if (secondArgument.getLiteral().getType() == FieldSpec.DataType.STRING) {
                Parameters parameters = new Parameters(secondArgument.getLiteral().getStringValue());
                _accumulatorThreshold = parameters.getAccumulatorThreshold();
                _lgNominalEntries = parameters.getLgNominalEntries();
            } else {
                // A non-string literal is taken directly as the log2 of the nominal entries
                _lgNominalEntries = secondArgument.getLiteral().getIntValue();
            }
        } else {
            _lgNominalEntries = DEFAULT_LG_NOMINAL_ENTRIES;
        }
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.DISTINCTCOUNTCPCSKETCH;
    }

    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /**
     * Aggregates one block of values without group-by.
     *
     * <p>The holder ends up containing a {@link CpcSketchAccumulator} (BYTES input), a {@link DictIdsWrapper}
     * (dictionary-encoded input) or a {@link CpcSketch} (raw input); {@link #extractAggregationResult} converts each
     * of these into an accumulator.
     */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();

        // BYTES values are treated as serialized CPC sketches and merged through the accumulator
        if (storedType == FieldSpec.DataType.BYTES) {
            byte[][] bytesValues = blockValSet.getBytesValuesSV();
            try {
                CpcSketchAccumulator cpcSketchAccumulator = getAccumulator(aggregationResultHolder);
                for (CpcSketch sketch : deserializeSketches(bytesValues, length)) {
                    cpcSketchAccumulator.apply(sketch);
                }
            } catch (Exception e) {
                throw new RuntimeException("Caught exception while merging CPC sketches", e);
            }
            return;
        }

        // Dictionary-encoded values: only remember the dictionary ids, resolve them at extraction time
        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            getDictIdBitmap(aggregationResultHolder, dictionary).addN(dictIds, 0, length);
            return;
        }

        // Raw values: update the CpcSketch kept in the holder. The holder must keep holding the sketch itself
        // (extractAggregationResult wraps it into an accumulator); it is not converted to an accumulator here,
        // because getAccumulator() would then fail to cast the stored CpcSketch.
        CpcSketch cpcSketch = getCpcSketch(aggregationResultHolder);
        switch (storedType) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    cpcSketch.update((long) intValues[i]);
                }
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    cpcSketch.update(longValues[i]);
                }
                break;
            }
            case FLOAT: {
                float[] floatValues = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    cpcSketch.update((double) floatValues[i]);
                }
                break;
            }
            case DOUBLE: {
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    cpcSketch.update(doubleValues[i]);
                }
                break;
            }
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    cpcSketch.update(stringValues[i]);
                }
                break;
            }
            default:
                throw new IllegalStateException(
                    "Illegal data type for DISTINCT_COUNT_CPC aggregation function: " + storedType);
        }
    }

    /**
     * Aggregates one block of values where every row maps to exactly one group key.
     */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();

        // BYTES values are treated as serialized CPC sketches
        if (storedType == FieldSpec.DataType.BYTES) {
            byte[][] bytesValues = blockValSet.getBytesValuesSV();
            try {
                CpcSketch[] sketches = deserializeSketches(bytesValues, length);
                for (int i = 0; i < length; i++) {
                    getAccumulator(groupByResultHolder, groupKeyArray[i]).apply(sketches[i]);
                }
            } catch (Exception e) {
                throw new RuntimeException("Caught exception while aggregating CPC Sketches", e);
            }
            return;
        }

        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            for (int i = 0; i < length; i++) {
                getDictIdBitmap(groupByResultHolder, groupKeyArray[i], dictionary).add(dictIds[i]);
            }
            return;
        }

        switch (storedType) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    getCpcSketch(groupByResultHolder, groupKeyArray[i]).update((long) intValues[i]);
                }
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    getCpcSketch(groupByResultHolder, groupKeyArray[i]).update(longValues[i]);
                }
                break;
            }
            case FLOAT: {
                float[] floatValues = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    getCpcSketch(groupByResultHolder, groupKeyArray[i]).update((double) floatValues[i]);
                }
                break;
            }
            case DOUBLE: {
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    getCpcSketch(groupByResultHolder, groupKeyArray[i]).update(doubleValues[i]);
                }
                break;
            }
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    getCpcSketch(groupByResultHolder, groupKeyArray[i]).update(stringValues[i]);
                }
                break;
            }
            default:
                throw new IllegalStateException(
                    "Illegal data type for DISTINCT_COUNT_CPC aggregation function: " + storedType);
        }
    }

    /**
     * Aggregates one block of values where every row may map to several group keys.
     */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        FieldSpec.DataType storedType = blockValSet.getValueType().getStoredType();
        boolean singleValue = blockValSet.isSingleValue();

        // Single-value BYTES are treated as serialized CPC sketches
        if (singleValue && storedType == FieldSpec.DataType.BYTES) {
            byte[][] bytesValues = blockValSet.getBytesValuesSV();
            try {
                CpcSketch[] sketches = deserializeSketches(bytesValues, length);
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getAccumulator(groupByResultHolder, groupKey).apply(sketches[i]);
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException("Caught exception while aggregating CPC sketches", e);
            }
            return;
        }

        Dictionary dictionary = blockValSet.getDictionary();
        if (dictionary != null) {
            int[] dictIds = blockValSet.getDictionaryIdsSV();
            for (int i = 0; i < length; i++) {
                setDictIdForGroupKeys(groupByResultHolder, groupKeysArray[i], dictionary, dictIds[i]);
            }
            return;
        }

        switch (storedType) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getCpcSketch(groupByResultHolder, groupKey).update((long) intValues[i]);
                    }
                }
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getCpcSketch(groupByResultHolder, groupKey).update(longValues[i]);
                    }
                }
                break;
            }
            case FLOAT: {
                float[] floatValues = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getCpcSketch(groupByResultHolder, groupKey).update((double) floatValues[i]);
                    }
                }
                break;
            }
            case DOUBLE: {
                double[] doubleValues = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getCpcSketch(groupByResultHolder, groupKey).update(doubleValues[i]);
                    }
                }
                break;
            }
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    for (int groupKey : groupKeysArray[i]) {
                        getCpcSketch(groupByResultHolder, groupKey).update(stringValues[i]);
                    }
                }
                break;
            }
            default:
                throw new IllegalStateException(
                    "Illegal data type for DISTINCT_COUNT_CPC aggregation function: " + storedType);
        }
    }

    @Override
    public CpcSketchAccumulator extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        return toAccumulator(aggregationResultHolder.getResult());
    }

    @Override
    public CpcSketchAccumulator extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        return toAccumulator(groupByResultHolder.getResult(groupKey));
    }

    /**
     * Merges two intermediate results. Either side may be {@code null} or empty, in which case the other side is
     * returned as-is; otherwise {@code intermediateResult1} is updated in place and returned.
     */
    @Override
    public CpcSketchAccumulator merge(CpcSketchAccumulator intermediateResult1,
        CpcSketchAccumulator intermediateResult2) {
        if (intermediateResult1 == null || intermediateResult1.isEmpty()) {
            return intermediateResult2;
        }
        if (intermediateResult2 == null || intermediateResult2.isEmpty()) {
            return intermediateResult1;
        }
        // Re-apply this function's configuration, the accumulator may have been deserialized without it
        intermediateResult1.setLgNominalEntries(_lgNominalEntries);
        intermediateResult1.setThreshold(_accumulatorThreshold);
        intermediateResult1.merge(intermediateResult2);
        return intermediateResult1;
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.LONG;
    }

    /**
     * Returns the sketch estimate rounded to the nearest {@code long}.
     */
    @Override
    public Comparable extractFinalResult(CpcSketchAccumulator intermediateResult) {
        intermediateResult.setLgNominalEntries(_lgNominalEntries);
        intermediateResult.setThreshold(_accumulatorThreshold);
        return Long.valueOf(Math.round(intermediateResult.getResult().getEstimate()));
    }

    /**
     * Adds two final distinct-count estimates.
     *
     * <p>NOTE(review): summing distinct counts overcounts values present in both inputs; confirm callers only merge
     * final results of disjoint value sets.
     */
    @Override
    public Comparable mergeFinalResult(Comparable finalResult1, Comparable finalResult2) {
        return Long.valueOf((Long) finalResult1 + (Long) finalResult2);
    }

    /**
     * Returns the {@link CpcSketch} stored in the holder, creating and storing a new one when the holder is empty.
     */
    protected CpcSketch getCpcSketch(AggregationResultHolder aggregationResultHolder) {
        CpcSketch cpcSketch = (CpcSketch) aggregationResultHolder.getResult();
        if (cpcSketch == null) {
            cpcSketch = new CpcSketch(_lgNominalEntries);
            aggregationResultHolder.setValue(cpcSketch);
        }
        return cpcSketch;
    }

    /**
     * Returns the dictionary-id bitmap stored in the holder, creating and storing a new wrapper when needed.
     */
    protected static RoaringBitmap getDictIdBitmap(AggregationResultHolder aggregationResultHolder,
        Dictionary dictionary) {
        DictIdsWrapper dictIdsWrapper = (DictIdsWrapper) aggregationResultHolder.getResult();
        if (dictIdsWrapper == null) {
            dictIdsWrapper = new DictIdsWrapper(dictionary);
            aggregationResultHolder.setValue(dictIdsWrapper);
        }
        return dictIdsWrapper._dictIdBitmap;
    }

    /**
     * Returns the dictionary-id bitmap for the given group key, creating and storing a new wrapper when needed.
     */
    protected static RoaringBitmap getDictIdBitmap(GroupByResultHolder groupByResultHolder, int groupKey,
        Dictionary dictionary) {
        DictIdsWrapper dictIdsWrapper = (DictIdsWrapper) groupByResultHolder.getResult(groupKey);
        if (dictIdsWrapper == null) {
            dictIdsWrapper = new DictIdsWrapper(dictionary);
            groupByResultHolder.setValueForKey(groupKey, dictIdsWrapper);
        }
        return dictIdsWrapper._dictIdBitmap;
    }

    /**
     * Returns the {@link CpcSketch} for the given group key, creating and storing a new one when needed.
     */
    protected CpcSketch getCpcSketch(GroupByResultHolder groupByResultHolder, int groupKey) {
        CpcSketch cpcSketch = (CpcSketch) groupByResultHolder.getResult(groupKey);
        if (cpcSketch == null) {
            cpcSketch = new CpcSketch(_lgNominalEntries);
            groupByResultHolder.setValueForKey(groupKey, cpcSketch);
        }
        return cpcSketch;
    }

    /** Records {@code dictId} in the bitmap of every group key in {@code groupKeys}. */
    private static void setDictIdForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys,
        Dictionary dictionary, int dictId) {
        for (int groupKey : groupKeys) {
            getDictIdBitmap(groupByResultHolder, groupKey, dictionary).add(dictId);
        }
    }

    /**
     * Converts a value stored in a result holder (null, {@link CpcSketch}, {@link DictIdsWrapper} or
     * {@link CpcSketchAccumulator}) into an accumulator configured with this function's parameters.
     */
    private CpcSketchAccumulator toAccumulator(Object result) {
        if (result == null) {
            return new CpcSketchAccumulator(_lgNominalEntries, _accumulatorThreshold);
        }
        if (result instanceof CpcSketch) {
            return convertSketchAccumulator(result);
        }
        if (result instanceof DictIdsWrapper) {
            return convertSketchAccumulator(dictionaryToCpcSketch((DictIdsWrapper) result));
        }
        return (CpcSketchAccumulator) result;
    }

    /** Builds a sketch from the dictionary values whose ids were collected in the wrapper's bitmap. */
    private CpcSketch dictionaryToCpcSketch(DictIdsWrapper dictIdsWrapper) {
        CpcSketch cpcSketch = new CpcSketch(_lgNominalEntries);
        Dictionary dictionary = dictIdsWrapper._dictionary;
        PeekableIntIterator iterator = dictIdsWrapper._dictIdBitmap.getIntIterator();
        while (iterator.hasNext()) {
            addObjectToSketch(dictionary.get(iterator.next()), cpcSketch);
        }
        return cpcSketch;
    }

    /**
     * Feeds a single boxed value (or an array of boxed values) into the sketch. INT and FLOAT values are widened to
     * {@code long} and {@code double} respectively, matching the raw-value aggregation paths.
     *
     * @throws IllegalStateException for unsupported value types
     */
    private void addObjectToSketch(Object rawValue, CpcSketch sketch) {
        if (rawValue instanceof String) {
            sketch.update((String) rawValue);
        } else if (rawValue instanceof Integer) {
            sketch.update((long) (Integer) rawValue);
        } else if (rawValue instanceof Long) {
            sketch.update((long) (Long) rawValue);
        } else if (rawValue instanceof Double) {
            sketch.update((double) (Double) rawValue);
        } else if (rawValue instanceof Float) {
            sketch.update((double) (Float) rawValue);
        } else if (rawValue instanceof Object[]) {
            addObjectsToSketch((Object[]) rawValue, sketch);
        } else {
            throw new IllegalStateException(
                "Unsupported data type for CPC Sketch aggregation: " + rawValue.getClass().getSimpleName());
        }
    }

    /**
     * Feeds every element of a boxed array into the sketch, using the same widening as {@link #addObjectToSketch}.
     *
     * @throws IllegalStateException for unsupported array component types
     */
    private void addObjectsToSketch(Object[] rawValues, CpcSketch sketch) {
        if (rawValues instanceof String[]) {
            for (String s : (String[]) rawValues) {
                sketch.update(s);
            }
        } else if (rawValues instanceof Integer[]) {
            for (Integer i : (Integer[]) rawValues) {
                sketch.update((long) i.intValue());
            }
        } else if (rawValues instanceof Long[]) {
            for (Long l : (Long[]) rawValues) {
                sketch.update(l.longValue());
            }
        } else if (rawValues instanceof Double[]) {
            for (Double d : (Double[]) rawValues) {
                sketch.update(d.doubleValue());
            }
        } else if (rawValues instanceof Float[]) {
            for (Float f : (Float[]) rawValues) {
                sketch.update((double) f.floatValue());
            }
        } else {
            throw new IllegalStateException(
                "Unsupported data type for CPC Sketch aggregation: " + rawValues.getClass().getSimpleName());
        }
    }

    /** Returns the accumulator stored in the holder, creating and storing a new one when the holder is empty. */
    private CpcSketchAccumulator getAccumulator(AggregationResultHolder aggregationResultHolder) {
        CpcSketchAccumulator accumulator = (CpcSketchAccumulator) aggregationResultHolder.getResult();
        if (accumulator == null) {
            accumulator = new CpcSketchAccumulator(_lgNominalEntries, _accumulatorThreshold);
            aggregationResultHolder.setValue(accumulator);
        }
        return accumulator;
    }

    /** Returns the accumulator for the given group key, creating and storing a new one when needed. */
    private CpcSketchAccumulator getAccumulator(GroupByResultHolder groupByResultHolder, int groupKey) {
        CpcSketchAccumulator accumulator = (CpcSketchAccumulator) groupByResultHolder.getResult(groupKey);
        if (accumulator == null) {
            accumulator = new CpcSketchAccumulator(_lgNominalEntries, _accumulatorThreshold);
            groupByResultHolder.setValueForKey(groupKey, accumulator);
        }
        return accumulator;
    }

    /** Deserializes the first {@code length} byte arrays into heap CPC sketches. */
    private CpcSketch[] deserializeSketches(byte[][] serializedSketches, int length) {
        CpcSketch[] sketches = new CpcSketch[length];
        for (int i = 0; i < length; i++) {
            sketches[i] = CpcSketch.heapify(Memory.wrap(serializedSketches[i]));
        }
        return sketches;
    }

    /**
     * Wraps a {@link CpcSketch} into a new accumulator configured with this function's parameters; any other value is
     * cast to {@link CpcSketchAccumulator} and returned unchanged.
     */
    protected CpcSketchAccumulator convertSketchAccumulator(Object result) {
        if (result instanceof CpcSketch) {
            CpcSketchAccumulator accumulator = new CpcSketchAccumulator(_lgNominalEntries, _accumulatorThreshold);
            accumulator.apply((CpcSketch) result);
            return accumulator;
        }
        return (CpcSketchAccumulator) result;
    }

    /**
     * Parser for the {@code key=value;key=value} parameter string. Keys are matched case-insensitively and whitespace
     * anywhere in the string is ignored.
     */
    private static class Parameters {
        private static final char PARAMETER_DELIMITER = ';';
        private static final char PARAMETER_KEY_VALUE_SEPARATOR = '=';
        private static final String NOMINAL_ENTRIES_KEY = "nominalEntries";
        private static final String ACCUMULATOR_THRESHOLD_KEY = "accumulatorThreshold";

        private int _nominalEntries = 1 << DEFAULT_LG_NOMINAL_ENTRIES;
        private int _accumulatorThreshold = DEFAULT_ACCUMULATOR_THRESHOLD;

        /**
         * @throws IllegalArgumentException on a malformed pair or unknown key
         * @throws NumberFormatException if a value is not an integer
         */
        Parameters(String parametersString) {
            // deleteWhitespace returns a new string, so its result must be used; otherwise e.g.
            // "nominalEntries = 1024" would produce the (rejected) key "nominalEntries ".
            String normalized = StringUtils.deleteWhitespace(parametersString);
            for (String keyValuePair : StringUtils.split(normalized, PARAMETER_DELIMITER)) {
                String[] keyAndValue = StringUtils.split(keyValuePair, PARAMETER_KEY_VALUE_SEPARATOR);
                Preconditions.checkArgument(keyAndValue.length == 2, "Invalid parameter: %s", keyValuePair);
                String key = keyAndValue[0];
                String value = keyAndValue[1];
                if (key.equalsIgnoreCase(NOMINAL_ENTRIES_KEY)) {
                    _nominalEntries = Integer.parseInt(value);
                } else if (key.equalsIgnoreCase(ACCUMULATOR_THRESHOLD_KEY)) {
                    _accumulatorThreshold = Integer.parseInt(value);
                } else {
                    throw new IllegalArgumentException("Invalid parameter key: " + key);
                }
            }
        }

        /** Returns log2 of the configured nominal entries; {@code Util.exactLog2OfInt} rejects non-powers of two. */
        int getLgNominalEntries() {
            return Util.exactLog2OfInt(_nominalEntries);
        }

        int getAccumulatorThreshold() {
            return _accumulatorThreshold;
        }
    }

    /** Holder-stored state for dictionary-encoded input: the dictionary plus the set of dictionary ids seen. */
    private static final class DictIdsWrapper {
        final Dictionary _dictionary;
        final RoaringBitmap _dictIdBitmap;

        private DictIdsWrapper(Dictionary dictionary) {
            _dictionary = dictionary;
            _dictIdBitmap = new RoaringBitmap();
        }
    }
}

