/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.core.query.aggregation.function;

import com.google.common.base.Preconditions;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.BaseSingleInputAggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.utils.BigDecimalUtils;
import org.roaringbitmap.RoaringBitmap;

/**
 * Aggregation function that sums a single column using {@link BigDecimal} arithmetic, so the running sum is exact and
 * cannot overflow.
 *
 * <p>Supported stored types are INT, LONG, FLOAT, DOUBLE, STRING (parsed with {@link BigDecimal#BigDecimal(String)}),
 * BIG_DECIMAL and BYTES (decoded with {@code BigDecimalUtils.deserialize}). FLOAT and DOUBLE values are read through
 * the string representation supplied by the {@link BlockValSet} in every code path. As a result, a given input value
 * is always converted the same way, regardless of whether the block contains nulls or which aggregation path runs.
 *
 * <p>An optional second argument gives the precision and an optional third argument gives the scale. Both are applied
 * only in {@link #extractFinalResult(BigDecimal)}; intermediate results stay exact.
 *
 * <p>When null handling is enabled, rows flagged in the block's null bitmap are skipped, and a result that never saw a
 * non-null value is reported as {@code null}. When null handling is disabled, the empty sum is {@link BigDecimal#ZERO}.
 */
public class SumPrecisionAggregationFunction
extends BaseSingleInputAggregationFunction<BigDecimal, BigDecimal> {
    // Number of significant digits for the final result, or null to leave the final result unrounded.
    private final Integer _precision;
    // Scale applied after rounding to _precision, or null to keep the scale produced by rounding.
    private final Integer _scale;
    // Whether rows flagged in the null bitmap are excluded from the sum.
    private final boolean _nullHandlingEnabled;

    /**
     * Creates the function.
     *
     * @param arguments the expression to sum, optionally followed by literal precision and scale arguments
     * @param nullHandlingEnabled whether null rows are skipped and an all-null input yields {@code null}
     * @throws IllegalArgumentException if more than 3 arguments are supplied
     */
    public SumPrecisionAggregationFunction(List<ExpressionContext> arguments, boolean nullHandlingEnabled) {
        super(arguments.get(0));
        int numArguments = arguments.size();
        Preconditions.checkArgument(numArguments <= 3, "SumPrecision expects at most 3 arguments, got: %s", numArguments);
        if (numArguments > 1) {
            this._precision = Integer.valueOf(arguments.get(1).getLiteral());
            this._scale = numArguments > 2 ? Integer.valueOf(arguments.get(2).getLiteral()) : null;
        } else {
            this._precision = null;
            this._scale = null;
        }
        this._nullHandlingEnabled = nullHandlingEnabled;
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.SUMPRECISION;
    }

    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /**
     * Adds the first {@code length} values of the input column to the single running sum.
     * Delegates to the null-aware path when null handling is enabled and the block has any null rows.
     */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(this._expression);
        RoaringBitmap nullBitmap = this._nullHandlingEnabled ? blockValSet.getNullBitmap() : null;
        if (nullBitmap != null && !nullBitmap.isEmpty()) {
            this.aggregateNullHandlingEnabled(length, aggregationResultHolder, blockValSet, nullBitmap);
            return;
        }
        BigDecimal sum = this.getDefaultResult(aggregationResultHolder);
        switch (blockValSet.getValueType().getStoredType()) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; ++i) {
                    sum = sum.add(BigDecimal.valueOf(intValues[i]));
                }
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; ++i) {
                    sum = sum.add(BigDecimal.valueOf(longValues[i]));
                }
                break;
            }
            case FLOAT:
            case DOUBLE:
            case STRING: {
                // Floating-point values are summed via their string form so the decimal digits are used exactly.
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; ++i) {
                    sum = sum.add(new BigDecimal(stringValues[i]));
                }
                break;
            }
            case BIG_DECIMAL: {
                BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV();
                for (int i = 0; i < length; ++i) {
                    sum = sum.add(bigDecimalValues[i]);
                }
                break;
            }
            case BYTES: {
                byte[][] bytesValues = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; ++i) {
                    sum = sum.add(BigDecimalUtils.deserialize(bytesValues[i]));
                }
                break;
            }
            default: {
                throw new IllegalStateException();
            }
        }
        aggregationResultHolder.setValue(sum);
    }

    /**
     * Null-aware variant of {@link #aggregate}. Rows set in {@code nullBitmap} are skipped. If every row is null,
     * the holder is left untouched, so an all-null input keeps a {@code null} result.
     */
    private void aggregateNullHandlingEnabled(int length, AggregationResultHolder aggregationResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) {
        boolean allNull = nullBitmap.getCardinality() >= length;
        BigDecimal sum = BigDecimal.ZERO;
        switch (blockValSet.getValueType().getStoredType()) {
            case INT: {
                if (allNull) {
                    break;
                }
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (!nullBitmap.contains(i)) {
                        sum = sum.add(BigDecimal.valueOf(intValues[i]));
                    }
                }
                this.setAggregationResult(aggregationResultHolder, sum);
                break;
            }
            case LONG: {
                if (allNull) {
                    break;
                }
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (!nullBitmap.contains(i)) {
                        sum = sum.add(BigDecimal.valueOf(longValues[i]));
                    }
                }
                this.setAggregationResult(aggregationResultHolder, sum);
                break;
            }
            case FLOAT:
            case DOUBLE:
            case STRING: {
                // FLOAT/DOUBLE use the same string-based conversion as every other path. Converting through
                // BigDecimal.valueOf(double) here would sum e.g. 0.1f as 0.10000000149011612, making the result
                // depend on whether the block happened to contain nulls.
                if (allNull) {
                    break;
                }
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (!nullBitmap.contains(i)) {
                        sum = sum.add(new BigDecimal(stringValues[i]));
                    }
                }
                this.setAggregationResult(aggregationResultHolder, sum);
                break;
            }
            case BIG_DECIMAL: {
                if (allNull) {
                    break;
                }
                BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (!nullBitmap.contains(i)) {
                        sum = sum.add(bigDecimalValues[i]);
                    }
                }
                this.setAggregationResult(aggregationResultHolder, sum);
                break;
            }
            case BYTES: {
                if (allNull) {
                    break;
                }
                byte[][] bytesValues = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (!nullBitmap.contains(i)) {
                        sum = sum.add(BigDecimalUtils.deserialize(bytesValues[i]));
                    }
                }
                this.setAggregationResult(aggregationResultHolder, sum);
                break;
            }
            default: {
                throw new IllegalStateException();
            }
        }
    }

    /**
     * Adds {@code sum} to the value already in the holder, or stores it as-is when the holder is still empty.
     */
    protected void setAggregationResult(AggregationResultHolder aggregationResultHolder, BigDecimal sum) {
        BigDecimal otherSum = (BigDecimal) aggregationResultHolder.getResult();
        aggregationResultHolder.setValue(otherSum == null ? sum : sum.add(otherSum));
    }

    /**
     * Adds each of the first {@code length} values to the running sum of its group ({@code groupKeyArray[i]}).
     * Delegates to the null-aware path when null handling is enabled and the block has any null rows.
     */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(this._expression);
        RoaringBitmap nullBitmap = this._nullHandlingEnabled ? blockValSet.getNullBitmap() : null;
        if (nullBitmap != null && !nullBitmap.isEmpty()) {
            this.aggregateGroupBySVNullHandlingEnabled(length, groupKeyArray, groupByResultHolder, blockValSet, nullBitmap);
            return;
        }
        switch (blockValSet.getValueType().getStoredType()) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; ++i) {
                    int groupKey = groupKeyArray[i];
                    BigDecimal sum = this.getDefaultResult(groupByResultHolder, groupKey);
                    groupByResultHolder.setValueForKey(groupKey, sum.add(BigDecimal.valueOf(intValues[i])));
                }
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; ++i) {
                    int groupKey = groupKeyArray[i];
                    BigDecimal sum = this.getDefaultResult(groupByResultHolder, groupKey);
                    groupByResultHolder.setValueForKey(groupKey, sum.add(BigDecimal.valueOf(longValues[i])));
                }
                break;
            }
            case FLOAT:
            case DOUBLE:
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; ++i) {
                    int groupKey = groupKeyArray[i];
                    BigDecimal sum = this.getDefaultResult(groupByResultHolder, groupKey);
                    groupByResultHolder.setValueForKey(groupKey, sum.add(new BigDecimal(stringValues[i])));
                }
                break;
            }
            case BIG_DECIMAL: {
                BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV();
                for (int i = 0; i < length; ++i) {
                    int groupKey = groupKeyArray[i];
                    BigDecimal sum = this.getDefaultResult(groupByResultHolder, groupKey);
                    groupByResultHolder.setValueForKey(groupKey, sum.add(bigDecimalValues[i]));
                }
                break;
            }
            case BYTES: {
                byte[][] bytesValues = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; ++i) {
                    int groupKey = groupKeyArray[i];
                    BigDecimal sum = this.getDefaultResult(groupByResultHolder, groupKey);
                    groupByResultHolder.setValueForKey(groupKey, sum.add(BigDecimalUtils.deserialize(bytesValues[i])));
                }
                break;
            }
            default: {
                throw new IllegalStateException();
            }
        }
    }

    /**
     * Null-aware variant of {@link #aggregateGroupBySV}. Rows set in {@code nullBitmap} are skipped, so a group that
     * only sees null rows keeps a {@code null} result.
     */
    private void aggregateGroupBySVNullHandlingEnabled(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) {
        boolean allNull = nullBitmap.getCardinality() >= length;
        switch (blockValSet.getValueType().getStoredType()) {
            case INT: {
                if (allNull) {
                    break;
                }
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (!nullBitmap.contains(i)) {
                        this.setGroupByResult(groupKeyArray[i], groupByResultHolder, BigDecimal.valueOf(intValues[i]));
                    }
                }
                break;
            }
            case LONG: {
                if (allNull) {
                    break;
                }
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (!nullBitmap.contains(i)) {
                        this.setGroupByResult(groupKeyArray[i], groupByResultHolder, BigDecimal.valueOf(longValues[i]));
                    }
                }
                break;
            }
            case FLOAT:
            case DOUBLE:
            case STRING: {
                if (allNull) {
                    break;
                }
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (!nullBitmap.contains(i)) {
                        this.setGroupByResult(groupKeyArray[i], groupByResultHolder, new BigDecimal(stringValues[i]));
                    }
                }
                break;
            }
            case BIG_DECIMAL: {
                if (allNull) {
                    break;
                }
                BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (!nullBitmap.contains(i)) {
                        this.setGroupByResult(groupKeyArray[i], groupByResultHolder, bigDecimalValues[i]);
                    }
                }
                break;
            }
            case BYTES: {
                if (allNull) {
                    break;
                }
                byte[][] bytesValues = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (!nullBitmap.contains(i)) {
                        this.setGroupByResult(groupKeyArray[i], groupByResultHolder, BigDecimalUtils.deserialize(bytesValues[i]));
                    }
                }
                break;
            }
            default: {
                throw new IllegalStateException();
            }
        }
    }

    /**
     * Adds {@code value} to the sum stored for {@code groupKey}, or stores it as-is when the group has no sum yet.
     */
    private void setGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, BigDecimal value) {
        BigDecimal sum = (BigDecimal) groupByResultHolder.getResult(groupKey);
        groupByResultHolder.setValueForKey(groupKey, sum == null ? value : sum.add(value));
    }

    /**
     * Adds each of the first {@code length} values to the running sum of every group in {@code groupKeysArray[i]}.
     * Delegates to the null-aware path when null handling is enabled and the block has any null rows, matching the
     * single-value group-by and non-group-by paths.
     */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(this._expression);
        RoaringBitmap nullBitmap = this._nullHandlingEnabled ? blockValSet.getNullBitmap() : null;
        if (nullBitmap != null && !nullBitmap.isEmpty()) {
            this.aggregateGroupByMVNullHandlingEnabled(length, groupKeysArray, groupByResultHolder, blockValSet, nullBitmap);
            return;
        }
        switch (blockValSet.getValueType().getStoredType()) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; ++i) {
                    BigDecimal value = BigDecimal.valueOf(intValues[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        BigDecimal sum = this.getDefaultResult(groupByResultHolder, groupKey);
                        groupByResultHolder.setValueForKey(groupKey, sum.add(value));
                    }
                }
                break;
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; ++i) {
                    BigDecimal value = BigDecimal.valueOf(longValues[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        BigDecimal sum = this.getDefaultResult(groupByResultHolder, groupKey);
                        groupByResultHolder.setValueForKey(groupKey, sum.add(value));
                    }
                }
                break;
            }
            case FLOAT:
            case DOUBLE:
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; ++i) {
                    String value = stringValues[i];
                    for (int groupKey : groupKeysArray[i]) {
                        BigDecimal sum = this.getDefaultResult(groupByResultHolder, groupKey);
                        groupByResultHolder.setValueForKey(groupKey, sum.add(new BigDecimal(value)));
                    }
                }
                break;
            }
            case BIG_DECIMAL: {
                BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV();
                for (int i = 0; i < length; ++i) {
                    BigDecimal value = bigDecimalValues[i];
                    for (int groupKey : groupKeysArray[i]) {
                        BigDecimal sum = this.getDefaultResult(groupByResultHolder, groupKey);
                        groupByResultHolder.setValueForKey(groupKey, sum.add(value));
                    }
                }
                break;
            }
            case BYTES: {
                byte[][] bytesValues = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; ++i) {
                    byte[] value = bytesValues[i];
                    for (int groupKey : groupKeysArray[i]) {
                        BigDecimal sum = this.getDefaultResult(groupByResultHolder, groupKey);
                        groupByResultHolder.setValueForKey(groupKey, sum.add(BigDecimalUtils.deserialize(value)));
                    }
                }
                break;
            }
            default: {
                throw new IllegalStateException();
            }
        }
    }

    /**
     * Null-aware variant of {@link #aggregateGroupByMV}. Rows set in {@code nullBitmap} contribute to none of their
     * groups, so a group that only sees null rows keeps a {@code null} result.
     */
    private void aggregateGroupByMVNullHandlingEnabled(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, BlockValSet blockValSet, RoaringBitmap nullBitmap) {
        boolean allNull = nullBitmap.getCardinality() >= length;
        switch (blockValSet.getValueType().getStoredType()) {
            case INT: {
                if (allNull) {
                    break;
                }
                int[] intValues = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (nullBitmap.contains(i)) {
                        continue;
                    }
                    BigDecimal value = BigDecimal.valueOf(intValues[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        this.setGroupByResult(groupKey, groupByResultHolder, value);
                    }
                }
                break;
            }
            case LONG: {
                if (allNull) {
                    break;
                }
                long[] longValues = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (nullBitmap.contains(i)) {
                        continue;
                    }
                    BigDecimal value = BigDecimal.valueOf(longValues[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        this.setGroupByResult(groupKey, groupByResultHolder, value);
                    }
                }
                break;
            }
            case FLOAT:
            case DOUBLE:
            case STRING: {
                if (allNull) {
                    break;
                }
                String[] stringValues = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (nullBitmap.contains(i)) {
                        continue;
                    }
                    BigDecimal value = new BigDecimal(stringValues[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        this.setGroupByResult(groupKey, groupByResultHolder, value);
                    }
                }
                break;
            }
            case BIG_DECIMAL: {
                if (allNull) {
                    break;
                }
                BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (nullBitmap.contains(i)) {
                        continue;
                    }
                    BigDecimal value = bigDecimalValues[i];
                    for (int groupKey : groupKeysArray[i]) {
                        this.setGroupByResult(groupKey, groupByResultHolder, value);
                    }
                }
                break;
            }
            case BYTES: {
                if (allNull) {
                    break;
                }
                byte[][] bytesValues = blockValSet.getBytesValuesSV();
                for (int i = 0; i < length; ++i) {
                    if (nullBitmap.contains(i)) {
                        continue;
                    }
                    BigDecimal value = BigDecimalUtils.deserialize(bytesValues[i]);
                    for (int groupKey : groupKeysArray[i]) {
                        this.setGroupByResult(groupKey, groupByResultHolder, value);
                    }
                }
                break;
            }
            default: {
                throw new IllegalStateException();
            }
        }
    }

    /**
     * Returns the accumulated sum. If nothing was accumulated, returns {@code null} when null handling is enabled,
     * otherwise {@link BigDecimal#ZERO}.
     */
    @Override
    public BigDecimal extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        BigDecimal result = (BigDecimal) aggregationResultHolder.getResult();
        if (result == null) {
            return this._nullHandlingEnabled ? null : BigDecimal.ZERO;
        }
        return result;
    }

    /**
     * Returns the accumulated sum for {@code groupKey}. If nothing was accumulated, returns {@code null} when null
     * handling is enabled, otherwise {@link BigDecimal#ZERO}.
     */
    @Override
    public BigDecimal extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        BigDecimal result = (BigDecimal) groupByResultHolder.getResult(groupKey);
        if (result == null) {
            return this._nullHandlingEnabled ? null : BigDecimal.ZERO;
        }
        return result;
    }

    /**
     * Adds two intermediate sums. A {@code null} operand means "no contribution" and yields the other operand, so a
     * missing intermediate result never causes a NullPointerException.
     */
    @Override
    public BigDecimal merge(BigDecimal intermediateResult1, BigDecimal intermediateResult2) {
        if (intermediateResult1 == null) {
            return intermediateResult2;
        }
        if (intermediateResult2 == null) {
            return intermediateResult1;
        }
        return intermediateResult1.add(intermediateResult2);
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.STRING;
    }

    /**
     * Applies the optional precision/scale to the exact intermediate sum.
     *
     * @param intermediateResult the merged sum, possibly {@code null} (propagated unchanged)
     * @return the sum rounded to {@code _precision} significant digits (HALF_EVEN) and then, if a scale was given,
     *     set to that scale (HALF_EVEN); the unmodified sum when no precision was given
     */
    @Override
    public BigDecimal extractFinalResult(BigDecimal intermediateResult) {
        if (intermediateResult == null) {
            return null;
        }
        if (this._precision == null) {
            return intermediateResult;
        }
        BigDecimal result = intermediateResult.round(new MathContext(this._precision, RoundingMode.HALF_EVEN));
        return this._scale == null ? result : result.setScale(this._scale, RoundingMode.HALF_EVEN);
    }

    /**
     * Returns the current sum in the holder, or {@link BigDecimal#ZERO} if none has been stored yet.
     */
    public BigDecimal getDefaultResult(AggregationResultHolder aggregationResultHolder) {
        BigDecimal result = (BigDecimal) aggregationResultHolder.getResult();
        return result != null ? result : BigDecimal.ZERO;
    }

    /**
     * Returns the current sum for {@code groupKey}, or {@link BigDecimal#ZERO} if none has been stored yet.
     */
    public BigDecimal getDefaultResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        BigDecimal result = (BigDecimal) groupByResultHolder.getResult(groupKey);
        return result != null ? result : BigDecimal.ZERO;
    }
}

