/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import io.airlift.slice.Slice;
import io.trino.operator.aggregation.DecimalAverageAggregation;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Type;
import io.trino.spi.type.UnscaledDecimal128Arithmetic;
import java.math.BigDecimal;
import java.math.BigInteger;
import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

/**
 * Tests the overflow/underflow bookkeeping of {@link DecimalAverageAggregation} for
 * {@code DECIMAL(38, 0)} inputs.
 *
 * <p>The aggregation state exposes a running unscaled sum ({@code getLongDecimal()}), a signed
 * overflow counter ({@code getOverflow()}) and the number of accumulated values
 * ({@code getLong()}). Each test feeds values near 2^127 so that the running sum wraps, then checks
 * both the raw state and the final average.
 *
 * <p>All tests share the mutable {@link #state} field, which is recreated before every test method.
 * The class runs single-threaded.
 */
@Test(singleThreaded = true)
public class TestDecimalAverageAggregation
{
    private static final BigInteger TWO = new BigInteger("2");
    private static final DecimalType TYPE = DecimalType.createDecimalType(38, 0);

    private LongDecimalWithOverflowAndLongState state;

    @BeforeMethod
    public void setUp()
    {
        state = newState();
    }

    @Test
    public void testOverflow()
    {
        BigInteger value = TWO.pow(126);

        addToState(state, value);
        Assert.assertEquals(state.getLong(), 1L);
        Assert.assertEquals(state.getOverflow(), 0L);
        Assert.assertEquals(state.getLongDecimal(), UnscaledDecimal128Arithmetic.unscaledDecimal(value));

        // 2^126 + 2^126 = 2^127: one positive overflow is recorded and the remainder is zero.
        addToState(state, value);
        Assert.assertEquals(state.getLong(), 2L);
        Assert.assertEquals(state.getOverflow(), 1L);
        Assert.assertEquals(state.getLongDecimal(), UnscaledDecimal128Arithmetic.unscaledDecimal(0L));

        assertAverage(value);
    }

    @Test
    public void testUnderflow()
    {
        BigInteger value = TWO.pow(126).negate();

        addToState(state, value);
        Assert.assertEquals(state.getLong(), 1L);
        Assert.assertEquals(state.getOverflow(), 0L);
        Assert.assertEquals(state.getLongDecimal(), UnscaledDecimal128Arithmetic.unscaledDecimal(value));

        // -2^126 + -2^126 = -2^127: one negative overflow is recorded.
        addToState(state, value);
        Assert.assertEquals(state.getLong(), 2L);
        Assert.assertEquals(state.getOverflow(), -1L);
        // Numeric comparison via compare() rather than equals(); presumably the wrapped remainder
        // is not guaranteed to be byte-identical to the canonical zero slice — TODO confirm.
        Assert.assertEquals(UnscaledDecimal128Arithmetic.compare(state.getLongDecimal(), UnscaledDecimal128Arithmetic.unscaledDecimal(0L)), 0);

        assertAverage(value);
    }

    @Test
    public void testUnderflowAfterOverflow()
    {
        // 2^126 + 2^126 + 2^125 = 2^127 + 2^125: one overflow, remainder 2^125.
        addToState(state, TWO.pow(126));
        addToState(state, TWO.pow(126));
        addToState(state, TWO.pow(125));
        Assert.assertEquals(state.getOverflow(), 1L);
        Assert.assertEquals(state.getLongDecimal(), UnscaledDecimal128Arithmetic.unscaledDecimal(TWO.pow(125)));

        // Subtracting 3 * 2^126 brings the total to -2^125 and clears the overflow counter.
        for (int i = 0; i < 3; i++) {
            addToState(state, TWO.pow(126).negate());
        }
        Assert.assertEquals(state.getOverflow(), 0L);
        Assert.assertEquals(state.getLongDecimal(), UnscaledDecimal128Arithmetic.unscaledDecimal(TWO.pow(125).negate()));

        // Six values were accumulated in total.
        assertAverage(TWO.pow(125).negate().divide(BigInteger.valueOf(6L)));
    }

    @Test
    public void testCombineOverflow()
    {
        addToState(state, TWO.pow(125));
        addToState(state, TWO.pow(126));

        LongDecimalWithOverflowAndLongState otherState = newState();
        addToState(otherState, TWO.pow(125));
        addToState(otherState, TWO.pow(126));

        // Each side holds 3 * 2^125; the merged 6 * 2^125 = 2^127 + 2^126 overflows once.
        DecimalAverageAggregation.combine(state, otherState);
        Assert.assertEquals(state.getLong(), 4L);
        Assert.assertEquals(state.getOverflow(), 1L);
        Assert.assertEquals(state.getLongDecimal(), UnscaledDecimal128Arithmetic.unscaledDecimal(TWO.pow(126)));

        BigInteger total = TWO.pow(126).add(TWO.pow(126)).add(TWO.pow(125)).add(TWO.pow(125));
        assertAverage(total.divide(BigInteger.valueOf(4L)));
    }

    @Test
    public void testCombineUnderflow()
    {
        addToState(state, TWO.pow(125).negate());
        addToState(state, TWO.pow(126).negate());

        LongDecimalWithOverflowAndLongState otherState = newState();
        addToState(otherState, TWO.pow(125).negate());
        addToState(otherState, TWO.pow(126).negate());

        // Mirror of testCombineOverflow: the merged -(2^127 + 2^126) underflows once.
        DecimalAverageAggregation.combine(state, otherState);
        Assert.assertEquals(state.getLong(), 4L);
        Assert.assertEquals(state.getOverflow(), -1L);
        Assert.assertEquals(state.getLongDecimal(), UnscaledDecimal128Arithmetic.unscaledDecimal(TWO.pow(126).negate()));

        BigInteger total = TWO.pow(126).add(TWO.pow(126)).add(TWO.pow(125)).add(TWO.pow(125));
        assertAverage(total.negate().divide(BigInteger.valueOf(4L)));
    }

    /**
     * Creates a fresh, empty aggregation state from the single-state factory.
     */
    private static LongDecimalWithOverflowAndLongState newState()
    {
        return new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
    }

    /**
     * Asserts that the average computed from {@link #state} equals {@code expected} as a scale-0
     * {@link BigDecimal}.
     */
    private void assertAverage(BigInteger expected)
    {
        Assert.assertEquals(DecimalAverageAggregation.average(state, TYPE), new BigDecimal(expected));
    }

    /**
     * Feeds one {@code DECIMAL(38, 0)} value into {@code target} through the aggregation's
     * long-decimal input function, using a single-position block.
     */
    private static void addToState(LongDecimalWithOverflowAndLongState target, BigInteger value)
    {
        BlockBuilder builder = TYPE.createFixedSizeBlockBuilder(1);
        TYPE.writeSlice(builder, UnscaledDecimal128Arithmetic.unscaledDecimal(value));
        Block block = builder.build();
        DecimalAverageAggregation.inputLongDecimal(TYPE, target, block, 0);
    }
}

