/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.block.BlockAssertions;
import io.trino.metadata.FunctionBundle;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.PagesIndex;
import io.trino.operator.aggregation.AccumulatorCompiler;
import io.trino.operator.aggregation.AggregationTestUtils;
import io.trino.operator.aggregation.WindowAccumulator;
import io.trino.operator.window.PagesWindowIndex;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.Page;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.function.AggregationImplementation;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionNullability;
import io.trino.spi.function.WindowIndex;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.testing.assertions.TrinoExceptionAssert;
import java.lang.reflect.Constructor;
import java.util.List;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Shared test scenarios for an aggregation function.
 *
 * <p>Subclasses name the function under test, declare its parameter types, and supply input
 * blocks plus expected results for a test sequence. The {@code @Test} methods here then run
 * the function over empty, single-position, multi-position, negative-only, positive-only,
 * all-null, partially-null and sliding-window inputs.
 */
public abstract class AbstractTestAggregationFunction {
    protected final TestingFunctionResolution functionResolution;

    /** Resolves functions with a default {@link TestingFunctionResolution}. */
    protected AbstractTestAggregationFunction() {
        this.functionResolution = new TestingFunctionResolution();
    }

    /**
     * Resolves functions with a {@link TestingFunctionResolution} built from {@code functions}.
     *
     * @param functions the function bundle handed to {@link TestingFunctionResolution}
     */
    protected AbstractTestAggregationFunction(FunctionBundle functions) {
        this.functionResolution = new TestingFunctionResolution(functions);
    }

    /**
     * Returns the input blocks, one per entry of {@link #getFunctionParameterTypes()}, covering
     * {@code length} positions of the test sequence beginning at {@code start}.
     */
    protected abstract Block[] getSequenceBlocks(int start, int length);

    /** Returns the name used to resolve the aggregation function under test. */
    protected abstract String getFunctionName();

    /** Returns the argument types used both to resolve the function and to build input blocks. */
    protected abstract List<Type> getFunctionParameterTypes();

    /**
     * Returns the expected result of aggregating {@code length} positions of the test sequence
     * beginning at {@code start}.
     */
    protected abstract Object getExpectedValue(int start, int length);

    /**
     * Returns the expected result when the {@code length} values beginning at {@code start} are
     * mixed with nulls, for {@code lengthIncludingNulls} positions in total.
     *
     * <p>Defaults to {@link #getExpectedValue(int, int)}, i.e. nulls are assumed not to affect
     * the result; override this for functions that observe nulls.
     */
    protected Object getExpectedValueIncludingNulls(int start, int length, int lengthIncludingNulls) {
        return getExpectedValue(start, length);
    }

    @Test
    public void testNoPositions() {
        testAggregation(getExpectedValue(0, 0), getSequenceBlocks(0, 0));
    }

    @Test
    public void testSinglePosition() {
        testAggregation(getExpectedValue(0, 1), getSequenceBlocks(0, 1));
    }

    @Test
    public void testMultiplePositions() {
        testAggregation(getExpectedValue(0, 5), getSequenceBlocks(0, 5));
    }

    @Test
    public void testAllPositionsNull() {
        // A function without parameters has no input channels to fill with nulls.
        List<Type> parameterTypes = getFunctionParameterTypes();
        if (parameterTypes.isEmpty()) {
            return;
        }
        int nullPositions = 10;
        Block[] blocks = new Block[parameterTypes.size()];
        for (int i = 0; i < parameterTypes.size(); ++i) {
            // Each channel must be a null block of its own parameter's type, not the first one's.
            blocks[i] = RunLengthEncodedBlock.create(parameterTypes.get(i), null, nullPositions);
        }
        testAggregation(getExpectedValueIncludingNulls(0, 0, nullPositions), blocks);
    }

    @Test
    public void testMixedNullAndNonNullPositions() {
        List<Type> parameterTypes = getFunctionParameterTypes();
        if (parameterTypes.isEmpty()) {
            return;
        }
        // 10 values interleaved with 10 nulls -> 20 positions.
        Block[] alternatingNullsBlocks = createAlternatingNullsBlock(parameterTypes, getSequenceBlocks(0, 10));
        testAggregation(getExpectedValueIncludingNulls(0, 10, 20), alternatingNullsBlocks);
    }

    @Test
    public void testNegativeOnlyValues() {
        testAggregation(getExpectedValue(-10, 5), getSequenceBlocks(-10, 5));
    }

    @Test
    public void testPositiveOnlyValues() {
        testAggregation(getExpectedValue(2, 4), getSequenceBlocks(2, 4));
    }

    @Test
    public void testSlidingWindow() {
        int totalPositions = 12;

        // Window i starts at position i. Widths grow and then shrink, so successive windows both
        // gain and lose positions; the first and last windows are empty.
        int[] windowWidths = new int[totalPositions];
        Object[] expectedValues = new Object[totalPositions];
        for (int i = 0; i < totalPositions; ++i) {
            int windowWidth = Math.min(i, totalPositions - 1 - i);
            windowWidths[i] = windowWidth;
            expectedValues[i] = getExpectedValue(i, windowWidth);
        }

        Page inputPage = new Page(totalPositions, getSequenceBlocks(0, totalPositions));
        PagesIndex pagesIndex = new PagesIndex.TestingFactory(false).newPagesIndex(getFunctionParameterTypes(), totalPositions);
        pagesIndex.addPage(inputPage);
        // No window reaches the final position, so the index can stop before it
        // (NOTE(review): assumes the end argument is exclusive — confirm against PagesWindowIndex).
        WindowIndex windowIndex = new PagesWindowIndex(pagesIndex, 0, totalPositions - 1);

        ResolvedFunction resolvedFunction = functionResolution.resolveFunction(getFunctionName(), TypeSignatureProvider.fromTypes(getFunctionParameterTypes()));
        AggregationImplementation aggregationImplementation = functionResolution.getPlannerContext().getFunctionManager().getAggregationImplementation(resolvedFunction);
        Type outputType = resolvedFunction.getSignature().getReturnType();
        boolean supportsRemoveInput = aggregationImplementation.getRemoveInputFunction().isPresent();

        WindowAccumulator aggregation = createWindowAccumulator(resolvedFunction, aggregationImplementation);
        int oldStart = 0;
        int oldWidth = 0;
        for (int start = 0; start < totalPositions; ++start) {
            int width = windowWidths[start];
            if (supportsRemoveInput) {
                // Slide incrementally: remove positions that left the window...
                for (int oldi = oldStart; oldi < oldStart + oldWidth; ++oldi) {
                    if (oldi < start || oldi >= start + width) {
                        aggregation.removeInput(windowIndex, oldi, oldi);
                    }
                }
                // ...and add positions that entered it.
                for (int newi = start; newi < start + width; ++newi) {
                    if (newi < oldStart || newi >= oldStart + oldWidth) {
                        aggregation.addInput(windowIndex, newi, newi);
                    }
                }
            }
            else {
                // Without removeInput, rebuild the accumulator from scratch for every window.
                // For width 0 this passes end = start - 1, presumably an empty inclusive range.
                aggregation = createWindowAccumulator(resolvedFunction, aggregationImplementation);
                aggregation.addInput(windowIndex, start, start + width - 1);
            }
            oldStart = start;
            oldWidth = width;

            BlockBuilder blockBuilder = outputType.createBlockBuilder(null, 1);
            aggregation.evaluateFinal(blockBuilder);
            Block block = blockBuilder.build();
            Object expected = expectedValues[start];
            Assertions.assertThat((Boolean) AggregationTestUtils.makeValidityAssertion(expected).apply(BlockAssertions.getOnlyValue(outputType, block), expected))
                    .as("window starting at %s with width %s", start, width)
                    .isTrue();
        }
    }

    /**
     * Compiles and instantiates a fresh {@link WindowAccumulator} for the resolved function.
     *
     * @throws RuntimeException wrapping any reflective instantiation failure
     */
    private static WindowAccumulator createWindowAccumulator(ResolvedFunction resolvedFunction, AggregationImplementation aggregationImplementation) {
        BoundSignature signature = resolvedFunction.getSignature();
        FunctionNullability nullability = resolvedFunction.getFunctionNullability();
        try {
            Constructor<? extends WindowAccumulator> constructor = AccumulatorCompiler.generateWindowAccumulatorClass(signature, aggregationImplementation, nullability);
            // NOTE(review): the empty list presumably stands for lambda-argument providers,
            // which these tests never supply.
            return constructor.newInstance(ImmutableList.of());
        }
        catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Interleaves nulls with the values of each block: every source position becomes a null
     * followed by the original value, so each output block has twice as many positions.
     *
     * @param types the type of each block, matched to {@code sequenceBlocks} by index
     * @param sequenceBlocks the blocks whose values are interleaved with nulls
     * @return one block per entry of {@code sequenceBlocks}, in the same order
     */
    protected static Block[] createAlternatingNullsBlock(List<Type> types, Block... sequenceBlocks) {
        Block[] alternatingNullsBlocks = new Block[sequenceBlocks.length];
        for (int i = 0; i < sequenceBlocks.length; ++i) {
            Block source = sequenceBlocks[i];
            Type type = types.get(i);
            int positionCount = source.getPositionCount();
            // Each source position produces two entries: a null and the value.
            BlockBuilder blockBuilder = type.createBlockBuilder(null, positionCount * 2);
            for (int position = 0; position < positionCount; ++position) {
                blockBuilder.appendNull();
                type.appendTo(source, position, blockBuilder);
            }
            alternatingNullsBlocks[i] = blockBuilder.build();
        }
        return alternatingNullsBlocks;
    }

    /**
     * Asserts that aggregating {@code blocks} with the function under test yields {@code expectedValue}.
     */
    protected void testAggregation(Object expectedValue, Block... blocks) {
        AggregationTestUtils.assertAggregation(functionResolution, getFunctionName(), TypeSignatureProvider.fromTypes(getFunctionParameterTypes()), expectedValue, blocks);
    }

    /**
     * Asserts that running {@code runnable} throws a Trino exception with error code
     * {@code INVALID_FUNCTION_ARGUMENT}.
     */
    protected void assertInvalidAggregation(Runnable runnable) {
        TrinoExceptionAssert.assertTrinoExceptionThrownBy(runnable::run).hasErrorCode(new ErrorCodeSupplier[]{StandardErrorCode.INVALID_FUNCTION_ARGUMENT});
    }
}

