/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator;

import com.google.common.collect.ImmutableList;
import io.airlift.concurrent.Threads;
import io.airlift.stats.GcMonitor;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.units.DataSize;
import io.trino.RowPagesBuilder;
import io.trino.Session;
import io.trino.SessionTestUtils;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.memory.MemoryPool;
import io.trino.memory.QueryContext;
import io.trino.operator.AppendOnlyVariableWidthData;
import io.trino.operator.DriverContext;
import io.trino.operator.Operator;
import io.trino.operator.OperatorAssertion;
import io.trino.operator.OperatorFactory;
import io.trino.spi.Page;
import io.trino.spi.QueryId;
import io.trino.spi.block.Block;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.spiller.SpillSpaceTracker;
import io.trino.testing.TestingTaskContext;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;
import org.assertj.core.api.AbstractDoubleAssert;
import org.assertj.core.api.AbstractLongAssert;
import org.assertj.core.api.Assertions;

public final class GroupByHashYieldAssertion {
    private static final ExecutorService EXECUTOR = Executors.newCachedThreadPool(Threads.daemonThreadsNamed((String)"GroupByHashYieldAssertion-%s"));
    private static final ScheduledExecutorService SCHEDULED_EXECUTOR = Executors.newScheduledThreadPool(2, Threads.daemonThreadsNamed((String)"GroupByHashYieldAssertion-scheduledExecutor-%s"));

    private GroupByHashYieldAssertion() {
    }

    public static List<Page> createPagesWithDistinctHashKeys(Type type, int pageCount, int positionCountPerPage) {
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(true, (List<Integer>)ImmutableList.of((Object)0), type);
        for (int i = 0; i < pageCount; ++i) {
            rowPagesBuilder.addSequencePage(positionCountPerPage, positionCountPerPage * i);
        }
        return rowPagesBuilder.build();
    }

    public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<Page> input, Type hashKeyType, OperatorFactory operatorFactory, Function<Operator, Integer> getHashCapacity, long additionalMemoryInBytes) {
        ((AbstractLongAssert)Assertions.assertThat((long)additionalMemoryInBytes).as("additionalMemoryInBytes should be a relatively small number", new Object[0])).isLessThan(0x200000L);
        LinkedList<Page> result = new LinkedList<Page>();
        QueryId queryId = new QueryId("test_query");
        TaskId anotherTaskId = new TaskId(new StageId("another_query", 0), 0, 0);
        MemoryPool memoryPool = new MemoryPool(DataSize.of((long)1L, (DataSize.Unit)DataSize.Unit.GIGABYTE));
        QueryContext queryContext = new QueryContext(queryId, DataSize.of((long)512L, (DataSize.Unit)DataSize.Unit.MEGABYTE), memoryPool, (GcMonitor)new TestingGcMonitor(), (Executor)EXECUTOR, SCHEDULED_EXECUTOR, SCHEDULED_EXECUTOR, DataSize.of((long)512L, (DataSize.Unit)DataSize.Unit.MEGABYTE), new SpillSpaceTracker(DataSize.of((long)512L, (DataSize.Unit)DataSize.Unit.MEGABYTE)));
        DriverContext driverContext = TestingTaskContext.createTaskContext((QueryContext)queryContext, (Executor)EXECUTOR, (Session)SessionTestUtils.TEST_SESSION).addPipelineContext(0, true, true, false).addDriverContext();
        Operator operator = operatorFactory.createOperator(driverContext);
        byte[] pointer = new byte[8];
        AppendOnlyVariableWidthData variableWidthData = new AppendOnlyVariableWidthData();
        int yieldCount = 0;
        long maxReservedBytes = 0L;
        for (Page page : input) {
            long pageVariableWidthSize = 0L;
            if (hashKeyType == VarcharType.VARCHAR) {
                long oldVariableWidthSize = variableWidthData.getRetainedSizeBytes();
                Block block = page.getBlock(0);
                for (int position = 0; position < page.getPositionCount(); ++position) {
                    variableWidthData.allocate(pointer, 0, hashKeyType.getFlatVariableWidthSize(block, position));
                }
                pageVariableWidthSize = variableWidthData.getRetainedSizeBytes() - oldVariableWidthSize;
            }
            Assertions.assertThat((boolean)operator.needsInput()).isTrue();
            memoryPool.reserve(anotherTaskId, "test", memoryPool.getFreeBytes() - additionalMemoryInBytes - pageVariableWidthSize);
            long oldMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
            int oldCapacity = getHashCapacity.apply(operator);
            operator.addInput(page);
            Page output = operator.getOutput();
            if (output != null) {
                result.add(output);
            }
            long newMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
            maxReservedBytes = Math.max(maxReservedBytes, newMemoryUsage);
            if (newMemoryUsage < DataSize.of((long)4L, (DataSize.Unit)DataSize.Unit.MEGABYTE).toBytes()) {
                memoryPool.free(anotherTaskId, "test", ((Long)memoryPool.getTaskMemoryReservations().get(anotherTaskId)).longValue());
                output = operator.getOutput();
                if (output == null) continue;
                result.add(output);
                continue;
            }
            long actualHashIncreased = newMemoryUsage - oldMemoryUsage - pageVariableWidthSize;
            if (operator.needsInput()) {
                Assertions.assertThat((boolean)operator.getOperatorContext().isWaitingForMemory().isDone()).isTrue();
                Assertions.assertThat((int)getHashCapacity.apply(operator)).isEqualTo(oldCapacity);
                Assertions.assertThat((long)actualHashIncreased).isLessThan(additionalMemoryInBytes);
                memoryPool.free(anotherTaskId, "test", ((Long)memoryPool.getTaskMemoryReservations().get(anotherTaskId)).longValue());
                continue;
            }
            ++yieldCount;
            Assertions.assertThat((boolean)operator.getOperatorContext().isWaitingForMemory().isDone()).isFalse();
            Assertions.assertThat((int)oldCapacity).isEqualTo((long)getHashCapacity.apply(operator).intValue());
            long expectedHashBytes = GroupByHashYieldAssertion.getHashTableSizeInBytes(hashKeyType, oldCapacity * 2);
            Assertions.assertThat((long)actualHashIncreased).isBetween(Long.valueOf(expectedHashBytes), Long.valueOf(expectedHashBytes + additionalMemoryInBytes));
            Assertions.assertThat((Object)operator.getOutput()).isNull();
            memoryPool.free(anotherTaskId, "test", ((Long)memoryPool.getTaskMemoryReservations().get(anotherTaskId)).longValue());
            output = operator.getOutput();
            if (output != null) {
                result.add(output);
            }
            Assertions.assertThat((boolean)operator.needsInput()).isTrue();
            Assertions.assertThat((Integer)getHashCapacity.apply(operator)).isGreaterThan(oldCapacity);
            long rehashedMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
            long expectedMemoryUsageAfterRehash = oldMemoryUsage + GroupByHashYieldAssertion.getHashTableSizeInBytes(hashKeyType, oldCapacity);
            double memoryUsageErrorUpperBound = 1.01;
            double memoryUsageError = (double)rehashedMemoryUsage * 1.0 / (double)expectedMemoryUsageAfterRehash;
            if (memoryUsageError > memoryUsageErrorUpperBound) {
                ((AbstractDoubleAssert)Assertions.assertThat((double)((double)rehashedMemoryUsage * 1.0 / (double)(expectedMemoryUsageAfterRehash + additionalMemoryInBytes))).as("rehashedMemoryUsage " + rehashedMemoryUsage + ", expectedMemoryUsageAfterRehash: " + expectedMemoryUsageAfterRehash, new Object[0])).isBetween(Double.valueOf(0.97), Double.valueOf(memoryUsageErrorUpperBound));
            } else {
                Assertions.assertThat((double)memoryUsageError).isBetween(Double.valueOf(0.99), Double.valueOf(memoryUsageErrorUpperBound));
            }
            Assertions.assertThat((boolean)operator.needsInput()).isTrue();
            Assertions.assertThat((boolean)operator.getOperatorContext().isWaitingForMemory().isDone()).isTrue();
        }
        result.addAll(OperatorAssertion.finishOperator(operator));
        return new GroupByHashYieldResult(yieldCount, maxReservedBytes, result);
    }

    private static long getHashTableSizeInBytes(Type hashKeyType, int capacity) {
        if (hashKeyType == BigintType.BIGINT) {
            return (long)capacity * 18L;
        }
        int sizePerEntry = 5;
        return (long)capacity * (long)sizePerEntry;
    }

    public static final class GroupByHashYieldResult {
        private final int yieldCount;
        private final long maxReservedBytes;
        private final List<Page> output;

        public GroupByHashYieldResult(int yieldCount, long maxReservedBytes, List<Page> output) {
            this.yieldCount = yieldCount;
            this.maxReservedBytes = maxReservedBytes;
            this.output = Objects.requireNonNull(output, "output is null");
        }

        public int getYieldCount() {
            return this.yieldCount;
        }

        public long getMaxReservedBytes() {
            return this.maxReservedBytes;
        }

        public List<Page> getOutput() {
            return this.output;
        }
    }
}

