/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import io.trino.operator.aggregation.DecimalSumAggregation;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowState;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowStateFactory;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.Int128ArrayBlock;
import io.trino.spi.block.VariableWidthBlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Int128;
import java.math.BigInteger;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.testng.Assert;

/**
 * Exercises {@link DecimalSumAggregation} with a {@code decimal(38, 0)} type and checks the
 * value reported by {@link LongDecimalWithOverflowState#getOverflow()} together with the
 * 128-bit value stored in the state after inputs whose sum leaves the signed 128-bit range.
 */
public class TestDecimalSumAggregation {
    private static final BigInteger TWO = BigInteger.valueOf(2);
    private static final DecimalType TYPE = DecimalType.createDecimalType(38, 0);

    /** 2^126 + 2^126 = 2^127: expects overflow 1 and a stored value of (Long.MIN_VALUE, 0). */
    @Test
    public void testOverflow() {
        LongDecimalWithOverflowState state = newState();
        BigInteger addend = TWO.pow(126);

        addToState(state, addend);
        assertState(state, 0L, Int128.valueOf(addend));

        addToState(state, addend);
        assertState(state, 1L, Int128.valueOf(Long.MIN_VALUE, 0L));
    }

    /** -2^126 - 2^126 = -2^127: expects overflow to stay 0 and a stored value of (Long.MIN_VALUE, 0). */
    @Test
    public void testUnderflow() {
        LongDecimalWithOverflowState state = newState();
        BigInteger addend = TWO.pow(126).negate();

        addToState(state, addend);
        assertState(state, 0L, Int128.valueOf(addend));

        addToState(state, addend);
        assertState(state, 0L, Int128.valueOf(Long.MIN_VALUE, 0L));
    }

    /** Pushes the sum above 2^127, then adds negative values until it is -2^125 with overflow back at 0. */
    @Test
    public void testUnderflowAfterOverflow() {
        LongDecimalWithOverflowState state = newState();

        addToState(state, TWO.pow(126));
        addToState(state, TWO.pow(126));
        addToState(state, TWO.pow(125));
        // 0xA000000000000000L == -6917529027641081856L
        assertState(state, 1L, Int128.valueOf(0xA000000000000000L, 0L));

        BigInteger negative = TWO.pow(126).negate();
        addToState(state, negative);
        addToState(state, negative);
        addToState(state, negative);
        assertState(state, 0L, Int128.valueOf(TWO.pow(125).negate()));
    }

    /** Combines two states that each hold 2^125 + 2^126; expects overflow 1 after the combine. */
    @Test
    public void testCombineOverflow() {
        LongDecimalWithOverflowState state = newState();
        addToState(state, TWO.pow(125));
        addToState(state, TWO.pow(126));

        LongDecimalWithOverflowState otherState = newState();
        addToState(otherState, TWO.pow(125));
        addToState(otherState, TWO.pow(126));

        DecimalSumAggregation.combine(state, otherState);

        // 0xC000000000000000L == -4611686018427387904L
        assertState(state, 1L, Int128.valueOf(0xC000000000000000L, 0L));
    }

    /** Combines two states that each hold -2^125 - 2^126; expects overflow -1 after the combine. */
    @Test
    public void testCombineUnderflow() {
        LongDecimalWithOverflowState state = newState();
        addToState(state, TWO.pow(125).negate());
        addToState(state, TWO.pow(126).negate());

        LongDecimalWithOverflowState otherState = newState();
        addToState(otherState, TWO.pow(125).negate());
        addToState(otherState, TWO.pow(126).negate());

        DecimalSumAggregation.combine(state, otherState);

        assertState(state, -1L, Int128.valueOf(0x4000000000000000L, 0L));
    }

    /** With overflow at 1, {@code outputDecimal} is expected to throw "Decimal overflow". */
    @Test
    public void testOverflowOnOutput() {
        LongDecimalWithOverflowState state = newState();
        addToState(state, TWO.pow(126));
        addToState(state, TWO.pow(126));
        Assert.assertEquals(state.getOverflow(), 1L);

        AbstractThrowableAssert<?, ? extends Throwable> thrown = Assertions.assertThatThrownBy(
                () -> DecimalSumAggregation.outputDecimal(state, new VariableWidthBlockBuilder(null, 10, 100)));
        thrown.isInstanceOf(ArithmeticException.class).hasMessage("Decimal overflow");
    }

    /** Creates a fresh single (non-grouped) state from a new factory instance. */
    private static LongDecimalWithOverflowState newState() {
        return new LongDecimalWithOverflowStateFactory().createSingleState();
    }

    /**
     * Asserts first the state's overflow value and then the 128-bit value it currently stores.
     *
     * @param state the aggregation state under test
     * @param expectedOverflow the value {@code state.getOverflow()} must equal
     * @param expectedDecimal the value {@link #getDecimal} must return for {@code state}
     */
    private void assertState(LongDecimalWithOverflowState state, long expectedOverflow, Int128 expectedDecimal) {
        Assert.assertEquals(state.getOverflow(), expectedOverflow);
        Assert.assertEquals(this.getDecimal(state), expectedDecimal);
    }

    /**
     * Feeds {@code value} into {@code state} through the input function that matches {@link #TYPE}:
     * the short-decimal input when {@code TYPE.isShort()}, otherwise the long-decimal input reading
     * position 0 of a one-entry block built from {@code value}.
     */
    private static void addToState(LongDecimalWithOverflowState state, BigInteger value) {
        Int128 decimal = Int128.valueOf(value);
        if (!TYPE.isShort()) {
            BlockBuilder singleValueBuilder = TYPE.createFixedSizeBlockBuilder(1);
            TYPE.writeObject(singleValueBuilder, decimal);
            Int128ArrayBlock block = (Int128ArrayBlock) singleValueBuilder.buildValueBlock();
            DecimalSumAggregation.inputLongDecimal(state, block, 0);
            return;
        }
        DecimalSumAggregation.inputShortDecimal(state, decimal.toLongExact());
    }

    /**
     * Reads the two consecutive longs at the state's decimal-array offset and wraps them in an
     * {@link Int128}.
     */
    private Int128 getDecimal(LongDecimalWithOverflowState state) {
        long[] words = state.getDecimalArray();
        int start = state.getDecimalArrayOffset();
        long first = words[start];
        long second = words[start + 1];
        return Int128.valueOf(first, second);
    }
}

