/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.plugin.raptor.legacy.storage;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.VerifyException;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multiset;
import io.airlift.concurrent.Threads;
import io.airlift.log.Logger;
import io.airlift.stats.CounterStat;
import io.airlift.units.Duration;
import io.prestosql.plugin.raptor.legacy.NodeSupplier;
import io.prestosql.plugin.raptor.legacy.RaptorConnectorId;
import io.prestosql.plugin.raptor.legacy.backup.BackupService;
import io.prestosql.plugin.raptor.legacy.metadata.BucketNode;
import io.prestosql.plugin.raptor.legacy.metadata.Distribution;
import io.prestosql.plugin.raptor.legacy.metadata.ShardManager;
import io.prestosql.plugin.raptor.legacy.storage.BucketBalancerConfig;
import io.prestosql.spi.Node;
import io.prestosql.spi.NodeManager;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.inject.Inject;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

/**
 * Rebalances the bucket-to-node assignments of every distribution across the currently active worker nodes.
 *
 * <p>A periodic job is scheduled by {@link #start()} only when the balancer is enabled, a backup service is
 * available, and this node is the coordinator. For each distribution the balancer aims for every active node to hold
 * between {@code floor(buckets / activeNodes)} and {@code ceil(buckets / activeNodes)} buckets. Buckets assigned to
 * nodes that are no longer active are always moved. Targets are chosen greedily by lowest projected assigned bytes.
 */
public class BucketBalancer {
    private static final Logger log = Logger.get(BucketBalancer.class);

    private final NodeSupplier nodeSupplier;
    private final ShardManager shardManager;
    private final boolean enabled;
    private final Duration interval;
    private final boolean backupAvailable;
    private final boolean coordinator;
    private final ScheduledExecutorService executor;

    private final AtomicBoolean started = new AtomicBoolean();

    private final CounterStat bucketsBalanced = new CounterStat();
    private final CounterStat jobErrors = new CounterStat();

    @Inject
    public BucketBalancer(
            NodeManager nodeManager,
            NodeSupplier nodeSupplier,
            ShardManager shardManager,
            BucketBalancerConfig config,
            BackupService backupService,
            RaptorConnectorId connectorId) {
        this(nodeSupplier,
                shardManager,
                config.isBalancerEnabled(),
                config.getBalancerInterval(),
                backupService.isBackupAvailable(),
                nodeManager.getCurrentNode().isCoordinator(),
                connectorId.toString());
    }

    /**
     * Creates a balancer with explicit settings.
     *
     * @param nodeSupplier source of the current worker nodes
     * @param shardManager metadata store for distributions and bucket assignments
     * @param enabled whether periodic balancing may be scheduled
     * @param interval delay between periodic balance runs
     * @param backupAvailable whether a backup service is available; periodic balancing requires it
     * @param coordinator whether this node is the coordinator; periodic balancing only runs there
     * @param connectorId used to name the balancer's daemon thread
     */
    public BucketBalancer(
            NodeSupplier nodeSupplier,
            ShardManager shardManager,
            boolean enabled,
            Duration interval,
            boolean backupAvailable,
            boolean coordinator,
            String connectorId) {
        this.nodeSupplier = Objects.requireNonNull(nodeSupplier, "nodeSupplier is null");
        this.shardManager = Objects.requireNonNull(shardManager, "shardManager is null");
        this.enabled = enabled;
        this.interval = Objects.requireNonNull(interval, "interval is null");
        this.backupAvailable = backupAvailable;
        this.coordinator = coordinator;
        this.executor = Executors.newSingleThreadScheduledExecutor(Threads.daemonThreadsNamed("bucket-balancer-" + connectorId));
    }

    /**
     * Schedules the periodic balance job, at most once.
     *
     * <p>Nothing is scheduled unless the balancer is enabled, a backup is available, and this node is the coordinator.
     */
    @PostConstruct
    public void start() {
        if (enabled && backupAvailable && coordinator && !started.getAndSet(true)) {
            executor.scheduleWithFixedDelay(this::runBalanceJob, interval.toMillis(), interval.toMillis(), TimeUnit.MILLISECONDS);
        }
    }

    /**
     * Stops the balancer's executor and interrupts any run in progress.
     */
    @PreDestroy
    public void shutdown() {
        executor.shutdownNow();
    }

    @Managed
    @Nested
    public CounterStat getBucketsBalanced() {
        return bucketsBalanced;
    }

    @Managed
    @Nested
    public CounterStat getJobErrors() {
        return jobErrors;
    }

    /**
     * Submits one balance run to the balancer's executor and returns immediately.
     *
     * <p>Unlike {@link #start()}, this does not check the enabled, backup, or coordinator flags.
     */
    @Managed
    public void startBalanceJob() {
        executor.submit(this::runBalanceJob);
    }

    /**
     * Runs one balance pass, logging and counting failures instead of propagating them.
     */
    private void runBalanceJob() {
        try {
            balance();
        }
        catch (Throwable t) {
            // Swallow on purpose: an exception escaping a scheduleWithFixedDelay task suppresses every later run.
            log.error(t, "Error balancing buckets");
            jobErrors.update(1L);
        }
    }

    /**
     * Computes the bucket moves for the current cluster state and applies them.
     *
     * <p>Synchronized so that at most one balance pass runs at a time on this instance.
     *
     * @return the number of buckets moved
     */
    @VisibleForTesting
    synchronized int balance() {
        log.info("Bucket balancer started. Computing assignments...");
        Multimap<String, BucketAssignment> sourceToAssignmentChanges = computeAssignmentChanges(fetchClusterState());

        log.info("Moving buckets...");
        int moves = updateAssignments(sourceToAssignmentChanges);

        log.info("Bucket balancing finished. Moved %s buckets.", moves);
        return moves;
    }

    /**
     * Plans bucket moves for every distribution in {@code clusterState}.
     *
     * <p>A bucket is moved off its source node when:
     * <ul>
     *   <li>the source node is inactive, or</li>
     *   <li>the source holds more than the target minimum, unless the source is at the target maximum and the chosen
     *       target is already at the target minimum (in that case the move would not reduce skew).</li>
     * </ul>
     * Projected per-node byte totals are updated as moves are planned, so later choices account for earlier ones.
     *
     * @return new assignments, keyed by the node each bucket is moved away from
     * @throws VerifyException if no eligible target node exists for a bucket that must move
     */
    private static Multimap<String, BucketAssignment> computeAssignmentChanges(ClusterState clusterState) {
        Multimap<String, BucketAssignment> sourceToAllocationChanges = HashMultimap.create();

        Set<String> activeNodes = clusterState.getActiveNodes();
        if (activeNodes.isEmpty()) {
            // Without active nodes the target bounds would divide by zero and no bucket could be placed anywhere.
            log.warn("No active nodes available; skipping bucket balancing");
            return sourceToAllocationChanges;
        }

        // Projected bytes per node (active nodes plus any node that currently holds buckets), updated as moves are planned.
        Map<String, Long> allocationBytes = new HashMap<>(clusterState.getAssignedBytes());

        for (Distribution distribution : clusterState.getDistributionAssignments().keySet()) {
            // Number of this distribution's buckets held by each node.
            Multiset<String> allocationCounts = HashMultiset.create();
            Collection<BucketAssignment> distributionAssignments = clusterState.getDistributionAssignments().get(distribution);
            distributionAssignments.stream()
                    .map(BucketAssignment::getNodeIdentifier)
                    .forEach(allocationCounts::add);

            int currentMin = allocationBytes.keySet().stream()
                    .mapToInt(allocationCounts::count)
                    .min()
                    .getAsInt();
            int currentMax = allocationBytes.keySet().stream()
                    .mapToInt(allocationCounts::count)
                    .max()
                    .getAsInt();

            int numBuckets = distributionAssignments.size();
            int targetMin = (int) Math.floor((numBuckets * 1.0) / activeNodes.size());
            int targetMax = (int) Math.ceil((numBuckets * 1.0) / activeNodes.size());

            log.info("Distribution %s: Current bucket skew: min %s, max %s. Target bucket skew: min %s, max %s", distribution.getId(), currentMin, currentMax, targetMin, targetMax);

            long bucketSize = clusterState.getDistributionBucketSize().get(distribution);

            // Iterate over a snapshot because allocationCounts is mutated while moves are planned.
            for (String source : ImmutableSet.copyOf(allocationCounts)) {
                List<BucketAssignment> existingAssignments = distributionAssignments.stream()
                        .filter(assignment -> assignment.getNodeIdentifier().equals(source))
                        .collect(Collectors.toList());

                for (BucketAssignment existingAssignment : existingAssignments) {
                    // An active node that is already at or below the target minimum keeps its remaining buckets.
                    if (activeNodes.contains(source) && allocationCounts.count(source) <= targetMin) {
                        break;
                    }

                    String target = selectTarget(source, activeNodes, allocationCounts, allocationBytes, targetMax);

                    // Only move the bucket if the move reduces skew.
                    if (activeNodes.contains(source) && allocationCounts.count(source) == targetMax && allocationCounts.count(target) == targetMin) {
                        break;
                    }

                    allocationCounts.remove(source);
                    allocationCounts.add(target);
                    allocationBytes.compute(source, (node, bytes) -> bytes - bucketSize);
                    allocationBytes.compute(target, (node, bytes) -> bytes + bucketSize);

                    sourceToAllocationChanges.put(
                            existingAssignment.getNodeIdentifier(),
                            new BucketAssignment(existingAssignment.getDistributionId(), existingAssignment.getBucketNumber(), target));
                }
            }
        }

        return sourceToAllocationChanges;
    }

    /**
     * Chooses the node that should receive a bucket moved off {@code source}.
     *
     * <p>Candidates are active nodes other than {@code source} that hold fewer than {@code targetMax} buckets of the
     * distribution. The candidate with the lowest projected assigned bytes is chosen. Candidates are first sorted by
     * bucket count, presumably so that byte ties favor the node with fewer buckets.
     * NOTE(review): Stream.min does not document which of several equal elements it returns.
     *
     * @throws VerifyException if no candidate exists
     */
    private static String selectTarget(
            String source,
            Set<String> activeNodes,
            Multiset<String> allocationCounts,
            Map<String, Long> allocationBytes,
            int targetMax) {
        return activeNodes.stream()
                .filter(candidate -> !candidate.equals(source) && allocationCounts.count(candidate) < targetMax)
                .sorted(Comparator.comparingInt(allocationCounts::count))
                .min(Comparator.comparingDouble(allocationBytes::get))
                .orElseThrow(() -> new VerifyException("unable to find target for rebalancing"));
    }

    /**
     * Persists the planned moves through the shard manager.
     *
     * <p>Sources with the most pending moves are processed first.
     *
     * @param sourceToAllocationChanges new assignments keyed by the node each bucket leaves
     * @return the number of buckets moved
     */
    private int updateAssignments(Multimap<String, BucketAssignment> sourceToAllocationChanges) {
        List<String> sourceNodes = sourceToAllocationChanges.asMap().entrySet().stream()
                .sorted((a, b) -> Integer.compare(b.getValue().size(), a.getValue().size()))
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());

        int moves = 0;
        for (String source : sourceNodes) {
            for (BucketAssignment reassignment : sourceToAllocationChanges.get(source)) {
                shardManager.updateBucketAssignment(reassignment.getDistributionId(), reassignment.getBucketNumber(), reassignment.getNodeIdentifier());
                bucketsBalanced.update(1L);
                moves++;
                log.info("Distribution %s: Moved bucket %s from %s to %s", reassignment.getDistributionId(), reassignment.getBucketNumber(), source, reassignment.getNodeIdentifier());
            }
        }
        return moves;
    }

    /**
     * Takes a snapshot of the current worker nodes and of every distribution's bucket assignments.
     *
     * <p>The estimated size of each bucket is the distribution's total size divided by its bucket count, using integer
     * division. Each node's assigned bytes are the sum of the estimated sizes of the buckets it holds. Active nodes
     * start at 0, and nodes that hold buckets but are not active are included as well.
     */
    @VisibleForTesting
    ClusterState fetchClusterState() {
        Set<String> activeNodes = nodeSupplier.getWorkerNodes().stream()
                .map(Node::getNodeIdentifier)
                .collect(Collectors.toSet());

        Map<String, Long> assignedNodeSize = new HashMap<>();
        for (String node : activeNodes) {
            assignedNodeSize.put(node, 0L);
        }

        ImmutableMultimap.Builder<Distribution, BucketAssignment> distributionAssignments = ImmutableMultimap.builder();
        ImmutableMap.Builder<Distribution, Long> distributionBucketSize = ImmutableMap.builder();

        for (Distribution distribution : shardManager.getDistributions()) {
            long distributionSize = shardManager.getDistributionSizeInBytes(distribution.getId());
            // Plain long division: a round-trip through double would lose precision for sizes above 2^53.
            long bucketSize = distributionSize / distribution.getBucketCount();
            distributionBucketSize.put(distribution, bucketSize);

            for (BucketNode bucketNode : shardManager.getBucketNodes(distribution.getId())) {
                String node = bucketNode.getNodeIdentifier();
                distributionAssignments.put(distribution, new BucketAssignment(distribution.getId(), bucketNode.getBucketNumber(), node));
                assignedNodeSize.merge(node, bucketSize, Math::addExact);
            }
        }

        return new ClusterState(activeNodes, assignedNodeSize, distributionAssignments.build(), distributionBucketSize.build());
    }

    /**
     * Immutable assignment of one bucket of a distribution to a node.
     */
    @VisibleForTesting
    static class BucketAssignment {
        private final long distributionId;
        private final int bucketNumber;
        private final String nodeIdentifier;

        public BucketAssignment(long distributionId, int bucketNumber, String nodeIdentifier) {
            this.distributionId = distributionId;
            this.bucketNumber = bucketNumber;
            this.nodeIdentifier = Objects.requireNonNull(nodeIdentifier, "nodeIdentifier is null");
        }

        public long getDistributionId() {
            return distributionId;
        }

        public int getBucketNumber() {
            return bucketNumber;
        }

        public String getNodeIdentifier() {
            return nodeIdentifier;
        }
    }

    /**
     * Immutable snapshot of the inputs to one balance pass.
     */
    @VisibleForTesting
    static class ClusterState {
        private final Set<String> activeNodes;
        private final Map<String, Long> assignedBytes;
        private final Multimap<Distribution, BucketAssignment> distributionAssignments;
        private final Map<Distribution, Long> distributionBucketSize;

        /**
         * Creates a snapshot from defensive immutable copies of the given collections.
         *
         * @param activeNodes identifiers of the active worker nodes
         * @param assignedBytes estimated bytes currently assigned to each node
         * @param distributionAssignments bucket assignments for each distribution
         * @param distributionBucketSize estimated size of one bucket for each distribution
         */
        public ClusterState(
                Set<String> activeNodes,
                Map<String, Long> assignedBytes,
                Multimap<Distribution, BucketAssignment> distributionAssignments,
                Map<Distribution, Long> distributionBucketSize) {
            this.activeNodes = ImmutableSet.copyOf(Objects.requireNonNull(activeNodes, "activeNodes is null"));
            this.assignedBytes = ImmutableMap.copyOf(Objects.requireNonNull(assignedBytes, "assignedBytes is null"));
            this.distributionAssignments = ImmutableMultimap.copyOf(Objects.requireNonNull(distributionAssignments, "distributionAssignments is null"));
            this.distributionBucketSize = ImmutableMap.copyOf(Objects.requireNonNull(distributionBucketSize, "distributionBucketSize is null"));
        }

        public Set<String> getActiveNodes() {
            return activeNodes;
        }

        public Map<String, Long> getAssignedBytes() {
            return assignedBytes;
        }

        public Multimap<Distribution, BucketAssignment> getDistributionAssignments() {
            return distributionAssignments;
        }

        public Map<Distribution, Long> getDistributionBucketSize() {
            return distributionBucketSize;
        }
    }
}

