/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.server.coordinator.loading;

import com.google.common.collect.Sets;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.server.coordinator.DruidCluster;
import org.apache.druid.server.coordinator.ServerHolder;
import org.apache.druid.server.coordinator.balancer.BalancerStrategy;
import org.apache.druid.server.coordinator.loading.ReplicationThrottler;
import org.apache.druid.server.coordinator.loading.RoundRobinServerSelector;
import org.apache.druid.server.coordinator.loading.SegmentAction;
import org.apache.druid.server.coordinator.loading.SegmentLoadQueueManager;
import org.apache.druid.server.coordinator.loading.SegmentLoadingConfig;
import org.apache.druid.server.coordinator.loading.SegmentReplicaCount;
import org.apache.druid.server.coordinator.loading.SegmentReplicaCountMap;
import org.apache.druid.server.coordinator.loading.SegmentReplicationStatus;
import org.apache.druid.server.coordinator.loading.SegmentStatusInTier;
import org.apache.druid.server.coordinator.rules.SegmentActionHandler;
import org.apache.druid.server.coordinator.stats.CoordinatorRunStats;
import org.apache.druid.server.coordinator.stats.CoordinatorStat;
import org.apache.druid.server.coordinator.stats.Dimension;
import org.apache.druid.server.coordinator.stats.RowKey;
import org.apache.druid.server.coordinator.stats.Stats;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.SegmentId;

/**
 * Handles segment load, replicate, drop, move, broadcast and delete actions against the servers of a
 * {@link DruidCluster}, queueing the actual operations through a {@link SegmentLoadQueueManager} and
 * recording the outcome (assigned / dropped / moved / skipped) in the supplied {@link CoordinatorRunStats}.
 * <p>
 * Required and current replica counts per segment and tier are tracked in a {@link SegmentReplicaCountMap}
 * and exposed via {@link #getReplicationStatus()}. Segments marked for deletion and load rules that refer to
 * tiers absent from the cluster are collected and exposed through getters.
 * <p>
 * Instances hold mutable state (replica counts, replication throttler, pending deletes) and are marked
 * {@code @NotThreadSafe}; an instance must not be shared between threads without external synchronization.
 */
@NotThreadSafe
public class StrategicSegmentAssigner
implements SegmentActionHandler {
    private static final EmittingLogger log = new EmittingLogger(StrategicSegmentAssigner.class);

    /**
     * Tier under which a segment's replica count is tracked when its rule requires it on no tier at all.
     * NOTE(review): presumably mirrors DruidServer.DEFAULT_TIER — confirm.
     */
    private static final String DEFAULT_TIER = "_default_tier";

    private final SegmentLoadQueueManager loadQueueManager;
    private final DruidCluster cluster;
    private final CoordinatorRunStats stats;
    private final SegmentReplicaCountMap replicaCountMap;
    private final ReplicationThrottler replicationThrottler;
    /** Only non-null when round-robin assignment is enabled in the loading config. */
    private final RoundRobinServerSelector serverSelector;
    private final BalancerStrategy strategy;
    private final boolean useRoundRobinAssignment;
    private final Map<String, Set<String>> datasourceToInvalidLoadTiers = new HashMap<>();
    private final Map<String, Integer> tierToHistoricalCount = new HashMap<>();
    private final Map<String, Set<SegmentId>> segmentsToDelete = new HashMap<>();

    public StrategicSegmentAssigner(SegmentLoadQueueManager loadQueueManager, DruidCluster cluster, BalancerStrategy strategy, SegmentLoadingConfig loadingConfig, CoordinatorRunStats stats) {
        this.stats = stats;
        this.cluster = cluster;
        this.strategy = strategy;
        this.loadQueueManager = loadQueueManager;
        this.replicaCountMap = SegmentReplicaCountMap.create(cluster);
        this.replicationThrottler = StrategicSegmentAssigner.createReplicationThrottler(cluster, loadingConfig);
        this.useRoundRobinAssignment = loadingConfig.isUseRoundRobinSegmentAssignment();
        this.serverSelector = this.useRoundRobinAssignment ? new RoundRobinServerSelector(cluster) : null;
        cluster.getHistoricals().forEach((tier, historicals) -> this.tierToHistoricalCount.put(tier, historicals.size()));
    }

    public CoordinatorRunStats getStats() {
        return this.stats;
    }

    public SegmentReplicationStatus getReplicationStatus() {
        return this.replicaCountMap.toReplicationStatus();
    }

    /** Returns the live (mutable) map of datasource to segment ids queued via {@link #deleteSegment}. */
    public Map<String, Set<SegmentId>> getSegmentsToDelete() {
        return this.segmentsToDelete;
    }

    /** Returns the live (mutable) map of datasource to tiers that rules asked for but the cluster lacks. */
    public Map<String, Set<String>> getDatasourceToInvalidLoadTiers() {
        return this.datasourceToInvalidLoadTiers;
    }

    /**
     * Tries to move a segment off {@code sourceServer} onto one of {@code destinationServers} in the same tier.
     *
     * @param segment            segment to move
     * @param sourceServer       server currently serving (or loading) the segment
     * @param destinationServers candidate servers; those in other tiers or unable to load the segment are ignored
     * @return true if a move (or a cancel-load followed by a new load) was queued, false otherwise
     */
    public boolean moveSegment(DataSegment segment, ServerHolder sourceServer, List<ServerHolder> destinationServers) {
        final String tier = sourceServer.getServer().getTier();
        // Collect into an explicit ArrayList: the source server may be appended below, and
        // Collectors.toList() does not guarantee that the returned list is mutable.
        final List<ServerHolder> eligibleDestinationServers = destinationServers.stream()
            .filter(s -> s.getServer().getTier().equals(tier))
            .filter(s -> s.canLoadSegment(segment))
            .collect(Collectors.toCollection(ArrayList::new));
        if (eligibleDestinationServers.isEmpty()) {
            this.incrementSkipStat(Stats.Segments.MOVE_SKIPPED, "No eligible server", segment, tier);
            return false;
        }

        // Let the strategy also consider leaving the segment where it is, unless the source is decommissioning
        if (!sourceServer.isDecommissioning()) {
            eligibleDestinationServers.add(sourceServer);
        }

        final ServerHolder destination =
            this.strategy.findDestinationServerToMoveSegment(segment, sourceServer, eligibleDestinationServers);
        if (destination == null || destination.getServer().equals(sourceServer.getServer())) {
            this.incrementSkipStat(Stats.Segments.MOVE_SKIPPED, "Optimally placed", segment, tier);
            return false;
        }
        if (this.moveSegment(segment, sourceServer, destination)) {
            this.incrementStat(Stats.Segments.MOVED, segment, tier, 1L);
            return true;
        }
        this.incrementSkipStat(Stats.Segments.MOVE_SKIPPED, "Encountered error", segment, tier);
        return false;
    }

    /**
     * Moves a segment from {@code sourceServer} to {@code destinationServer}. If the source is still loading
     * the segment, that load is cancelled and the segment is loaded on the destination instead (as a
     * replica if the tier already has a loaded copy that is not being dropped).
     */
    private boolean moveSegment(DataSegment segment, ServerHolder sourceServer, ServerHolder destinationServer) {
        final String tier = sourceServer.getServer().getTier();
        if (sourceServer.isLoadingSegment(segment)) {
            if (!sourceServer.cancelOperation(SegmentAction.LOAD, segment)) {
                return false;
            }
            final int loadedCountOnTier = this.replicaCountMap.get(segment.getId(), tier).loadedNotDropping();
            return loadedCountOnTier >= 1
                ? this.replicateSegment(segment, destinationServer)
                : this.loadSegment(segment, destinationServer);
        }
        if (sourceServer.isServingSegment(segment)) {
            return this.loadQueueManager.moveSegment(segment, sourceServer, destinationServer);
        }
        return false;
    }

    /**
     * Brings the replicas of a segment on every tier of the cluster in line with {@code tierToReplicaCount}.
     * Tiers absent from the map require zero replicas. Tiers present in the map but not in the cluster are
     * recorded in {@link #getDatasourceToInvalidLoadTiers()}.
     */
    @Override
    public void replicateSegment(DataSegment segment, Map<String, Integer> tierToReplicaCount) {
        final Set<String> allTiersInCluster = Sets.newHashSet(this.cluster.getTierNames());
        if (tierToReplicaCount.isEmpty()) {
            // Still track the segment so it appears in the replication status
            this.replicaCountMap.computeIfAbsent(segment.getId(), DEFAULT_TIER);
        } else {
            tierToReplicaCount.forEach((tier, requiredReplicas) -> {
                this.reportTierCapacityStats(segment, requiredReplicas, tier);
                final SegmentReplicaCount replicaCount = this.replicaCountMap.computeIfAbsent(segment.getId(), tier);
                replicaCount.setRequired(requiredReplicas, this.tierToHistoricalCount.getOrDefault(tier, 0));
                if (!allTiersInCluster.contains(tier)) {
                    this.datasourceToInvalidLoadTiers
                        .computeIfAbsent(segment.getDataSource(), ds -> new HashSet<>())
                        .add(tier);
                }
            });
        }

        // The cluster-wide surplus caps how many drops may be queued across all tiers combined
        final SegmentReplicaCount replicaCountInCluster = this.replicaCountMap.getTotal(segment.getId());
        final int replicaSurplus =
            replicaCountInCluster.loadedNotDropping() - replicaCountInCluster.requiredAndLoadable();
        int dropsQueued = 0;
        for (String tier : allTiersInCluster) {
            dropsQueued += this.updateReplicasInTier(
                segment,
                tier,
                tierToReplicaCount.getOrDefault(tier, 0),
                replicaSurplus - dropsQueued
            );
        }
    }

    /**
     * Queues loads or drops on a single tier so that its projected replica count (loaded-not-dropping plus
     * loading) approaches {@code requiredReplicas}. In-flight operations in the opposite direction are
     * cancelled first. If the tier needs no replicas, in-flight moves are cancelled.
     *
     * @return the number of drops queued on this tier (never more than {@code maxReplicasToDrop})
     */
    private int updateReplicasInTier(DataSegment segment, String tier, int requiredReplicas, int maxReplicasToDrop) {
        final SegmentReplicaCount replicaCountOnTier = this.replicaCountMap.get(segment.getId(), tier);
        final int projectedReplicas = replicaCountOnTier.loadedNotDropping() + replicaCountOnTier.loading();
        final int movingReplicas = replicaCountOnTier.moving();
        final boolean shouldCancelMoves = requiredReplicas == 0 && movingReplicas > 0;
        if (projectedReplicas == requiredReplicas && !shouldCancelMoves) {
            return 0;
        }

        final SegmentStatusInTier segmentStatus = new SegmentStatusInTier(segment, this.cluster.getHistoricalsByTier(tier));
        if (shouldCancelMoves) {
            this.cancelOperations(SegmentAction.MOVE_TO, movingReplicas, segment, segmentStatus);
            this.cancelOperations(SegmentAction.MOVE_FROM, movingReplicas, segment, segmentStatus);
        }

        if (projectedReplicas < requiredReplicas) {
            final int replicaDeficit = requiredReplicas - projectedReplicas;
            // Each cancelled drop keeps a loaded replica, reducing the number of new loads needed
            final int cancelledDrops = this.cancelOperations(SegmentAction.DROP, replicaDeficit, segment, segmentStatus);
            final int numReplicasToLoad = replicaDeficit - cancelledDrops;
            if (numReplicasToLoad > 0) {
                final int numLoadedReplicas = replicaCountOnTier.loadedNotDropping() + cancelledDrops;
                final int numLoadsQueued = this.loadReplicas(numReplicasToLoad, numLoadedReplicas, segment, tier, segmentStatus);
                this.incrementStat(Stats.Segments.ASSIGNED, segment, tier, numLoadsQueued);
            }
        }

        if (projectedReplicas > requiredReplicas) {
            final int replicaSurplus = projectedReplicas - requiredReplicas;
            // Each cancelled load removes a projected replica, reducing the number of drops needed
            final int cancelledLoads = this.cancelOperations(SegmentAction.LOAD, replicaSurplus, segment, segmentStatus);
            final int numReplicasToDrop = Math.min(replicaSurplus - cancelledLoads, maxReplicasToDrop);
            if (numReplicasToDrop > 0) {
                final int dropsQueuedOnTier = this.dropReplicas(numReplicasToDrop, segment, tier, segmentStatus);
                this.incrementStat(Stats.Segments.DROPPED, segment, tier, dropsQueuedOnTier);
                return dropsQueuedOnTier;
            }
        }
        return 0;
    }

    /** Records the tier's max replication factor and adds this segment's size times replicas to required capacity. */
    private void reportTierCapacityStats(DataSegment segment, int requiredReplicas, String tier) {
        final RowKey rowKey = RowKey.of(Dimension.TIER, tier);
        this.stats.updateMax(Stats.Tier.REPLICATION_FACTOR, rowKey, requiredReplicas);
        this.stats.add(Stats.Tier.REQUIRED_CAPACITY, rowKey, segment.getSize() * (long) requiredReplicas);
    }

    /**
     * Loads the segment on every non-decommissioning broadcast-target server and removes it from
     * decommissioning ones. The required replica count per tier is the number of non-decommissioning
     * broadcast targets in that tier.
     */
    @Override
    public void broadcastSegment(DataSegment segment) {
        final Object2IntOpenHashMap<String> tierToRequiredReplicas = new Object2IntOpenHashMap<>();
        for (ServerHolder server : this.cluster.getAllServers()) {
            if (!server.getServer().getType().isSegmentBroadcastTarget()) {
                continue;
            }
            final String tier = server.getServer().getTier();
            int numDropsQueued = 0;
            int numLoadsQueued = 0;
            if (server.isDecommissioning()) {
                numDropsQueued += this.dropBroadcastSegment(segment, server) ? 1 : 0;
            } else {
                tierToRequiredReplicas.addTo(tier, 1);
                numLoadsQueued += this.loadBroadcastSegment(segment, server) ? 1 : 0;
            }
            if (numLoadsQueued > 0) {
                this.incrementStat(Stats.Segments.ASSIGNED, segment, tier, numLoadsQueued);
            }
            if (numDropsQueued > 0) {
                this.incrementStat(Stats.Segments.DROPPED, segment, tier, numDropsQueued);
            }
        }
        tierToRequiredReplicas.object2IntEntrySet().fastForEach(
            entry -> this.replicaCountMap.computeIfAbsent(segment.getId(), entry.getKey())
                                         .setRequired(entry.getIntValue(), entry.getIntValue())
        );
    }

    /** Marks the segment for deletion; the result is exposed through {@link #getSegmentsToDelete()}. */
    @Override
    public void deleteSegment(DataSegment segment) {
        this.segmentsToDelete.computeIfAbsent(segment.getDataSource(), ds -> new HashSet<>()).add(segment.getId());
    }

    /**
     * Ensures a broadcast segment ends up on the given server: no-op if already served or loading,
     * cancels a pending drop if one exists, otherwise queues a load if the server can take it.
     *
     * @return true if a drop was cancelled or a load was queued
     */
    private boolean loadBroadcastSegment(DataSegment segment, ServerHolder server) {
        if (server.isServingSegment(segment) || server.isLoadingSegment(segment)) {
            return false;
        }
        if (server.isDroppingSegment(segment)) {
            return server.cancelOperation(SegmentAction.DROP, segment);
        }
        if (server.canLoadSegment(segment)) {
            return this.loadSegment(segment, server);
        }
        final String skipReason;
        if (server.getAvailableSize() < segment.getSize()) {
            skipReason = "Not enough disk space";
        } else if (server.isLoadQueueFull()) {
            skipReason = "Load queue is full";
        } else {
            skipReason = "Unknown error";
        }
        this.incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, skipReason, segment, server.getServer().getTier());
        return false;
    }

    /**
     * Removes a broadcast segment from the given server by cancelling a pending load or queueing a drop.
     *
     * @return true if a load was cancelled or a drop was queued
     */
    private boolean dropBroadcastSegment(DataSegment segment, ServerHolder server) {
        if (server.isLoadingSegment(segment)) {
            return server.cancelOperation(SegmentAction.LOAD, segment);
        }
        if (server.isServingSegment(segment)) {
            return this.loadQueueManager.dropSegment(segment, server);
        }
        return false;
    }

    /**
     * Queues up to {@code numToDrop} drops of the segment on the tier. Decommissioning servers are drained
     * first. Any remainder is taken from active servers, chosen by the balancer strategy unless round-robin
     * is enabled or every active server has to drop anyway.
     *
     * @return the number of drops actually queued
     */
    private int dropReplicas(int numToDrop, DataSegment segment, String tier, SegmentStatusInTier segmentStatus) {
        if (numToDrop <= 0) {
            return 0;
        }
        final List<ServerHolder> eligibleServers = segmentStatus.getServersEligibleToDrop();
        if (eligibleServers.isEmpty()) {
            this.incrementSkipStat(Stats.Segments.DROP_SKIPPED, "No eligible server", segment, tier);
            return 0;
        }

        // Ordered by the reverse of ServerHolder's natural ordering
        // (NOTE(review): presumably most-full first — confirm against ServerHolder.compareTo)
        final TreeSet<ServerHolder> eligibleLiveServers = new TreeSet<>(Comparator.reverseOrder());
        final TreeSet<ServerHolder> eligibleDyingServers = new TreeSet<>(Comparator.reverseOrder());
        for (ServerHolder server : eligibleServers) {
            if (server.isDecommissioning()) {
                eligibleDyingServers.add(server);
            } else {
                eligibleLiveServers.add(server);
            }
        }

        int numDropsQueued = this.dropReplicasFromServers(numToDrop, segment, eligibleDyingServers.iterator(), tier);
        if (numToDrop > numDropsQueued) {
            final int remainingNumToDrop = numToDrop - numDropsQueued;
            final Iterator<ServerHolder> serverIterator =
                this.useRoundRobinAssignment || eligibleLiveServers.size() <= remainingNumToDrop
                ? eligibleLiveServers.iterator()
                : this.strategy.findServersToDropSegment(segment, new ArrayList<>(eligibleLiveServers));
            numDropsQueued += this.dropReplicasFromServers(remainingNumToDrop, segment, serverIterator, tier);
        }
        return numDropsQueued;
    }

    /** Walks {@code serverIterator} queueing drops until {@code numToDrop} succeed or servers run out. */
    private int dropReplicasFromServers(int numToDrop, DataSegment segment, Iterator<ServerHolder> serverIterator, String tier) {
        int numDropsQueued = 0;
        while (numToDrop > numDropsQueued && serverIterator.hasNext()) {
            final ServerHolder holder = serverIterator.next();
            if (this.loadQueueManager.dropSegment(segment, holder)) {
                ++numDropsQueued;
            } else {
                this.incrementSkipStat(Stats.Segments.DROP_SKIPPED, "Encountered error", segment, tier);
            }
        }
        return numDropsQueued;
    }

    /**
     * Queues up to {@code numToLoad} loads of the segment on the tier. If the tier already has a loaded copy
     * ({@code numLoadedReplicas >= 1}), the loads are queued as throttled REPLICATE actions; otherwise as LOAD.
     *
     * @return the number of loads actually queued
     */
    private int loadReplicas(int numToLoad, int numLoadedReplicas, DataSegment segment, String tier, SegmentStatusInTier segmentStatus) {
        final boolean isAlreadyLoadedOnTier = numLoadedReplicas >= 1;
        if (isAlreadyLoadedOnTier && this.replicationThrottler.isReplicationThrottledForTier(tier)) {
            return 0;
        }
        final List<ServerHolder> eligibleServers = segmentStatus.getServersEligibleToLoad();
        if (eligibleServers.isEmpty()) {
            this.incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, "No eligible server", segment, tier);
            return 0;
        }
        final Iterator<ServerHolder> serverIterator = this.useRoundRobinAssignment
            ? this.serverSelector.getServersInTierToLoadSegment(tier, segment)
            : this.strategy.findServersToLoadSegment(segment, eligibleServers);
        if (!serverIterator.hasNext()) {
            this.incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, "No strategic server", segment, tier);
            return 0;
        }
        int numLoadsQueued = 0;
        while (numLoadsQueued < numToLoad && serverIterator.hasNext()) {
            final ServerHolder server = serverIterator.next();
            final boolean queuedSuccessfully = isAlreadyLoadedOnTier
                ? this.replicateSegment(segment, server)
                : this.loadSegment(segment, server);
            if (queuedSuccessfully) {
                ++numLoadsQueued;
            }
        }
        return numLoadsQueued;
    }

    /** Queues an unthrottled LOAD of the segment on the server, recording a skip stat on failure. */
    private boolean loadSegment(DataSegment segment, ServerHolder server) {
        final String tier = server.getServer().getTier();
        final boolean assigned = this.loadQueueManager.loadSegment(segment, server, SegmentAction.LOAD);
        if (!assigned) {
            this.incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, "Encountered error", segment, tier);
        }
        return assigned;
    }

    /**
     * Queues a REPLICATE of the segment on the server unless the tier is replication-throttled.
     * Successful assignments are counted against the tier's throttle.
     */
    private boolean replicateSegment(DataSegment segment, ServerHolder server) {
        final String tier = server.getServer().getTier();
        if (this.replicationThrottler.isReplicationThrottledForTier(tier)) {
            this.incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, "Throttled replication", segment, tier);
            return false;
        }
        final boolean assigned = this.loadQueueManager.loadSegment(segment, server, SegmentAction.REPLICATE);
        if (assigned) {
            this.replicationThrottler.incrementAssignedReplicas(tier);
        } else {
            this.incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, "Encountered error", segment, tier);
        }
        return assigned;
    }

    /** Seeds a throttler with the number of replicas each historical tier is already loading. */
    private static ReplicationThrottler createReplicationThrottler(DruidCluster cluster, SegmentLoadingConfig loadingConfig) {
        final Map<String, Integer> tierToLoadingReplicaCount = new HashMap<>();
        cluster.getHistoricals().forEach((tier, historicals) -> {
            final int numLoadingReplicas = historicals.stream().mapToInt(ServerHolder::getNumLoadingReplicas).sum();
            tierToLoadingReplicaCount.put(tier, numLoadingReplicas);
        });
        return new ReplicationThrottler(tierToLoadingReplicaCount, loadingConfig.getReplicationThrottleLimit());
    }

    /**
     * Cancels up to {@code maxNumToCancel} in-flight operations of the given action for the segment,
     * visiting servers in the order returned by {@code segmentStatus}.
     *
     * @return the number of operations successfully cancelled
     */
    private int cancelOperations(SegmentAction action, int maxNumToCancel, DataSegment segment, SegmentStatusInTier segmentStatus) {
        final List<ServerHolder> servers = segmentStatus.getServersPerforming(action);
        if (servers.isEmpty() || maxNumToCancel <= 0) {
            return 0;
        }
        int numCancelled = 0;
        for (int i = 0; i < servers.size() && numCancelled < maxNumToCancel; ++i) {
            if (servers.get(i).cancelOperation(action, segment)) {
                ++numCancelled;
            }
        }
        return numCancelled;
    }

    private void incrementSkipStat(CoordinatorStat stat, String reason, DataSegment segment, String tier) {
        final RowKey key = RowKey.with(Dimension.TIER, tier)
                                 .with(Dimension.DATASOURCE, segment.getDataSource())
                                 .and(Dimension.DESCRIPTION, reason);
        this.stats.add(stat, key, 1L);
    }

    private void incrementStat(CoordinatorStat stat, DataSegment segment, String tier, long value) {
        this.stats.addToSegmentStat(stat, tier, segment.getDataSource(), value);
    }
}

