/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.exchange;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.execution.resourcegroups.IndexedPriorityQueue;
import it.unimi.dsi.fastutil.longs.Long2LongMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicLongArray;
import java.util.function.Supplier;
import javax.annotation.concurrent.ThreadSafe;

@ThreadSafe
public class UniformPartitionRebalancer {
    private static final Logger log = Logger.get(UniformPartitionRebalancer.class);
    private static final double SKEWNESS_THRESHOLD = 0.7;
    private final List<Supplier<Long>> writerPhysicalWrittenBytesSuppliers;
    private final Supplier<Long2LongMap> partitionRowCountsSupplier;
    private final long writerMinSize;
    private final int numberOfWriters;
    private final long rebalanceThresholdMinPhysicalWrittenBytes;
    private final AtomicLongArray writerPhysicalWrittenBytesAtLastRebalance;
    private final PartitionInfo[] partitionInfos;

    public UniformPartitionRebalancer(List<Supplier<Long>> writerPhysicalWrittenBytesSuppliers, Supplier<Long2LongMap> partitionRowCountsSupplier, int partitionCount, int numberOfWriters, long writerMinSize) {
        this.writerPhysicalWrittenBytesSuppliers = Objects.requireNonNull(writerPhysicalWrittenBytesSuppliers, "writerPhysicalWrittenBytesSuppliers is null");
        this.partitionRowCountsSupplier = Objects.requireNonNull(partitionRowCountsSupplier, "partitionRowCountsSupplier is null");
        this.writerMinSize = writerMinSize;
        this.numberOfWriters = numberOfWriters;
        this.rebalanceThresholdMinPhysicalWrittenBytes = Math.max(DataSize.of((long)50L, (DataSize.Unit)DataSize.Unit.MEGABYTE).toBytes(), writerMinSize);
        this.writerPhysicalWrittenBytesAtLastRebalance = new AtomicLongArray(numberOfWriters);
        this.partitionInfos = new PartitionInfo[partitionCount];
        for (int i = 0; i < partitionCount; ++i) {
            this.partitionInfos[i] = new PartitionInfo(i % numberOfWriters);
        }
    }

    public int getWriterId(int partitionId, int index) {
        return this.partitionInfos[partitionId].getWriterId(index);
    }

    @VisibleForTesting
    List<Integer> getWriterIds(int partitionId) {
        return this.partitionInfos[partitionId].getWriterIds();
    }

    public void rebalancePartitions() {
        List writerPhysicalWrittenBytes = (List)this.writerPhysicalWrittenBytesSuppliers.stream().map(Supplier::get).collect(ImmutableList.toImmutableList());
        if ((long)this.getPhysicalWrittenBytesSinceLastRebalance(writerPhysicalWrittenBytes) > this.rebalanceThresholdMinPhysicalWrittenBytes) {
            this.rebalancePartitions(writerPhysicalWrittenBytes);
        }
    }

    private int getPhysicalWrittenBytesSinceLastRebalance(List<Long> writerPhysicalWrittenBytes) {
        int physicalWrittenBytesSinceLastRebalance = 0;
        for (int writerId = 0; writerId < writerPhysicalWrittenBytes.size(); ++writerId) {
            physicalWrittenBytesSinceLastRebalance = (int)((long)physicalWrittenBytesSinceLastRebalance + (writerPhysicalWrittenBytes.get(writerId) - this.writerPhysicalWrittenBytesAtLastRebalance.get(writerId)));
        }
        return physicalWrittenBytesSinceLastRebalance;
    }

    private synchronized void rebalancePartitions(List<Long> writerPhysicalWrittenBytes) {
        List<WriterId> minSkewedWriters;
        WriterId maxWriter;
        Long2LongMap partitionRowCounts = this.partitionRowCountsSupplier.get();
        RebalanceContext context = new RebalanceContext(writerPhysicalWrittenBytes, partitionRowCounts);
        IndexedPriorityQueue<WriterId> maxWriters = new IndexedPriorityQueue<WriterId>();
        IndexedPriorityQueue<WriterId> minWriters = new IndexedPriorityQueue<WriterId>();
        for (int writerId = 0; writerId < this.numberOfWriters; ++writerId) {
            WriterId writer = new WriterId(writerId);
            maxWriters.addOrUpdate(writer, context.getWriterEstimatedWrittenBytes(writer));
            minWriters.addOrUpdate(writer, Long.MAX_VALUE - context.getWriterEstimatedWrittenBytes(writer));
        }
        while ((maxWriter = (WriterId)maxWriters.poll()) != null && !(minSkewedWriters = this.findSkewedMinWriters(context, maxWriter, minWriters)).isEmpty()) {
            for (WriterId minSkewedWriter : minSkewedWriters) {
                List<WriterId> affectedWriters = context.rebalancePartition(maxWriter, minSkewedWriter);
                if (affectedWriters.isEmpty()) continue;
                for (WriterId affectedWriter : affectedWriters) {
                    maxWriters.addOrUpdate(affectedWriter, context.getWriterEstimatedWrittenBytes(maxWriter));
                    minWriters.addOrUpdate(affectedWriter, Long.MAX_VALUE - context.getWriterEstimatedWrittenBytes(maxWriter));
                }
            }
            for (WriterId minSkewedWriter : minSkewedWriters) {
                maxWriters.addOrUpdate(minSkewedWriter, context.getWriterEstimatedWrittenBytes(minSkewedWriter));
                minWriters.addOrUpdate(minSkewedWriter, Long.MAX_VALUE - context.getWriterEstimatedWrittenBytes(minSkewedWriter));
            }
        }
        this.resetStateForNextRebalance(context, writerPhysicalWrittenBytes, partitionRowCounts);
    }

    private List<WriterId> findSkewedMinWriters(RebalanceContext context, WriterId maxWriter, IndexedPriorityQueue<WriterId> minWriters) {
        long minWriterWrittenBytes;
        double skewness;
        WriterId minWriter;
        ImmutableList.Builder minSkewedWriters = ImmutableList.builder();
        long maxWriterWrittenBytes = context.getWriterEstimatedWrittenBytes(maxWriter);
        while ((minWriter = minWriters.poll()) != null && !((skewness = (double)(maxWriterWrittenBytes - (minWriterWrittenBytes = context.getWriterEstimatedWrittenBytes(minWriter))) / (double)maxWriterWrittenBytes) <= 0.7) && !Double.isNaN(skewness)) {
            minSkewedWriters.add((Object)minWriter);
        }
        return minSkewedWriters.build();
    }

    private void resetStateForNextRebalance(RebalanceContext context, List<Long> writerPhysicalWrittenBytes, Long2LongMap partitionRowCounts) {
        partitionRowCounts.forEach((serializedKey, rowCount) -> {
            WriterPartitionId writerPartitionId = WriterPartitionId.deserialize(serializedKey);
            PartitionInfo partitionInfo = this.partitionInfos[writerPartitionId.partitionId];
            if (context.isPartitionRebalanced(writerPartitionId.partitionId)) {
                partitionInfo.resetPhysicalWrittenBytesAtLastRebalance();
            } else {
                long writtenBytes = context.estimatePartitionWrittenBytesSinceLastRebalance(new WriterId(writerPartitionId.writerId), (long)rowCount);
                partitionInfo.addToPhysicalWrittenBytesAtLastRebalance(writtenBytes);
            }
        });
        for (int i = 0; i < this.numberOfWriters; ++i) {
            this.writerPhysicalWrittenBytesAtLastRebalance.set(i, writerPhysicalWrittenBytes.get(i));
        }
    }

    @ThreadSafe
    private static class PartitionInfo {
        private final List<Integer> writerAssignments;
        private final AtomicLong physicalWrittenBytesAtLastRebalance = new AtomicLong(0L);

        private PartitionInfo(int initialWriterId) {
            this.writerAssignments = new CopyOnWriteArrayList<Integer>((Collection<Integer>)ImmutableList.of((Object)initialWriterId));
        }

        private boolean containsWriter(int writerId) {
            return this.writerAssignments.contains(writerId);
        }

        private void addWriter(int writerId) {
            this.writerAssignments.add(writerId);
        }

        private int getWriterId(int index) {
            return this.writerAssignments.get(Math.floorMod(index, this.getWriterCount()));
        }

        private List<Integer> getWriterIds() {
            return ImmutableList.copyOf(this.writerAssignments);
        }

        private int getWriterCount() {
            return this.writerAssignments.size();
        }

        private void resetPhysicalWrittenBytesAtLastRebalance() {
            this.physicalWrittenBytesAtLastRebalance.set(0L);
        }

        private void addToPhysicalWrittenBytesAtLastRebalance(long writtenBytes) {
            this.physicalWrittenBytesAtLastRebalance.addAndGet(writtenBytes);
        }

        private long getPhysicalWrittenBytesAtLastRebalancePerWriter() {
            return this.physicalWrittenBytesAtLastRebalance.get() / (long)this.writerAssignments.size();
        }
    }

    private class RebalanceContext {
        private final Set<Integer> rebalancedPartitions = new HashSet<Integer>();
        private final long[] writerPhysicalWrittenBytesSinceLastRebalance;
        private final long[] writerRowCountSinceLastRebalance;
        private final long[] writerEstimatedWrittenBytes;
        private final List<IndexedPriorityQueue<PartitionIdWithRowCount>> writerMaxPartitions;

        private RebalanceContext(List<Long> writerPhysicalWrittenBytes, Long2LongMap partitionRowCounts) {
            int writerId;
            this.writerPhysicalWrittenBytesSinceLastRebalance = new long[UniformPartitionRebalancer.this.numberOfWriters];
            this.writerEstimatedWrittenBytes = new long[UniformPartitionRebalancer.this.numberOfWriters];
            for (writerId = 0; writerId < writerPhysicalWrittenBytes.size(); ++writerId) {
                long physicalWrittenBytesSinceLastRebalance;
                this.writerPhysicalWrittenBytesSinceLastRebalance[writerId] = physicalWrittenBytesSinceLastRebalance = writerPhysicalWrittenBytes.get(writerId) - UniformPartitionRebalancer.this.writerPhysicalWrittenBytesAtLastRebalance.get(writerId);
                this.writerEstimatedWrittenBytes[writerId] = physicalWrittenBytesSinceLastRebalance;
            }
            this.writerRowCountSinceLastRebalance = new long[UniformPartitionRebalancer.this.numberOfWriters];
            this.writerMaxPartitions = new ArrayList<IndexedPriorityQueue<PartitionIdWithRowCount>>(UniformPartitionRebalancer.this.numberOfWriters);
            for (writerId = 0; writerId < UniformPartitionRebalancer.this.numberOfWriters; ++writerId) {
                this.writerMaxPartitions.add(new IndexedPriorityQueue());
            }
            partitionRowCounts.forEach((serializedKey, rowCount) -> {
                WriterPartitionId writerPartitionId = WriterPartitionId.deserialize(serializedKey);
                int n = writerPartitionId.writerId;
                this.writerRowCountSinceLastRebalance[n] = this.writerRowCountSinceLastRebalance[n] + rowCount;
                this.writerMaxPartitions.get(writerPartitionId.writerId).addOrUpdate(new PartitionIdWithRowCount(writerPartitionId.partitionId, (long)rowCount), (long)rowCount);
            });
        }

        private List<WriterId> rebalancePartition(WriterId from, WriterId to) {
            IndexedPriorityQueue<PartitionIdWithRowCount> maxPartitions = this.writerMaxPartitions.get(from.id);
            ImmutableList.Builder affectedWriters = ImmutableList.builder();
            for (PartitionIdWithRowCount partitionToRebalance : maxPartitions) {
                PartitionInfo partitionInfo = UniformPartitionRebalancer.this.partitionInfos[partitionToRebalance.id];
                if (this.isPartitionRebalanced(partitionToRebalance.id) || partitionInfo.containsWriter(to.id)) continue;
                maxPartitions.remove(partitionToRebalance);
                long estimatedPartitionWrittenBytesSinceLastRebalance = this.estimatePartitionWrittenBytesSinceLastRebalance(from, partitionToRebalance.rowCount);
                long estimatedPartitionWrittenBytes = estimatedPartitionWrittenBytesSinceLastRebalance + partitionInfo.getPhysicalWrittenBytesAtLastRebalancePerWriter();
                if (partitionInfo.getWriterCount() > UniformPartitionRebalancer.this.numberOfWriters || estimatedPartitionWrittenBytes < UniformPartitionRebalancer.this.writerMinSize) break;
                partitionInfo.addWriter(to.id);
                this.rebalancedPartitions.add(partitionToRebalance.id);
                this.updateWriterEstimatedWrittenBytes(to, estimatedPartitionWrittenBytesSinceLastRebalance, partitionInfo);
                for (int writer : partitionInfo.getWriterIds()) {
                    affectedWriters.add((Object)new WriterId(writer));
                }
                log.debug("Scaled partition (%s) to writer %s with writer count %s", new Object[]{partitionToRebalance.id, to.id, partitionInfo.getWriterCount()});
                break;
            }
            return affectedWriters.build();
        }

        private void updateWriterEstimatedWrittenBytes(WriterId to, long estimatedPartitionWrittenBytesSinceLastRebalance, PartitionInfo partitionInfo) {
            int newWriterCount = partitionInfo.getWriterCount();
            int oldWriterCount = newWriterCount - 1;
            for (int writer : partitionInfo.getWriterIds()) {
                if (writer == to.id) continue;
                int n = writer;
                this.writerEstimatedWrittenBytes[n] = this.writerEstimatedWrittenBytes[n] - estimatedPartitionWrittenBytesSinceLastRebalance / (long)newWriterCount;
            }
            int n = to.id;
            this.writerEstimatedWrittenBytes[n] = this.writerEstimatedWrittenBytes[n] + estimatedPartitionWrittenBytesSinceLastRebalance * (long)oldWriterCount / (long)newWriterCount;
        }

        private long getWriterEstimatedWrittenBytes(WriterId writer) {
            return this.writerEstimatedWrittenBytes[writer.id];
        }

        private boolean isPartitionRebalanced(int partitionId) {
            return this.rebalancedPartitions.contains(partitionId);
        }

        private long estimatePartitionWrittenBytesSinceLastRebalance(WriterId writer, long partitionRowCount) {
            if (this.writerRowCountSinceLastRebalance[writer.id] == 0L) {
                return 0L;
            }
            return this.writerPhysicalWrittenBytesSinceLastRebalance[writer.id] * partitionRowCount / this.writerRowCountSinceLastRebalance[writer.id];
        }
    }

    private record WriterId(int id) {
    }

    public record WriterPartitionId(int writerId, int partitionId) {
        public static WriterPartitionId deserialize(long value) {
            int writerId = (int)(value >> 32);
            int partitionId = (int)value;
            return new WriterPartitionId(writerId, partitionId);
        }

        public static long serialize(WriterPartitionId writerPartitionId) {
            return (long)writerPartitionId.writerId << 32 | (long)writerPartitionId.partitionId & 0xFFFFFFFFL;
        }
    }

    private record PartitionIdWithRowCount(int id, long rowCount) {
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            PartitionIdWithRowCount that = (PartitionIdWithRowCount)o;
            return this.id == that.id;
        }

        @Override
        public int hashCode() {
            return Objects.hashCode(this.id);
        }
    }
}

