/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import io.trino.block.BlockAssertions;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.aggregation.Aggregator;
import io.trino.operator.aggregation.AggregatorFactory;
import io.trino.operator.aggregation.GroupedAggregator;
import io.trino.operator.aggregation.TestingAggregationFunction;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.plan.AggregationNode;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.OptionalInt;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
import org.apache.commons.math3.util.Precision;
import org.testng.Assert;

/**
 * Test helpers that run an aggregation function through several execution paths and check that
 * every path produces the expected value.
 *
 * <p>The paths are:
 * <ul>
 *   <li>single step;</li>
 *   <li>partial followed by final;</li>
 *   <li>grouped single step;</li>
 *   <li>grouped partial followed by grouped final;</li>
 *   <li>masked aggregation.</li>
 * </ul>
 * Most paths are also re-run with reversed and offset argument channels, and the results are
 * compared against each other.
 */
public final class AggregationTestUtils {
    /** Number of all-null channels prepended to each page when checking that argument channel mapping is honored. */
    private static final int CHANNEL_OFFSET = 3;

    /** Large group id used to re-run a grouped aggregation and compare against group 0. */
    private static final int LARGE_GROUP_ID = 4000;

    private AggregationTestUtils() {
    }

    /**
     * Asserts that aggregating the given blocks, wrapped into a single page, yields {@code expectedValue}.
     */
    public static void assertAggregation(TestingFunctionResolution functionResolution, String name, List<TypeSignatureProvider> parameterTypes, Object expectedValue, Block... blocks) {
        assertAggregation(functionResolution, name, parameterTypes, expectedValue, new Page(blocks));
    }

    /**
     * Asserts that aggregating {@code page} yields {@code expectedValue}.
     *
     * <p>Results are compared with the comparator from {@link #makeValidityAssertion(Object)}.
     */
    public static void assertAggregation(TestingFunctionResolution functionResolution, String name, List<TypeSignatureProvider> parameterTypes, Object expectedValue, Page page) {
        BiFunction<Object, Object, Boolean> equalAssertion = makeValidityAssertion(expectedValue);
        assertAggregation(functionResolution, name, parameterTypes, equalAssertion, null, page, expectedValue);
    }

    /**
     * Builds the equality check used to compare an aggregation result against {@code expectedValue}.
     *
     * <p>For a non-NaN {@link Double} or {@link Float} expected value, the check uses commons-math
     * {@code Precision.equals} with a tolerance of 1e-10. A result of another type, or null,
     * compares as not equal. In all other cases, including NaN, {@link Objects#equals} is used.
     *
     * @param expectedValue the value the aggregation is expected to produce
     * @return a function {@code (actual, expected) -> equal}
     */
    public static BiFunction<Object, Object, Boolean> makeValidityAssertion(Object expectedValue) {
        if (expectedValue instanceof Double && !expectedValue.equals(Double.NaN)) {
            return (actual, expected) -> actual instanceof Double
                    && expected instanceof Double
                    && Precision.equals(((Double) actual).doubleValue(), ((Double) expected).doubleValue(), 1.0E-10);
        }
        if (expectedValue instanceof Float && !expectedValue.equals(Float.NaN)) {
            // NOTE(review): 1e-10 is far below float resolution for most magnitudes; confirm this tolerance is intended.
            return (actual, expected) -> actual instanceof Float
                    && expected instanceof Float
                    && Precision.equals(((Float) actual).floatValue(), ((Float) expected).floatValue(), 1.0E-10f);
        }
        return Objects::equals;
    }

    /**
     * Resolves the aggregate function {@code name(parameterTypes)} and asserts that every execution
     * path yields {@code expectedValue}.
     *
     * <p>How the input is fed depends on the page size:
     * <ul>
     *   <li>An empty page is aggregated as "no pages".</li>
     *   <li>A single-position page is fed as-is.</li>
     *   <li>A larger page is split into two halves, so multi-page input is exercised.</li>
     * </ul>
     *
     * @param equalAssertion comparator used for the result checks
     * @param testDescription optional description prefixed to failure messages; may be null
     */
    public static void assertAggregation(TestingFunctionResolution functionResolution, String name, List<TypeSignatureProvider> parameterTypes, BiFunction<Object, Object, Boolean> equalAssertion, String testDescription, Page page, Object expectedValue) {
        TestingAggregationFunction function = functionResolution.getAggregateFunction(name, parameterTypes);
        int positions = page.getPositionCount();
        // Check every channel, including channel 0, against the page's declared position count.
        for (int channel = 0; channel < page.getChannelCount(); channel++) {
            Assert.assertEquals(positions, page.getBlock(channel).getPositionCount(), "input blocks provided are not equal in position count");
        }
        if (positions == 0) {
            assertAggregationInternal(function, equalAssertion, testDescription, expectedValue);
        }
        else if (positions == 1) {
            assertAggregationInternal(function, equalAssertion, testDescription, expectedValue, page);
        }
        else {
            int split = positions / 2;
            Page firstHalf = page.getRegion(0, split);
            Page secondHalf = page.getRegion(split, positions - split);
            assertAggregationInternal(function, equalAssertion, testDescription, expectedValue, firstHalf, secondHalf);
        }
    }

    /**
     * Evaluates {@code aggregator} into a new block builder of {@code intermediateType} and returns the built block.
     */
    public static Block getIntermediateBlock(Type intermediateType, Aggregator aggregator) {
        BlockBuilder blockBuilder = intermediateType.createBlockBuilder(null, 1000);
        aggregator.evaluate(blockBuilder);
        return blockBuilder.build();
    }

    /**
     * Evaluates group 0 of {@code aggregator} into a new block builder of {@code intermediateType}
     * and returns the built block.
     */
    public static Block getIntermediateBlock(Type intermediateType, GroupedAggregator aggregator) {
        BlockBuilder blockBuilder = intermediateType.createBlockBuilder(null, 1000);
        aggregator.evaluate(0, blockBuilder);
        return blockBuilder.build();
    }

    /**
     * Evaluates {@code aggregator} into a new block builder of {@code finalType} and returns the built block.
     */
    public static Block getFinalBlock(Type finalType, Aggregator aggregator) {
        BlockBuilder blockBuilder = finalType.createBlockBuilder(null, 1000);
        aggregator.evaluate(blockBuilder);
        return blockBuilder.build();
    }

    /**
     * Checks every execution path against {@code expectedValue}.
     *
     * <p>The grouped and masked paths are skipped when there are no pages.
     */
    private static void assertAggregationInternal(TestingAggregationFunction function, BiFunction<Object, Object, Boolean> isEqual, String testDescription, Object expectedValue, Page... pages) {
        assertFunctionEquals(isEqual, testDescription, aggregation(function, pages), expectedValue);
        assertFunctionEquals(isEqual, testDescription, partialAggregation(function, pages), expectedValue);
        if (pages.length > 0) {
            assertFunctionEquals(isEqual, testDescription, groupedAggregation(isEqual, function, pages), expectedValue);
            assertFunctionEquals(isEqual, testDescription, groupedPartialAggregation(isEqual, function, pages), expectedValue);
            assertFunctionEquals(isEqual, testDescription, distinctAggregation(function, pages), expectedValue);
        }
    }

    /**
     * Fails the test when {@code isEqual(actualValue, expectedValue)} is false.
     *
     * <p>The failure message includes both values, plus {@code testDescription} when it is non-null.
     */
    private static void assertFunctionEquals(BiFunction<Object, Object, Boolean> isEqual, String testDescription, Object actualValue, Object expectedValue) {
        if (!isEqual.apply(actualValue, expectedValue)) {
            StringBuilder message = new StringBuilder();
            if (testDescription != null) {
                message.append(String.format("Test: %s, ", testDescription));
            }
            message.append(String.format("Expected: %s, actual: %s", expectedValue, actualValue));
            Assert.fail(message.toString());
        }
    }

    /**
     * Runs a single-step aggregation with a mask channel appended after the input channels.
     *
     * <p>The input is first aggregated with an all-true mask. It is then re-aggregated with an
     * additional all-false-masked copy of the input appended, once using flat mask blocks and once
     * using RLE mask blocks. Both re-runs must match the first result exactly.
     *
     * @param pages the input pages; must not be empty, and every page is assumed to have the same
     *        channel count as {@code pages[0]}
     * @return the result of the all-true-mask aggregation
     */
    private static Object distinctAggregation(TestingAggregationFunction function, Page... pages) {
        int parameterCount = function.getParameterCount();
        OptionalInt maskChannel = OptionalInt.of(pages[0].getChannelCount());
        Page[] includedPages = maskPages(true, pages);
        Object result = aggregation(function, createArgs(parameterCount), maskChannel, includedPages);

        Page[] dupedPages = new Page[pages.length * 2];
        System.arraycopy(includedPages, 0, dupedPages, 0, pages.length);
        System.arraycopy(maskPages(false, pages), 0, dupedPages, pages.length, pages.length);
        Object resultWithDupes = aggregation(function, createArgs(parameterCount), maskChannel, dupedPages);
        Assert.assertEquals(resultWithDupes, result, "Inconsistent results with mask");

        System.arraycopy(maskPagesWithRle(true, pages), 0, dupedPages, 0, pages.length);
        System.arraycopy(maskPagesWithRle(false, pages), 0, dupedPages, pages.length, pages.length);
        Object resultWithRleMasks = aggregation(function, createArgs(parameterCount), maskChannel, dupedPages);
        Assert.assertEquals(resultWithRleMasks, result, "Inconsistent results with RLE mask");
        return result;
    }

    /**
     * Returns copies of {@code pages}, each with an RLE boolean column of {@code maskValue} appended.
     */
    private static Page[] maskPagesWithRle(boolean maskValue, Page... pages) {
        Page[] maskedPages = new Page[pages.length];
        for (int i = 0; i < pages.length; i++) {
            Page page = pages[i];
            Block mask = RunLengthEncodedBlock.create(BooleanType.createBlockForSingleNonNullValue(maskValue), page.getPositionCount());
            maskedPages[i] = page.appendColumn(mask);
        }
        return maskedPages;
    }

    /**
     * Returns copies of {@code pages}, each with a flat boolean column appended whose every position
     * holds {@code maskValue}.
     */
    private static Page[] maskPages(boolean maskValue, Page... pages) {
        Page[] maskedPages = new Page[pages.length];
        for (int i = 0; i < pages.length; i++) {
            Page page = pages[i];
            BlockBuilder blockBuilder = BooleanType.BOOLEAN.createBlockBuilder(null, page.getPositionCount());
            for (int position = 0; position < page.getPositionCount(); position++) {
                BooleanType.BOOLEAN.writeBoolean(blockBuilder, maskValue);
            }
            maskedPages[i] = page.appendColumn(blockBuilder.build());
        }
        return maskedPages;
    }

    /**
     * Runs a single-step aggregation over {@code pages}, using argument channels {@code 0..n-1}.
     *
     * <p>The aggregation is also run with reversed channels (when there is more than one parameter)
     * and with offset channels. Those results must equal the first one.
     *
     * @return the single final value
     */
    public static Object aggregation(TestingAggregationFunction function, Page... pages) {
        int parameterCount = function.getParameterCount();
        Object result = aggregation(function, createArgs(parameterCount), OptionalInt.empty(), pages);
        if (parameterCount > 1) {
            Object reversedResult = aggregation(function, reverseArgs(parameterCount), OptionalInt.empty(), reverseColumns(pages));
            Assert.assertEquals(reversedResult, result, "Inconsistent results with reversed channels");
        }
        Object offsetResult = aggregation(function, offsetArgs(parameterCount, CHANNEL_OFFSET), OptionalInt.empty(), offsetColumns(pages, CHANNEL_OFFSET));
        Assert.assertEquals(offsetResult, result, "Inconsistent results with channel offset");
        return result;
    }

    /**
     * Runs one single-step aggregator over the non-empty pages and returns its only final value.
     */
    private static Object aggregation(TestingAggregationFunction function, int[] args, OptionalInt maskChannel, Page... pages) {
        Aggregator aggregator = function.createAggregatorFactory(AggregationNode.Step.SINGLE, Ints.asList(args), maskChannel).createAggregator();
        for (Page page : pages) {
            if (page.getPositionCount() > 0) {
                aggregator.processPage(page);
            }
        }
        Block block = getFinalBlock(function.getFinalType(), aggregator);
        return BlockAssertions.getOnlyValue(function.getFinalType(), block);
    }

    /**
     * Runs a partial-then-final aggregation over {@code pages}.
     *
     * <p>The aggregation is also run with reversed channels (when there is more than one parameter)
     * and with offset channels. Those results must equal the first one.
     *
     * @return the single final value
     */
    public static Object partialAggregation(TestingAggregationFunction function, Page... pages) {
        int parameterCount = function.getParameterCount();
        Object result = partialAggregation(function, createArgs(parameterCount), pages);
        if (parameterCount > 1) {
            Object reversedResult = partialAggregation(function, reverseArgs(parameterCount), reverseColumns(pages));
            Assert.assertEquals(reversedResult, result, "Inconsistent results with reversed channels");
        }
        Object offsetResult = partialAggregation(function, offsetArgs(parameterCount, CHANNEL_OFFSET), offsetColumns(pages, CHANNEL_OFFSET));
        Assert.assertEquals(offsetResult, result, "Inconsistent results with channel offset");
        return result;
    }

    /**
     * Feeds one intermediate block per input page into a single FINAL aggregator.
     *
     * <p>Each input page gets its own fresh PARTIAL aggregator. The intermediate state of a PARTIAL
     * aggregator that saw no input is also fed to the FINAL aggregator, once before the real
     * partials and once after.
     *
     * @return the single final value
     */
    private static Object partialAggregation(TestingAggregationFunction function, int[] args, Page... pages) {
        // The FINAL aggregator reads the intermediate state from channel 0 of the single-block pages built below.
        AggregatorFactory finalAggregatorFactory = function.createAggregatorFactory(AggregationNode.Step.FINAL, ImmutableList.of(0), OptionalInt.empty());
        Aggregator finalAggregator = finalAggregatorFactory.createAggregator();
        AggregatorFactory partialAggregatorFactory = function.createAggregatorFactory(AggregationNode.Step.PARTIAL, Ints.asList(args), OptionalInt.empty());

        Block emptyBlock = getIntermediateBlock(function.getIntermediateType(), partialAggregatorFactory.createAggregator());
        finalAggregator.processPage(new Page(emptyBlock));
        for (Page page : pages) {
            Aggregator partialAggregator = partialAggregatorFactory.createAggregator();
            if (page.getPositionCount() > 0) {
                partialAggregator.processPage(page);
            }
            Block partialBlock = getIntermediateBlock(function.getIntermediateType(), partialAggregator);
            finalAggregator.processPage(new Page(partialBlock));
        }
        finalAggregator.processPage(new Page(emptyBlock));

        Block finalBlock = getFinalBlock(function.getFinalType(), finalAggregator);
        return BlockAssertions.getOnlyValue(function.getFinalType(), finalBlock);
    }

    /**
     * Runs a grouped aggregation over {@code page}, with exact-equality consistency checks.
     */
    public static Object groupedAggregation(TestingAggregationFunction function, Page page) {
        return groupedAggregation(Objects::equals, function, page);
    }

    /**
     * Runs a grouped single-step aggregation over {@code pages}.
     *
     * <p>The aggregation is also run with reversed channels (when there is more than one parameter)
     * and with offset channels. Those results are compared to the first one using {@code isEqual}.
     */
    private static Object groupedAggregation(BiFunction<Object, Object, Boolean> isEqual, TestingAggregationFunction function, Page... pages) {
        int parameterCount = function.getParameterCount();
        Object result = groupedAggregation(function, createArgs(parameterCount), pages);
        if (parameterCount > 1) {
            Object reversedResult = groupedAggregation(function, reverseArgs(parameterCount), reverseColumns(pages));
            assertFunctionEquals(isEqual, "Inconsistent results with reversed channels", reversedResult, result);
        }
        Object offsetResult = groupedAggregation(function, offsetArgs(parameterCount, CHANNEL_OFFSET), offsetColumns(pages, CHANNEL_OFFSET));
        assertFunctionEquals(isEqual, "Inconsistent results with channel offset", offsetResult, result);
        return result;
    }

    /**
     * Aggregates all {@code pages} into group 0 of one grouped aggregator.
     *
     * <p>The same pages are then aggregated into a large group id on the same aggregator. That
     * second value must equal the group-0 value exactly.
     *
     * @return the value of group 0
     */
    public static Object groupedAggregation(TestingAggregationFunction function, int[] args, Page... pages) {
        GroupedAggregator groupedAggregator = function.createAggregatorFactory(AggregationNode.Step.SINGLE, Ints.asList(args), OptionalInt.empty()).createGroupedAggregator();
        for (Page page : pages) {
            groupedAggregator.processPage(0, createGroupByIdBlock(0, page.getPositionCount()), page);
        }
        Object groupValue = getGroupValue(function.getFinalType(), groupedAggregator, 0);

        for (Page page : pages) {
            groupedAggregator.processPage(LARGE_GROUP_ID, createGroupByIdBlock(LARGE_GROUP_ID, page.getPositionCount()), page);
        }
        Object largeGroupValue = getGroupValue(function.getFinalType(), groupedAggregator, LARGE_GROUP_ID);
        Assert.assertEquals(largeGroupValue, groupValue, "Inconsistent results with large group id");
        return groupValue;
    }

    /**
     * Runs a grouped partial-then-final aggregation over {@code pages}.
     *
     * <p>The aggregation is also run with reversed channels (when there is more than one parameter)
     * and with offset channels. Those results are compared to the first one using {@code isEqual}.
     */
    private static Object groupedPartialAggregation(BiFunction<Object, Object, Boolean> isEqual, TestingAggregationFunction function, Page... pages) {
        int parameterCount = function.getParameterCount();
        Object result = groupedPartialAggregation(function, createArgs(parameterCount), pages);
        if (parameterCount > 1) {
            Object reversedResult = groupedPartialAggregation(function, reverseArgs(parameterCount), reverseColumns(pages));
            assertFunctionEquals(isEqual, "Inconsistent results with reversed channels", reversedResult, result);
        }
        Object offsetResult = groupedPartialAggregation(function, offsetArgs(parameterCount, CHANNEL_OFFSET), offsetColumns(pages, CHANNEL_OFFSET));
        assertFunctionEquals(isEqual, "Inconsistent results with channel offset", offsetResult, result);
        return result;
    }

    /**
     * Grouped counterpart of the partial-then-final aggregation, using group 0 throughout.
     *
     * <p>Each input page gets a fresh PARTIAL grouped aggregator. Its group-0 intermediate block is
     * fed into a single FINAL grouped aggregator. The intermediate block of an untouched PARTIAL
     * aggregator is also fed, once before the real partials and once after.
     *
     * @return the final value of group 0
     */
    private static Object groupedPartialAggregation(TestingAggregationFunction function, int[] args, Page... pages) {
        AggregatorFactory finalFactory = function.createAggregatorFactory(AggregationNode.Step.FINAL, ImmutableList.of(0), OptionalInt.empty());
        GroupedAggregator finalAggregator = finalFactory.createGroupedAggregator();
        AggregatorFactory partialFactory = function.createAggregatorFactory(AggregationNode.Step.PARTIAL, Ints.asList(args), OptionalInt.empty());

        Block emptyBlock = getIntermediateBlock(function.getIntermediateType(), partialFactory.createGroupedAggregator());
        finalAggregator.processPage(0, createGroupByIdBlock(0, emptyBlock.getPositionCount()), new Page(emptyBlock));
        for (Page page : pages) {
            GroupedAggregator partialAggregator = partialFactory.createGroupedAggregator();
            partialAggregator.processPage(0, createGroupByIdBlock(0, page.getPositionCount()), page);
            Block partialBlock = getIntermediateBlock(function.getIntermediateType(), partialAggregator);
            finalAggregator.processPage(0, createGroupByIdBlock(0, partialBlock.getPositionCount()), new Page(partialBlock));
        }
        finalAggregator.processPage(0, createGroupByIdBlock(0, emptyBlock.getPositionCount()), new Page(emptyBlock));
        return getGroupValue(function.getFinalType(), finalAggregator, 0);
    }

    /**
     * Returns an array of length {@code positions} with every element set to {@code groupId}.
     */
    public static int[] createGroupByIdBlock(int groupId, int positions) {
        int[] groupIds = new int[positions];
        Arrays.fill(groupIds, groupId);
        return groupIds;
    }

    /**
     * Returns the identity channel mapping {@code [0, 1, ..., parameterCount - 1]}.
     */
    static int[] createArgs(int parameterCount) {
        int[] args = new int[parameterCount];
        for (int i = 0; i < args.length; i++) {
            args[i] = i;
        }
        return args;
    }

    /**
     * Returns the channel mapping {@code [parameterCount - 1, ..., 1, 0]}, matching {@link #reverseColumns(Page[])}.
     */
    private static int[] reverseArgs(int parameterCount) {
        int[] args = createArgs(parameterCount);
        // Ints.asList is a view backed by the array, so reversing it reverses args in place.
        Collections.reverse(Ints.asList(args));
        return args;
    }

    /**
     * Returns the channel mapping {@code [offset, offset + 1, ...]}, matching {@link #offsetColumns(Page[], int)}.
     */
    private static int[] offsetArgs(int parameterCount, int offset) {
        int[] args = createArgs(parameterCount);
        for (int i = 0; i < args.length; i++) {
            args[i] += offset;
        }
        return args;
    }

    /**
     * Returns copies of {@code pages} with their channels in reverse order.
     *
     * <p>Pages with no positions are returned unchanged.
     */
    private static Page[] reverseColumns(Page[] pages) {
        Page[] newPages = new Page[pages.length];
        for (int i = 0; i < pages.length; i++) {
            Page page = pages[i];
            if (page.getPositionCount() == 0) {
                newPages[i] = page;
                continue;
            }
            int channelCount = page.getChannelCount();
            Block[] newBlocks = new Block[channelCount];
            for (int channel = 0; channel < channelCount; channel++) {
                newBlocks[channel] = page.getBlock(channelCount - channel - 1);
            }
            newPages[i] = new Page(page.getPositionCount(), newBlocks);
        }
        return newPages;
    }

    /**
     * Returns copies of {@code pages} with {@code offset} all-null boolean channels prepended.
     *
     * <p>The original channels are shifted right by {@code offset}.
     */
    public static Page[] offsetColumns(Page[] pages, int offset) {
        Page[] newPages = new Page[pages.length];
        for (int i = 0; i < pages.length; i++) {
            Page page = pages[i];
            Block[] newBlocks = new Block[page.getChannelCount() + offset];
            for (int channel = 0; channel < offset; channel++) {
                newBlocks[channel] = createAllNullBlock(page.getPositionCount());
            }
            for (int channel = 0; channel < page.getChannelCount(); channel++) {
                newBlocks[channel + offset] = page.getBlock(channel);
            }
            newPages[i] = new Page(page.getPositionCount(), newBlocks);
        }
        return newPages;
    }

    /**
     * Returns an RLE {@code BOOLEAN} block of {@code positionCount} nulls.
     */
    private static Block createAllNullBlock(int positionCount) {
        return RunLengthEncodedBlock.create(BooleanType.BOOLEAN, null, positionCount);
    }

    /**
     * Evaluates group {@code groupId} of {@code groupedAggregator} into a block of {@code finalType}
     * and returns its only value.
     */
    public static Object getGroupValue(Type finalType, GroupedAggregator groupedAggregator, int groupId) {
        BlockBuilder out = finalType.createBlockBuilder(null, 1);
        groupedAggregator.evaluate(groupId, out);
        return BlockAssertions.getOnlyValue(finalType, out.build());
    }

    /**
     * Returns {@code [start, start + 1, ..., start + length - 1]} as doubles.
     */
    public static double[] constructDoublePrimitiveArray(int start, int length) {
        return IntStream.range(start, start + length).asDoubleStream().toArray();
    }
}

