/*
 * Decompiled with CFR 0.152.
 */
package org.apache.paimon.flink.lookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.paimon.flink.CatalogITCaseBase;
import org.apache.paimon.flink.lookup.partitioner.BucketIdExtractor;
import org.apache.paimon.flink.lookup.partitioner.BucketShufflePartitioner;
import org.apache.paimon.flink.lookup.partitioner.BucketShuffleStrategy;
import org.apache.paimon.flink.lookup.partitioner.ShuffleStrategy;
import org.apache.paimon.manifest.ManifestEntry;
import org.apache.paimon.table.FileStoreTable;
import org.apache.paimon.table.source.DataSplit;
import org.apache.paimon.table.source.Split;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link BucketShufflePartitioner}.
 *
 * <p>Each test creates a bucketed table, reads join keys back together with the bucket they were
 * stored in, routes every join key through the partitioner and then asserts that the bucket of
 * each key is contained in the bucket set that {@link BucketShuffleStrategy} reports as required
 * for the subtask the key was routed to.
 */
public class BucketShufflePartitionerTest extends CatalogITCaseBase {

    /** Number of rows inserted into the test table; col1 and col2 take the values 1..ROW_COUNT. */
    private static final int ROW_COUNT = 1000;

    /**
     * Routes keys of a 100-bucket table to 27 subtasks and asserts that every key lands on subtask
     * {@code bucketId % parallelism}.
     *
     * <p>NOTE(review): despite the method name, this case uses MORE buckets (100) than lookup
     * subtasks (27). The name is kept unchanged so existing test filters and reports still match.
     */
    @Test
    public void testBucketNumLessThanLookupJoinParallelism() throws Exception {
        int numBuckets = 100;
        int lookupJoinParallelism = 27;
        FileStoreTable table = createTestTable(numBuckets);
        BucketShuffleStrategy strategy = new BucketShuffleStrategy(numBuckets);
        BucketShufflePartitioner partitioner =
                createBucketShufflePartitioner(table, numBuckets, strategy);
        List<Tuple2<RowData, Integer>> joinKeysList =
                getGroundTruthJoinKeysWithBucketId(table, numBuckets);
        Map<Integer, List<Tuple2<RowData, Integer>>> partitionResult =
                partitionJoinKeys(partitioner, joinKeysList, lookupJoinParallelism);

        for (int subtaskId = 0; subtaskId < lookupJoinParallelism; ++subtaskId) {
            Set<Integer> requiredCacheBucketIds =
                    strategy.getRequiredCacheBucketIds(subtaskId, lookupJoinParallelism);
            // A subtask that received no join keys has no map entry; treat it as an empty list.
            for (Tuple2<RowData, Integer> joinKeysWithBucketId :
                    partitionResult.getOrDefault(subtaskId, Collections.emptyList())) {
                int bucketId = joinKeysWithBucketId.f1;
                Assertions.assertThat(requiredCacheBucketIds).contains(bucketId);
                Assertions.assertThat(bucketId % lookupJoinParallelism).isEqualTo(subtaskId);
            }
        }
    }

    /**
     * Routes keys of a 78-bucket table to 78 subtasks and asserts that every subtask requires
     * exactly one bucket and that every key lands on subtask {@code bucketId % parallelism}.
     */
    @Test
    public void testBucketNumEqualToLookupJoinParallelism() throws Exception {
        int numBuckets = 78;
        int lookupJoinParallelism = 78;
        FileStoreTable table = createTestTable(numBuckets);
        BucketShuffleStrategy strategy = new BucketShuffleStrategy(numBuckets);
        BucketShufflePartitioner partitioner =
                createBucketShufflePartitioner(table, numBuckets, strategy);
        List<Tuple2<RowData, Integer>> joinKeysList =
                getGroundTruthJoinKeysWithBucketId(table, numBuckets);
        Map<Integer, List<Tuple2<RowData, Integer>>> partitionResult =
                partitionJoinKeys(partitioner, joinKeysList, lookupJoinParallelism);

        for (int subtaskId = 0; subtaskId < lookupJoinParallelism; ++subtaskId) {
            Set<Integer> requiredCacheBucketIds =
                    strategy.getRequiredCacheBucketIds(subtaskId, lookupJoinParallelism);
            // The size check does not depend on the individual key, so it is done once per subtask.
            Assertions.assertThat(requiredCacheBucketIds).hasSize(1);
            // A subtask that received no join keys has no map entry; treat it as an empty list.
            for (Tuple2<RowData, Integer> joinKeysWithBucketId :
                    partitionResult.getOrDefault(subtaskId, Collections.emptyList())) {
                int bucketId = joinKeysWithBucketId.f1;
                Assertions.assertThat(requiredCacheBucketIds).contains(bucketId);
                Assertions.assertThat(bucketId % lookupJoinParallelism).isEqualTo(subtaskId);
            }
        }
    }

    /**
     * Routes keys of a 4-bucket table to 15 subtasks and asserts that every key lands on a subtask
     * for which {@code subtaskId % numBuckets} equals the key's bucket.
     *
     * <p>NOTE(review): despite the method name, this case uses FEWER buckets (4) than lookup
     * subtasks (15). The name is kept unchanged so existing test filters and reports still match.
     */
    @Test
    public void testBucketNumLargerThanLookupJoinParallelism() throws Exception {
        int numBuckets = 4;
        int lookupJoinParallelism = 15;
        FileStoreTable table = createTestTable(numBuckets);
        BucketShuffleStrategy strategy = new BucketShuffleStrategy(numBuckets);
        BucketShufflePartitioner partitioner =
                createBucketShufflePartitioner(table, numBuckets, strategy);
        List<Tuple2<RowData, Integer>> joinKeysList =
                getGroundTruthJoinKeysWithBucketId(table, numBuckets);
        Map<Integer, List<Tuple2<RowData, Integer>>> partitionResult =
                partitionJoinKeys(partitioner, joinKeysList, lookupJoinParallelism);

        for (int subtaskId = 0; subtaskId < lookupJoinParallelism; ++subtaskId) {
            Set<Integer> requiredCacheBucketIds =
                    strategy.getRequiredCacheBucketIds(subtaskId, lookupJoinParallelism);
            // With more subtasks than buckets a subtask may receive no keys at all (col1 of the
            // join key is random), so a missing map entry must not fail the test with an NPE.
            for (Tuple2<RowData, Integer> joinKeysWithBucketId :
                    partitionResult.getOrDefault(subtaskId, Collections.emptyList())) {
                int bucketId = joinKeysWithBucketId.f1;
                Assertions.assertThat(requiredCacheBucketIds).contains(bucketId);
                Assertions.assertThat(subtaskId % numBuckets).isEqualTo(bucketId);
            }
        }
    }

    /**
     * Groups the given join keys by the subtask id the partitioner assigns to them.
     *
     * @param partitioner partitioner under test
     * @param joinKeysList join keys paired with the bucket they were read from
     * @param lookupJoinParallelism number of lookup join subtasks passed to the partitioner
     * @return map from subtask id to the join keys routed to it; subtasks that received no key
     *     have no entry
     */
    private Map<Integer, List<Tuple2<RowData, Integer>>> partitionJoinKeys(
            BucketShufflePartitioner partitioner,
            List<Tuple2<RowData, Integer>> joinKeysList,
            int lookupJoinParallelism) {
        Map<Integer, List<Tuple2<RowData, Integer>>> partitionResult = new HashMap<>();
        for (Tuple2<RowData, Integer> joinKeysWithBucketId : joinKeysList) {
            int subtaskId = partitioner.partition(joinKeysWithBucketId.f0, lookupJoinParallelism);
            partitionResult
                    .computeIfAbsent(subtaskId, ignored -> new ArrayList<>())
                    .add(joinKeysWithBucketId);
        }
        return partitionResult;
    }

    /**
     * Reads, for every bucket in {@code [0, numBuckets)}, the rows of the first data file the scan
     * returns for that bucket and turns each row into a join key paired with that bucket id.
     *
     * <p>The join key is {@code (col1, col2)}: col2 is taken from the stored row, while col1 is a
     * random numeric string. The table is created with {@code 'bucket-key' = 'col2'}, so col1 is
     * presumably irrelevant for the bucket and is randomized on purpose.
     *
     * @param table table to read from
     * @param numBuckets number of buckets to scan
     * @return join keys paired with the id of the bucket their source row was read from
     * @throws IOException if reading a data file fails
     * @throws IndexOutOfBoundsException if the scan returns no file for one of the buckets
     */
    private List<Tuple2<RowData, Integer>> getGroundTruthJoinKeysWithBucketId(
            FileStoreTable table, int numBuckets) throws IOException {
        List<Tuple2<RowData, Integer>> joinKeyRows = new ArrayList<>();
        Random random = new Random();
        for (int bucketId = 0; bucketId < numBuckets; bucketId++) {
            ManifestEntry file =
                    table.store().newScan().withBucket(bucketId).plan().files().get(0);
            DataSplit dataSplit =
                    DataSplit.builder()
                            .withPartition(file.partition())
                            .withBucket(file.bucket())
                            .withDataFiles(Collections.singletonList(file.file()))
                            .withBucketPath("not used")
                            .build();
            // Effectively-final copy of the loop variable for use inside the lambda.
            int bucket = bucketId;
            table.newReadBuilder()
                    .newRead()
                    .createReader(dataSplit)
                    .forEachRemaining(
                            internalRow -> {
                                RowData joinKey =
                                        GenericRowData.of(
                                                String.valueOf(random.nextInt(numBuckets)),
                                                internalRow.getInt(1));
                                joinKeyRows.add(Tuple2.of(joinKey, bucket));
                            });
        }
        return joinKeyRows;
    }

    /**
     * Creates the partitioner under test, with join keys {@code (col1, col2)} and bucket key
     * {@code col2}.
     *
     * @param table table whose schema is handed to the bucket id extractor
     * @param numBuckets number of buckets of the table
     * @param strategy shuffle strategy the partitioner delegates to
     * @return a new partitioner
     */
    private BucketShufflePartitioner createBucketShufflePartitioner(
            FileStoreTable table, int numBuckets, BucketShuffleStrategy strategy) {
        BucketIdExtractor bucketIdExtractor =
                new BucketIdExtractor(
                        numBuckets,
                        table.schema(),
                        Arrays.asList("col1", "col2"),
                        Collections.singletonList("col2"));
        return new BucketShufflePartitioner(strategy, bucketIdExtractor);
    }

    /**
     * Creates table {@code Test} with the given bucket count and bucket key {@code col2}, and
     * inserts {@link #ROW_COUNT} rows {@code ('i', i, 101.1)} for i in 1..ROW_COUNT with a single
     * INSERT statement.
     *
     * @param bucketNum value of the table's {@code bucket} option
     * @return the created table
     */
    private FileStoreTable createTestTable(int bucketNum) throws Exception {
        String tableName = "Test";
        String ddl =
                String.format(
                        "CREATE TABLE %s (col1 STRING, col2 INT, col3 FLOAT) WITH ('bucket'='%s', 'bucket-key' = 'col2')",
                        tableName, bucketNum);
        batchSql(ddl);

        StringBuilder dml = new StringBuilder(String.format("INSERT INTO %s VALUES ", tableName));
        for (int index = 1; index <= ROW_COUNT; ++index) {
            if (index > 1) {
                dml.append(", ");
            }
            dml.append(String.format("('%s', %s, %s)", index, index, 101.1f));
        }
        batchSql(dml.toString());
        return paimonTable(tableName);
    }
}

