/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.annotations.VisibleForTesting;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.Int128ArrayBlock;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.LiteralParameters;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Int128Math;
import io.trino.spi.type.Type;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;

/**
 * Implements {@code avg} over {@code decimal(p, s)} values.
 * <p>
 * The aggregation state holds:
 * <ul>
 *   <li>a row count ({@code getLong()});</li>
 *   <li>a 128-bit running sum of the unscaled inputs, stored as a (high, low) pair of longs in
 *       {@code getDecimalArray()} starting at {@code getDecimalArrayOffset()};</li>
 *   <li>an overflow counter ({@code getOverflow()}).</li>
 * </ul>
 * The exact unscaled sum is {@code int128(high, low) + overflow * 2^128}.
 */
@AggregationFunction(value="avg")
@Description(value="Calculates the average value")
public final class DecimalAverageAggregation {
    /** 2^128: the weight of one unit of the state's overflow counter. */
    private static final BigInteger OVERFLOW_MULTIPLIER = BigInteger.ONE.shiftLeft(128);

    private DecimalAverageAggregation() {
    }

    /**
     * Adds one short-decimal value to the running sum and increments the row count.
     *
     * @param state aggregation state to update
     * @param rightLow unscaled input value; sign-extended to 128 bits before the addition
     */
    @InputFunction
    @LiteralParameters(value={"p", "s"})
    public static void inputShortDecimal(@AggregationState LongDecimalWithOverflowAndLongState state, @SqlType(value="decimal(p,s)") long rightLow) {
        state.addLong(1L);
        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();
        // Arithmetic shift copies the sign bit, producing the high word of the 128-bit operand.
        long rightHigh = rightLow >> 63;
        // The result array/offset arguments direct the wrapped sum back into the state;
        // any carry out of 128 bits is returned and accumulated separately.
        long overflow = Int128Math.addWithOverflow(decimal[offset], decimal[offset + 1], rightHigh, rightLow, decimal, offset);
        state.addOverflow(overflow);
    }

    /**
     * Adds one long-decimal value, read from {@code block} at {@code position}, to the running sum
     * and increments the row count.
     *
     * @param state aggregation state to update
     * @param block block holding 128-bit decimal values
     * @param position position of the value within {@code block}
     */
    @InputFunction
    @LiteralParameters(value={"p", "s"})
    public static void inputLongDecimal(@AggregationState LongDecimalWithOverflowAndLongState state, @BlockPosition @SqlType(value="decimal(p, s)", nativeContainerType=Int128.class) Int128ArrayBlock block, @BlockIndex int position) {
        state.addLong(1L);
        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();
        long rightHigh = block.getLong(position, 0);
        long rightLow = block.getLong(position, 8);
        long overflow = Int128Math.addWithOverflow(decimal[offset], decimal[offset + 1], rightHigh, rightLow, decimal, offset);
        state.addOverflow(overflow);
    }

    /**
     * Merges {@code otherState} into {@code state}: sums, overflow counters and row counts are added.
     * When {@code state} has seen no rows yet, its sum and overflow are overwritten with the other
     * state's values instead of being added to.
     *
     * @param state state to merge into
     * @param otherState state to merge from; not modified
     */
    @CombineFunction
    public static void combine(@AggregationState LongDecimalWithOverflowAndLongState state, @AggregationState LongDecimalWithOverflowAndLongState otherState) {
        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();
        long[] otherDecimal = otherState.getDecimalArray();
        int otherOffset = otherState.getDecimalArrayOffset();
        if (state.getLong() > 0L) {
            long overflow = Int128Math.addWithOverflow(decimal[offset], decimal[offset + 1], otherDecimal[otherOffset], otherDecimal[otherOffset + 1], decimal, offset);
            state.addOverflow(overflow + otherState.getOverflow());
        } else {
            decimal[offset] = otherDecimal[otherOffset];
            decimal[offset + 1] = otherDecimal[otherOffset + 1];
            state.setOverflow(otherState.getOverflow());
        }
        state.addLong(otherState.getLong());
    }

    /**
     * Writes the average to {@code out}, or a null when no rows were aggregated.
     * Short decimal types are written as a long (failing if the value does not fit); long
     * decimal types are written as an {@link Int128}.
     *
     * @param type the {@code decimal(p,s)} result type
     * @param state final aggregation state
     * @param out builder receiving the result
     */
    @OutputFunction(value="decimal(p,s)")
    public static void outputDecimal(@TypeParameter(value="decimal(p,s)") Type type, @AggregationState LongDecimalWithOverflowAndLongState state, BlockBuilder out) {
        DecimalType decimalType = (DecimalType) type;
        if (state.getLong() == 0L) {
            out.appendNull();
            return;
        }
        Int128 average = average(state, decimalType);
        if (decimalType.isShort()) {
            Decimals.writeShortDecimal(out, average.toLongExact());
        } else {
            type.writeObject(out, average);
        }
    }

    /**
     * Computes the unscaled average ({@code sum / count}) at the scale of {@code type}.
     * <p>
     * When the state has recorded overflow, the exact sum is rebuilt with {@link BigInteger}
     * arithmetic and divided with {@link RoundingMode#HALF_UP}. Otherwise the 128-bit sum is
     * divided with {@link Int128Math#divideRoundUp}, which presumably also rounds half away from
     * zero (confirm against Int128Math).
     *
     * @param state aggregation state; callers are expected to ensure the row count is non-zero
     * @param type decimal type whose scale the sum is expressed in
     * @return the unscaled average
     * @throws ArithmeticException if the result does not fit a decimal
     */
    @VisibleForTesting
    public static Int128 average(LongDecimalWithOverflowAndLongState state, DecimalType type) {
        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();
        long overflow = state.getOverflow();
        long count = state.getLong();
        if (overflow != 0L) {
            // Both the wrapped 128-bit value and overflow * 2^128 are UNSCALED integers, so they
            // must be added before the type's scale is applied. Adding the overflow term as a
            // scale-0 BigDecimal would inflate it by 10^scale for any non-zero scale.
            BigInteger unscaledSum = Int128.valueOf(decimal[offset], decimal[offset + 1]).toBigInteger()
                    .add(OVERFLOW_MULTIPLIER.multiply(BigInteger.valueOf(overflow)));
            BigDecimal sum = new BigDecimal(unscaledSum, type.getScale());
            BigDecimal averageValue = sum.divide(BigDecimal.valueOf(count), type.getScale(), RoundingMode.HALF_UP);
            return Decimals.encodeScaledValue(averageValue, type.getScale());
        }
        Int128 result = Int128Math.divideRoundUp(decimal[offset], decimal[offset + 1], 0, 0L, count, 0);
        if (Decimals.overflows(result)) {
            throw new ArithmeticException("Decimal overflow");
        }
        return result;
    }
}

