/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.airlift.concurrent.Threads;
import io.trino.RowPagesBuilder;
import io.trino.Session;
import io.trino.SessionTestUtils;
import io.trino.block.BlockAssertions;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.AggregationOperator;
import io.trino.operator.DriverContext;
import io.trino.operator.Operator;
import io.trino.operator.OperatorAssertion;
import io.trino.operator.OperatorFactory;
import io.trino.operator.aggregation.AggregatorFactory;
import io.trino.operator.aggregation.TestingAggregationFunction;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.ByteArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.MaterializedResult;
import io.trino.testing.TestingTaskContext;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

/**
 * Tests for {@link AggregationOperator}.
 *
 * <p>Covers:
 * <ul>
 *   <li>mask channels whose null positions carry non-zero underlying bytes;</li>
 *   <li>distinct aggregation combined with null/RLE masks;</li>
 *   <li>several global aggregates evaluated in a single operator;</li>
 *   <li>driver memory accounting before and after the operator is closed.</li>
 * </ul>
 */
@TestInstance(TestInstance.Lifecycle.PER_METHOD)
public class TestAggregationOperator {
    private static final TestingFunctionResolution FUNCTION_RESOLUTION = new TestingFunctionResolution();

    private static final TestingAggregationFunction LONG_AVERAGE =
            FUNCTION_RESOLUTION.getAggregateFunction("avg", TypeSignatureProvider.fromTypes(BigintType.BIGINT));
    private static final TestingAggregationFunction DOUBLE_SUM =
            FUNCTION_RESOLUTION.getAggregateFunction("sum", TypeSignatureProvider.fromTypes(DoubleType.DOUBLE));
    private static final TestingAggregationFunction LONG_SUM =
            FUNCTION_RESOLUTION.getAggregateFunction("sum", TypeSignatureProvider.fromTypes(BigintType.BIGINT));
    private static final TestingAggregationFunction REAL_SUM =
            FUNCTION_RESOLUTION.getAggregateFunction("sum", TypeSignatureProvider.fromTypes(RealType.REAL));
    // count(*): no argument types
    private static final TestingAggregationFunction COUNT =
            FUNCTION_RESOLUTION.getAggregateFunction("count", ImmutableList.of());

    private ExecutorService executor;
    private ScheduledExecutorService scheduledExecutor;

    @BeforeEach
    public void setUp() {
        executor = Executors.newCachedThreadPool(Threads.daemonThreadsNamed(getClass().getSimpleName() + "-%s"));
        scheduledExecutor = Executors.newScheduledThreadPool(2, Threads.daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s"));
    }

    @AfterEach
    public void tearDown() {
        executor.shutdownNow();
        scheduledExecutor.shutdownNow();
    }

    /**
     * Verifies that a null mask position is treated as "excluded" even when the byte stored
     * underneath it is non-zero. Position 1 is null but holds 27. Position 3 is non-null and
     * holds 75. Position 2 is non-null and holds 0. Only one row is counted.
     */
    @Test
    public void testMaskWithDirtyNulls() {
        List<Page> input = ImmutableList.of(new Page(
                4,
                BlockAssertions.createLongsBlock(1, 2, 3, 4),
                new ByteArrayBlock(4, Optional.of(new boolean[] {true, true, false, false}), new byte[] {0, 27, 0, 75})));

        OperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(COUNT.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.of(1))));

        DriverContext driverContext = createDriverContext();
        MaterializedResult expected = MaterializedResult.resultBuilder(driverContext.getSession(), BigintType.BIGINT)
                .row(1L)
                .build();

        OperatorAssertion.assertOperatorEquals(operatorFactory, driverContext, input, expected);
    }

    /**
     * Verifies that a distinct {@code sum} masked on channel 1 ignores rows whose mask is null.
     * The null mask is supplied both as a plain all-null block and as an RLE block of a null value.
     * Only the middle page is counted: distinct values 10 and 11 sum to 21.
     */
    @Test
    public void testDistinctMaskWithNulls() {
        AggregatorFactory distinctFactory = LONG_SUM.createDistinctAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(0), OptionalInt.of(1));
        DriverContext driverContext = createDriverContext();
        OperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory(0, new PlanNodeId("test"), ImmutableList.of(distinctFactory));

        // every position is null; the underlying "true" bytes must not leak through
        ByteArrayBlock trueMaskAllNull = new ByteArrayBlock(4, Optional.of(new boolean[] {true, true, true, true}), new byte[] {1, 1, 1, 1});
        Block trueNullRleMask = RunLengthEncodedBlock.create(trueMaskAllNull.getSingleValueBlock(0), 4);

        List<Page> nullTrueMaskInput = ImmutableList.of(
                new Page(4, BlockAssertions.createLongsBlock(1, 2, 3, 4), trueMaskAllNull),
                new Page(4, BlockAssertions.createLongsBlock(10, 11, 10, 11), BlockAssertions.createBooleansBlock(true, true, true, true)),
                new Page(4, BlockAssertions.createLongsBlock(5, 6, 7, 8), trueNullRleMask));

        MaterializedResult expected = MaterializedResult.resultBuilder(driverContext.getSession(), BigintType.BIGINT)
                .row(21L)
                .build();

        OperatorAssertion.assertOperatorEquals(operatorFactory, driverContext, nullTrueMaskInput, expected);
    }

    /**
     * Runs nine unmasked global aggregates over one 100-row sequence page.
     * Columns start at 0, 0, 300, 500, 400, 500 and 500.
     * Also checks that driver memory usage is back to zero once the operator has finished.
     */
    @Test
    public void testAggregation() {
        TestingAggregationFunction countVarcharColumn = FUNCTION_RESOLUTION.getAggregateFunction("count", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR));
        TestingAggregationFunction maxVarcharColumn = FUNCTION_RESOLUTION.getAggregateFunction("max", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR));

        List<Page> input = RowPagesBuilder.rowPagesBuilder(
                        VarcharType.VARCHAR, BigintType.BIGINT, VarcharType.VARCHAR, BigintType.BIGINT, RealType.REAL, DoubleType.DOUBLE, VarcharType.VARCHAR)
                .addSequencePage(100, 0, 0, 300, 500, 400, 500, 500)
                .build();

        OperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(
                        singleStep(COUNT, 0),
                        singleStep(LONG_SUM, 1),
                        singleStep(LONG_AVERAGE, 1),
                        singleStep(maxVarcharColumn, 2),
                        singleStep(countVarcharColumn, 0),
                        singleStep(LONG_SUM, 3),
                        singleStep(REAL_SUM, 4),
                        singleStep(DOUBLE_SUM, 5),
                        singleStep(maxVarcharColumn, 6)));

        DriverContext driverContext = createDriverContext();
        MaterializedResult expected = MaterializedResult.resultBuilder(
                        driverContext.getSession(),
                        BigintType.BIGINT, BigintType.BIGINT, DoubleType.DOUBLE, VarcharType.VARCHAR, BigintType.BIGINT,
                        BigintType.BIGINT, RealType.REAL, DoubleType.DOUBLE, VarcharType.VARCHAR)
                .row(100L, 4950L, 49.5, "399", 100L, 54950L, 44950.0f, 54950.0, "599")
                .build();

        OperatorAssertion.assertOperatorEquals(operatorFactory, driverContext, input, expected);
        Assertions.assertThat(driverContext.getMemoryUsage()).isEqualTo(0L);
    }

    /**
     * Verifies that adding input makes the driver report a positive memory usage.
     * Also verifies that the reservation is fully released once the operator is drained and closed.
     */
    @Test
    public void testMemoryTracking()
            throws Exception {
        Page input = Iterables.getOnlyElement(RowPagesBuilder.rowPagesBuilder(BigintType.BIGINT).addSequencePage(100, 0).build());
        OperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(singleStep(LONG_SUM, 0)));
        DriverContext driverContext = createDriverContext();

        // Operator is AutoCloseable; closing it must release its memory reservation
        try (Operator operator = operatorFactory.createOperator(driverContext)) {
            Assertions.assertThat(operator.needsInput()).isTrue();
            operator.addInput(input);
            Assertions.assertThat(driverContext.getMemoryUsage()).isGreaterThan(0L);
            OperatorAssertion.toPages(operator, Collections.emptyIterator());
        }
        Assertions.assertThat(driverContext.getMemoryUsage()).isEqualTo(0L);
    }

    /** Builds a fresh driver context backed by this test's executors and the shared test session. */
    private DriverContext createDriverContext() {
        return TestingTaskContext.createTaskContext(executor, scheduledExecutor, SessionTestUtils.TEST_SESSION)
                .addPipelineContext(0, true, true, false)
                .addDriverContext();
    }

    /** Creates an unmasked SINGLE-step aggregator over one input channel. */
    private static AggregatorFactory singleStep(TestingAggregationFunction function, int inputChannel) {
        return function.createAggregatorFactory(AggregationNode.Step.SINGLE, ImmutableList.of(inputChannel), OptionalInt.empty());
    }
}

