/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.concurrent.Threads;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import io.trino.ExceededMemoryLimitException;
import io.trino.RowPagesBuilder;
import io.trino.Session;
import io.trino.SessionTestUtils;
import io.trino.block.BlockAssertions;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.DriverContext;
import io.trino.operator.DummySpillerFactory;
import io.trino.operator.GroupByHashYieldAssertion;
import io.trino.operator.HashAggregationOperator;
import io.trino.operator.Operator;
import io.trino.operator.OperatorAssertion;
import io.trino.operator.OperatorFactory;
import io.trino.operator.OperatorStats;
import io.trino.operator.SpillContext;
import io.trino.operator.aggregation.TestingAggregationFunction;
import io.trino.operator.aggregation.builder.HashAggregationBuilder;
import io.trino.operator.aggregation.builder.InMemoryHashAggregationBuilder;
import io.trino.operator.aggregation.partial.PartialAggregationController;
import io.trino.plugin.base.metrics.LongCount;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.VariableWidthBlockBuilder;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.spi.type.VarcharType;
import io.trino.spiller.Spiller;
import io.trino.spiller.SpillerFactory;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.gen.JoinCompiler;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.MaterializedResult;
import io.trino.testing.TestingTaskContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

@Test(singleThreaded=true)
public class TestHashAggregationOperator {
    // Function lookup shared by all tests; the aggregate functions below are resolved through it.
    private static final TestingFunctionResolution FUNCTION_RESOLUTION = new TestingFunctionResolution();
    // avg, sum and min are resolved for a single BIGINT argument; count is resolved with no argument types.
    private static final TestingAggregationFunction LONG_AVERAGE = FUNCTION_RESOLUTION.getAggregateFunction("avg", TypeSignatureProvider.fromTypes(BigintType.BIGINT));
    private static final TestingAggregationFunction LONG_SUM = FUNCTION_RESOLUTION.getAggregateFunction("sum", TypeSignatureProvider.fromTypes(BigintType.BIGINT));
    private static final TestingAggregationFunction COUNT = FUNCTION_RESOLUTION.getAggregateFunction("count", ImmutableList.of());
    private static final TestingAggregationFunction LONG_MIN = FUNCTION_RESOLUTION.getAggregateFunction("min", TypeSignatureProvider.fromTypes(BigintType.BIGINT));
    // 64 KiB.
    private static final int MAX_BLOCK_SIZE_IN_BYTES = 64 * 1024;
    // Both executors are assigned in setUp() and shut down in tearDown().
    private ExecutorService executor;
    private ScheduledExecutorService scheduledExecutor;
    private final TypeOperators typeOperators = new TypeOperators();
    // Constructed from typeOperators, so it has to be declared after it.
    private final JoinCompiler joinCompiler = new JoinCompiler(this.typeOperators);
    // Assigned in setUp() and cleared in tearDown().
    private DummySpillerFactory spillerFactory;

    /**
     * Creates a cached daemon thread pool, a two-thread scheduled daemon pool and a new
     * {@link DummySpillerFactory} before each test method.
     */
    @BeforeMethod
    public void setUp() {
        String testName = this.getClass().getSimpleName();
        this.executor = Executors.newCachedThreadPool(Threads.daemonThreadsNamed(testName + "-%s"));
        this.scheduledExecutor = Executors.newScheduledThreadPool(2, Threads.daemonThreadsNamed(testName + "-scheduledExecutor-%s"));
        this.spillerFactory = new DummySpillerFactory();
    }

    /**
     * Drops the spiller factory and shuts down both thread pools after each test method.
     *
     * <p>Because this method is declared with {@code alwaysRun=true}, the executors may still be
     * {@code null} when it is invoked, so each one is null-checked. The scheduled pool is shut
     * down in a {@code finally} block so it is still stopped if shutting down the first pool throws.
     */
    @AfterMethod(alwaysRun=true)
    public void tearDown() {
        this.spillerFactory = null;
        try {
            if (this.executor != null) {
                this.executor.shutdownNow();
            }
        }
        finally {
            if (this.scheduledExecutor != null) {
                this.scheduledExecutor.shutdownNow();
            }
        }
    }

    /**
     * Data provider with two single-element rows: {@code true} and {@code false}, in that order.
     * Tests using it receive the value as their {@code hashEnabled} parameter.
     */
    @DataProvider(name="hashEnabled")
    public static Object[][] hashEnabled() {
        Object[] enabledRow = {true};
        Object[] disabledRow = {false};
        return new Object[][]{enabledRow, disabledRow};
    }

    /**
     * Data provider for the tests parameterised on hashing, spilling and merge memory limits.
     *
     * <p>Each row is consumed positionally by the test methods as
     * {@code hashEnabled, spillEnabled, revokeMemoryWhenAddingPages, memoryLimitForMerge,
     * memoryLimitForMergeWithMemory}.
     */
    @DataProvider(name="hashEnabledAndMemoryLimitForMergeValues")
    public static Object[][] hashEnabledAndMemoryLimitForMergeValuesProvider() {
        return new Object[][]{
                // the only two rows with hashEnabled == true
                {true, true, true, 8, Integer.MAX_VALUE},
                {true, true, false, 8, Integer.MAX_VALUE},
                // the only row with spillEnabled == false
                {false, false, false, 0, 0},
                // spilling enabled, both limits zero
                {false, true, true, 0, 0},
                {false, true, false, 0, 0},
                // spilling enabled, only the first limit non-zero
                {false, true, true, 8, 0},
                {false, true, false, 8, 0},
                // spilling enabled, both limits non-zero
                {false, true, true, 8, Integer.MAX_VALUE},
                {false, true, false, 8, Integer.MAX_VALUE},
        };
    }

    /**
     * Data provider with one {@link Type} per row: {@code VARCHAR} first, then {@code BIGINT}.
     */
    @DataProvider
    public Object[][] dataType() {
        Object[] varcharRow = {VarcharType.VARCHAR};
        Object[] bigintRow = {BigintType.BIGINT};
        return new Object[][]{varcharRow, bigintRow};
    }

    /**
     * Aggregates three sequence pages of 40,000 rows each, grouped on channel 1, and compares the
     * count / sum / avg / max results of every group with the expected rows (ignoring order).
     *
     * <p>Additionally asserts that more than one output page is produced and that the spiller
     * factory recorded at least one spill if and only if {@code spillEnabled} is {@code true}.
     */
    @Test(dataProvider="hashEnabledAndMemoryLimitForMergeValues")
    public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) {
        int rowsPerPage = 40000;
        TestingAggregationFunction countVarchar = FUNCTION_RESOLUTION.getAggregateFunction("count", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR));
        TestingAggregationFunction countBoolean = FUNCTION_RESOLUTION.getAggregateFunction("count", TypeSignatureProvider.fromTypes(BooleanType.BOOLEAN));
        TestingAggregationFunction maxVarchar = FUNCTION_RESOLUTION.getAggregateFunction("max", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR));

        List<Integer> groupByChannels = Ints.asList(1);
        RowPagesBuilder pagesBuilder = RowPagesBuilder.rowPagesBuilder(
                hashEnabled,
                groupByChannels,
                VarcharType.VARCHAR, VarcharType.VARCHAR, VarcharType.VARCHAR, BigintType.BIGINT, BooleanType.BOOLEAN);
        List<Page> input = pagesBuilder
                .addSequencePage(rowsPerPage, 100, 0, 100000, 0, 500)
                .addSequencePage(rowsPerPage, 100, 0, 200000, 0, 500)
                .addSequencePage(rowsPerPage, 100, 0, 300000, 0, 500)
                .build();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(VarcharType.VARCHAR),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                false,
                ImmutableList.of(
                        COUNT.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty()),
                        LONG_SUM.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(3), OptionalInt.empty()),
                        LONG_AVERAGE.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(3), OptionalInt.empty()),
                        maxVarchar.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(2), OptionalInt.empty()),
                        countVarchar.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty()),
                        countBoolean.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(4), OptionalInt.empty())),
                pagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16, DataSize.Unit.MEGABYTE)),
                spillEnabled,
                DataSize.succinctBytes(memoryLimitForMerge),
                DataSize.succinctBytes(memoryLimitForMergeWithMemory),
                this.spillerFactory,
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        DriverContext driverContext = this.createDriverContext(memoryLimitForMerge);

        // One expected row per group key 0..rowsPerPage-1; each key occurs once in each of the three pages.
        MaterializedResult.Builder expectedRows = MaterializedResult.resultBuilder(
                driverContext.getSession(),
                VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT, DoubleType.DOUBLE, VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT);
        for (int key = 0; key < rowsPerPage; key++) {
            expectedRows.row(Integer.toString(key), 3L, 3L * key, (double) key, Integer.toString(300000 + key), 3L, 3L);
        }
        MaterializedResult expected = expectedRows.build();

        List<Page> outputPages = OperatorAssertion.toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages);
        io.airlift.testing.Assertions.assertGreaterThan(outputPages.size(), 1, "Expected more than one output page");
        OperatorAssertion.assertPagesEqualIgnoreOrder(driverContext, outputPages, expected, hashEnabled, Optional.of(groupByChannels.size()));

        boolean spilled = this.spillerFactory.getSpillsCount() > 0L;
        Assert.assertTrue(spillEnabled == spilled, String.format("Spill state mismatch. Expected spill: %s, spill count: %s", spillEnabled, this.spillerFactory.getSpillsCount()));
    }

    /**
     * Runs the aggregation over an empty input (no pages are added to the builder) with the
     * group ids 42 and 49 passed as the global aggregation group ids, and expects exactly one
     * row per group id carrying zero counts and {@code null} for the other aggregates.
     */
    @Test(dataProvider="hashEnabledAndMemoryLimitForMergeValues")
    public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) {
        TestingAggregationFunction countVarchar = FUNCTION_RESOLUTION.getAggregateFunction("count", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR));
        TestingAggregationFunction countBoolean = FUNCTION_RESOLUTION.getAggregateFunction("count", TypeSignatureProvider.fromTypes(BooleanType.BOOLEAN));
        TestingAggregationFunction maxVarchar = FUNCTION_RESOLUTION.getAggregateFunction("max", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR));

        Optional<Integer> groupIdChannel = Optional.of(1);
        List<Integer> groupByChannels = Ints.asList(1, 2);
        List<Integer> globalAggregationGroupIds = Ints.asList(42, 49);

        RowPagesBuilder pagesBuilder = RowPagesBuilder.rowPagesBuilder(
                hashEnabled,
                groupByChannels,
                VarcharType.VARCHAR, VarcharType.VARCHAR, VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT, BooleanType.BOOLEAN);
        // No pages are added: the operator is fed an empty input.
        List<Page> input = pagesBuilder.build();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT),
                groupByChannels,
                globalAggregationGroupIds,
                AggregationNode.Step.SINGLE,
                true,
                ImmutableList.of(
                        COUNT.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty()),
                        LONG_MIN.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(4), OptionalInt.empty()),
                        LONG_AVERAGE.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(4), OptionalInt.empty()),
                        maxVarchar.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(2), OptionalInt.empty()),
                        countVarchar.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty()),
                        countBoolean.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(5), OptionalInt.empty())),
                pagesBuilder.getHashChannel(),
                groupIdChannel,
                100000,
                Optional.of(DataSize.of(16, DataSize.Unit.MEGABYTE)),
                spillEnabled,
                DataSize.succinctBytes(memoryLimitForMerge),
                DataSize.succinctBytes(memoryLimitForMergeWithMemory),
                this.spillerFactory,
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        DriverContext driverContext = this.createDriverContext(memoryLimitForMerge);

        MaterializedResult.Builder expectedRows = MaterializedResult.resultBuilder(
                driverContext.getSession(),
                VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT, BigintType.BIGINT, DoubleType.DOUBLE, VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT);
        expectedRows.row(null, 42L, 0L, null, null, null, 0L, 0L);
        expectedRows.row(null, 49L, 0L, null, null, null, 0L, 0L);
        MaterializedResult expected = expectedRows.build();

        OperatorAssertion.assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, Optional.of(groupByChannels.size()), revokeMemoryWhenAddingPages);
    }

    /**
     * Runs an {@code array_agg} aggregation over three 10-row sequence pages and checks the memory
     * reservations reported by the operator's single nested operator stats entry afterwards.
     *
     * <p>The user memory reservation is expected to be 4752672 bytes when both {@code spillEnabled}
     * and {@code revokeMemoryWhenAddingPages} are set and 0 otherwise; the revocable reservation is
     * always expected to be 0.
     *
     * <p>The operator is created directly from the factory, so it is wrapped in try-with-resources
     * and closed once the assertions have run.
     *
     * @throws Exception if closing the operator fails
     */
    @Test(dataProvider="hashEnabledAndMemoryLimitForMergeValues")
    public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) throws Exception {
        TestingAggregationFunction arrayAggColumn = FUNCTION_RESOLUTION.getAggregateFunction("array_agg", TypeSignatureProvider.fromTypes(BigintType.BIGINT));

        List<Integer> hashChannels = Ints.asList(1);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, hashChannels, BigintType.BIGINT, BigintType.BIGINT);
        List<Page> input = rowPagesBuilder
                .addSequencePage(10, 100, 0)
                .addSequencePage(10, 200, 0)
                .addSequencePage(10, 300, 0)
                .build();

        DriverContext driverContext = TestingTaskContext.createTaskContext(this.executor, this.scheduledExecutor, SessionTestUtils.TEST_SESSION, DataSize.of(11, DataSize.Unit.MEGABYTE))
                .addPipelineContext(0, true, true, false)
                .addDriverContext();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                hashChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                true,
                ImmutableList.of(arrayAggColumn.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16, DataSize.Unit.MEGABYTE)),
                spillEnabled,
                DataSize.succinctBytes(memoryLimitForMerge),
                DataSize.succinctBytes(memoryLimitForMergeWithMemory),
                this.spillerFactory,
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        try (Operator operator = operatorFactory.createOperator(driverContext)) {
            OperatorAssertion.toPages(operator, input.iterator(), revokeMemoryWhenAddingPages);

            // Reservations are read before the operator is closed, exactly as in the unclosed version of this test.
            long expectedUserMemory = spillEnabled && revokeMemoryWhenAddingPages ? 4752672L : 0L;
            Assert.assertEquals(Iterables.getOnlyElement(operator.getOperatorContext().getNestedOperatorStats()).getUserMemoryReservation().toBytes(), expectedUserMemory);
            Assert.assertEquals(Iterables.getOnlyElement(operator.getOperatorContext().getNestedOperatorStats()).getRevocableMemoryReservation().toBytes(), 0L);
        }
    }

    /**
     * Expects {@link ExceededMemoryLimitException} with a message starting
     * "Query exceeded per-node memory limit of 10B" when the aggregation runs in a task context
     * created with a 10-byte {@link DataSize}.
     */
    @Test(dataProvider="hashEnabled", expectedExceptions={ExceededMemoryLimitException.class}, expectedExceptionsMessageRegExp="Query exceeded per-node memory limit of 10B.*")
    public void testMemoryLimit(boolean hashEnabled) {
        TestingAggregationFunction maxVarchar = FUNCTION_RESOLUTION.getAggregateFunction("max", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR));

        List<Integer> groupByChannels = Ints.asList(1);
        RowPagesBuilder pagesBuilder = RowPagesBuilder.rowPagesBuilder(
                hashEnabled,
                groupByChannels,
                VarcharType.VARCHAR, BigintType.BIGINT, VarcharType.VARCHAR, BigintType.BIGINT);
        List<Page> input = pagesBuilder
                .addSequencePage(10, 100, 0, 100, 0)
                .addSequencePage(10, 100, 0, 200, 0)
                .addSequencePage(10, 100, 0, 300, 0)
                .build();

        DriverContext driverContext = TestingTaskContext.createTaskContext(this.executor, this.scheduledExecutor, SessionTestUtils.TEST_SESSION, DataSize.ofBytes(10))
                .addPipelineContext(0, true, true, false)
                .addDriverContext();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                ImmutableList.of(
                        COUNT.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty()),
                        LONG_MIN.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(3), OptionalInt.empty()),
                        LONG_AVERAGE.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(3), OptionalInt.empty()),
                        maxVarchar.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(2), OptionalInt.empty())),
                pagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16, DataSize.Unit.MEGABYTE)),
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        // Expected to throw; the exception is matched by the @Test annotation.
        OperatorAssertion.toPages(operatorFactory, driverContext, input);
    }

    /**
     * Feeds the aggregation a page holding a single 200,000-byte VARCHAR value (larger than
     * {@code MAX_BLOCK_SIZE_IN_BYTES}) between two 10-row sequence pages. The test only requires
     * that processing completes; the output pages are not inspected.
     */
    @Test(dataProvider="hashEnabledAndMemoryLimitForMergeValues")
    public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) {
        BlockBuilder largeValueBuilder = VarcharType.VARCHAR.createBlockBuilder(null, 1, MAX_BLOCK_SIZE_IN_BYTES);
        VarcharType.VARCHAR.writeSlice(largeValueBuilder, Slices.allocate(200000));
        // Result deliberately unused here, as in the original test; the block is built again when added to the input.
        largeValueBuilder.build();

        List<Integer> groupByChannels = Ints.asList(0);
        RowPagesBuilder pagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels, VarcharType.VARCHAR);
        List<Page> input = pagesBuilder
                .addSequencePage(10, 100)
                .addBlocksPage(largeValueBuilder.build())
                .addSequencePage(10, 100)
                .build();

        DriverContext driverContext = this.createDriverContext(memoryLimitForMerge);

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(VarcharType.VARCHAR),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                false,
                ImmutableList.of(COUNT.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty())),
                pagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16, DataSize.Unit.MEGABYTE)),
                spillEnabled,
                DataSize.succinctBytes(memoryLimitForMerge),
                DataSize.succinctBytes(memoryLimitForMergeWithMemory),
                this.spillerFactory,
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        OperatorAssertion.toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages);
    }

    /**
     * Drives a count aggregation through
     * {@link GroupByHashYieldAssertion#finishOperatorWithYieldingGroupByHash} and checks the yield
     * statistics and the output.
     *
     * <p>Requires at least 5 yields and at least 20 MiB of maximum reserved bytes. Every output
     * page must have 3 channels with the value 1 in channel 2 at every position, and the total
     * number of output positions must be 3,600,000.
     */
    @Test(dataProvider="dataType")
    public void testMemoryReservationYield(Type type) {
        List<Page> input = GroupByHashYieldAssertion.createPagesWithDistinctHashKeys(type, 6000, 600);

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(type),
                ImmutableList.of(0),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                ImmutableList.of(COUNT.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty())),
                Optional.of(1),
                Optional.empty(),
                1,
                Optional.of(DataSize.of(16, DataSize.Unit.MEGABYTE)),
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        GroupByHashYieldAssertion.GroupByHashYieldResult yieldResult = GroupByHashYieldAssertion.finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, this::getHashCapacity, 450000L);
        io.airlift.testing.Assertions.assertGreaterThanOrEqual(yieldResult.getYieldCount(), 5);
        // 20L << 20 == 0x1400000 == 20 MiB
        io.airlift.testing.Assertions.assertGreaterThanOrEqual(yieldResult.getMaxReservedBytes(), 20L << 20);

        int totalPositions = 0;
        for (Page outputPage : yieldResult.getOutput()) {
            Assert.assertEquals(outputPage.getChannelCount(), 3);
            int positions = outputPage.getPositionCount();
            for (int position = 0; position < positions; position++) {
                Assert.assertEquals(outputPage.getBlock(2).getLong(position, 0), 1L);
                totalPositions++;
            }
        }
        // 3,600,000 is the product of the two size arguments given to createPagesWithDistinctHashKeys.
        Assert.assertEquals(totalPositions, 6000 * 600);
    }

    /**
     * Expects {@link ExceededMemoryLimitException} with a message starting
     * "Query exceeded per-node memory limit of 3MB" when a page holding a single 5,000,000-byte
     * VARCHAR value is aggregated in a task context created with a 3 MB {@link DataSize}.
     */
    @Test(dataProvider="hashEnabled", expectedExceptions={ExceededMemoryLimitException.class}, expectedExceptionsMessageRegExp="Query exceeded per-node memory limit of 3MB.*")
    public void testHashBuilderResizeLimit(boolean hashEnabled) {
        BlockBuilder largeValueBuilder = VarcharType.VARCHAR.createBlockBuilder(null, 1, MAX_BLOCK_SIZE_IN_BYTES);
        VarcharType.VARCHAR.writeSlice(largeValueBuilder, Slices.allocate(5000000));
        // Result deliberately unused here, as in the original test; the block is built again when added to the input.
        largeValueBuilder.build();

        List<Integer> groupByChannels = Ints.asList(0);
        RowPagesBuilder pagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels, VarcharType.VARCHAR);
        List<Page> input = pagesBuilder
                .addSequencePage(10, 100)
                .addBlocksPage(largeValueBuilder.build())
                .addSequencePage(10, 100)
                .build();

        DriverContext driverContext = TestingTaskContext.createTaskContext(this.executor, this.scheduledExecutor, SessionTestUtils.TEST_SESSION, DataSize.of(3, DataSize.Unit.MEGABYTE))
                .addPipelineContext(0, true, true, false)
                .addDriverContext();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(VarcharType.VARCHAR),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                ImmutableList.of(COUNT.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty())),
                pagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16, DataSize.Unit.MEGABYTE)),
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        // Expected to throw; the exception is matched by the @Test annotation.
        OperatorAssertion.toPages(operatorFactory, driverContext, input);
    }

    /**
     * Aggregates a single sequence page of 49,152 rows (1,572,864 / 32) with count and avg and
     * asserts that exactly two output pages are produced.
     */
    @Test(dataProvider="hashEnabled")
    public void testMultiSliceAggregationOutput(boolean hashEnabled) {
        // NOTE(review): 32 is presumably the per-row byte width of the output (four 8-byte values) - confirm.
        int fixedWidthSize = 32;
        // 1.5 * 1024 * 1024 == 1572864.0
        double targetBytes = 1.5 * 1024 * 1024;
        int multiSlicePositionCount = (int) (targetBytes / fixedWidthSize);

        List<Integer> groupByChannels = Ints.asList(1);
        RowPagesBuilder pagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels, BigintType.BIGINT, BigintType.BIGINT);
        List<Page> input = pagesBuilder.addSequencePage(multiSlicePositionCount, 0, 0).build();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                ImmutableList.of(
                        COUNT.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty()),
                        LONG_AVERAGE.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(1), OptionalInt.empty())),
                pagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16, DataSize.Unit.MEGABYTE)),
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        List<Page> outputPages = OperatorAssertion.toPages(operatorFactory, this.createDriverContext(), input);
        Assert.assertEquals(outputPages.size(), 2);
    }

    /**
     * Exercises a PARTIAL-step min aggregation configured with a 1 KB {@link DataSize} so that it
     * produces output before all input has been consumed.
     *
     * <p>The test adds input while the operator accepts it, drains the available output, asserts
     * that some output was produced and that the operator wants input again, then runs the
     * remaining input to completion. The combined output must equal a 2,000-row sequence, and the
     * driver context must report zero memory and zero revocable memory once the operator is closed.
     *
     * @throws Exception if closing the operator fails
     */
    @Test(dataProvider="hashEnabled")
    public void testMultiplePartialFlushes(boolean hashEnabled) throws Exception {
        List<Integer> groupByChannels = Ints.asList(0);
        RowPagesBuilder pagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels, BigintType.BIGINT);
        List<Page> input = pagesBuilder
                .addSequencePage(500, 0)
                .addSequencePage(500, 500)
                .addSequencePage(500, 1000)
                .addSequencePage(500, 1500)
                .build();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.PARTIAL,
                ImmutableList.of(LONG_MIN.createAggregatorFactory(AggregationNode.Step.PARTIAL, ImmutableList.of(0), OptionalInt.empty())),
                pagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(1, DataSize.Unit.KILOBYTE)),
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        DriverContext driverContext = this.createDriverContext(1024L);

        try (Operator operator = operatorFactory.createOperator(driverContext)) {
            List<Page> expectedPages = RowPagesBuilder.rowPagesBuilder(BigintType.BIGINT, BigintType.BIGINT)
                    .addSequencePage(2000, 0, 0)
                    .build();
            MaterializedResult expected = MaterializedResult.resultBuilder(driverContext.getSession(), BigintType.BIGINT, BigintType.BIGINT)
                    .pages(expectedPages)
                    .build();

            // Add input for as long as the operator accepts it.
            Iterator<Page> remainingInput = input.iterator();
            while (operator.needsInput() && remainingInput.hasNext()) {
                operator.addInput(remainingInput.next());
            }
            Assertions.assertThat(driverContext.getMemoryUsage()).isGreaterThan(0L);

            // Drain whatever output is available right now.
            List<Page> outputPages = new ArrayList<>();
            for (Page flushed = operator.getOutput(); flushed != null; flushed = operator.getOutput()) {
                outputPages.add(flushed);
            }
            Assert.assertTrue(!outputPages.isEmpty());
            Assert.assertTrue(operator.needsInput());

            // Feed the rest of the input and finish the operator.
            outputPages.addAll(OperatorAssertion.toPages(operator, remainingInput));
            if (hashEnabled) {
                // Channel 1 is removed from every page before comparing against the two-column expectation.
                outputPages = OperatorAssertion.dropChannel(outputPages, ImmutableList.of(1));
            }

            MaterializedResult actual = OperatorAssertion.toMaterializedResult(operator.getOperatorContext().getSession(), expected.getTypes(), outputPages);
            Assert.assertEquals(actual.getTypes(), expected.getTypes());
            io.airlift.testing.Assertions.assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows());
        }

        Assert.assertEquals(driverContext.getMemoryUsage(), 0L);
        Assert.assertEquals(driverContext.getRevocableMemoryUsage(), 0L);
    }

    /**
     * Runs a spill-enabled min aggregation over 150,010 distinct BIGINT keys (one page of 150,000
     * rows followed by one page of 10) and expects one output row per key whose aggregate equals
     * the key.
     */
    @Test
    public void testMergeWithMemorySpill() {
        RowPagesBuilder pagesBuilder = RowPagesBuilder.rowPagesBuilder(BigintType.BIGINT);

        // Used as the first page's row count, as a DataSize in bytes and as the driver context argument.
        int smallPagesSpillThresholdSize = 150000;
        List<Page> input = pagesBuilder
                .addSequencePage(smallPagesSpillThresholdSize, 0)
                .addSequencePage(10, smallPagesSpillThresholdSize)
                .build();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                ImmutableList.of(0),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                false,
                ImmutableList.of(LONG_MIN.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty())),
                pagesBuilder.getHashChannel(),
                Optional.empty(),
                1,
                Optional.of(DataSize.of(16, DataSize.Unit.MEGABYTE)),
                true,
                DataSize.ofBytes(smallPagesSpillThresholdSize),
                DataSize.succinctBytes(Integer.MAX_VALUE),
                this.spillerFactory,
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        DriverContext driverContext = this.createDriverContext(smallPagesSpillThresholdSize);

        MaterializedResult.Builder expectedRows = MaterializedResult.resultBuilder(driverContext.getSession(), BigintType.BIGINT, BigintType.BIGINT);
        int totalKeys = smallPagesSpillThresholdSize + 10;
        for (int key = 0; key < totalKeys; key++) {
            expectedRows.row((long) key, (long) key);
        }

        OperatorAssertion.assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expectedRows.build());
    }

    /**
     * Runs a spill-enabled aggregation with a {@code FailingSpillerFactory} and asserts that the
     * operator fails with a {@link RuntimeException} caused by an {@link IOException} and whose
     * message ends with "Failed to spill".
     */
    @Test
    public void testSpillerFailure() {
        TestingAggregationFunction maxVarchar = FUNCTION_RESOLUTION.getAggregateFunction("max", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR));

        List<Integer> groupByChannels = Ints.asList(1);
        List<Type> inputTypes = ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT, VarcharType.VARCHAR, BigintType.BIGINT);
        RowPagesBuilder pagesBuilder = RowPagesBuilder.rowPagesBuilder(false, groupByChannels, inputTypes);
        List<Page> input = pagesBuilder
                .addSequencePage(10, 100, 0, 100, 0)
                .addSequencePage(2000, 100, 0, 200, 0)
                .addSequencePage(10, 100, 0, 300, 0)
                .build();

        DriverContext driverContext = TestingTaskContext.builder(this.executor, this.scheduledExecutor, SessionTestUtils.TEST_SESSION)
                .setQueryMaxMemory(DataSize.valueOf("7MB"))
                .setMemoryPoolSize(DataSize.valueOf("1GB"))
                .build()
                .addPipelineContext(0, true, true, false)
                .addDriverContext();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                false,
                ImmutableList.of(
                        COUNT.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty()),
                        LONG_MIN.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(3), OptionalInt.empty()),
                        LONG_AVERAGE.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(3), OptionalInt.empty()),
                        maxVarchar.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(2), OptionalInt.empty())),
                pagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16, DataSize.Unit.MEGABYTE)),
                true,
                DataSize.succinctBytes(8L),
                DataSize.succinctBytes(Integer.MAX_VALUE),
                new FailingSpillerFactory(),
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        Assertions.assertThatThrownBy(() -> OperatorAssertion.toPages(operatorFactory, driverContext, input))
                .isInstanceOf(RuntimeException.class)
                .hasCauseInstanceOf(IOException.class)
                .hasMessageEndingWith("Failed to spill");
    }

    /**
     * Verifies memory accounting around a {@code SINGLE}-step hash aggregation operator.
     *
     * <p>After one 500-row input page is added, the driver context must report a positive memory
     * usage. Once {@code OperatorAssertion.toPages} has been run with no further input and the
     * try-with-resources block has closed the operator, both the regular and the revocable memory
     * usage reported by the driver context must be zero.
     *
     * @throws Exception if creating, running or closing the operator fails
     */
    @Test
    public void testMemoryTracking() throws Exception {
        List<Integer> hashChannels = Ints.asList(0);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(false, hashChannels, new Type[]{BigintType.BIGINT});
        Page input = Iterables.getOnlyElement(rowPagesBuilder.addSequencePage(500, 0).build());

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                hashChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                ImmutableList.of(LONG_MIN.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                this.joinCompiler,
                this.typeOperators,
                Optional.empty());

        DriverContext driverContext = this.createDriverContext(1024L);
        try (Operator operator = operatorFactory.createOperator(driverContext)) {
            Assert.assertTrue(operator.needsInput());
            operator.addInput(input);
            // Memory must be accounted for as soon as the operator holds input.
            Assertions.assertThat(driverContext.getMemoryUsage()).isGreaterThan(0L);
            // No additional pages are fed here; the produced pages are intentionally discarded.
            OperatorAssertion.toPages(operator, Collections.emptyIterator());
        }

        // Checked only after the operator has been closed by the try-with-resources block.
        Assert.assertEquals(driverContext.getMemoryUsage(), 0L);
        Assert.assertEquals(driverContext.getRevocableMemoryUsage(), 0L);
    }

    /**
     * Runs several operators created from one {@code PARTIAL}-step factory that share a single
     * {@link PartialAggregationController}, and checks after each run both the produced pages and
     * the value of {@code isPartialAggregationDisabled()} on the controller.
     *
     * <p>The sequence of assertions is order dependent because the controller is shared across
     * all operator runs.
     */
    @Test
    public void testAdaptivePartialAggregation() {
        List<Integer> hashChannels = Ints.asList(0);
        DataSize maxPartialMemory = DataSize.ofBytes(1L);
        // NOTE(review): 0.8 is presumably the unique-rows ratio threshold; confirm against PartialAggregationController.
        PartialAggregationController partialAggregationController = new PartialAggregationController(maxPartialMemory, 0.8);
        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                hashChannels,
                ImmutableList.of(),
                AggregationNode.Step.PARTIAL,
                ImmutableList.of(LONG_MIN.createAggregatorFactory(AggregationNode.Step.PARTIAL, ImmutableList.of(0), OptionalInt.empty())),
                Optional.empty(),
                Optional.empty(),
                100,
                Optional.of(maxPartialMemory),
                this.joinCompiler,
                this.typeOperators,
                Optional.of(partialAggregationController));

        // The controller starts with partial aggregation enabled.
        Assert.assertFalse(partialAggregationController.isPartialAggregationDisabled());

        // First operator: the first page (one duplicated key) is expected back as 9 rows,
        // the second page is expected back with all 10 rows.
        List<Page> operator1Input = RowPagesBuilder.rowPagesBuilder(false, hashChannels, new Type[]{BigintType.BIGINT})
                .addBlocksPage(new Block[]{BlockAssertions.createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8, 8)})
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(1L, 10))
                .build();
        List<Page> operator1Expected = RowPagesBuilder.rowPagesBuilder(new Type[]{BigintType.BIGINT, BigintType.BIGINT})
                .addBlocksPage(new Block[]{BlockAssertions.createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8), BlockAssertions.createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8)})
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(1L, 10), BlockAssertions.createRepeatedValuesBlock(1L, 10))
                .build();
        this.assertOperatorEquals(operatorFactory, operator1Input, operator1Expected);
        Assert.assertTrue(partialAggregationController.isPartialAggregationDisabled());

        // Second operator: with partial aggregation reported as disabled, every input row is expected in the output.
        List<Page> operator2Input = RowPagesBuilder.rowPagesBuilder(false, hashChannels, new Type[]{BigintType.BIGINT})
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(1L, 10))
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(2L, 10))
                .build();
        List<Page> operator2Expected = RowPagesBuilder.rowPagesBuilder(new Type[]{BigintType.BIGINT, BigintType.BIGINT})
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(1L, 10), BlockAssertions.createRepeatedValuesBlock(1L, 10))
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(2L, 10), BlockAssertions.createRepeatedValuesBlock(2L, 10))
                .build();
        this.assertOperatorEquals(operatorFactory, operator2Input, operator2Expected);

        // Three more single-page runs: the controller must still report "disabled" after the
        // first two and "enabled" after the third.
        for (int i = 1; i <= 3; ++i) {
            List<Page> operatorInput = RowPagesBuilder.rowPagesBuilder(false, hashChannels, new Type[]{BigintType.BIGINT})
                    .addBlocksPage(new Block[]{BlockAssertions.createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8)})
                    .build();
            List<Page> operatorExpected = RowPagesBuilder.rowPagesBuilder(new Type[]{BigintType.BIGINT, BigintType.BIGINT})
                    .addBlocksPage(new Block[]{BlockAssertions.createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8), BlockAssertions.createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8)})
                    .build();
            this.assertOperatorEquals(operatorFactory, operatorInput, operatorExpected);
            if (i <= 2) {
                Assert.assertTrue(partialAggregationController.isPartialAggregationDisabled());
            } else {
                Assert.assertFalse(partialAggregationController.isPartialAggregationDisabled());
            }
        }

        // A flush reported directly to the controller at this point must not flip it back to
        // "disabled": the final assertion below still expects partial aggregation to be enabled.
        partialAggregationController.onFlush(1000000L, 1000000L, OptionalLong.empty());

        // Last operator: each 100-row page of a single repeated key is expected to collapse to one row.
        List<Page> operator3Input = RowPagesBuilder.rowPagesBuilder(false, hashChannels, new Type[]{BigintType.BIGINT})
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(1L, 100))
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(2L, 100))
                .build();
        List<Page> operator3Expected = RowPagesBuilder.rowPagesBuilder(new Type[]{BigintType.BIGINT, BigintType.BIGINT})
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(1L, 1), BlockAssertions.createRepeatedValuesBlock(1L, 1))
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(2L, 1), BlockAssertions.createRepeatedValuesBlock(2L, 1))
                .build();
        this.assertOperatorEquals(operatorFactory, operator3Input, operator3Expected);
        Assert.assertFalse(partialAggregationController.isPartialAggregationDisabled());
    }

    /**
     * Checks the "Input rows processed without partial aggregation enabled" metric across two
     * operators that share one {@link PartialAggregationController}.
     *
     * <p>The first operator run must leave the controller reporting partial aggregation as
     * disabled while its own metric is still 0. The second operator run, executed in a new driver
     * context, must report all 20 of its input rows in that metric.
     */
    @Test
    public void testAdaptivePartialAggregationTriggeredOnlyOnFlush() {
        List<Integer> hashChannels = Ints.asList(0);
        // NOTE(review): 0.8 is presumably the unique-rows ratio threshold; confirm against PartialAggregationController.
        PartialAggregationController partialAggregationController = new PartialAggregationController(DataSize.ofBytes(1L), 0.8);
        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                hashChannels,
                ImmutableList.of(),
                AggregationNode.Step.PARTIAL,
                ImmutableList.of(LONG_MIN.createAggregatorFactory(AggregationNode.Step.PARTIAL, ImmutableList.of(0), OptionalInt.empty())),
                Optional.empty(),
                Optional.empty(),
                10,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                this.joinCompiler,
                this.typeOperators,
                Optional.of(partialAggregationController));

        // First operator: 12 input rows (a 10-row sequence page plus 2 repeated rows) are
        // expected to produce a single 10-row page.
        List<Page> operator1Input = RowPagesBuilder.rowPagesBuilder(false, hashChannels, new Type[]{BigintType.BIGINT})
                .addSequencePage(10, 0)
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(1L, 2))
                .build();
        List<Page> operator1Expected = RowPagesBuilder.rowPagesBuilder(new Type[]{BigintType.BIGINT, BigintType.BIGINT})
                .addSequencePage(10, 0, 0)
                .build();
        DriverContext firstDriverContext = this.createDriverContext(1024L);
        this.assertOperatorEquals(firstDriverContext, operatorFactory, operator1Input, operator1Expected);
        Assert.assertTrue(partialAggregationController.isPartialAggregationDisabled());
        // The first operator itself must not have counted any rows as processed without partial aggregation.
        this.assertInputRowsWithPartialAggregationDisabled(firstDriverContext, 0L);

        // Second operator: every input row is expected in the output, and all 20 rows must be
        // counted in the metric of the new driver context.
        List<Page> operator2Input = RowPagesBuilder.rowPagesBuilder(false, hashChannels, new Type[]{BigintType.BIGINT})
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(1L, 10))
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(2L, 10))
                .build();
        List<Page> operator2Expected = RowPagesBuilder.rowPagesBuilder(new Type[]{BigintType.BIGINT, BigintType.BIGINT})
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(1L, 10), BlockAssertions.createRepeatedValuesBlock(1L, 10))
                .addBlocksPage(BlockAssertions.createRepeatedValuesBlock(2L, 10), BlockAssertions.createRepeatedValuesBlock(2L, 10))
                .build();
        DriverContext secondDriverContext = this.createDriverContext(1024L);
        this.assertOperatorEquals(secondDriverContext, operatorFactory, operator2Input, operator2Expected);
        this.assertInputRowsWithPartialAggregationDisabled(secondDriverContext, 20L);
    }

    /**
     * Asserts the total of the "Input rows processed without partial aggregation enabled" metric
     * found on the first operator stats entry of {@code context}.
     *
     * <p>When the metric is absent, the observed value is taken to be 0.
     *
     * @param context driver context whose first operator stats entry is inspected
     * @param expectedRowCount the row count the metric is expected to report
     */
    private void assertInputRowsWithPartialAggregationDisabled(DriverContext context, long expectedRowCount) {
        OperatorStats firstOperatorStats = context.getDriverStats().getOperatorStats().get(0);
        LongCount rowsMetric = (LongCount)firstOperatorStats.getMetrics().getMetrics().get("Input rows processed without partial aggregation enabled");
        long observedRowCount = rowsMetric == null ? 0L : rowsMetric.getTotal();
        Assert.assertEquals(observedRowCount, expectedRowCount);
    }

    /**
     * Convenience overload that runs the comparison in a new driver context obtained from
     * {@code createDriverContext(1024L)}.
     *
     * @param operatorFactory factory of the operator under test
     * @param input pages fed to the operator
     * @param expectedPages pages the operator is expected to produce
     */
    private void assertOperatorEquals(OperatorFactory operatorFactory, List<Page> input, List<Page> expectedPages) {
        DriverContext freshContext = this.createDriverContext(1024L);
        this.assertOperatorEquals(freshContext, operatorFactory, input, expectedPages);
    }

    /**
     * Wraps {@code expectedPages} in a two-column ({@code BIGINT}, {@code BIGINT})
     * {@code MaterializedResult} and delegates to {@code OperatorAssertion.assertOperatorEquals}.
     *
     * @param driverContext context the operator runs in; also supplies the session for the expected result
     * @param operatorFactory factory of the operator under test
     * @param input pages fed to the operator
     * @param expectedPages pages the operator is expected to produce
     */
    private void assertOperatorEquals(DriverContext driverContext, OperatorFactory operatorFactory, List<Page> input, List<Page> expectedPages) {
        Type[] resultTypes = {BigintType.BIGINT, BigintType.BIGINT};
        MaterializedResult expectedResult = MaterializedResult.resultBuilder(driverContext.getSession(), resultTypes)
                .pages(expectedPages)
                .build();
        // Both trailing flags are passed as false; see OperatorAssertion for their meaning.
        OperatorAssertion.assertOperatorEquals(operatorFactory, driverContext, input, expectedResult, false, false);
    }

    /**
     * Creates a driver context using {@code Integer.MAX_VALUE} as the memory limit argument.
     *
     * @return a new driver context
     */
    private DriverContext createDriverContext() {
        long defaultMemoryLimit = Integer.MAX_VALUE;
        return this.createDriverContext(defaultMemoryLimit);
    }

    /**
     * Builds a testing task context whose memory pool size is {@code memoryLimit} bytes, then adds
     * pipeline context 0 and a driver context to it.
     *
     * @param memoryLimit memory pool size in bytes
     * @return the newly added driver context
     */
    private DriverContext createDriverContext(long memoryLimit) {
        DataSize memoryPoolSize = DataSize.succinctBytes(memoryLimit);
        return TestingTaskContext.builder(this.executor, this.scheduledExecutor, SessionTestUtils.TEST_SESSION)
                .setMemoryPoolSize(memoryPoolSize)
                .build()
                .addPipelineContext(0, true, true, false)
                .addDriverContext();
    }

    /**
     * Reads the capacity reported by the in-memory aggregation builder of a hash aggregation operator.
     *
     * <p>Asserts that {@code operator} is a {@link HashAggregationOperator} and, when a builder is
     * present, that it is an {@link InMemoryHashAggregationBuilder}.
     *
     * @param operator the operator to inspect
     * @return the builder's capacity, or 0 when the operator currently has no aggregation builder
     */
    private int getHashCapacity(Operator operator) {
        Assert.assertTrue(operator instanceof HashAggregationOperator);
        HashAggregationOperator hashAggregationOperator = (HashAggregationOperator)operator;
        HashAggregationBuilder builder = hashAggregationOperator.getAggregationBuilder();
        if (builder != null) {
            Assert.assertTrue(builder instanceof InMemoryHashAggregationBuilder);
            InMemoryHashAggregationBuilder inMemoryBuilder = (InMemoryHashAggregationBuilder)builder;
            return inMemoryBuilder.getCapacity();
        }
        return 0;
    }

    private static class FailingSpillerFactory
    implements SpillerFactory {
        private FailingSpillerFactory() {
        }

        public Spiller create(List<Type> types, SpillContext spillContext, AggregatedMemoryContext memoryContext) {
            return new Spiller(){

                public ListenableFuture<Void> spill(Iterator<Page> pageIterator) {
                    return Futures.immediateFailedFuture((Throwable)new IOException("Failed to spill"));
                }

                public List<Iterator<Page>> getSpills() {
                    return ImmutableList.of();
                }

                public void close() {
                }
            };
        }
    }
}

