/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.streams.processor.internals.assignment;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration;
import org.apache.kafka.streams.processor.internals.assignment.ClientState;
import org.apache.kafka.streams.processor.internals.assignment.ConstrainedPrioritySet;
import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor;
import org.apache.kafka.streams.processor.internals.assignment.TaskMovement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HighAvailabilityTaskAssignor
implements TaskAssignor {
    private static final Logger log = LoggerFactory.getLogger(HighAvailabilityTaskAssignor.class);

    @Override
    public boolean assign(Map<UUID, ClientState> clients, Set<TaskId> allTaskIds, Set<TaskId> statefulTaskIds, AssignorConfiguration.AssignmentConfigs configs) {
        TreeSet<TaskId> statefulTasks = new TreeSet<TaskId>(statefulTaskIds);
        TreeMap<UUID, ClientState> clientStates = new TreeMap<UUID, ClientState>(clients);
        HighAvailabilityTaskAssignor.assignActiveStatefulTasks(clientStates, statefulTasks);
        HighAvailabilityTaskAssignor.assignStandbyReplicaTasks(clientStates, statefulTasks, configs.numStandbyReplicas);
        AtomicInteger remainingWarmupReplicas = new AtomicInteger(configs.maxWarmupReplicas);
        Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = HighAvailabilityTaskAssignor.tasksToCaughtUpClients(statefulTasks, clientStates, configs.acceptableRecoveryLag);
        TreeMap<UUID, Set<TaskId>> warmups = new TreeMap<UUID, Set<TaskId>>();
        int neededActiveTaskMovements = TaskMovement.assignActiveTaskMovements(tasksToCaughtUpClients, clientStates, warmups, remainingWarmupReplicas);
        int neededStandbyTaskMovements = TaskMovement.assignStandbyTaskMovements(tasksToCaughtUpClients, clientStates, remainingWarmupReplicas, warmups);
        HighAvailabilityTaskAssignor.assignStatelessActiveTasks(clientStates, Utils.diff(TreeSet::new, allTaskIds, statefulTasks));
        boolean probingRebalanceNeeded = neededActiveTaskMovements + neededStandbyTaskMovements > 0;
        log.info("Decided on assignment: " + clientStates + " with" + (probingRebalanceNeeded ? "" : " no") + " followup probing rebalance.");
        return probingRebalanceNeeded;
    }

    private static void assignActiveStatefulTasks(SortedMap<UUID, ClientState> clientStates, SortedSet<TaskId> statefulTasks) {
        Iterator<ClientState> clientStateIterator = null;
        for (TaskId task : statefulTasks) {
            if (clientStateIterator == null || !clientStateIterator.hasNext()) {
                clientStateIterator = clientStates.values().iterator();
            }
            clientStateIterator.next().assignActive(task);
        }
        HighAvailabilityTaskAssignor.balanceTasksOverThreads(clientStates, ClientState::activeTasks, ClientState::unassignActive, ClientState::assignActive);
    }

    private static void assignStandbyReplicaTasks(TreeMap<UUID, ClientState> clientStates, Set<TaskId> statefulTasks, int numStandbyReplicas) {
        Map<TaskId, Integer> tasksToRemainingStandbys = statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> numStandbyReplicas));
        ConstrainedPrioritySet standbyTaskClientsByTaskLoad = new ConstrainedPrioritySet((client, task) -> !((ClientState)clientStates.get(client)).hasAssignedTask((TaskId)task), client -> ((ClientState)clientStates.get(client)).assignedTaskLoad());
        standbyTaskClientsByTaskLoad.offerAll(clientStates.keySet());
        for (TaskId task2 : statefulTasks) {
            UUID client2;
            int numRemainingStandbys;
            for (numRemainingStandbys = tasksToRemainingStandbys.get(task2).intValue(); numRemainingStandbys > 0 && (client2 = standbyTaskClientsByTaskLoad.poll(task2)) != null; --numRemainingStandbys) {
                clientStates.get(client2).assignStandby(task2);
                standbyTaskClientsByTaskLoad.offer(client2);
            }
            if (numRemainingStandbys <= 0) continue;
            log.warn("Unable to assign {} of {} standby tasks for task [{}]. There is not enough available capacity. You should increase the number of threads and/or application instances to maintain the requested number of standby replicas.", new Object[]{numRemainingStandbys, numStandbyReplicas, task2});
        }
        HighAvailabilityTaskAssignor.balanceTasksOverThreads(clientStates, ClientState::standbyTasks, ClientState::unassignStandby, ClientState::assignStandby);
    }

    private static void balanceTasksOverThreads(SortedMap<UUID, ClientState> clientStates, Function<ClientState, Set<TaskId>> currentAssignmentAccessor, BiConsumer<ClientState, TaskId> taskUnassignor, BiConsumer<ClientState, TaskId> taskAssignor) {
        boolean keepBalancing = true;
        while (keepBalancing) {
            keepBalancing = false;
            for (Map.Entry<UUID, ClientState> sourceEntry : clientStates.entrySet()) {
                UUID sourceClient = sourceEntry.getKey();
                ClientState sourceClientState = sourceEntry.getValue();
                for (Map.Entry<UUID, ClientState> destinationEntry : clientStates.entrySet()) {
                    UUID destinationClient = destinationEntry.getKey();
                    ClientState destinationClientState = destinationEntry.getValue();
                    if (sourceClient.equals(destinationClient)) continue;
                    TreeSet sourceTasks = new TreeSet(currentAssignmentAccessor.apply(sourceClientState));
                    Iterator sourceIterator = sourceTasks.iterator();
                    while (HighAvailabilityTaskAssignor.shouldMoveATask(sourceClientState, destinationClientState) && sourceIterator.hasNext()) {
                        TaskId taskToMove = (TaskId)sourceIterator.next();
                        boolean canMove = !destinationClientState.hasAssignedTask(taskToMove);
                        if (!canMove) continue;
                        taskUnassignor.accept(sourceClientState, taskToMove);
                        taskAssignor.accept(destinationClientState, taskToMove);
                        keepBalancing = true;
                    }
                }
            }
        }
    }

    private static boolean shouldMoveATask(ClientState sourceClientState, ClientState destinationClientState) {
        double skew = sourceClientState.assignedTaskLoad() - destinationClientState.assignedTaskLoad();
        if (skew <= 0.0) {
            return false;
        }
        double proposedAssignedTasksPerStreamThreadAtDestination = ((double)destinationClientState.assignedTaskCount() + 1.0) / (double)destinationClientState.capacity();
        double proposedAssignedTasksPerStreamThreadAtSource = ((double)sourceClientState.assignedTaskCount() - 1.0) / (double)sourceClientState.capacity();
        double proposedSkew = proposedAssignedTasksPerStreamThreadAtSource - proposedAssignedTasksPerStreamThreadAtDestination;
        if (proposedSkew < 0.0) {
            return false;
        }
        return proposedSkew < skew;
    }

    private static void assignStatelessActiveTasks(TreeMap<UUID, ClientState> clientStates, Iterable<TaskId> statelessTasks) {
        ConstrainedPrioritySet statelessActiveTaskClientsByTaskLoad = new ConstrainedPrioritySet((client, task) -> true, client -> ((ClientState)clientStates.get(client)).activeTaskLoad());
        statelessActiveTaskClientsByTaskLoad.offerAll(clientStates.keySet());
        for (TaskId task2 : statelessTasks) {
            UUID client2 = statelessActiveTaskClientsByTaskLoad.poll(task2);
            ClientState state = clientStates.get(client2);
            state.assignActive(task2);
            statelessActiveTaskClientsByTaskLoad.offer(client2);
        }
    }

    private static Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients(Set<TaskId> statefulTasks, Map<UUID, ClientState> clientStates, long acceptableRecoveryLag) {
        HashMap<TaskId, SortedSet<UUID>> taskToCaughtUpClients = new HashMap<TaskId, SortedSet<UUID>>();
        for (TaskId task : statefulTasks) {
            TreeSet<UUID> caughtUpClients = new TreeSet<UUID>();
            for (Map.Entry<UUID, ClientState> clientEntry : clientStates.entrySet()) {
                UUID client = clientEntry.getKey();
                long taskLag = clientEntry.getValue().lagFor(task);
                if (!HighAvailabilityTaskAssignor.activeRunning(taskLag) && !HighAvailabilityTaskAssignor.unbounded(acceptableRecoveryLag) && !HighAvailabilityTaskAssignor.acceptable(acceptableRecoveryLag, taskLag)) continue;
                caughtUpClients.add(client);
            }
            taskToCaughtUpClients.put(task, caughtUpClients);
        }
        return taskToCaughtUpClients;
    }

    private static boolean unbounded(long acceptableRecoveryLag) {
        return acceptableRecoveryLag == Long.MAX_VALUE;
    }

    private static boolean acceptable(long acceptableRecoveryLag, long taskLag) {
        return taskLag >= 0L && taskLag <= acceptableRecoveryLag;
    }

    private static boolean activeRunning(long taskLag) {
        return taskLag == -2L;
    }
}

