/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.output;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.execution.resourcegroups.IndexedPriorityQueue;
import io.trino.operator.PartitionFunction;
import io.trino.spi.connector.ConnectorBucketNodeMap;
import io.trino.spi.type.Type;
import io.trino.sql.planner.NodePartitioningManager;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.SystemPartitioningHandle;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicLongArray;
import java.util.stream.IntStream;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;

/**
 * Spreads skewed partitions over additional task buckets ("scale writers").
 *
 * <p>Every partition starts out assigned to exactly one task bucket: partition {@code p} goes to
 * task {@code p % taskCount}, and within a task the buckets are handed out round-robin. Callers
 * report rows per partition ({@link #addPartitionRowCount}) and bytes processed
 * ({@link #addDataProcessed}). Once at least
 * {@code max(minPartitionDataProcessedRebalanceThreshold, 50MB)} bytes have been processed since
 * the last rebalance, {@link #rebalance()} estimates the load of each task bucket. It then adds
 * additional buckets on other tasks to the assignment lists of the heaviest partitions of
 * overloaded buckets. Assignment lists only ever grow. {@link #getTaskId} selects one of a
 * partition's assigned tasks by index modulo the number of assignments.
 *
 * <p>Counters are atomics. The rebalance bookkeeping arrays are guarded by {@code this}. The
 * per-partition assignment lists are copy-on-write, so {@link #getTaskId} reads them without
 * locking.
 */
@ThreadSafe
public class SkewedPartitionRebalancer {
    private static final Logger log = Logger.get(SkewedPartitionRebalancer.class);
    private static final int SCALE_WRITERS_PARTITION_COUNT = 4096;
    private static final double TASK_BUCKET_SKEWNESS_THRESHOLD = 0.7;
    private static final long MIN_DATA_PROCESSED_REBALANCE_THRESHOLD = DataSize.of(50, DataSize.Unit.MEGABYTE).toBytes();

    private final int partitionCount;
    private final int taskCount;
    private final int taskBucketCount;
    private final long minPartitionDataProcessedRebalanceThreshold;
    private final long minDataProcessedRebalanceThreshold;

    private final AtomicLongArray partitionRowCount;
    private final AtomicLong dataProcessed;
    private final AtomicLong dataProcessedAtLastRebalance;

    @GuardedBy("this")
    private final long[] partitionDataSizeAtLastRebalance;
    @GuardedBy("this")
    private final long[] partitionDataSizeSinceLastRebalancePerTask;
    @GuardedBy("this")
    private final long[] estimatedTaskBucketDataSizeSinceLastRebalance;

    // One copy-on-write list per partition: appended to under the lock, read lock-free by getTaskId.
    private final List<List<TaskBucket>> partitionAssignments;

    /**
     * Returns whether partitions may be scaled out across tasks.
     *
     * <p>This requires more than one task, no fixed connector bucket-to-node mapping, and a
     * scaled-writer hash distribution.
     */
    public static boolean checkCanScalePartitionsRemotely(Session session, int taskCount, PartitioningHandle partitioningHandle, NodePartitioningManager nodePartitioningManager) {
        // Only handles with a catalog can carry a connector bucket node map.
        boolean hasFixedNodeMapping = partitioningHandle.getCatalogHandle()
                .map(handle -> nodePartitioningManager.getConnectorBucketNodeMap(session, partitioningHandle)
                        .map(ConnectorBucketNodeMap::hasFixedMapping)
                        .orElse(false))
                .orElse(false);
        return taskCount > 1 && !hasFixedNodeMapping && PartitioningHandle.isScaledWriterHashDistribution(partitioningHandle);
    }

    /**
     * Creates a partition function that maps each bucket to itself (identity bucket-to-partition).
     *
     * <p>For system partitioning handles the bucket count is {@link #SCALE_WRITERS_PARTITION_COUNT}.
     * Otherwise it is taken from the handle's bucket node map.
     */
    public static PartitionFunction createPartitionFunction(Session session, NodePartitioningManager nodePartitioningManager, PartitioningScheme scheme, List<Type> partitionChannelTypes) {
        PartitioningHandle handle = scheme.getPartitioning().getHandle();
        int bucketCount = handle.getConnectorHandle() instanceof SystemPartitioningHandle
                ? SCALE_WRITERS_PARTITION_COUNT
                : nodePartitioningManager.getBucketNodeMap(session, handle).getBucketCount();
        return nodePartitioningManager.getPartitionFunction(session, scheme, partitionChannelTypes, IntStream.range(0, bucketCount).toArray());
    }

    /**
     * Creates a rebalancer with {@code ceil(taskPartitionedWriterCount / 2)} buckets per task.
     *
     * @param partitionCount number of partitions to track
     * @param taskCount number of tasks partitions may be assigned to
     * @param taskPartitionedWriterCount writer count per task; half of it (rounded up) becomes the bucket count
     * @param minPartitionDataProcessedRebalanceThreshold minimum per-task bytes a partition must have
     *        processed since the last rebalance before it is moved; the global rebalance threshold is
     *        this value or 50MB, whichever is larger
     */
    public static SkewedPartitionRebalancer createSkewedPartitionRebalancer(int partitionCount, int taskCount, int taskPartitionedWriterCount, long minPartitionDataProcessedRebalanceThreshold) {
        int taskBucketCount = (int) Math.ceil(0.5 * taskPartitionedWriterCount);
        return new SkewedPartitionRebalancer(partitionCount, taskCount, taskBucketCount, minPartitionDataProcessedRebalanceThreshold);
    }

    /**
     * Returns the number of tasks implied by the scheme's bucket-to-partition mapping, i.e. the
     * highest partition index plus one.
     *
     * @throws IllegalArgumentException if the scheme has no bucket-to-partition mapping
     */
    public static int getTaskCount(PartitioningScheme partitioningScheme) {
        int[] bucketToPartition = partitioningScheme.getBucketToPartition()
                .orElseThrow(() -> new IllegalArgumentException("Bucket to partition must be set before calculating taskCount"));
        return IntStream.of(bucketToPartition).max().getAsInt() + 1;
    }

    private SkewedPartitionRebalancer(int partitionCount, int taskCount, int taskBucketCount, long minPartitionDataProcessedRebalanceThreshold) {
        this.partitionCount = partitionCount;
        this.taskCount = taskCount;
        this.taskBucketCount = taskBucketCount;
        this.minPartitionDataProcessedRebalanceThreshold = minPartitionDataProcessedRebalanceThreshold;
        this.minDataProcessedRebalanceThreshold = Math.max(minPartitionDataProcessedRebalanceThreshold, MIN_DATA_PROCESSED_REBALANCE_THRESHOLD);

        this.partitionRowCount = new AtomicLongArray(partitionCount);
        this.dataProcessed = new AtomicLong();
        this.dataProcessedAtLastRebalance = new AtomicLong();

        this.partitionDataSizeAtLastRebalance = new long[partitionCount];
        this.partitionDataSizeSinceLastRebalancePerTask = new long[partitionCount];
        this.estimatedTaskBucketDataSizeSinceLastRebalance = new long[taskCount * taskBucketCount];

        // Initial layout: partitions round-robin over tasks, and within a task round-robin over its buckets.
        int[] nextBucketPerTask = new int[taskCount];
        ImmutableList.Builder<List<TaskBucket>> assignments = ImmutableList.builder();
        for (int partition = 0; partition < partitionCount; partition++) {
            int taskId = partition % taskCount;
            int bucketId = nextBucketPerTask[taskId]++ % taskBucketCount;
            assignments.add(new CopyOnWriteArrayList<>(ImmutableList.of(new TaskBucket(taskId, bucketId))));
        }
        this.partitionAssignments = assignments.build();
    }

    /** Returns, per partition, the task ids it is currently assigned to (snapshot). */
    @VisibleForTesting
    List<List<Integer>> getPartitionAssignments() {
        ImmutableList.Builder<List<Integer>> assignedTasks = ImmutableList.builder();
        for (List<TaskBucket> partitionAssignment : partitionAssignments) {
            List<Integer> tasks = partitionAssignment.stream()
                    .map(taskBucket -> taskBucket.taskId)
                    .collect(ImmutableList.toImmutableList());
            assignedTasks.add(tasks);
        }
        return assignedTasks.build();
    }

    public int getTaskCount() {
        return taskCount;
    }

    /**
     * Returns the task for the given partition, picking among its assigned task buckets by
     * {@code floorMod(index, assignmentCount)}.
     */
    public int getTaskId(int partitionId, long index) {
        List<TaskBucket> taskIds = partitionAssignments.get(partitionId);
        // floorMod keeps the slot non-negative even for negative indexes; the result is < size, so the cast is safe.
        return taskIds.get((int) Math.floorMod(index, (long) taskIds.size())).taskId;
    }

    /** Records {@code dataSize} additional bytes processed across all partitions. */
    public void addDataProcessed(long dataSize) {
        dataProcessed.addAndGet(dataSize);
    }

    /** Records {@code rowCount} additional rows for {@code partition}. */
    public void addPartitionRowCount(int partition, long rowCount) {
        partitionRowCount.addAndGet(partition, rowCount);
    }

    /**
     * Rebalances partitions if enough data has been processed since the last rebalance.
     *
     * <p>The threshold check is lock-free; the rebalance itself runs under the instance lock.
     */
    public void rebalance() {
        long currentDataProcessed = dataProcessed.get();
        if (shouldRebalance(currentDataProcessed)) {
            rebalancePartitions(currentDataProcessed);
        }
    }

    private boolean shouldRebalance(long dataProcessed) {
        return dataProcessed - dataProcessedAtLastRebalance.get() >= minDataProcessedRebalanceThreshold;
    }

    private synchronized void rebalancePartitions(long dataProcessed) {
        // Re-check under the lock: another thread may have rebalanced since rebalance() checked.
        if (!shouldRebalance(dataProcessed)) {
            return;
        }

        long[] partitionDataSize = calculatePartitionDataSize(dataProcessed);
        for (int partition = 0; partition < partitionCount; partition++) {
            int totalAssignedTasks = partitionAssignments.get(partition).size();
            // Can be negative when a partition's estimated size falls below the size recorded at its last rebalance.
            partitionDataSizeSinceLastRebalancePerTask[partition] =
                    (partitionDataSize[partition] - partitionDataSizeAtLastRebalance[partition]) / totalAssignedTasks;
        }

        // Per task bucket: queue of the partitions assigned to it, prioritized by per-task data size.
        int totalTaskBuckets = taskCount * taskBucketCount;
        List<IndexedPriorityQueue<Integer>> taskBucketMaxPartitions = new ArrayList<>(totalTaskBuckets);
        for (int i = 0; i < totalTaskBuckets; i++) {
            taskBucketMaxPartitions.add(new IndexedPriorityQueue<>());
        }
        for (int partition = 0; partition < partitionCount; partition++) {
            for (TaskBucket taskBucket : partitionAssignments.get(partition)) {
                taskBucketMaxPartitions.get(taskBucket.id)
                        .addOrUpdate(partition, partitionDataSizeSinceLastRebalancePerTask[partition]);
            }
        }

        IndexedPriorityQueue<TaskBucket> maxTaskBuckets = new IndexedPriorityQueue<>();
        IndexedPriorityQueue<TaskBucket> minTaskBuckets = new IndexedPriorityQueue<>();
        for (int taskId = 0; taskId < taskCount; taskId++) {
            for (int bucketId = 0; bucketId < taskBucketCount; bucketId++) {
                TaskBucket taskBucket = new TaskBucket(taskId, bucketId);
                estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id] =
                        calculateTaskBucketDataSizeSinceLastRebalance(taskBucketMaxPartitions.get(taskBucket.id));
                updateTaskBucketPriority(taskBucket, maxTaskBuckets, minTaskBuckets);
            }
        }

        rebalanceBasedOnTaskBucketSkewness(maxTaskBuckets, minTaskBuckets, taskBucketMaxPartitions, partitionDataSize);
        dataProcessedAtLastRebalance.set(dataProcessed);
    }

    /**
     * Estimates each partition's share of {@code dataProcessed}, proportional to its share of rows.
     *
     * <p>Row counts are snapshotted once, so the total and the per-partition values are consistent
     * even while other threads keep adding rows. If no rows have been recorded, every estimate is
     * zero; this avoids a division by zero.
     */
    private long[] calculatePartitionDataSize(long dataProcessed) {
        long[] rowCounts = new long[partitionCount];
        long totalPartitionRowCount = 0;
        for (int partition = 0; partition < partitionCount; partition++) {
            rowCounts[partition] = partitionRowCount.get(partition);
            totalPartitionRowCount += rowCounts[partition];
        }

        long[] partitionDataSize = new long[partitionCount];
        if (totalPartitionRowCount == 0) {
            return partitionDataSize;
        }
        for (int partition = 0; partition < partitionCount; partition++) {
            partitionDataSize[partition] = scaleByRowShare(rowCounts[partition], totalPartitionRowCount, dataProcessed);
        }
        return partitionDataSize;
    }

    /**
     * Returns {@code rowCount * dataProcessed / totalRowCount}.
     *
     * <p>Exact integer arithmetic is used when the product fits in a {@code long}. Otherwise the
     * computation falls back to floating point instead of silently overflowing.
     */
    private static long scaleByRowShare(long rowCount, long totalRowCount, long dataProcessed) {
        if (dataProcessed == 0 || rowCount <= Long.MAX_VALUE / dataProcessed) {
            return rowCount * dataProcessed / totalRowCount;
        }
        return (long) (dataProcessed * ((double) rowCount / totalRowCount));
    }

    /** Sums the per-task data size of every partition in {@code maxPartitions}. */
    private long calculateTaskBucketDataSizeSinceLastRebalance(IndexedPriorityQueue<Integer> maxPartitions) {
        long estimatedDataSizeSinceLastRebalance = 0;
        for (int partition : maxPartitions) {
            estimatedDataSizeSinceLastRebalance += partitionDataSizeSinceLastRebalancePerTask[partition];
        }
        return estimatedDataSizeSinceLastRebalance;
    }

    /**
     * Publishes a bucket's current estimate to both queues.
     *
     * <p>The min queue is keyed by the negated estimate, which gives the same ordering as
     * {@code Long.MAX_VALUE - estimate} for non-negative estimates. Unlike that form, it does not
     * overflow when an estimate is negative.
     */
    private void updateTaskBucketPriority(TaskBucket taskBucket, IndexedPriorityQueue<TaskBucket> maxTaskBuckets, IndexedPriorityQueue<TaskBucket> minTaskBuckets) {
        long estimate = estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id];
        maxTaskBuckets.addOrUpdate(taskBucket, estimate);
        minTaskBuckets.addOrUpdate(taskBucket, -estimate);
    }

    /**
     * Repeatedly takes a bucket from {@code maxTaskBuckets}. For each bucket, it assigns that
     * bucket's partitions to skewed under-loaded buckets on other tasks.
     *
     * <p>Rebalancing stops when no skewed counterpart exists for the polled bucket.
     */
    private void rebalanceBasedOnTaskBucketSkewness(
            IndexedPriorityQueue<TaskBucket> maxTaskBuckets,
            IndexedPriorityQueue<TaskBucket> minTaskBuckets,
            List<IndexedPriorityQueue<Integer>> taskBucketMaxPartitions,
            long[] partitionDataSize) {
        while (true) {
            TaskBucket maxTaskBucket = maxTaskBuckets.poll();
            if (maxTaskBucket == null) {
                break;
            }

            IndexedPriorityQueue<Integer> maxPartitions = taskBucketMaxPartitions.get(maxTaskBucket.id);
            if (maxPartitions.isEmpty()) {
                continue;
            }

            List<TaskBucket> minSkewedTaskBuckets = findSkewedMinTaskBuckets(maxTaskBucket, minTaskBuckets);
            if (minSkewedTaskBuckets.isEmpty()) {
                break;
            }

            while (true) {
                Integer maxPartition = maxPartitions.poll();
                if (maxPartition == null) {
                    break;
                }
                // Only move partitions that processed enough data per task since the last rebalance.
                if (partitionDataSizeSinceLastRebalancePerTask[maxPartition] < minPartitionDataProcessedRebalanceThreshold) {
                    break;
                }
                // Assign the partition to the first candidate bucket that accepts it.
                for (TaskBucket minTaskBucket : minSkewedTaskBuckets) {
                    if (rebalancePartition(maxPartition, minTaskBucket, maxTaskBuckets, minTaskBuckets, partitionDataSize[maxPartition])) {
                        break;
                    }
                }
            }
        }
    }

    /**
     * Collects buckets on other tasks whose skewness relative to {@code maxTaskBucket} exceeds
     * {@link #TASK_BUCKET_SKEWNESS_THRESHOLD}.
     *
     * <p>Skewness is {@code (max - min) / max}. The scan stops at the first bucket that is not
     * skewed enough. This assumes {@code minTaskBuckets} iterates lightest-first (highest
     * priority first) — TODO confirm against IndexedPriorityQueue. A NaN skewness (both estimates
     * zero) also stops the scan.
     */
    private List<TaskBucket> findSkewedMinTaskBuckets(TaskBucket maxTaskBucket, IndexedPriorityQueue<TaskBucket> minTaskBuckets) {
        ImmutableList.Builder<TaskBucket> minSkewedTaskBuckets = ImmutableList.builder();
        long maxDataSize = estimatedTaskBucketDataSizeSinceLastRebalance[maxTaskBucket.id];
        for (TaskBucket minTaskBucket : minTaskBuckets) {
            double skewness = (double) (maxDataSize - estimatedTaskBucketDataSizeSinceLastRebalance[minTaskBucket.id]) / maxDataSize;
            if (skewness <= TASK_BUCKET_SKEWNESS_THRESHOLD || Double.isNaN(skewness)) {
                break;
            }
            if (maxTaskBucket.taskId != minTaskBucket.taskId) {
                minSkewedTaskBuckets.add(minTaskBucket);
            }
        }
        return minSkewedTaskBuckets.build();
    }

    /**
     * Adds {@code toTaskBucket} to the partition's assignments and updates the bucket estimates
     * and queue priorities.
     *
     * @return {@code false} if the partition is already assigned to a bucket on the same task
     */
    private boolean rebalancePartition(int partitionId, TaskBucket toTaskBucket, IndexedPriorityQueue<TaskBucket> maxTasks, IndexedPriorityQueue<TaskBucket> minTasks, long partitionDataSize) {
        List<TaskBucket> assignments = partitionAssignments.get(partitionId);
        if (assignments.stream().anyMatch(taskBucket -> taskBucket.taskId == toTaskBucket.taskId)) {
            return false;
        }

        assignments.add(toTaskBucket);
        partitionDataSizeAtLastRebalance[partitionId] = partitionDataSize;

        // Re-spread the partition's per-task share over the enlarged assignment.
        // The new bucket gains oldTaskCount/newTaskCount of a share; every existing bucket loses 1/newTaskCount.
        int newTaskCount = assignments.size();
        int oldTaskCount = newTaskCount - 1;
        long perTaskDataSize = partitionDataSizeSinceLastRebalancePerTask[partitionId];
        for (TaskBucket taskBucket : assignments) {
            if (taskBucket.equals(toTaskBucket)) {
                estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id] += perTaskDataSize * oldTaskCount / newTaskCount;
            }
            else {
                estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id] -= perTaskDataSize / newTaskCount;
            }
            updateTaskBucketPriority(taskBucket, maxTasks, minTasks);
        }

        log.debug("Rebalanced partition %s to task %s with taskCount %s", partitionId, toTaskBucket.taskId, assignments.size());
        return true;
    }

    /** A (task, bucket) pair, identified by the flat index {@code taskId * taskBucketCount + bucketId}. */
    private final class TaskBucket {
        private final int taskId;
        // Flat index into the per-bucket arrays and queue lists.
        private final int id;

        private TaskBucket(int taskId, int bucketId) {
            this.taskId = taskId;
            this.id = taskId * taskBucketCount + bucketId;
        }

        @Override
        public int hashCode() {
            return Integer.hashCode(id);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            TaskBucket that = (TaskBucket) o;
            return that.id == id;
        }
    }
}

