/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import io.trino.operator.aggregation.state.LongDecimalWithOverflowState;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.LiteralParameters;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Int128Math;

@AggregationFunction(value="sum")
@Description(value="Calculates the sum over the input values")
public final class DecimalSumAggregation {
    private DecimalSumAggregation() {
    }

    @InputFunction
    @LiteralParameters(value={"p", "s"})
    public static void inputShortDecimal(@AggregationState LongDecimalWithOverflowState state, @SqlType(value="decimal(p,s)") long rightLow) {
        state.setNotNull();
        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();
        long rightHigh = rightLow >> 63;
        long overflow = Int128Math.addWithOverflow((long)decimal[offset], (long)decimal[offset + 1], (long)rightHigh, (long)rightLow, (long[])decimal, (int)offset);
        state.setOverflow(Math.addExact(overflow, state.getOverflow()));
    }

    @InputFunction
    @LiteralParameters(value={"p", "s"})
    public static void inputLongDecimal(@AggregationState LongDecimalWithOverflowState state, @BlockPosition @SqlType(value="decimal(p,s)", nativeContainerType=Int128.class) Block block, @BlockIndex int position) {
        state.setNotNull();
        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();
        long rightHigh = block.getLong(position, 0);
        long rightLow = block.getLong(position, 8);
        long overflow = Int128Math.addWithOverflow((long)decimal[offset], (long)decimal[offset + 1], (long)rightHigh, (long)rightLow, (long[])decimal, (int)offset);
        state.addOverflow(overflow);
    }

    @CombineFunction
    public static void combine(@AggregationState LongDecimalWithOverflowState state, @AggregationState LongDecimalWithOverflowState otherState) {
        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();
        long[] otherDecimal = otherState.getDecimalArray();
        int otherOffset = otherState.getDecimalArrayOffset();
        if (state.isNotNull()) {
            long overflow = Int128Math.addWithOverflow((long)decimal[offset], (long)decimal[offset + 1], (long)otherDecimal[otherOffset], (long)otherDecimal[otherOffset + 1], (long[])decimal, (int)offset);
            state.addOverflow(Math.addExact(overflow, otherState.getOverflow()));
        } else {
            state.setNotNull();
            decimal[offset] = otherDecimal[otherOffset];
            decimal[offset + 1] = otherDecimal[otherOffset + 1];
            state.setOverflow(otherState.getOverflow());
        }
    }

    @OutputFunction(value="decimal(38,s)")
    public static void outputDecimal(@AggregationState LongDecimalWithOverflowState state, BlockBuilder out) {
        if (state.isNotNull()) {
            if (state.getOverflow() != 0L) {
                throw new ArithmeticException("Decimal overflow");
            }
            long[] decimal = state.getDecimalArray();
            int offset = state.getDecimalArrayOffset();
            long rawHigh = decimal[offset];
            long rawLow = decimal[offset + 1];
            Decimals.throwIfOverflows((long)rawHigh, (long)rawLow);
            out.writeLong(rawHigh);
            out.writeLong(rawLow);
            out.closeEntry();
        } else {
            out.appendNull();
        }
    }
}

