/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.executor.scheduler;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import io.trino.annotation.NotThreadSafe;
import io.trino.execution.executor.scheduler.PriorityQueue;
import io.trino.execution.executor.scheduler.State;
import io.trino.execution.executor.scheduler.Task;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Scheduling bookkeeping for the task handles that belong to one group.
 *
 * <p>Each handle maps to a {@link Task} that records its state ({@code RUNNABLE}, {@code RUNNING}
 * or {@code BLOCKED}) and its weight. Runnable handles are kept in {@code runnableQueue},
 * prioritized by task weight. Every non-blocked handle (runnable or running) is also tracked in
 * {@code baselineWeights}, whose head priority is the group's {@link #baselineWeight() baseline
 * weight}. New tasks start at the baseline. A blocked task stores its weight relative to the
 * baseline, so it is re-based against the current baseline when it is enqueued again.
 *
 * <p>The group {@link #state() state} is derived from its tasks:
 * <ul>
 *   <li>{@code BLOCKED} when every task is blocked, which includes the case of having no tasks;</li>
 *   <li>{@code RUNNING} when some task is not blocked but the runnable queue is empty;</li>
 *   <li>{@code RUNNABLE} otherwise.</li>
 * </ul>
 *
 * <p>Instances are not thread-safe.
 *
 * @param <T> the task handle type. Handles are {@link HashMap} keys, so they need consistent
 *     {@code equals}/{@code hashCode}.
 */
@NotThreadSafe
final class SchedulingGroup<T> {
    // Derived group state; recomputed by updateState() after every mutation.
    private State state;
    // Group-level weight: committed deltas plus speculative weight charged by dequeue().
    private long weight;
    // Every known handle and its per-task bookkeeping.
    private final Map<T, Task> tasks = new HashMap<>();
    // Handles currently RUNNABLE, prioritized by task weight.
    private final PriorityQueue<T> runnableQueue = new PriorityQueue<>();
    // Handles currently BLOCKED.
    private final Set<T> blocked = new HashSet<>();
    // Weights of all non-blocked handles; its head priority is the baseline weight.
    private final PriorityQueue<T> baselineWeights = new PriorityQueue<>();

    public SchedulingGroup() {
        state = State.BLOCKED;
    }

    /**
     * Marks {@code handle} as runnable and puts it on the runnable queue, registering it if it is new.
     *
     * <p>A new handle starts at the current baseline weight. A previously blocked handle has the
     * current baseline added back to its saved relative weight. Any uncommitted weight recorded by
     * {@link #dequeue} is removed from the group weight and replaced by {@code deltaWeight}, which is
     * also committed to the task.
     *
     * @param handle the task handle
     * @param deltaWeight weight to add to the group and commit to the task
     */
    public void enqueue(T handle, long deltaWeight) {
        Task task = tasks.get(handle);
        if (task == null) {
            task = new Task(baselineWeight());
            tasks.put(handle, task);
        } else if (task.state() == State.BLOCKED) {
            blocked.remove(handle);
            // Restore the absolute weight from the delta saved by block().
            task.addWeight(baselineWeight());
        }

        // Replace the speculative charge made by dequeue() with the actual delta.
        weight -= task.uncommittedWeight();
        weight += deltaWeight;

        task.commitWeight(deltaWeight);
        task.setState(State.RUNNABLE);
        runnableQueue.add(handle, task.weight());
        baselineWeights.addOrReplace(handle, task.weight());

        updateState();
    }

    /**
     * Takes the next handle from the runnable queue and marks it as running.
     *
     * @param expectedWeight speculative weight. It is added to the group weight immediately and
     *     recorded on the task as uncommitted weight; {@link #enqueue} later subtracts it again.
     * @return the dequeued handle
     * @throws IllegalArgumentException if the group is not {@code RUNNABLE}
     */
    public T dequeue(long expectedWeight) {
        Preconditions.checkArgument(state == State.RUNNABLE, "Group is not runnable: %s", state);

        T task = runnableQueue.takeOrThrow();

        Task info = tasks.get(task);
        info.setUncommittedWeight(expectedWeight);
        info.setState(State.RUNNING);
        weight += expectedWeight;

        // Running tasks still contribute to the baseline.
        baselineWeights.addOrReplace(task, info.weight());

        updateState();
        return task;
    }

    /**
     * Removes {@code task} from the group and from all internal queues and sets.
     * The group weight is left unchanged.
     *
     * @param task the task handle
     * @throws IllegalArgumentException if the handle is unknown
     */
    public void finish(T task) {
        Preconditions.checkArgument(tasks.containsKey(task), "Unknown task: %s", task);
        tasks.remove(task);
        blocked.remove(task);
        runnableQueue.removeIfPresent(task);
        baselineWeights.removeIfPresent(task);

        updateState();
    }

    /**
     * Marks a known handle that is not in the runnable queue as blocked.
     *
     * <p>{@code deltaWeight} is added to the group and committed to the task. The task's weight is
     * then made relative to the current baseline, which is computed before the handle leaves the
     * baseline queue. {@link #enqueue} adds the baseline back later.
     *
     * @param handle the task handle
     * @param deltaWeight weight to add to the group and commit to the task
     * @throws IllegalArgumentException if the handle is unknown or currently in the runnable queue
     */
    public void block(T handle, long deltaWeight) {
        Preconditions.checkArgument(tasks.containsKey(handle), "Unknown task: %s", handle);
        Preconditions.checkArgument(!runnableQueue.contains(handle), "Task is already in queue: %s", handle);

        // NOTE(review): unlike enqueue(), this does not subtract task.uncommittedWeight() from the
        // group weight before adding deltaWeight. If Task.commitWeight() clears the uncommitted
        // weight, the expectedWeight charged by dequeue() is never reconciled for tasks that
        // block. Confirm this is intended.
        weight += deltaWeight;

        Task task = tasks.get(handle);
        task.commitWeight(deltaWeight);
        task.setState(State.BLOCKED);
        // Save the weight as a delta from the baseline; enqueue() re-adds the baseline.
        task.addWeight(-baselineWeight());
        blocked.add(handle);
        baselineWeights.remove(handle);

        updateState();
    }

    /**
     * Returns the head priority of the baseline queue, or {@code 0} when no non-blocked task exists.
     * Whether the head is the lowest or the highest weight depends on the ordering of
     * {@code PriorityQueue} (not visible here).
     */
    public long baselineWeight() {
        if (baselineWeights.isEmpty()) {
            return 0L;
        }
        return baselineWeights.nextPriority();
    }

    /** Adds {@code delta} to the group weight without touching any individual task. */
    public void addWeight(long delta) {
        weight += delta;
    }

    /** Recomputes the derived group state from the blocked set and the runnable queue. */
    private void updateState() {
        if (blocked.size() == tasks.size()) {
            state = State.BLOCKED;
        } else if (runnableQueue.isEmpty()) {
            state = State.RUNNING;
        } else {
            state = State.RUNNABLE;
        }
    }

    /** Returns the group weight. */
    public long weight() {
        return weight;
    }

    /** Returns an immutable snapshot of all handles known to this group. */
    public Set<T> tasks() {
        return ImmutableSet.copyOf(tasks.keySet());
    }

    /** Returns the derived group state. */
    public State state() {
        return state;
    }

    /** Returns the head of the runnable queue, as reported by its {@code peek()}. */
    public T peek() {
        return runnableQueue.peek();
    }

    /** Returns the number of handles currently in the runnable queue. */
    public int runnableCount() {
        return runnableQueue.size();
    }

    /**
     * Renders one line per task with its state and weights. The handle at the head of the runnable
     * queue is marked with {@code =>}.
     */
    @Override
    public String toString() {
        // The queue does not change while formatting, so look up the head once.
        T head = peek();
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<T, Task> entry : tasks.entrySet()) {
            T key = entry.getKey();
            Task task = entry.getValue();
            // Identity comparison: marks the exact handle object sitting at the queue head.
            String prefix = "%s %s".formatted(key == head ? "=>" : "  ", key);
            // Unqualified enum labels: exhaustive switch expression, valid before Java 21 as well.
            String details = switch (task.state()) {
                case BLOCKED -> "[BLOCKED, saved delta = %s]".formatted(task.weight());
                case RUNNABLE -> "[RUNNABLE, weight = %s]".formatted(task.weight());
                case RUNNING -> "[RUNNING, weight = %s, uncommitted = %s]".formatted(task.weight(), task.uncommittedWeight());
            };
            builder.append(prefix).append(' ').append(details).append('\n');
        }
        return builder.toString();
    }
}

