/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.base.Ticker;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.log.Logger;
import io.trino.execution.NodeTaskMap;
import io.trino.execution.RemoteTask;
import io.trino.execution.scheduler.BucketNodeMap;
import io.trino.execution.scheduler.NodeAssignmentStats;
import io.trino.execution.scheduler.NodeMap;
import io.trino.execution.scheduler.NodeScheduler;
import io.trino.execution.scheduler.NodeSchedulerConfig;
import io.trino.execution.scheduler.NodeSelector;
import io.trino.execution.scheduler.ResettableRandomizedIterator;
import io.trino.execution.scheduler.SplitPlacementResult;
import io.trino.metadata.InternalNode;
import io.trino.metadata.InternalNodeManager;
import io.trino.metadata.Split;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import jakarta.annotation.Nullable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

/**
 * {@link NodeSelector} that assigns each split to one node out of a set of candidate nodes.
 *
 * <p>Candidates are either the nodes matching the split's addresses (as resolved by
 * {@code NodeScheduler.selectExactNodes}) or {@code minCandidates} nodes drawn from a randomized
 * iterator over the currently known nodes. Among the candidates the node is picked according to
 * the configured {@link NodeSchedulerConfig.SplitsBalancingPolicy}; when no candidate is below the
 * per-node limits, the per-task pending queue (whose limit is tuned by {@link QueueSizeAdjuster})
 * is used as a fallback.
 *
 * <p>NOTE(review): {@link QueueSizeAdjuster} keeps its state in unsynchronized collections, so
 * {@link #computeAssignments(Set, List)} is presumably expected to be called by one thread at a
 * time per selector instance - confirm against callers.
 */
public class UniformNodeSelector
        implements NodeSelector {
    private static final Logger log = Logger.get(UniformNodeSelector.class);

    private final InternalNodeManager nodeManager;
    private final NodeTaskMap nodeTaskMap;
    private final boolean includeCoordinator;
    // Holds the supplier of the node map; lockDownNodes() swaps it for a fixed snapshot.
    private final AtomicReference<Supplier<NodeMap>> nodeMap;
    private final int minCandidates;
    private final long maxSplitsWeightPerNode;
    private final long minPendingSplitsWeightPerTask;
    private final int maxUnacknowledgedSplitsPerTask;
    private final NodeSchedulerConfig.SplitsBalancingPolicy splitsBalancingPolicy;
    private final boolean optimizedLocalScheduling;
    private final QueueSizeAdjuster queueSizeAdjuster;

    /**
     * Creates a selector with a {@link QueueSizeAdjuster} driven by the system ticker.
     *
     * @param nodeManager source of the current node, must not be null
     * @param nodeTaskMap per-node task bookkeeping handed to {@link NodeAssignmentStats}, must not be null
     * @param includeCoordinator forwarded to the {@code NodeScheduler} node selection helpers
     * @param nodeMap supplier of the node map to schedule on, must not be null
     * @param minCandidates number of random candidates requested per split
     * @param maxSplitsWeightPerNode total split weight at which a node stops being considered free
     * @param minPendingSplitsWeightPerTask initial (and minimum) per-task pending split weight limit
     * @param maxAdjustedPendingSplitsWeightPerTask upper bound for the adjusted per-task pending limit
     * @param maxUnacknowledgedSplitsPerTask unacknowledged split count at which a node stops receiving splits, must be &gt; 0
     * @param splitsBalancingPolicy policy used to rank free nodes, must not be null
     * @param optimizedLocalScheduling whether remotely accessible splits with addresses first try their address nodes
     * @throws NullPointerException if a required reference argument is null
     * @throws IllegalArgumentException if {@code maxUnacknowledgedSplitsPerTask} is not positive
     */
    public UniformNodeSelector(
            InternalNodeManager nodeManager,
            NodeTaskMap nodeTaskMap,
            boolean includeCoordinator,
            Supplier<NodeMap> nodeMap,
            int minCandidates,
            long maxSplitsWeightPerNode,
            long minPendingSplitsWeightPerTask,
            long maxAdjustedPendingSplitsWeightPerTask,
            int maxUnacknowledgedSplitsPerTask,
            NodeSchedulerConfig.SplitsBalancingPolicy splitsBalancingPolicy,
            boolean optimizedLocalScheduling) {
        this(
                nodeManager,
                nodeTaskMap,
                includeCoordinator,
                nodeMap,
                minCandidates,
                maxSplitsWeightPerNode,
                minPendingSplitsWeightPerTask,
                maxUnacknowledgedSplitsPerTask,
                splitsBalancingPolicy,
                optimizedLocalScheduling,
                new QueueSizeAdjuster(minPendingSplitsWeightPerTask, maxAdjustedPendingSplitsWeightPerTask));
    }

    /**
     * Same as the public constructor, but with an explicitly supplied {@link QueueSizeAdjuster}.
     *
     * @throws NullPointerException if a required reference argument is null
     * @throws IllegalArgumentException if {@code maxUnacknowledgedSplitsPerTask} is not positive
     */
    @VisibleForTesting
    UniformNodeSelector(
            InternalNodeManager nodeManager,
            NodeTaskMap nodeTaskMap,
            boolean includeCoordinator,
            Supplier<NodeMap> nodeMap,
            int minCandidates,
            long maxSplitsWeightPerNode,
            long minPendingSplitsWeightPerTask,
            int maxUnacknowledgedSplitsPerTask,
            NodeSchedulerConfig.SplitsBalancingPolicy splitsBalancingPolicy,
            boolean optimizedLocalScheduling,
            QueueSizeAdjuster queueSizeAdjuster) {
        this.nodeManager = Objects.requireNonNull(nodeManager, "nodeManager is null");
        this.nodeTaskMap = Objects.requireNonNull(nodeTaskMap, "nodeTaskMap is null");
        this.includeCoordinator = includeCoordinator;
        // Fail fast here instead of with an anonymous NPE on the first scheduling call.
        this.nodeMap = new AtomicReference<>(Objects.requireNonNull(nodeMap, "nodeMap is null"));
        this.minCandidates = minCandidates;
        this.maxSplitsWeightPerNode = maxSplitsWeightPerNode;
        this.minPendingSplitsWeightPerTask = minPendingSplitsWeightPerTask;
        Preconditions.checkArgument(maxUnacknowledgedSplitsPerTask > 0, "maxUnacknowledgedSplitsPerTask must be > 0, found: %s", maxUnacknowledgedSplitsPerTask);
        this.maxUnacknowledgedSplitsPerTask = maxUnacknowledgedSplitsPerTask;
        this.splitsBalancingPolicy = Objects.requireNonNull(splitsBalancingPolicy, "splitsBalancingPolicy is null");
        this.optimizedLocalScheduling = optimizedLocalScheduling;
        this.queueSizeAdjuster = Objects.requireNonNull(queueSizeAdjuster, "queueSizeAdjuster is null");
    }

    /**
     * Freezes the node map: the supplier is replaced by one that always returns the node map
     * obtained from the current supplier at the time of this call.
     */
    @Override
    public void lockDownNodes() {
        nodeMap.set(Suppliers.ofInstance(nodeMap.get().get()));
    }

    /**
     * Returns the result of {@code NodeScheduler.getAllNodes} for the current node map.
     */
    @Override
    public List<InternalNode> allNodes() {
        return NodeScheduler.getAllNodes(nodeMap.get().get(), includeCoordinator);
    }

    /**
     * Returns the node reported by the node manager as the current node.
     */
    @Override
    public InternalNode selectCurrentNode() {
        return nodeManager.getCurrentNode();
    }

    /**
     * Delegates to {@code NodeScheduler.selectNodes} over the randomized nodes of the current node
     * map, passing through {@code limit} and {@code excludedNodes}.
     */
    @Override
    public List<InternalNode> selectRandomNodes(int limit, Set<InternalNode> excludedNodes) {
        return NodeScheduler.selectNodes(limit, NodeScheduler.randomizedNodes(nodeMap.get().get(), includeCoordinator, excludedNodes));
    }

    /**
     * Assigns as many of the given splits as possible to nodes.
     *
     * <p>For every split a candidate list is built (exact address nodes, or random nodes), then a
     * node is chosen via {@link #chooseNodeForSplit} and, failing that, via
     * {@link #chooseLeastQueuedNode}. Splits for which no node could be chosen are left out of the
     * returned assignment, and the returned future is built from the tasks/nodes they wait on.
     *
     * @param splits splits to place
     * @param existingTasks tasks already running for the stage
     * @return the assignments made plus a future produced by
     *         {@code NodeScheduler.toWhenHasSplitQueueSpaceFuture}
     * @throws TrinoException with {@code NO_NODES_AVAILABLE} if a split has no candidate node at all
     */
    @Override
    public SplitPlacementResult computeAssignments(Set<Split> splits, List<RemoteTask> existingTasks) {
        Multimap<InternalNode, Split> assignment = HashMultimap.create();
        NodeMap currentNodeMap = nodeMap.get().get();
        NodeAssignmentStats assignmentStats = new NodeAssignmentStats(nodeTaskMap, currentNodeMap, existingTasks);
        queueSizeAdjuster.update(existingTasks, assignmentStats);

        Set<InternalNode> blockedExactNodes = new HashSet<>();
        boolean splitWaitingForAnyNode = false;

        List<InternalNode> filteredNodes = NodeScheduler.filterNodes(currentNodeMap, includeCoordinator, ImmutableSet.of());
        ResettableRandomizedIterator<InternalNode> randomCandidates = new ResettableRandomizedIterator<>(filteredNodes);
        // Nodes that have not yet been seen as a candidate of a split that could not be placed.
        Set<InternalNode> schedulableNodes = new HashSet<>(filteredNodes);

        for (Split split : splits) {
            randomCandidates.reset();

            List<InternalNode> candidateNodes;
            boolean exactNodes;
            if (!split.isRemotelyAccessible()) {
                // The split can only run on its own addresses; no random fallback.
                candidateNodes = NodeScheduler.selectExactNodes(currentNodeMap, split.getAddresses(), includeCoordinator);
                exactNodes = true;
            } else if (optimizedLocalScheduling && !split.getAddresses().isEmpty()) {
                // Try the address nodes first, fall back to random nodes if none of them is known.
                candidateNodes = NodeScheduler.selectExactNodes(currentNodeMap, split.getAddresses(), includeCoordinator);
                exactNodes = !candidateNodes.isEmpty();
                if (!exactNodes) {
                    candidateNodes = NodeScheduler.selectNodes(minCandidates, randomCandidates);
                }
            } else {
                candidateNodes = NodeScheduler.selectNodes(minCandidates, randomCandidates);
                exactNodes = false;
            }

            if (candidateNodes.isEmpty()) {
                log.debug("No nodes available to schedule %s. Available nodes %s", split, currentNodeMap.getNodesByHost().keys());
                throw new TrinoException(StandardErrorCode.NO_NODES_AVAILABLE, "No nodes available to run query");
            }

            InternalNode chosenNode = chooseNodeForSplit(assignmentStats, candidateNodes);
            if (chosenNode == null) {
                chosenNode = chooseLeastQueuedNode(assignmentStats, candidateNodes);
            }

            if (chosenNode != null) {
                assignment.put(chosenNode, split);
                assignmentStats.addAssignedSplit(chosenNode, split.getSplitWeight());
            } else {
                candidateNodes.forEach(schedulableNodes::remove);
                if (!exactNodes) {
                    splitWaitingForAnyNode = true;
                } else if (!splitWaitingForAnyNode) {
                    // The exact node set is only used for the blocked future when no split waits for "any" node.
                    blockedExactNodes.addAll(candidateNodes);
                }
                if (splitWaitingForAnyNode && schedulableNodes.isEmpty()) {
                    // Every node has already been a candidate of an unplaceable split; stop trying.
                    break;
                }
            }
        }

        long lowWatermark = NodeScheduler.calculateLowWatermark(minPendingSplitsWeightPerTask);
        ListenableFuture<Void> blocked;
        if (splitWaitingForAnyNode) {
            blocked = NodeScheduler.toWhenHasSplitQueueSpaceFuture(existingTasks, lowWatermark);
        } else {
            blocked = NodeScheduler.toWhenHasSplitQueueSpaceFuture(blockedExactNodes, existingTasks, lowWatermark);
        }
        return new SplitPlacementResult(blocked, assignment);
    }

    /**
     * Delegates bucketed placement to {@code NodeScheduler.selectDistributionNodes} using the
     * current node map and this selector's weight and unacknowledged-split limits.
     */
    @Override
    public SplitPlacementResult computeAssignments(Set<Split> splits, List<RemoteTask> existingTasks, BucketNodeMap bucketNodeMap) {
        return NodeScheduler.selectDistributionNodes(
                nodeMap.get().get(),
                nodeTaskMap,
                maxSplitsWeightPerNode,
                minPendingSplitsWeightPerTask,
                maxUnacknowledgedSplitsPerTask,
                splits,
                existingTasks,
                bucketNodeMap);
    }

    /**
     * Picks, among the candidates that are still free (see {@link #getFreeNodesForStage}), the one
     * with the smallest weight under the configured balancing policy: queued weight for the stage
     * ({@code STAGE}) or total split weight on the node ({@code NODE}). On ties the later
     * candidate wins.
     *
     * @return the chosen node, or {@code null} if no candidate is free
     * @throws UnsupportedOperationException for any other balancing policy
     */
    @Nullable
    private InternalNode chooseNodeForSplit(NodeAssignmentStats assignmentStats, List<InternalNode> candidateNodes) {
        InternalNode chosenNode = null;
        long minWeight = Long.MAX_VALUE;
        List<InternalNode> freeNodes = getFreeNodesForStage(assignmentStats, candidateNodes);
        switch (splitsBalancingPolicy) {
            case STAGE: {
                for (InternalNode node : freeNodes) {
                    long queuedWeight = assignmentStats.getQueuedSplitsWeightForStage(node);
                    if (queuedWeight <= minWeight) {
                        chosenNode = node;
                        minWeight = queuedWeight;
                    }
                }
                break;
            }
            case NODE: {
                for (InternalNode node : freeNodes) {
                    long totalSplitsWeight = assignmentStats.getTotalSplitsWeight(node);
                    if (totalSplitsWeight <= minWeight) {
                        chosenNode = node;
                        minWeight = totalSplitsWeight;
                    }
                }
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unsupported split balancing policy " + splitsBalancingPolicy);
            }
        }
        return chosenNode;
    }

    /**
     * Fallback selection: picks the candidate with the smallest queued weight for the stage, as
     * long as that weight is below the node's adjusted pending limit and its unacknowledged split
     * count is below the configured maximum. Every candidate whose queued weight has reached its
     * adjusted limit is reported to the {@link QueueSizeAdjuster}.
     *
     * @return the chosen node, or {@code null} if no candidate qualifies
     */
    @Nullable
    private InternalNode chooseLeastQueuedNode(NodeAssignmentStats assignmentStats, List<InternalNode> candidateNodes) {
        InternalNode chosenNode = null;
        long minWeight = Long.MAX_VALUE;
        for (InternalNode node : candidateNodes) {
            long queuedWeight = assignmentStats.getQueuedSplitsWeightForStage(node);
            long adjustedMaxPendingSplitsWeightPerTask = queueSizeAdjuster.getAdjustedMaxPendingSplitsWeightPerTask(node.getNodeIdentifier());

            if (queuedWeight <= minWeight
                    && queuedWeight < adjustedMaxPendingSplitsWeightPerTask
                    && assignmentStats.getUnacknowledgedSplitCountForStage(node) < maxUnacknowledgedSplitsPerTask) {
                chosenNode = node;
                minWeight = queuedWeight;
            }

            if (queuedWeight >= adjustedMaxPendingSplitsWeightPerTask) {
                queueSizeAdjuster.scheduleAdjustmentForNode(node.getNodeIdentifier());
            }
        }
        return chosenNode;
    }

    /**
     * Filters {@code nodes} down to those whose total split weight is below
     * {@code maxSplitsWeightPerNode} and whose unacknowledged split count for the stage is below
     * {@code maxUnacknowledgedSplitsPerTask}. Input order is preserved.
     */
    private List<InternalNode> getFreeNodesForStage(NodeAssignmentStats assignmentStats, List<InternalNode> nodes) {
        ImmutableList.Builder<InternalNode> freeNodes = ImmutableList.builder();
        for (InternalNode node : nodes) {
            boolean belowNodeWeightLimit = assignmentStats.getTotalSplitsWeight(node) < maxSplitsWeightPerNode;
            if (belowNodeWeightLimit && assignmentStats.getUnacknowledgedSplitCountForStage(node) < maxUnacknowledgedSplitsPerTask) {
                freeNodes.add(node);
            }
        }
        return freeNodes.build();
    }

    /**
     * Tracks, per node id, an adjusted limit for the pending split weight of a task.
     *
     * <p>The limit starts at {@code minPendingSplitsWeightPerTask}. It is doubled (capped at
     * {@code maxAdjustedPendingSplitsWeightPerTask}) when a node was reported full since the
     * previous {@link #update} and has no queued weight at the next one, and it is divided by 1.5
     * (floored at the minimum) once at least {@link #SCALE_DOWN_INTERVAL} has elapsed since the
     * node's last adjustment. Adjustment is disabled when the minimum and maximum are equal.
     *
     * <p>NOTE(review): state is kept in a plain {@code HashMap}/{@code HashSet}; the class does no
     * synchronization of its own.
     */
    static class QueueSizeAdjuster {
        // One second, expressed in the ticker's unit (nanoseconds).
        private static final long SCALE_DOWN_INTERVAL = TimeUnit.SECONDS.toNanos(1L);

        private final Ticker ticker;
        private final Map<String, TaskAdjustmentInfo> taskAdjustmentInfos = new HashMap<>();
        // Node ids reported full via scheduleAdjustmentForNode() since the last update().
        private final Set<String> previousScheduleFullTasks = new HashSet<>();
        private final long minPendingSplitsWeightPerTask;
        private final long maxAdjustedPendingSplitsWeightPerTask;

        private QueueSizeAdjuster(long minPendingSplitsWeightPerTask, long maxAdjustedPendingSplitsWeightPerTask) {
            this(minPendingSplitsWeightPerTask, maxAdjustedPendingSplitsWeightPerTask, Ticker.systemTicker());
        }

        /**
         * @param minPendingSplitsWeightPerTask initial and minimum per-task limit
         * @param maxAdjustedPendingSplitsWeightPerTask cap for the adjusted per-task limit
         * @param ticker time source used to timestamp adjustments, must not be null
         */
        @VisibleForTesting
        QueueSizeAdjuster(long minPendingSplitsWeightPerTask, long maxAdjustedPendingSplitsWeightPerTask, Ticker ticker) {
            this.ticker = Objects.requireNonNull(ticker, "ticker is null");
            this.maxAdjustedPendingSplitsWeightPerTask = maxAdjustedPendingSplitsWeightPerTask;
            this.minPendingSplitsWeightPerTask = minPendingSplitsWeightPerTask;
        }

        /**
         * Re-evaluates the adjusted limit of the node of every existing task and then clears the
         * set of nodes reported full. Does nothing when adjustment is disabled.
         *
         * @param existingTasks tasks whose node ids are inspected
         * @param nodeAssignmentStats source of the queued split weight per node id
         */
        public void update(List<RemoteTask> existingTasks, NodeAssignmentStats nodeAssignmentStats) {
            if (!isEnabled()) {
                return;
            }
            for (RemoteTask task : existingTasks) {
                String nodeId = task.getNodeId();
                TaskAdjustmentInfo adjustmentInfo = taskAdjustmentInfos.computeIfAbsent(nodeId, key -> new TaskAdjustmentInfo(minPendingSplitsWeightPerTask));
                Optional<Long> lastAdjustmentNanos = adjustmentInfo.getLastAdjustmentNanos();

                if (previousScheduleFullTasks.contains(nodeId) && nodeAssignmentStats.getQueuedSplitsWeightForStage(nodeId) == 0L) {
                    // The queue was full last time and is drained now: scale the limit up.
                    long scaledUp = Math.min(maxAdjustedPendingSplitsWeightPerTask, adjustmentInfo.getAdjustedMaxSplitsWeightPerTask() * 2L);
                    adjustmentInfo.setAdjustedMaxSplitsWeightPerTask(scaledUp, ticker.read());
                } else if (lastAdjustmentNanos.isPresent() && ticker.read() - lastAdjustmentNanos.get() >= SCALE_DOWN_INTERVAL) {
                    // Only limits that have been adjusted before are scaled down again.
                    long scaledDown = (long) Math.max((double) minPendingSplitsWeightPerTask, adjustmentInfo.getAdjustedMaxSplitsWeightPerTask() / 1.5);
                    adjustmentInfo.setAdjustedMaxSplitsWeightPerTask(scaledDown, ticker.read());
                }
            }
            previousScheduleFullTasks.clear();
        }

        /**
         * Returns the adjusted limit recorded for {@code nodeId}, or
         * {@code minPendingSplitsWeightPerTask} if nothing is recorded for that node.
         */
        public long getAdjustedMaxPendingSplitsWeightPerTask(String nodeId) {
            TaskAdjustmentInfo adjustmentInfo = taskAdjustmentInfos.get(nodeId);
            if (adjustmentInfo == null) {
                return minPendingSplitsWeightPerTask;
            }
            return adjustmentInfo.getAdjustedMaxSplitsWeightPerTask();
        }

        /**
         * Records that the node's queue was found full, so that the next {@link #update} may scale
         * its limit up. Does nothing when adjustment is disabled.
         */
        public void scheduleAdjustmentForNode(String nodeIdentifier) {
            if (!isEnabled()) {
                return;
            }
            previousScheduleFullTasks.add(nodeIdentifier);
        }

        // Adjustment is pointless when there is no room between the minimum and the maximum.
        private boolean isEnabled() {
            return maxAdjustedPendingSplitsWeightPerTask != minPendingSplitsWeightPerTask;
        }

        /**
         * Current adjusted limit of one node together with the ticker reading of its last
         * adjustment (empty until the limit is adjusted for the first time).
         */
        private static final class TaskAdjustmentInfo {
            private long adjustedMaxSplitsWeightPerTask;
            private Optional<Long> lastAdjustmentNanos = Optional.empty();

            TaskAdjustmentInfo(long adjustedMaxSplitsWeightPerTask) {
                this.adjustedMaxSplitsWeightPerTask = adjustedMaxSplitsWeightPerTask;
            }

            long getAdjustedMaxSplitsWeightPerTask() {
                return adjustedMaxSplitsWeightPerTask;
            }

            /**
             * Stores the new limit and remembers {@code nowNanos} as the time of the adjustment.
             */
            void setAdjustedMaxSplitsWeightPerTask(long adjustedMaxSplitsWeightPerTask, long nowNanos) {
                this.adjustedMaxSplitsWeightPerTask = adjustedMaxSplitsWeightPerTask;
                this.lastAdjustmentNanos = Optional.of(nowNanos);
            }

            Optional<Long> getLastAdjustmentNanos() {
                return lastAdjustmentNanos;
            }
        }
    }
}

