/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.operator.aggregation;

import com.google.common.primitives.Ints;
import io.prestosql.block.BlockAssertions;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.MetadataManager;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.operator.PagesIndex;
import io.prestosql.operator.aggregation.Accumulator;
import io.prestosql.operator.aggregation.AccumulatorFactory;
import io.prestosql.operator.aggregation.AggregationTestUtils;
import io.prestosql.operator.aggregation.InternalAggregationFunction;
import io.prestosql.operator.window.PagesWindowIndex;
import io.prestosql.spi.ErrorCodeSupplier;
import io.prestosql.spi.Page;
import io.prestosql.spi.StandardErrorCode;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.block.RunLengthEncodedBlock;
import io.prestosql.spi.function.WindowIndex;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.analyzer.TypeSignatureProvider;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.testing.assertions.PrestoExceptionAssert;
import java.util.List;
import java.util.Optional;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

public abstract class AbstractTestAggregationFunction {
    protected Metadata metadata;

    @BeforeClass
    public final void initTestAggregationFunction() {
        this.metadata = MetadataManager.createTestMetadataManager();
    }

    @AfterClass(alwaysRun=true)
    public final void destroyTestAggregationFunction() {
        this.metadata = null;
    }

    protected abstract Block[] getSequenceBlocks(int var1, int var2);

    protected final InternalAggregationFunction getFunction() {
        ResolvedFunction resolvedFunction = this.metadata.resolveFunction(QualifiedName.of((String)this.getFunctionName()), TypeSignatureProvider.fromTypes(this.getFunctionParameterTypes()));
        return this.metadata.getAggregateFunctionImplementation(resolvedFunction);
    }

    protected abstract String getFunctionName();

    protected abstract List<Type> getFunctionParameterTypes();

    protected abstract Object getExpectedValue(int var1, int var2);

    protected Object getExpectedValueIncludingNulls(int start, int length, int lengthIncludingNulls) {
        return this.getExpectedValue(start, length);
    }

    @Test
    public void testNoPositions() {
        this.testAggregation(this.getExpectedValue(0, 0), this.getSequenceBlocks(0, 0));
    }

    @Test
    public void testSinglePosition() {
        this.testAggregation(this.getExpectedValue(0, 1), this.getSequenceBlocks(0, 1));
    }

    @Test
    public void testMultiplePositions() {
        this.testAggregation(this.getExpectedValue(0, 5), this.getSequenceBlocks(0, 5));
    }

    @Test
    public void testAllPositionsNull() {
        List parameterTypes = this.getFunction().getParameterTypes();
        if (parameterTypes.isEmpty()) {
            return;
        }
        Block[] blocks = new Block[parameterTypes.size()];
        for (int i = 0; i < parameterTypes.size(); ++i) {
            blocks[i] = RunLengthEncodedBlock.create((Type)((Type)parameterTypes.get(0)), null, (int)10);
        }
        this.testAggregation(this.getExpectedValueIncludingNulls(0, 0, 10), blocks);
    }

    @Test
    public void testMixedNullAndNonNullPositions() {
        List parameterTypes = this.getFunction().getParameterTypes();
        if (parameterTypes.isEmpty()) {
            return;
        }
        Block[] alternatingNullsBlocks = AbstractTestAggregationFunction.createAlternatingNullsBlock(parameterTypes, this.getSequenceBlocks(0, 10));
        this.testAggregation(this.getExpectedValueIncludingNulls(0, 10, 20), alternatingNullsBlocks);
    }

    @Test
    public void testNegativeOnlyValues() {
        this.testAggregation(this.getExpectedValue(-10, 5), this.getSequenceBlocks(-10, 5));
    }

    @Test
    public void testPositiveOnlyValues() {
        this.testAggregation(this.getExpectedValue(2, 4), this.getSequenceBlocks(2, 4));
    }

    @Test
    public void testSlidingWindow() {
        int totalPositions = 12;
        int[] windowWidths = new int[totalPositions];
        Object[] expectedValues = new Object[totalPositions];
        for (int i = 0; i < totalPositions; ++i) {
            int windowWidth;
            windowWidths[i] = windowWidth = Integer.min(i, totalPositions - 1 - i);
            expectedValues[i] = this.getExpectedValue(i, windowWidth);
        }
        Page inputPage = new Page(totalPositions, this.getSequenceBlocks(0, totalPositions));
        InternalAggregationFunction function = this.getFunction();
        List channels = Ints.asList((int[])AggregationTestUtils.createArgs(function));
        AccumulatorFactory accumulatorFactory = function.bind(channels, Optional.empty());
        PagesIndex pagesIndex = new PagesIndex.TestingFactory(false).newPagesIndex(function.getParameterTypes(), totalPositions);
        pagesIndex.addPage(inputPage);
        PagesWindowIndex windowIndex = new PagesWindowIndex(pagesIndex, 0, totalPositions - 1);
        Accumulator aggregation = accumulatorFactory.createAccumulator();
        int oldStart = 0;
        int oldWidth = 0;
        for (int start = 0; start < totalPositions; ++start) {
            int width = windowWidths[start];
            if (accumulatorFactory.hasRemoveInput()) {
                for (int oldi = oldStart; oldi < oldStart + oldWidth; ++oldi) {
                    if (oldi >= start && oldi < start + width) continue;
                    aggregation.removeInput((WindowIndex)windowIndex, channels, oldi, oldi);
                }
                for (int newi = start; newi < start + width; ++newi) {
                    if (newi >= oldStart && newi < oldStart + oldWidth) continue;
                    aggregation.addInput((WindowIndex)windowIndex, channels, newi, newi);
                }
            } else {
                aggregation = accumulatorFactory.createAccumulator();
                aggregation.addInput((WindowIndex)windowIndex, channels, start, start + width - 1);
            }
            oldStart = start;
            oldWidth = width;
            Block block = AggregationTestUtils.getFinalBlock(aggregation);
            AggregationTestUtils.makeValidityAssertion(expectedValues[start]).apply(BlockAssertions.getOnlyValue(aggregation.getFinalType(), block), expectedValues[start]);
        }
    }

    protected static Block[] createAlternatingNullsBlock(List<Type> types, Block ... sequenceBlocks) {
        Block[] alternatingNullsBlocks = new Block[sequenceBlocks.length];
        for (int i = 0; i < sequenceBlocks.length; ++i) {
            int positionCount = sequenceBlocks[i].getPositionCount();
            Type type = types.get(i);
            BlockBuilder blockBuilder = type.createBlockBuilder(null, positionCount);
            for (int position = 0; position < positionCount; ++position) {
                blockBuilder.appendNull();
                type.appendTo(sequenceBlocks[i], position, blockBuilder);
            }
            alternatingNullsBlocks[i] = blockBuilder.build();
        }
        return alternatingNullsBlocks;
    }

    protected void testAggregation(Object expectedValue, Block ... blocks) {
        AggregationTestUtils.assertAggregation(this.getFunction(), expectedValue, blocks);
    }

    protected void assertInvalidAggregation(Runnable runnable) {
        PrestoExceptionAssert.assertPrestoExceptionThrownBy(runnable::run).hasErrorCode((ErrorCodeSupplier)StandardErrorCode.INVALID_FUNCTION_ARGUMENT);
    }
}

