/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.operator;

import com.facebook.presto.Session;
import com.facebook.presto.block.BlockAssertions;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.DictionaryBlock;
import com.facebook.presto.common.block.DictionaryId;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.GroupByHash;
import com.facebook.presto.operator.UpdateMemory;
import com.facebook.presto.operator.Work;
import com.facebook.presto.spi.function.aggregation.GroupByIdBlock;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.facebook.presto.testing.TestingSession;
import com.facebook.presto.type.TypeUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.math.DoubleMath;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Tests for {@link GroupByHash}: group id assignment, membership checks, value
 * output, and the interaction between hash-table growth and the
 * {@link UpdateMemory} callback.
 *
 * <p>All pages used here carry the grouped column(s) followed by a hash column
 * produced by {@link TypeUtils#getHashBlock}.
 */
public class TestGroupByHash {
    /** Number of distinct values fed to the hash in the group-id tests. */
    private static final int MAX_GROUP_ID = 500;
    private static final int[] CONTAINS_CHANNELS = new int[] {0};
    private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build();
    private static final JoinCompiler JOIN_COMPILER = new JoinCompiler(MetadataManager.createTestMetadataManager(), new FeaturesConfig());

    /** Supplies the column types the memory-related tests are run against. */
    @DataProvider
    public Object[][] dataType() {
        return new Object[][] {{VarcharType.VARCHAR}, {BigintType.BIGINT}};
    }

    /**
     * Adds single-value pages repeatedly and checks that the group count only
     * grows the first time a value is seen, and that each value maps to the
     * group id equal to its insertion order.
     */
    @Test
    public void testAddPage() {
        GroupByHash groupByHash = createSingleChannelGroupByHash(BigintType.BIGINT, 100);
        for (int tries = 0; tries < 2; tries++) {
            for (int value = 0; value < MAX_GROUP_ID; value++) {
                Page page = createPageWithHash(BigintType.BIGINT, BlockAssertions.createLongsBlock(value));
                // On the second pass every value is already a known group.
                int expectedGroupCount = tries == 0 ? value + 1 : MAX_GROUP_ID;
                for (int addValuesTries = 0; addValuesTries < 10; addValuesTries++) {
                    groupByHash.addPage(page).process();
                    Assert.assertEquals(groupByHash.getGroupCount(), expectedGroupCount);

                    GroupByIdBlock groupIds = computeGroupIds(groupByHash, page);
                    Assert.assertEquals(groupByHash.getGroupCount(), expectedGroupCount);
                    Assert.assertEquals(groupIds.getGroupCount(), (long) expectedGroupCount);
                    Assert.assertEquals(groupIds.getPositionCount(), 1);
                    Assert.assertEquals(groupIds.getGroupId(0), (long) value);
                }
            }
        }
    }

    /**
     * Adds a null value followed by a large sequence that starts at 1, then
     * checks that the value 0 (which was never added) is not reported as present.
     */
    @Test
    public void testNullGroup() {
        GroupByHash groupByHash = createSingleChannelGroupByHash(BigintType.BIGINT, 100);

        Page nullPage = createPageWithHash(BigintType.BIGINT, BlockAssertions.createLongsBlock(new Long[] {null}));
        groupByHash.addPage(nullPage).process();

        // NOTE(review): the large sequence presumably exists to force the table to grow
        // after the null group was inserted - confirm against GroupByHash.
        Page sequencePage = createPageWithHash(BigintType.BIGINT, BlockAssertions.createLongSequenceBlock(1, 132748));
        groupByHash.addPage(sequencePage).process();

        Page zeroPage = createPageWithHash(BigintType.BIGINT, BlockAssertions.createLongsBlock(0));
        Assert.assertFalse(groupByHash.contains(0, zeroPage, CONTAINS_CHANNELS));
    }

    /**
     * Same scenario as {@link #testAddPage()} but driven only through
     * {@code getGroupIds}, without calling {@code addPage} first.
     */
    @Test
    public void testGetGroupIds() {
        GroupByHash groupByHash = createSingleChannelGroupByHash(BigintType.BIGINT, 100);
        for (int tries = 0; tries < 2; tries++) {
            for (int value = 0; value < MAX_GROUP_ID; value++) {
                Page page = createPageWithHash(BigintType.BIGINT, BlockAssertions.createLongsBlock(value));
                // On the second pass every value is already a known group.
                long expectedGroupCount = tries == 0 ? value + 1 : MAX_GROUP_ID;
                for (int addValuesTries = 0; addValuesTries < 10; addValuesTries++) {
                    GroupByIdBlock groupIds = computeGroupIds(groupByHash, page);
                    Assert.assertEquals(groupIds.getGroupCount(), expectedGroupCount);
                    Assert.assertEquals(groupIds.getPositionCount(), 1);
                    Assert.assertEquals(groupIds.getGroupId(0), (long) value);
                }
            }
        }
    }

    /**
     * Checks the reported output types: the grouped type followed by a BIGINT,
     * which matches the hash column compared in {@link #testAppendTo()}.
     */
    @Test
    public void testTypes() {
        GroupByHash groupByHash = createSingleChannelGroupByHash(VarcharType.VARCHAR, 100);
        List<Type> expectedTypes = ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT);
        Assert.assertEquals(groupByHash.getTypes(), expectedTypes);
    }

    /**
     * Groups 100 distinct strings and checks that writing every group back out
     * reproduces the input values and their hash column in insertion order.
     */
    @Test
    public void testAppendTo() {
        Block valuesBlock = BlockAssertions.createStringSequenceBlock(0, 100);
        Block hashBlock = TypeUtils.getHashBlock(ImmutableList.of(VarcharType.VARCHAR), new Block[] {valuesBlock});
        GroupByHash groupByHash = createSingleChannelGroupByHash(VarcharType.VARCHAR, 100);

        GroupByIdBlock groupIds = computeGroupIds(groupByHash, new Page(new Block[] {valuesBlock, hashBlock}));
        for (int i = 0; i < groupIds.getPositionCount(); i++) {
            Assert.assertEquals(groupIds.getGroupId(i), (long) i);
        }
        Assert.assertEquals(groupByHash.getGroupCount(), 100);

        Page page = buildOutputPage(groupByHash);
        for (int i = 0; i < groupByHash.getTypes().size(); i++) {
            Assert.assertEquals(page.getBlock(i).getPositionCount(), 100);
        }
        Assert.assertEquals(page.getPositionCount(), 100);
        BlockAssertions.assertBlockEquals(VarcharType.VARCHAR, page.getBlock(0), valuesBlock);
        BlockAssertions.assertBlockEquals(BigintType.BIGINT, page.getBlock(1), hashBlock);
    }

    /**
     * Feeds 100 rows holding only 50 distinct values and checks that exactly
     * 50 groups are produced and written out in first-seen order.
     */
    @Test
    public void testAppendToMultipleTuplesPerGroup() {
        List<Long> values = new ArrayList<>();
        for (long i = 0; i < 100; i++) {
            values.add(i % 50);
        }
        Page inputPage = createPageWithHash(BigintType.BIGINT, BlockAssertions.createLongsBlock(values));

        GroupByHash groupByHash = createSingleChannelGroupByHash(BigintType.BIGINT, 100);
        groupByHash.getGroupIds(inputPage).process();
        Assert.assertEquals(groupByHash.getGroupCount(), 50);

        Page outputPage = buildOutputPage(groupByHash);
        Assert.assertEquals(outputPage.getPositionCount(), 50);
        BlockAssertions.assertBlockEquals(BigintType.BIGINT, outputPage.getBlock(0), BlockAssertions.createLongSequenceBlock(0, 50));
    }

    /** Checks {@code contains} for a value that was grouped and one that was not. */
    @Test
    public void testContains() {
        Block valuesBlock = BlockAssertions.createDoubleSequenceBlock(0, 10);
        GroupByHash groupByHash = createSingleChannelGroupByHash(DoubleType.DOUBLE, 100);
        groupByHash.getGroupIds(createPageWithHash(DoubleType.DOUBLE, valuesBlock)).process();

        Page presentPage = createPageWithHash(DoubleType.DOUBLE, BlockAssertions.createDoublesBlock(3.0));
        Assert.assertTrue(groupByHash.contains(0, presentPage, CONTAINS_CHANNELS));

        Page absentPage = createPageWithHash(DoubleType.DOUBLE, BlockAssertions.createDoublesBlock(11.0));
        Assert.assertFalse(groupByHash.contains(0, absentPage, CONTAINS_CHANNELS));
    }

    /** Checks {@code contains} when the group key spans two columns. */
    @Test
    public void testContainsMultipleColumns() {
        List<Type> types = ImmutableList.of(DoubleType.DOUBLE, VarcharType.VARCHAR);
        int[] hashChannels = new int[] {0, 1};

        Block valuesBlock = BlockAssertions.createDoubleSequenceBlock(0, 10);
        Block stringValuesBlock = BlockAssertions.createStringSequenceBlock(0, 10);
        Block hashBlock = TypeUtils.getHashBlock(types, new Block[] {valuesBlock, stringValuesBlock});

        // With two key columns the hash column sits at channel 2.
        GroupByHash groupByHash = GroupByHash.createGroupByHash(TEST_SESSION, types, hashChannels, Optional.of(2), 100, JOIN_COMPILER);
        groupByHash.getGroupIds(new Page(new Block[] {valuesBlock, stringValuesBlock, hashBlock})).process();

        Block testValuesBlock = BlockAssertions.createDoublesBlock(3.0);
        Block testStringValuesBlock = BlockAssertions.createStringsBlock("3");
        Block testHashBlock = TypeUtils.getHashBlock(types, new Block[] {testValuesBlock, testStringValuesBlock});
        Page testPage = new Page(new Block[] {testValuesBlock, testStringValuesBlock, testHashBlock});
        Assert.assertTrue(groupByHash.contains(0, testPage, hashChannels));
    }

    /**
     * Creates the hash with a size argument of 4 and inserts 100 distinct values,
     * then checks that every inserted value is still found.
     */
    @Test
    public void testForceRehash() {
        Block valuesBlock = BlockAssertions.createStringSequenceBlock(0, 100);
        Page page = createPageWithHash(VarcharType.VARCHAR, valuesBlock);

        GroupByHash groupByHash = createSingleChannelGroupByHash(VarcharType.VARCHAR, 4);
        groupByHash.getGroupIds(page).process();

        for (int position = 0; position < valuesBlock.getPositionCount(); position++) {
            Assert.assertTrue(groupByHash.contains(position, page, CONTAINS_CHANNELS));
        }
    }

    /**
     * Counts how often the {@link UpdateMemory} callback fires while adding one
     * million distinct values, and expects {@code 2 * floor(log2(length / 0.75))}.
     */
    @Test(dataProvider = "dataType")
    public void testUpdateMemory(Type type) {
        int length = 1_000_000;
        Page page = createPageWithHash(type, createSequenceBlock(type, length));

        AtomicInteger rehashCount = new AtomicInteger();
        GroupByHash groupByHash = createMemoryTrackedGroupByHash(type, false, () -> {
            rehashCount.incrementAndGet();
            return true;
        });
        groupByHash.addPage(page).process();

        Assert.assertEquals(rehashCount.get(), 2 * DoubleMath.log2(length / 0.75, RoundingMode.FLOOR));
    }

    /** Checks that adding a page with zero positions finishes on the first {@code process()} call. */
    @Test(dataProvider = "dataType")
    public void testEmptyPage(Type type) {
        Page page = createPageWithHash(type, createSequenceBlock(type, 0));

        AtomicInteger currentQuota = new AtomicInteger(0);
        AtomicInteger allowedQuota = new AtomicInteger(6);
        GroupByHash groupByHash = createMemoryTrackedGroupByHash(type, false, quotaLimitedUpdateMemory(currentQuota, allowedQuota));

        Work<?> addPageWork = groupByHash.addPage(page);
        Assert.assertTrue(addPageWork.process());
    }

    /**
     * Drives {@code addPage} and then {@code getGroupIds} over one million
     * distinct values with a memory callback that refuses once its quota is
     * used up, and checks the number of granted callbacks and yields.
     */
    @Test(dataProvider = "dataType")
    public void testMemoryReservationYield(Type type) {
        int length = 1_000_000;
        Page page = createPageWithHash(type, createSequenceBlock(type, length));

        AtomicInteger currentQuota = new AtomicInteger(0);
        AtomicInteger allowedQuota = new AtomicInteger(6);
        UpdateMemory updateMemory = quotaLimitedUpdateMemory(currentQuota, allowedQuota);

        // Scenario 1: addPage.
        GroupByHash groupByHash = createMemoryTrackedGroupByHash(type, false, updateMemory);
        int yields = processCountingYields(groupByHash.addPage(page), currentQuota, allowedQuota);

        Assert.assertEquals(groupByHash.getGroupCount(), length);
        // 40 == 2 * floor(log2(1_000_000 / 0.75)), the callback count testUpdateMemory expects.
        Assert.assertEquals(currentQuota.get(), 40);
        // The quota is raised by 6 per yield, so yields is the granted count / 6.
        Assert.assertEquals(yields, currentQuota.get() / 3 / 2);

        // Scenario 2: getGroupIds on a fresh hash with the quota reset.
        currentQuota.set(0);
        allowedQuota.set(6);
        groupByHash = createMemoryTrackedGroupByHash(type, false, updateMemory);
        Work<?> getGroupIdsWork = groupByHash.getGroupIds(page);
        yields = processCountingYields(getGroupIdsWork, currentQuota, allowedQuota);

        Assert.assertEquals(groupByHash.getGroupCount(), length);
        Assert.assertEquals(((GroupByIdBlock) getGroupIdsWork.getResult()).getPositionCount(), length);
        Assert.assertEquals(currentQuota.get(), 40);
        Assert.assertEquals(yields, currentQuota.get() / 3 / 2);
    }

    /**
     * Same as {@link #testMemoryReservationYield(Type)} but with dictionary-encoded
     * input of 1000 positions over a two-million-entry dictionary, and with the
     * factory's boolean flag set to {@code true}.
     */
    @Test
    public void testMemoryReservationYieldWithDictionary() {
        int dictionaryLength = 1_000;
        int length = 2_000_000;
        int[] ids = IntStream.range(0, dictionaryLength).toArray();
        DictionaryId dictionaryId = DictionaryId.randomDictionaryId();

        Block valuesBlock = new DictionaryBlock(dictionaryLength, BlockAssertions.createStringSequenceBlock(0, length), ids, dictionaryId);
        Block valuesHashBlock = TypeUtils.getHashBlock(ImmutableList.of(VarcharType.VARCHAR), new Block[] {valuesBlock});
        Block hashBlock = new DictionaryBlock(dictionaryLength, valuesHashBlock, ids, dictionaryId);
        Page page = new Page(new Block[] {valuesBlock, hashBlock});

        AtomicInteger currentQuota = new AtomicInteger(0);
        AtomicInteger allowedQuota = new AtomicInteger(6);
        UpdateMemory updateMemory = quotaLimitedUpdateMemory(currentQuota, allowedQuota);

        // Scenario 1: addPage.
        GroupByHash groupByHash = createMemoryTrackedGroupByHash(VarcharType.VARCHAR, true, updateMemory);
        int yields = processCountingYields(groupByHash.addPage(page), currentQuota, allowedQuota);

        Assert.assertEquals(groupByHash.getGroupCount(), dictionaryLength);
        // 20 == 2 * floor(log2(1_000 / 0.75)): growth follows the 1000 referenced positions.
        Assert.assertEquals(currentQuota.get(), 20);
        // The quota is raised by 6 per yield, so yields is the granted count / 6.
        Assert.assertEquals(yields, currentQuota.get() / 3 / 2);

        // Scenario 2: getGroupIds on a fresh hash with the quota reset.
        currentQuota.set(0);
        allowedQuota.set(6);
        groupByHash = createMemoryTrackedGroupByHash(VarcharType.VARCHAR, true, updateMemory);
        Work<?> getGroupIdsWork = groupByHash.getGroupIds(page);
        yields = processCountingYields(getGroupIdsWork, currentQuota, allowedQuota);

        Assert.assertEquals(groupByHash.getGroupCount(), dictionaryLength);
        Assert.assertEquals(((GroupByIdBlock) getGroupIdsWork.getResult()).getPositionCount(), dictionaryLength);
        Assert.assertEquals(currentQuota.get(), 20);
        Assert.assertEquals(yields, currentQuota.get() / 3 / 2);
    }

    /**
     * Creates a session-based hash grouping on channel 0 with the hash column at channel 1.
     *
     * @param type the type of the single grouped column
     * @param expectedSize forwarded as the factory's integer size argument
     *        (presumably the expected number of groups - confirm against GroupByHash)
     */
    private static GroupByHash createSingleChannelGroupByHash(Type type, int expectedSize) {
        return GroupByHash.createGroupByHash(TEST_SESSION, ImmutableList.of(type), new int[] {0}, Optional.of(1), expectedSize, JOIN_COMPILER);
    }

    /**
     * Creates a hash grouping on channel 0 with a size argument of 1 and an explicit memory callback.
     *
     * @param type the type of the single grouped column
     * @param processDictionary forwarded as the factory's boolean argument
     *        (presumably enables dictionary-aware processing - confirm against GroupByHash)
     * @param updateMemory callback handed to the hash
     */
    private static GroupByHash createMemoryTrackedGroupByHash(Type type, boolean processDictionary, UpdateMemory updateMemory) {
        return GroupByHash.createGroupByHash(ImmutableList.of(type), new int[] {0}, Optional.of(1), 1, processDictionary, JOIN_COMPILER, updateMemory);
    }

    /** Builds a two-channel page: the given block followed by its hash block. */
    private static Page createPageWithHash(Type type, Block block) {
        Block hashBlock = TypeUtils.getHashBlock(ImmutableList.of(type), new Block[] {block});
        return new Page(new Block[] {block, hashBlock});
    }

    /**
     * Creates a sequence block of the requested type via {@link BlockAssertions}, using arguments {@code (0, length)}.
     *
     * @throws IllegalArgumentException if the type is neither VARCHAR nor BIGINT
     */
    private static Block createSequenceBlock(Type type, int length) {
        if (type == VarcharType.VARCHAR) {
            return BlockAssertions.createStringSequenceBlock(0, length);
        }
        if (type == BigintType.BIGINT) {
            return BlockAssertions.createLongSequenceBlock(0, length);
        }
        throw new IllegalArgumentException("unsupported data type");
    }

    /** Runs {@code getGroupIds} for the page with a single {@code process()} call and returns its result. */
    private static GroupByIdBlock computeGroupIds(GroupByHash groupByHash, Page page) {
        Work<?> work = groupByHash.getGroupIds(page);
        work.process();
        return (GroupByIdBlock) work.getResult();
    }

    /** Writes the values of every group, in group id order, into a new page. */
    private static Page buildOutputPage(GroupByHash groupByHash) {
        PageBuilder pageBuilder = new PageBuilder(groupByHash.getTypes());
        for (int groupId = 0; groupId < groupByHash.getGroupCount(); groupId++) {
            pageBuilder.declarePosition();
            groupByHash.appendValuesTo(groupId, pageBuilder, 0);
        }
        return pageBuilder.build();
    }

    /**
     * Returns a callback that grants a request (and counts it in {@code currentQuota})
     * only while {@code currentQuota} is below {@code allowedQuota}.
     */
    private static UpdateMemory quotaLimitedUpdateMemory(AtomicInteger currentQuota, AtomicInteger allowedQuota) {
        return () -> {
            if (currentQuota.get() < allowedQuota.get()) {
                currentQuota.getAndIncrement();
                return true;
            }
            return false;
        };
    }

    /**
     * Processes the work until it reports completion, raising the allowed quota
     * by 6 each time it yields.
     *
     * <p>On every yield this asserts that the quota is fully consumed and that a
     * second {@code process()} call without raising the quota yields again
     * without consuming anything.
     *
     * @return the number of times the work yielded
     */
    private static int processCountingYields(Work<?> work, AtomicInteger currentQuota, AtomicInteger allowedQuota) {
        int yields = 0;
        while (!work.process()) {
            Assert.assertEquals(currentQuota.get(), allowedQuota.get());
            // Still blocked: retrying without more quota must not make progress.
            Assert.assertFalse(work.process());
            Assert.assertEquals(currentQuota.get(), allowedQuota.get());
            yields++;
            allowedQuota.getAndAdd(6);
        }
        return yields;
    }
}

