/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.core.query.aggregation.function;

import com.google.common.base.Preconditions;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.List;
import java.util.Map;
import java.util.function.IntFunction;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.NullableSingleInputAggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.utils.BigDecimalUtils;

/**
 * Sums a single expression exactly using {@link BigDecimal} arithmetic.
 *
 * <p>Arguments: the expression to sum, an optional precision literal and an optional scale literal
 * (at most three arguments in total). Intermediate results are exact sums; precision and scale, when
 * given, are applied only when producing the final result (see {@link #extractFinalResult}).
 *
 * <p>Input values are converted according to the stored type of the block: INT and LONG via
 * {@link BigDecimal#valueOf(long)}; FLOAT, DOUBLE and STRING by parsing their string form (presumably
 * so that floating-point values keep their printed decimal representation rather than their binary
 * expansion); BIG_DECIMAL as is; BYTES via {@link BigDecimalUtils#deserialize(byte[])}.
 */
public class SumPrecisionAggregationFunction
extends NullableSingleInputAggregationFunction<BigDecimal, BigDecimal> {
    /** Number of significant digits kept in the final result, or {@code null} when not specified. */
    private final Integer _precision;
    /** Scale applied to the final result after rounding to {@link #_precision}, or {@code null}. */
    private final Integer _scale;

    /**
     * Creates the function from its call arguments.
     *
     * @param arguments the expression to sum, optionally followed by precision and scale, which are
     *                  read as integer literals
     * @param nullHandlingEnabled whether null input values/results are treated as SQL nulls
     * @throws IllegalArgumentException if more than three arguments are given
     */
    public SumPrecisionAggregationFunction(List<ExpressionContext> arguments, boolean nullHandlingEnabled) {
        super(arguments.get(0), nullHandlingEnabled);
        int numArguments = arguments.size();
        Preconditions.checkArgument(numArguments <= 3, "SumPrecision expects at most 3 arguments, got: %s",
            numArguments);
        // A scale can only be supplied together with a precision because arguments are positional.
        _precision = numArguments > 1 ? arguments.get(1).getLiteral().getIntValue() : null;
        _scale = numArguments > 2 ? arguments.get(2).getLiteral().getIntValue() : null;
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.SUMPRECISION;
    }

    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /**
     * Returns a function mapping a document index of {@code blockValSet} to its value as a
     * {@link BigDecimal}. Dispatches on the stored type once and fetches the value array once.
     *
     * @throws IllegalStateException if the block's stored type is not supported
     */
    private static IntFunction<BigDecimal> valueReader(BlockValSet blockValSet) {
        switch (blockValSet.getValueType().getStoredType()) {
            case INT: {
                int[] intValues = blockValSet.getIntValuesSV();
                return i -> BigDecimal.valueOf(intValues[i]);
            }
            case LONG: {
                long[] longValues = blockValSet.getLongValuesSV();
                return i -> BigDecimal.valueOf(longValues[i]);
            }
            case FLOAT:
            case DOUBLE:
            case STRING: {
                String[] stringValues = blockValSet.getStringValuesSV();
                return i -> new BigDecimal(stringValues[i]);
            }
            case BIG_DECIMAL: {
                BigDecimal[] bigDecimalValues = blockValSet.getBigDecimalValuesSV();
                return i -> bigDecimalValues[i];
            }
            case BYTES: {
                byte[][] bytesValues = blockValSet.getBytesValuesSV();
                return i -> BigDecimalUtils.deserialize(bytesValues[i]);
            }
            default:
                throw new IllegalStateException(
                    "Cannot compute SUMPRECISION for stored type: " + blockValSet.getValueType().getStoredType());
        }
    }

    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        IntFunction<BigDecimal> reader = valueReader(blockValSet);
        // Sum every non-null range [from, to) and chain the partial sums. The accumulator starts as null,
        // so a null result presumably means no non-null value was visited; updateAggregationResult handles it.
        BigDecimal sum = foldNotNull(length, blockValSet, null, (acc, from, to) -> {
            BigDecimal rangeSum = BigDecimal.ZERO;
            for (int i = from; i < to; i++) {
                rangeSum = rangeSum.add(reader.apply(i));
            }
            return acc == null ? rangeSum : acc.add(rangeSum);
        });
        updateAggregationResult(aggregationResultHolder, sum);
    }

    /**
     * Adds {@code sum} to the value stored in {@code aggregationResultHolder}.
     *
     * <p>A {@code null} sum is ignored when null handling is enabled, and treated as zero otherwise.
     */
    protected void updateAggregationResult(AggregationResultHolder aggregationResultHolder, BigDecimal sum) {
        if (sum == null) {
            if (_nullHandlingEnabled) {
                return;
            }
            sum = BigDecimal.ZERO;
        }
        BigDecimal otherSum = (BigDecimal) aggregationResultHolder.getResult();
        aggregationResultHolder.setValue(otherSum == null ? sum : sum.add(otherSum));
    }

    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        IntFunction<BigDecimal> reader = valueReader(blockValSet);
        forEachNotNull(length, blockValSet, (int from, int to) -> {
            for (int i = from; i < to; i++) {
                updateGroupByResult(groupKeyArray[i], groupByResultHolder, reader.apply(i));
            }
        });
    }

    /** Adds {@code value} to the running sum of {@code groupKey}, starting from {@code value} itself. */
    private void updateGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, BigDecimal value) {
        BigDecimal sum = (BigDecimal) groupByResultHolder.getResult(groupKey);
        groupByResultHolder.setValueForKey(groupKey, sum == null ? value : sum.add(value));
    }

    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
        Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet blockValSet = blockValSetMap.get(_expression);
        IntFunction<BigDecimal> reader = valueReader(blockValSet);
        forEachNotNull(length, blockValSet, (int from, int to) -> {
            for (int i = from; i < to; i++) {
                int[] groupKeys = groupKeysArray[i];
                if (groupKeys.length == 0) {
                    continue;
                }
                // Convert once per document; BigDecimal is immutable so the value can be shared by all keys.
                BigDecimal value = reader.apply(i);
                for (int groupKey : groupKeys) {
                    updateGroupByResult(groupKey, groupByResultHolder, value);
                }
            }
        });
    }

    /** Returns the aggregated sum; an empty result is {@code null} with null handling, zero otherwise. */
    @Override
    public BigDecimal extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        BigDecimal result = (BigDecimal) aggregationResultHolder.getResult();
        if (result == null) {
            return _nullHandlingEnabled ? null : BigDecimal.ZERO;
        }
        return result;
    }

    /** Returns the group's sum; an empty result is {@code null} with null handling, zero otherwise. */
    @Override
    public BigDecimal extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        BigDecimal result = (BigDecimal) groupByResultHolder.getResult(groupKey);
        if (result == null) {
            return _nullHandlingEnabled ? null : BigDecimal.ZERO;
        }
        return result;
    }

    /**
     * Adds two intermediate sums. A {@code null} operand is treated as absent and the other operand is
     * returned, so this never throws on null input regardless of the null-handling mode.
     */
    @Override
    public BigDecimal merge(BigDecimal intermediateResult1, BigDecimal intermediateResult2) {
        if (intermediateResult1 == null) {
            return intermediateResult2;
        }
        if (intermediateResult2 == null) {
            return intermediateResult1;
        }
        return intermediateResult1.add(intermediateResult2);
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    // NOTE(review): extractFinalResult returns a BigDecimal while the declared final column type is STRING;
    // confirm that the result serializer converts the value accordingly.
    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.STRING;
    }

    /**
     * Applies the configured precision and scale to an exact sum.
     *
     * <p>Without a precision the sum is returned unchanged. Otherwise it is rounded to {@code _precision}
     * significant digits (HALF_EVEN; a precision of 0 means unlimited per {@link MathContext}) and, when a
     * scale is configured, then rescaled to {@code _scale} with HALF_EVEN rounding.
     *
     * @return the rounded result, or {@code null} if {@code intermediateResult} is {@code null}
     */
    @Override
    public BigDecimal extractFinalResult(BigDecimal intermediateResult) {
        if (intermediateResult == null) {
            return null;
        }
        if (_precision == null) {
            return intermediateResult;
        }
        BigDecimal result = intermediateResult.round(new MathContext(_precision, RoundingMode.HALF_EVEN));
        return _scale == null ? result : result.setScale(_scale, RoundingMode.HALF_EVEN);
    }

    /**
     * Merges two final results. Adding two rounded values can exceed the configured precision
     * (e.g. 99 + 99 = 198 at precision 2), so the precision/scale is re-applied to the merged value;
     * re-applying it to an already rounded value leaves that value unchanged.
     */
    @Override
    public BigDecimal mergeFinalResult(BigDecimal finalResult1, BigDecimal finalResult2) {
        return extractFinalResult(merge(finalResult1, finalResult2));
    }

    /** Returns the group's stored sum, or zero when the group has no result yet. */
    public BigDecimal getDefaultResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        BigDecimal result = (BigDecimal) groupByResultHolder.getResult(groupKey);
        return result != null ? result : BigDecimal.ZERO;
    }
}

