/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.core.query.aggregation.function;

import java.util.List;
import java.util.Map;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.NullableSingleInputAggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.segment.local.customobject.AvgPair;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.spi.data.FieldSpec;

/**
 * AVG aggregation over a single input expression.
 *
 * <p>The intermediate result is an {@link AvgPair} carrying a running sum and count; the final
 * result is {@code sum / count} as a {@link Double}. Input blocks whose value type is
 * {@code BYTES} are read as serialized {@link AvgPair}s and merged; every other value type is read
 * as doubles, each contributing a count of one.
 *
 * <p>Rows are visited through {@code forEachNotNull}, which is defined in the superclass and not
 * visible here (presumably it skips null rows when null handling is enabled — confirm against
 * {@link NullableSingleInputAggregationFunction}).
 */
public class AvgAggregationFunction extends NullableSingleInputAggregationFunction<AvgPair, Double> {
    /** Final result reported when the intermediate result has a count of zero. */
    private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;

    /**
     * Creates the function from the raw argument list.
     *
     * @param arguments function arguments; checked by {@code verifySingleArgument} under the name "AVG"
     * @param nullHandlingEnabled whether null handling is enabled for this query
     */
    public AvgAggregationFunction(List<ExpressionContext> arguments, boolean nullHandlingEnabled) {
        this(verifySingleArgument(arguments, "AVG"), nullHandlingEnabled);
    }

    /**
     * Creates the function for an already-resolved input expression.
     *
     * @param expression the expression to average
     * @param nullHandlingEnabled whether null handling is enabled for this query
     */
    protected AvgAggregationFunction(ExpressionContext expression, boolean nullHandlingEnabled) {
        super(expression, nullHandlingEnabled);
    }

    /** @return {@link AggregationFunctionType#AVG} */
    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.AVG;
    }

    /** @return a fresh object-backed holder for the non-group-by {@link AvgPair} */
    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    /** @return a fresh object-backed holder for per-group {@link AvgPair}s */
    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /**
     * Folds one block of values into the non-group-by result.
     *
     * <p>The block is first accumulated into a local {@link AvgPair}; the holder is only touched
     * when that local pair ends up with a non-zero count.
     *
     * @param length number of rows in the block
     * @param aggregationResultHolder holder receiving the accumulated sum and count
     * @param blockValSetMap block values keyed by expression; the entry for this function's
     *     expression is used
     */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet valueSet = blockValSetMap.get(_expression);
        AvgPair blockTotal = new AvgPair();
        if (valueSet.getValueType() == FieldSpec.DataType.BYTES) {
            // BYTES input: each row is a serialized AvgPair to be merged in.
            byte[][] serializedPairs = valueSet.getBytesValuesSV();
            forEachNotNull(length, valueSet, (int start, int end) -> {
                for (int row = start; row < end; row++) {
                    AvgPair rowPair = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(serializedPairs[row]);
                    blockTotal.apply(rowPair);
                }
            });
        } else {
            // Any other type: each row is a single observation with count 1.
            double[] rawValues = valueSet.getDoubleValuesSV();
            forEachNotNull(length, valueSet, (int start, int end) -> {
                for (int row = start; row < end; row++) {
                    blockTotal.apply(rawValues[row], 1L);
                }
            });
        }
        if (blockTotal.getCount() == 0L) {
            // Nothing was accumulated; leave the holder untouched.
            return;
        }
        updateAggregationResult(aggregationResultHolder, blockTotal.getSum(), blockTotal.getCount());
    }

    /**
     * Adds {@code sum} and {@code count} to the pair stored in the holder, creating the pair on
     * first use.
     *
     * @param aggregationResultHolder holder to update
     * @param sum sum to add
     * @param count count to add
     */
    protected void updateAggregationResult(AggregationResultHolder aggregationResultHolder, double sum, long count) {
        AvgPair stored = (AvgPair) aggregationResultHolder.getResult();
        if (stored != null) {
            stored.apply(sum, count);
            return;
        }
        aggregationResultHolder.setValue(new AvgPair(sum, count));
    }

    /**
     * Folds one block of values into per-group results, with exactly one group key per row.
     *
     * @param length number of rows in the block
     * @param groupKeyArray group key of each row
     * @param groupByResultHolder holder receiving per-group sums and counts
     * @param blockValSetMap block values keyed by expression; the entry for this function's
     *     expression is used
     */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet valueSet = blockValSetMap.get(_expression);
        if (valueSet.getValueType() == FieldSpec.DataType.BYTES) {
            byte[][] serializedPairs = valueSet.getBytesValuesSV();
            forEachNotNull(length, valueSet, (int start, int end) -> {
                for (int row = start; row < end; row++) {
                    AvgPair rowPair = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(serializedPairs[row]);
                    updateGroupByResult(groupKeyArray[row], groupByResultHolder, rowPair.getSum(), rowPair.getCount());
                }
            });
            return;
        }
        double[] rawValues = valueSet.getDoubleValuesSV();
        forEachNotNull(length, valueSet, (int start, int end) -> {
            for (int row = start; row < end; row++) {
                updateGroupByResult(groupKeyArray[row], groupByResultHolder, rawValues[row], 1L);
            }
        });
    }

    /**
     * Folds one block of values into per-group results, where a row may belong to several groups.
     * The row's value is applied once to every group key listed for that row.
     *
     * @param length number of rows in the block
     * @param groupKeysArray group keys of each row
     * @param groupByResultHolder holder receiving per-group sums and counts
     * @param blockValSetMap block values keyed by expression; the entry for this function's
     *     expression is used
     */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) {
        BlockValSet valueSet = blockValSetMap.get(_expression);
        if (valueSet.getValueType() == FieldSpec.DataType.BYTES) {
            byte[][] serializedPairs = valueSet.getBytesValuesSV();
            forEachNotNull(length, valueSet, (int start, int end) -> {
                for (int row = start; row < end; row++) {
                    // Deserialize once per row, then fan out to all of the row's groups.
                    AvgPair rowPair = ObjectSerDeUtils.AVG_PAIR_SER_DE.deserialize(serializedPairs[row]);
                    int[] rowKeys = groupKeysArray[row];
                    for (int k = 0; k < rowKeys.length; k++) {
                        updateGroupByResult(rowKeys[k], groupByResultHolder, rowPair.getSum(), rowPair.getCount());
                    }
                }
            });
            return;
        }
        double[] rawValues = valueSet.getDoubleValuesSV();
        forEachNotNull(length, valueSet, (int start, int end) -> {
            for (int row = start; row < end; row++) {
                int[] rowKeys = groupKeysArray[row];
                for (int k = 0; k < rowKeys.length; k++) {
                    updateGroupByResult(rowKeys[k], groupByResultHolder, rawValues[row], 1L);
                }
            }
        });
    }

    /**
     * Adds {@code sum} and {@code count} to the pair stored for {@code groupKey}, creating the pair
     * on first use.
     *
     * @param groupKey group to update
     * @param groupByResultHolder holder to update
     * @param sum sum to add
     * @param count count to add
     */
    protected void updateGroupByResult(int groupKey, GroupByResultHolder groupByResultHolder, double sum, long count) {
        AvgPair stored = (AvgPair) groupByResultHolder.getResult(groupKey);
        if (stored != null) {
            stored.apply(sum, count);
            return;
        }
        groupByResultHolder.setValueForKey(groupKey, new AvgPair(sum, count));
    }

    /**
     * Returns the pair stored in the holder.
     *
     * @param aggregationResultHolder holder to read
     * @return the stored pair; if none is stored, {@code null} when null handling is enabled,
     *     otherwise a new empty {@link AvgPair}
     */
    @Override
    public AvgPair extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        return orMissingResult((AvgPair) aggregationResultHolder.getResult());
    }

    /**
     * Returns the pair stored for {@code groupKey}.
     *
     * @param groupByResultHolder holder to read
     * @param groupKey group to read
     * @return the stored pair; if none is stored, {@code null} when null handling is enabled,
     *     otherwise a new empty {@link AvgPair}
     */
    @Override
    public AvgPair extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        return orMissingResult((AvgPair) groupByResultHolder.getResult(groupKey));
    }

    /** Returns {@code stored} if present, else the missing-result value for the current null-handling mode. */
    private AvgPair orMissingResult(AvgPair stored) {
        if (stored != null) {
            return stored;
        }
        return _nullHandlingEnabled ? null : new AvgPair();
    }

    /**
     * Merges two intermediate results by applying the second onto the first (the first is mutated
     * and returned).
     *
     * <p>Only when null handling is enabled is a {@code null} operand tolerated, in which case the
     * other operand is returned as-is.
     *
     * @param intermediateResult1 left operand; receives the merged state
     * @param intermediateResult2 right operand
     * @return the merged result
     */
    @Override
    public AvgPair merge(AvgPair intermediateResult1, AvgPair intermediateResult2) {
        boolean leftMissing = _nullHandlingEnabled && intermediateResult1 == null;
        if (leftMissing) {
            return intermediateResult2;
        }
        boolean rightMissing = _nullHandlingEnabled && intermediateResult2 == null;
        if (rightMissing) {
            return intermediateResult1;
        }
        intermediateResult1.apply(intermediateResult2);
        return intermediateResult1;
    }

    /** @return {@code OBJECT}, the column type used for the intermediate {@link AvgPair} */
    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    /** @return {@code DOUBLE}, the column type of the final average */
    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.DOUBLE;
    }

    /**
     * Converts an intermediate result into the final average.
     *
     * @param intermediateResult the accumulated sum and count, possibly {@code null}
     * @return {@code null} for a {@code null} input, negative infinity when the count is zero,
     *     otherwise {@code sum / count}
     */
    @Override
    public Double extractFinalResult(AvgPair intermediateResult) {
        if (intermediateResult == null) {
            return null;
        }
        long samples = intermediateResult.getCount();
        return samples == 0L ? DEFAULT_FINAL_RESULT : intermediateResult.getSum() / (double) samples;
    }
}

