/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.core.query.aggregation.function;

import com.tdunning.math.stats.TDigest;
import java.util.Locale;
import java.util.Map;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.BaseSingleInputAggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.data.FieldSpec;

/**
 * Aggregation function that estimates a percentile of its single input expression using a
 * {@link TDigest}.
 *
 * <p>Non-{@code BYTES} inputs are read as doubles and added value-by-value into a merging digest
 * created with {@link #DEFAULT_TDIGEST_COMPRESSION}. {@code BYTES} inputs are treated as serialized
 * digests (decoded via {@link ObjectSerDeUtils#TDIGEST_SER_DE}) and merged into the running digest.
 * The final result is {@code digest.quantile(percentile / 100.0)}.
 *
 * <p>The constructor overload selects the naming scheme. The {@code int} overload ({@code _version == 0})
 * embeds the percentile as an integer directly after the function name. The {@code double} overload
 * ({@code _version == 1}) uses the double percentile: appended after the name in
 * {@link #getColumnName()}, and as a second argument in {@link #getResultColumnName()}.
 */
public class PercentileTDigestAggregationFunction
extends BaseSingleInputAggregationFunction<TDigest, Double> {
    /** Compression passed to {@link TDigest#createMergingDigest(double)} for every digest created here. */
    public static final int DEFAULT_TDIGEST_COMPRESSION = 100;

    /** Naming scheme: 0 when built from an {@code int} percentile, 1 when built from a {@code double}. */
    protected final int _version;
    /** Requested percentile; divided by 100 before being passed to {@link TDigest#quantile(double)}. */
    protected final double _percentile;

    /**
     * Creates the function with the legacy (version 0) naming scheme.
     *
     * @param expression input expression to aggregate
     * @param percentile percentile on a 0-100 scale
     */
    public PercentileTDigestAggregationFunction(ExpressionContext expression, int percentile) {
        super(expression);
        _version = 0;
        _percentile = percentile;
    }

    /**
     * Creates the function with the version 1 naming scheme.
     *
     * @param expression input expression to aggregate
     * @param percentile percentile on a 0-100 scale
     */
    public PercentileTDigestAggregationFunction(ExpressionContext expression, double percentile) {
        super(expression);
        _version = 1;
        _percentile = percentile;
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.PERCENTILETDIGEST;
    }

    /**
     * Returns the internal column name: function name, percentile (as an int for version 0, as a
     * double otherwise), {@code "_"}, then the expression.
     */
    @Override
    public String getColumnName() {
        String functionName = AggregationFunctionType.PERCENTILETDIGEST.getName();
        if (_version == 0) {
            return functionName + (int) _percentile + "_" + _expression;
        }
        return functionName + _percentile + "_" + _expression;
    }

    /**
     * Returns the user-facing result column name, using the lower-cased function name.
     * Version 0 yields {@code <name><intPercentile>(<expr>)}; version 1 yields
     * {@code <name>(<expr>, <percentile>)}.
     */
    @Override
    public String getResultColumnName() {
        // Locale.ROOT keeps the name independent of the JVM default locale (e.g. Turkish 'I' casing).
        String functionName = AggregationFunctionType.PERCENTILETDIGEST.getName().toLowerCase(Locale.ROOT);
        if (_version == 0) {
            return functionName + (int) _percentile + "(" + _expression + ")";
        }
        return functionName + "(" + _expression + ", " + _percentile + ")";
    }

    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /**
     * Accumulates the first {@code length} rows into the single digest held by
     * {@code aggregationResultHolder}.
     */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        if (blockValSet.getValueType() != FieldSpec.DataType.BYTES) {
            double[] doubleValues = blockValSet.getDoubleValuesSV();
            TDigest tDigest = getDefaultTDigest(aggregationResultHolder);
            for (int i = 0; i < length; i++) {
                tDigest.add(doubleValues[i]);
            }
            return;
        }

        // BYTES input: every row is a serialized TDigest to merge.
        byte[][] bytesValues = blockValSet.getBytesValuesSV();
        TDigest tDigest = (TDigest) aggregationResultHolder.getResult();
        int start = 0;
        if (tDigest == null) {
            if (length == 0) {
                // Nothing to seed the result with; reading bytesValues[0] here would throw.
                return;
            }
            // Adopt the first deserialized digest as the result rather than merging it into an empty one.
            tDigest = ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[0]);
            aggregationResultHolder.setValue(tDigest);
            start = 1;
        }
        for (int i = start; i < length; i++) {
            tDigest.add(ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i]));
        }
    }

    /**
     * Accumulates row {@code i} into the digest for group {@code groupKeyArray[i]}.
     */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        if (blockValSet.getValueType() != FieldSpec.DataType.BYTES) {
            double[] doubleValues = blockValSet.getDoubleValuesSV();
            for (int i = 0; i < length; i++) {
                getDefaultTDigest(groupByResultHolder, groupKeyArray[i]).add(doubleValues[i]);
            }
            return;
        }

        byte[][] bytesValues = blockValSet.getBytesValuesSV();
        for (int i = 0; i < length; i++) {
            TDigest value = ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i]);
            int groupKey = groupKeyArray[i];
            TDigest tDigest = (TDigest) groupByResultHolder.getResult(groupKey);
            if (tDigest != null) {
                tDigest.add(value);
            } else {
                // The freshly deserialized digest belongs to this row only, so it can become the group's result.
                groupByResultHolder.setValueForKey(groupKey, value);
            }
        }
    }

    /**
     * Accumulates row {@code i} into the digest of every group in {@code groupKeysArray[i]}.
     */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        if (blockValSet.getValueType() != FieldSpec.DataType.BYTES) {
            double[] doubleValues = blockValSet.getDoubleValuesSV();
            for (int i = 0; i < length; i++) {
                double value = doubleValues[i];
                for (int groupKey : groupKeysArray[i]) {
                    getDefaultTDigest(groupByResultHolder, groupKey).add(value);
                }
            }
            return;
        }

        byte[][] bytesValues = blockValSet.getBytesValuesSV();
        for (int i = 0; i < length; i++) {
            TDigest value = ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i]);
            for (int groupKey : groupKeysArray[i]) {
                TDigest tDigest = (TDigest) groupByResultHolder.getResult(groupKey);
                if (tDigest != null) {
                    tDigest.add(value);
                } else {
                    // The same row feeds several groups, so each new group gets its own copy rather than
                    // sharing `value`, which later additions to one group would otherwise mutate.
                    groupByResultHolder.setValueForKey(groupKey, ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i]));
                }
            }
        }
    }

    /**
     * Returns the accumulated digest, or a new empty digest if nothing was aggregated.
     */
    @Override
    public TDigest extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        TDigest tDigest = (TDigest) aggregationResultHolder.getResult();
        return tDigest != null ? tDigest : TDigest.createMergingDigest(DEFAULT_TDIGEST_COMPRESSION);
    }

    /**
     * Returns the digest for {@code groupKey}, or a new empty digest if the group has none.
     */
    @Override
    public TDigest extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        TDigest tDigest = (TDigest) groupByResultHolder.getResult(groupKey);
        return tDigest != null ? tDigest : TDigest.createMergingDigest(DEFAULT_TDIGEST_COMPRESSION);
    }

    /**
     * Merges two intermediate digests. If either is empty the other is returned unchanged;
     * otherwise {@code intermediateResult2} is added into {@code intermediateResult1}, which is returned.
     */
    @Override
    public TDigest merge(TDigest intermediateResult1, TDigest intermediateResult2) {
        if (intermediateResult1.size() == 0L) {
            return intermediateResult2;
        }
        if (intermediateResult2.size() == 0L) {
            return intermediateResult1;
        }
        intermediateResult1.add(intermediateResult2);
        return intermediateResult1;
    }

    @Override
    public boolean isIntermediateResultComparable() {
        return false;
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.DOUBLE;
    }

    /**
     * Returns the estimated value at {@code _percentile} (0-100 scale, converted to a 0-1 quantile).
     */
    @Override
    public Double extractFinalResult(TDigest intermediateResult) {
        return intermediateResult.quantile(_percentile / 100.0);
    }

    /**
     * Returns the holder's digest, creating and storing an empty one first if absent.
     */
    protected static TDigest getDefaultTDigest(AggregationResultHolder aggregationResultHolder) {
        TDigest tDigest = (TDigest) aggregationResultHolder.getResult();
        if (tDigest == null) {
            tDigest = TDigest.createMergingDigest(DEFAULT_TDIGEST_COMPRESSION);
            aggregationResultHolder.setValue(tDigest);
        }
        return tDigest;
    }

    /**
     * Returns the digest for {@code groupKey}, creating and storing an empty one first if absent.
     */
    protected static TDigest getDefaultTDigest(GroupByResultHolder groupByResultHolder, int groupKey) {
        TDigest tDigest = (TDigest) groupByResultHolder.getResult(groupKey);
        if (tDigest == null) {
            tDigest = TDigest.createMergingDigest(DEFAULT_TDIGEST_COMPRESSION);
            groupByResultHolder.setValueForKey(groupKey, tDigest);
        }
        return tDigest;
    }
}

