/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.executor.dedicated;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.errorprone.annotations.ThreadSafe;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.inject.Inject;
import io.airlift.concurrent.ThreadPoolExecutorMBean;
import io.airlift.concurrent.Threads;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.opentelemetry.api.trace.Tracer;
import io.trino.execution.SplitRunner;
import io.trino.execution.TaskId;
import io.trino.execution.TaskManagerConfig;
import io.trino.execution.executor.RunningSplitInfo;
import io.trino.execution.executor.TaskExecutor;
import io.trino.execution.executor.TaskHandle;
import io.trino.execution.executor.dedicated.TaskEntry;
import io.trino.execution.executor.scheduler.FairScheduler;
import io.trino.spi.VersionEmbedder;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.DoubleSupplier;
import java.util.function.Predicate;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

/**
 * {@link TaskExecutor} that runs task splits on a {@link FairScheduler}.
 *
 * <p>Intermediate splits are handed to their task immediately via {@code TaskEntry.runSplit}. Leaf
 * splits are queued per task and started by this executor. Starting them honours:
 * <ul>
 *   <li>a per-task minimum ({@code minDriversPerTask});</li>
 *   <li>a per-task cap (the smaller of the task's target concurrency and {@code maxDriversPerTask});</li>
 *   <li>a global target for running leaf drivers ({@code targetGlobalLeafDrivers}).</li>
 * </ul>
 *
 * <p>All mutable executor state is guarded by {@code this}. Background threads only read the task
 * map through {@link #snapshotTasks()} or from synchronized methods.
 */
@ThreadSafe
public class ThreadPerDriverTaskExecutor implements TaskExecutor {
    private static final Logger LOG = Logger.get(ThreadPerDriverTaskExecutor.class);

    private final FairScheduler scheduler;
    private final Tracer tracer;
    private final VersionEmbedder versionEmbedder;
    // Global number of running leaf drivers the fill phase of scheduleMoreLeafSplits aims for.
    private final int targetGlobalLeafDrivers;
    // Leaf drivers every task is topped up to, regardless of the global target.
    private final int minDriversPerTask;
    // Upper bound on leaf drivers per task (further limited by each task's target concurrency).
    private final int maxDriversPerTask;
    // Runs the periodic scheduling, concurrency-adjustment and diagnostics jobs started in start().
    private final ScheduledThreadPoolExecutor backgroundTasks = new ScheduledThreadPoolExecutor(2, Threads.daemonThreadsNamed("task-executor-scheduler-%s"));

    @GuardedBy(value="this")
    private final Map<TaskId, TaskEntry> tasks = new HashMap<TaskId, TaskEntry>();
    @GuardedBy(value="this")
    private boolean closed;
    // Leaf splits started through scheduleLeafSplit whose completion callback has not yet run.
    @GuardedBy(value="this")
    private int runningLeafDrivers;

    @Inject
    public ThreadPerDriverTaskExecutor(TaskManagerConfig config, Tracer tracer, VersionEmbedder versionEmbedder) {
        this(tracer, versionEmbedder, new FairScheduler(config.getMaxWorkerThreads(), "SplitRunner-%d", Ticker.systemTicker()), config.getMinDriversPerTask(), config.getMaxDriversPerTask(), config.getMinDrivers());
    }

    /**
     * Creates an executor over an explicit scheduler.
     *
     * @param tracer tracer passed to every task entry
     * @param versionEmbedder version embedder passed to every task entry
     * @param scheduler scheduler that runs split work
     * @param minDriversPerTask leaf drivers each task is topped up to on every scheduling pass
     * @param maxDriversPerTask per-task upper bound on running leaf drivers
     * @param targetGlobalLeafDrivers global number of running leaf drivers to fill up to
     * @throws NullPointerException if {@code tracer}, {@code versionEmbedder} or {@code scheduler} is null
     */
    @VisibleForTesting
    public ThreadPerDriverTaskExecutor(Tracer tracer, VersionEmbedder versionEmbedder, FairScheduler scheduler, int minDriversPerTask, int maxDriversPerTask, int targetGlobalLeafDrivers) {
        this.scheduler = Objects.requireNonNull(scheduler, "scheduler is null");
        this.tracer = Objects.requireNonNull(tracer, "tracer is null");
        this.versionEmbedder = Objects.requireNonNull(versionEmbedder, "versionEmbedder is null");
        this.minDriversPerTask = minDriversPerTask;
        this.maxDriversPerTask = maxDriversPerTask;
        this.targetGlobalLeafDrivers = targetGlobalLeafDrivers;
    }

    /**
     * Starts the scheduler and the periodic background jobs: leaf scheduling every 100 ms,
     * concurrency adjustment every 10 ms, and diagnostics logging every 30 s.
     */
    @Override
    @PostConstruct
    public synchronized void start() {
        this.scheduler.start();
        this.backgroundTasks.scheduleWithFixedDelay(logFailures("scheduleMoreLeafSplits", this::scheduleMoreLeafSplits), 0L, 100L, TimeUnit.MILLISECONDS);
        this.backgroundTasks.scheduleWithFixedDelay(logFailures("adjustConcurrency", this::adjustConcurrency), 0L, 10L, TimeUnit.MILLISECONDS);
        this.backgroundTasks.scheduleWithFixedDelay(logFailures("logDiagnostics", this::logDiagnostics), 0L, 30L, TimeUnit.SECONDS);
    }

    /**
     * Marks the executor closed, destroys every registered task, stops the background jobs and
     * closes the scheduler.
     */
    @Override
    @PreDestroy
    public synchronized void stop() {
        this.closed = true;
        this.tasks.values().forEach(TaskEntry::destroy);
        this.backgroundTasks.shutdownNow();
        this.scheduler.close();
    }

    /**
     * Registers a new task.
     *
     * <p>{@code splitConcurrencyAdjustFrequency} and {@code maxDriversPerTask} are not used by this
     * implementation. The per-task cap comes from the executor-wide {@code maxDriversPerTask}.
     *
     * @throws IllegalArgumentException if the executor has been stopped
     */
    @Override
    public synchronized TaskHandle addTask(TaskId taskId, DoubleSupplier utilizationSupplier, int initialSplitConcurrency, Duration splitConcurrencyAdjustFrequency, OptionalInt maxDriversPerTask) {
        Preconditions.checkArgument(!this.closed, "Executor is already closed");
        TaskEntry task = new TaskEntry(taskId, this.scheduler, this.versionEmbedder, this.tracer, initialSplitConcurrency, utilizationSupplier);
        this.tasks.put(taskId, task);
        return task;
    }

    /**
     * Unregisters the task and destroys it unless it is already destroyed.
     */
    @Override
    public synchronized void removeTask(TaskHandle handle) {
        TaskEntry entry = (TaskEntry) handle;
        this.tasks.remove(entry.taskId());
        if (!entry.isDestroyed()) {
            entry.destroy();
        }
    }

    /**
     * Submits splits for a task.
     *
     * <p>Intermediate splits are run immediately. Leaf splits are queued on the task and then
     * considered by an immediate scheduling pass.
     *
     * @return one future per split, in input order
     * @throws IllegalArgumentException if the executor has been stopped
     */
    @Override
    public synchronized List<ListenableFuture<Void>> enqueueSplits(TaskHandle handle, boolean intermediate, List<? extends SplitRunner> splits) {
        Preconditions.checkArgument(!this.closed, "Executor is already closed");
        TaskEntry entry = (TaskEntry) handle;
        List<ListenableFuture<Void>> futures = new ArrayList<ListenableFuture<Void>>();
        for (SplitRunner splitRunner : splits) {
            if (intermediate) {
                futures.add(entry.runSplit(splitRunner));
            }
            else {
                futures.add(entry.enqueueLeafSplit(splitRunner));
            }
        }
        this.scheduleMoreLeafSplits();
        return futures;
    }

    /**
     * Tries to start one queued leaf split of {@code task}. Counts it in {@code runningLeafDrivers}
     * when one was started.
     *
     * <p>Must be called while holding this executor's monitor.
     *
     * @return whether a leaf split was started
     */
    private boolean scheduleLeafSplit(TaskEntry task) {
        boolean scheduled = task.dequeueAndRunLeafSplit(this::leafSplitDone);
        if (scheduled) {
            ++this.runningLeafDrivers;
        }
        return scheduled;
    }

    /** Completion callback for leaf splits. Releases the global slot and refills it. */
    private synchronized void leafSplitDone() {
        --this.runningLeafDrivers;
        this.scheduleMoreLeafSplits();
    }

    /**
     * Starts queued leaf splits in two phases.
     *
     * <ol>
     *   <li>Top every task up to {@code minDriversPerTask} running leaf splits.</li>
     *   <li>Round-robin across tasks until {@code targetGlobalLeafDrivers} leaf drivers are
     *       running or no task can take another one.</li>
     * </ol>
     */
    private synchronized void scheduleMoreLeafSplits() {
        for (TaskEntry task : this.tasks.values()) {
            int missing = Math.max(0, this.minDriversPerTask - task.runningLeafSplits());
            for (int i = 0; i < missing; ++i) {
                if (!this.scheduleLeafSplit(task)) {
                    break;
                }
            }
        }

        ArrayDeque<TaskEntry> queue = new ArrayDeque<TaskEntry>(this.tasks.values());
        // Only successful schedules consume the global budget. A task that is at its cap, or has
        // nothing to run, no longer uses up a slot that a later task in the queue could have used.
        while (this.runningLeafDrivers < this.targetGlobalLeafDrivers && !queue.isEmpty()) {
            TaskEntry task = queue.poll();
            if (task.runningLeafSplits() >= Math.min(task.targetConcurrency(), this.maxDriversPerTask)) {
                continue;
            }
            // Re-queue only after a successful schedule. Each iteration then either drops a task
            // or raises runningLeafDrivers, which guarantees the loop terminates.
            if (this.scheduleLeafSplit(task) && task.hasPendingLeafSplits()) {
                queue.add(task);
            }
        }
    }

    /**
     * Asks every registered task to update its concurrency. Iterates a snapshot so that
     * concurrent addTask/removeTask calls cannot break the iteration.
     */
    private void adjustConcurrency() {
        for (TaskEntry task : this.snapshotTasks()) {
            task.updateConcurrency();
        }
    }

    /** Logs scheduler diagnostics and per-task split counts at debug level. */
    private void logDiagnostics() {
        if (!LOG.isDebugEnabled()) {
            return;
        }
        StringBuilder builder = new StringBuilder();
        builder.append("Queue:\n");
        builder.append(this.scheduler.diagnostics().indent(4));
        builder.append("Query tasks:\n");
        for (TaskEntry task : this.snapshotTasks()) {
            builder.append("%s: [total running = %s, leaf running = %s, leaf pending = %s, target concurrency = %s]\n".formatted(task.taskId(), task.totalRunningSplits(), task.runningLeafSplits(), task.pendingLeafSplitCount(), task.targetConcurrency()).indent(4));
        }
        LOG.debug("\n" + builder);
    }

    /** Returns an immutable copy of the registered tasks, taken under the executor's monitor. */
    private synchronized List<TaskEntry> snapshotTasks() {
        return List.copyOf(this.tasks.values());
    }

    /**
     * Wraps a periodic job so that a runtime failure is logged instead of silently cancelling the
     * job's future executions (the behaviour of {@code scheduleWithFixedDelay}).
     */
    private static Runnable logFailures(String name, Runnable action) {
        return () -> {
            try {
                action.run();
            }
            catch (RuntimeException e) {
                LOG.error(e, "Background task %s failed", name);
            }
        };
    }

    /** Stuck-split detection is not implemented by this executor. Always returns an empty set. */
    @Override
    public Set<TaskId> getStuckSplitTaskIds(Duration processingDurationThreshold, Predicate<RunningSplitInfo> filter) {
        return ImmutableSet.of();
    }

    @Managed
    public synchronized int getTasks() {
        return this.tasks.size();
    }

    @Managed
    public synchronized int getTotalRunningSplits() {
        return this.tasks.values().stream().mapToInt(TaskEntry::totalRunningSplits).sum();
    }

    @Managed
    public synchronized int getTotalRunningLeafSplits() {
        return this.tasks.values().stream().mapToInt(TaskEntry::runningLeafSplits).sum();
    }

    @Managed
    public synchronized int getTotalPendingLeafSplits() {
        return this.tasks.values().stream().mapToInt(TaskEntry::pendingLeafSplitCount).sum();
    }

    @Managed(description="Scheduler executor")
    @Nested
    public ThreadPoolExecutorMBean getSchedulerExecutor() {
        return this.scheduler.getSchedulerExecutor();
    }

    @Managed(description="Task executor")
    @Nested
    public ThreadPoolExecutorMBean getTaskExecutor() {
        return this.scheduler.getTaskExecutor();
    }

    @Managed
    public int getConcurrencyControlTotalSlots() {
        return this.scheduler.getConcurrencyControlTotalSlots();
    }

    @Managed
    public int getConcurrencyControlAvailableSlots() {
        return this.scheduler.getConcurrencyControlAvailableSlots();
    }
}

