/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.concurrent.ThreadPoolExecutorMBean;
import io.airlift.concurrent.Threads;
import io.airlift.log.Logger;
import io.airlift.node.NodeInfo;
import io.airlift.stats.CounterStat;
import io.airlift.stats.GcMonitor;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.collect.cache.NonEvictableLoadingCache;
import io.trino.collect.cache.SafeCaches;
import io.trino.event.SplitMonitor;
import io.trino.exchange.ExchangeManagerRegistry;
import io.trino.execution.DynamicFiltersCollector;
import io.trino.execution.LocationFactory;
import io.trino.execution.SplitAssignment;
import io.trino.execution.SqlTask;
import io.trino.execution.SqlTaskExecutionFactory;
import io.trino.execution.SqlTaskIoStats;
import io.trino.execution.StateMachine;
import io.trino.execution.TaskFailureListener;
import io.trino.execution.TaskId;
import io.trino.execution.TaskInfo;
import io.trino.execution.TaskManagementExecutor;
import io.trino.execution.TaskManagerConfig;
import io.trino.execution.TaskState;
import io.trino.execution.TaskStatus;
import io.trino.execution.buffer.BufferResult;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.executor.TaskExecutor;
import io.trino.memory.LocalMemoryManager;
import io.trino.memory.NodeMemoryConfig;
import io.trino.memory.QueryContext;
import io.trino.operator.RetryPolicy;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.VersionEmbedder;
import io.trino.spi.predicate.Domain;
import io.trino.spiller.LocalSpillManager;
import io.trino.spiller.NodeSpillConfig;
import io.trino.sql.planner.LocalExecutionPlanner;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.DynamicFilterId;
import java.io.Closeable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.inject.Inject;
import org.joda.time.DateTime;
import org.joda.time.ReadableInstant;
import org.weakref.jmx.Flatten;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

/**
 * Worker-side registry and entry point for {@link SqlTask}s.
 * <p>
 * Tasks are held in a loading cache keyed by {@link TaskId}. Every per-task method below looks the task up
 * with {@code getUnchecked}, so referencing a task id that is not yet known creates a new {@code SqlTask}
 * for it. {@link QueryContext}s are kept in a second loading cache keyed by {@link QueryId} (with weak
 * values), and a task obtains its context from that cache when it is created.
 * <p>
 * Background maintenance scheduled from {@link #start()} removes old finished tasks, fails tasks that have
 * not been accessed within the client timeout, and refreshes the aggregated I/O statistics exposed via JMX.
 */
public class SqlTaskManager
implements Closeable {
    private static final Logger log = Logger.get(SqlTaskManager.class);
    private final VersionEmbedder versionEmbedder;
    // Created and owned by this manager; shut down in close().
    private final ExecutorService taskNotificationExecutor;
    private final ThreadPoolExecutorMBean taskNotificationExecutorMBean;
    // Supplied from outside (TaskManagementExecutor); NOT shut down by this manager.
    private final ScheduledExecutorService taskManagementExecutor;
    // Created and owned by this manager; shut down in close().
    private final ScheduledExecutorService driverYieldExecutor;
    // How long a finished task is retained before removeOldTasks() drops it.
    private final Duration infoCacheTime;
    // How long a running task may go without a heartbeat before failAbandonedTasks() fails it.
    private final Duration clientTimeout;
    // Weak values: a context may be garbage-collected once nothing else strongly references it.
    private final NonEvictableLoadingCache<QueryId, QueryContext> queryContexts;
    private final NonEvictableLoadingCache<TaskId, SqlTask> tasks;
    // Snapshot republished by updateStats(); returned by getIoStats().
    private final SqlTaskIoStats cachedStats = new SqlTaskIoStats();
    // Accumulates the I/O stats of tasks as they finish (see the callback passed to createSqlTask).
    private final SqlTaskIoStats finishedTaskStats = new SqlTaskIoStats();
    // Node-wide upper bound for per-node query memory, in bytes.
    private final long queryMaxMemoryPerNode;
    private final CounterStat failedTasks = new CounterStat();

    /**
     * Creates the manager together with the executors it owns: a fixed "task-notification" pool and a
     * scheduled "task-yield" pool. Tasks and query contexts are created lazily on first lookup.
     */
    @Inject
    public SqlTaskManager(VersionEmbedder versionEmbedder, LocalExecutionPlanner planner, LocationFactory locationFactory, TaskExecutor taskExecutor, SplitMonitor splitMonitor, NodeInfo nodeInfo, LocalMemoryManager localMemoryManager, TaskManagementExecutor taskManagementExecutor, TaskManagerConfig config, NodeMemoryConfig nodeMemoryConfig, LocalSpillManager localSpillManager, NodeSpillConfig nodeSpillConfig, GcMonitor gcMonitor, ExchangeManagerRegistry exchangeManagerRegistry) {
        Objects.requireNonNull(nodeInfo, "nodeInfo is null");
        Objects.requireNonNull(config, "config is null");
        // Validate eagerly; previously this was only checked inside the cache loader, i.e. on first task creation.
        Objects.requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null");
        this.infoCacheTime = config.getInfoMaxAge();
        this.clientTimeout = config.getClientTimeout();
        DataSize maxBufferSize = config.getSinkMaxBufferSize();
        DataSize maxBroadcastBufferSize = config.getSinkMaxBroadcastBufferSize();
        this.versionEmbedder = Objects.requireNonNull(versionEmbedder, "versionEmbedder is null");
        this.taskNotificationExecutor = Executors.newFixedThreadPool(config.getTaskNotificationThreads(), Threads.threadsNamed("task-notification-%s"));
        this.taskNotificationExecutorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) this.taskNotificationExecutor);
        this.taskManagementExecutor = Objects.requireNonNull(taskManagementExecutor, "taskManagementExecutor cannot be null").getExecutor();
        this.driverYieldExecutor = Executors.newScheduledThreadPool(config.getTaskYieldThreads(), Threads.threadsNamed("task-yield-%s"));
        SqlTaskExecutionFactory sqlTaskExecutionFactory = new SqlTaskExecutionFactory(this.taskNotificationExecutor, taskExecutor, planner, splitMonitor, config);
        DataSize maxQueryMemoryPerNode = nodeMemoryConfig.getMaxQueryMemoryPerNode();
        DataSize maxQuerySpillPerNode = nodeSpillConfig.getQueryMaxSpillPerNode();
        this.queryMaxMemoryPerNode = maxQueryMemoryPerNode.toBytes();

        // Explicitly typed lambdas so CacheLoader.from infers the key type instead of falling back to Object.
        this.queryContexts = SafeCaches.buildNonEvictableCache(
                CacheBuilder.newBuilder().weakValues(),
                CacheLoader.from((QueryId queryId) -> this.createQueryContext(queryId, localMemoryManager, localSpillManager, gcMonitor, maxQueryMemoryPerNode, maxQuerySpillPerNode)));

        this.tasks = SafeCaches.buildNonEvictableCache(
                CacheBuilder.newBuilder(),
                CacheLoader.from((TaskId taskId) -> SqlTask.createSqlTask(
                        taskId,
                        locationFactory.createLocalTaskLocation(taskId),
                        nodeInfo.getNodeId(),
                        this.queryContexts.getUnchecked(taskId.getQueryId()),
                        sqlTaskExecutionFactory,
                        this.taskNotificationExecutor,
                        // Fold a task's I/O stats into the running total of finished tasks.
                        sqlTask -> this.finishedTaskStats.merge(sqlTask.getIoStats()),
                        maxBufferSize,
                        maxBroadcastBufferSize,
                        exchangeManagerRegistry,
                        this.failedTasks)));
    }

    /**
     * Builds the {@link QueryContext} for a query from the node's memory pool, spill space tracker and
     * the executors owned by this manager.
     */
    private QueryContext createQueryContext(QueryId queryId, LocalMemoryManager localMemoryManager, LocalSpillManager localSpillManager, GcMonitor gcMonitor, DataSize maxQueryUserMemoryPerNode, DataSize maxQuerySpillPerNode) {
        return new QueryContext(queryId, maxQueryUserMemoryPerNode, localMemoryManager.getMemoryPool(), gcMonitor, this.taskNotificationExecutor, this.driverYieldExecutor, maxQuerySpillPerNode, localSpillManager.getSpillSpaceTracker());
    }

    /**
     * Schedules the background maintenance.
     * <ul>
     *   <li>Every 200 ms: {@link #removeOldTasks()} and abandoned-task detection.</li>
     *   <li>Every second: refresh of the cached I/O statistics.</li>
     * </ul>
     * Each job catches and logs any {@code Throwable}, so one failure does not cancel the schedule.
     */
    @PostConstruct
    public void start() {
        this.taskManagementExecutor.scheduleWithFixedDelay(() -> {
            try {
                this.removeOldTasks();
            }
            catch (Throwable e) {
                log.warn(e, "Error removing old tasks");
            }
            try {
                this.failAbandonedTasks();
            }
            catch (Throwable e) {
                log.warn(e, "Error canceling abandoned tasks");
            }
        }, 200L, 200L, TimeUnit.MILLISECONDS);
        this.taskManagementExecutor.scheduleWithFixedDelay(() -> {
            try {
                this.updateStats();
            }
            catch (Throwable e) {
                log.warn(e, "Error updating stats");
            }
        }, 0L, 1L, TimeUnit.SECONDS);
    }

    /**
     * Shuts the manager down.
     * <ol>
     *   <li>Fails every task that is not yet done with {@code SERVER_SHUTTING_DOWN}.</li>
     *   <li>If at least one task was failed, sleeps five seconds. This is presumably to let the failures
     *       propagate before the executors go away (TODO confirm).</li>
     *   <li>Shuts down the executors this manager created.</li>
     * </ol>
     * The externally supplied task management executor is left running.
     */
    @Override
    @PreDestroy
    public void close() {
        boolean taskCanceled = false;
        for (SqlTask task : this.tasks.asMap().values()) {
            if (task.getTaskState().isDone()) {
                continue;
            }
            task.failed(new TrinoException(StandardErrorCode.SERVER_SHUTTING_DOWN, String.format("Server is shutting down. Task %s has been canceled", task.getTaskId())));
            taskCanceled = true;
        }
        if (taskCanceled) {
            try {
                TimeUnit.SECONDS.sleep(5L);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        this.taskNotificationExecutor.shutdownNow();
        // Also created in the constructor; previously never shut down, leaking its threads.
        this.driverYieldExecutor.shutdownNow();
    }

    /** Aggregated I/O stats, refreshed once per second by the maintenance job. */
    @Managed
    @Flatten
    public SqlTaskIoStats getIoStats() {
        return this.cachedStats;
    }

    @Managed(description="Task notification executor")
    @Nested
    public ThreadPoolExecutorMBean getTaskNotificationExecutor() {
        return this.taskNotificationExecutorMBean;
    }

    @Managed(description="Failed tasks counter")
    @Nested
    public CounterStat getFailedTasks() {
        return this.failedTasks;
    }

    /** Returns an immutable snapshot of all tasks currently held by this manager. */
    public List<SqlTask> getAllTasks() {
        return ImmutableList.copyOf(this.tasks.asMap().values());
    }

    /** Returns an immutable snapshot of the {@link TaskInfo} of every task currently held. */
    public List<TaskInfo> getAllTaskInfo() {
        return this.tasks.asMap().values().stream()
                .map(SqlTask::getTaskInfo)
                .collect(ImmutableList.toImmutableList());
    }

    /** Returns the current task info, creating the task if unknown, and records a heartbeat. */
    public TaskInfo getTaskInfo(TaskId taskId) {
        Objects.requireNonNull(taskId, "taskId is null");
        SqlTask sqlTask = this.tasks.getUnchecked(taskId);
        sqlTask.recordHeartbeat();
        return sqlTask.getTaskInfo();
    }

    /** Returns the current task status, creating the task if unknown, and records a heartbeat. */
    public TaskStatus getTaskStatus(TaskId taskId) {
        Objects.requireNonNull(taskId, "taskId is null");
        SqlTask sqlTask = this.tasks.getUnchecked(taskId);
        sqlTask.recordHeartbeat();
        return sqlTask.getTaskStatus();
    }

    /**
     * Records a heartbeat and returns a future for the task info, delegating to
     * {@code SqlTask.getTaskInfo(currentVersion)}. Creates the task if unknown.
     */
    public ListenableFuture<TaskInfo> getTaskInfo(TaskId taskId, long currentVersion) {
        Objects.requireNonNull(taskId, "taskId is null");
        SqlTask sqlTask = this.tasks.getUnchecked(taskId);
        sqlTask.recordHeartbeat();
        return sqlTask.getTaskInfo(currentVersion);
    }

    /** Returns the instance id of the task, creating the task if unknown, and records a heartbeat. */
    public String getTaskInstanceId(TaskId taskId) {
        Objects.requireNonNull(taskId, "taskId is null");
        SqlTask sqlTask = this.tasks.getUnchecked(taskId);
        sqlTask.recordHeartbeat();
        return sqlTask.getTaskInstanceId();
    }

    /**
     * Records a heartbeat and returns a future for the task status, delegating to
     * {@code SqlTask.getTaskStatus(currentVersion)}. Creates the task if unknown.
     */
    public ListenableFuture<TaskStatus> getTaskStatus(TaskId taskId, long currentVersion) {
        Objects.requireNonNull(taskId, "taskId is null");
        SqlTask sqlTask = this.tasks.getUnchecked(taskId);
        sqlTask.recordHeartbeat();
        return sqlTask.getTaskStatus(currentVersion);
    }

    /**
     * Records a heartbeat and delegates to {@code SqlTask.acknowledgeAndGetNewDynamicFilterDomains}.
     * Creates the task if unknown.
     */
    public DynamicFiltersCollector.VersionedDynamicFilterDomains acknowledgeAndGetNewDynamicFilterDomains(TaskId taskId, long currentDynamicFiltersVersion) {
        Objects.requireNonNull(taskId, "taskId is null");
        SqlTask sqlTask = this.tasks.getUnchecked(taskId);
        sqlTask.recordHeartbeat();
        return sqlTask.acknowledgeAndGetNewDynamicFilterDomains(currentDynamicFiltersVersion);
    }

    /**
     * Applies an update to a task, creating the task if unknown. The work runs inside
     * {@link VersionEmbedder#embedVersion}. Unchecked exceptions propagate as-is; any checked exception is
     * wrapped in a {@link RuntimeException}.
     */
    public TaskInfo updateTask(Session session, TaskId taskId, Optional<PlanFragment> fragment, List<SplitAssignment> splitAssignments, OutputBuffers outputBuffers, Map<DynamicFilterId, Domain> dynamicFilterDomains) {
        try {
            return this.versionEmbedder.embedVersion(() -> this.doUpdateTask(session, taskId, fragment, splitAssignments, outputBuffers, dynamicFilterDomains)).call();
        }
        catch (Exception e) {
            Throwables.throwIfUnchecked(e);
            // doUpdateTask declares no checked exceptions; this only guards Callable.call()'s signature.
            throw new RuntimeException(e);
        }
    }

    /**
     * Does the work of {@link #updateTask}.
     * <ol>
     *   <li>Validates the arguments.</li>
     *   <li>Initializes the query's memory limits if they are not yet initialized.</li>
     *   <li>Records a heartbeat.</li>
     *   <li>Forwards the update to the task.</li>
     * </ol>
     */
    private TaskInfo doUpdateTask(Session session, TaskId taskId, Optional<PlanFragment> fragment, List<SplitAssignment> splitAssignments, OutputBuffers outputBuffers, Map<DynamicFilterId, Domain> dynamicFilterDomains) {
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(fragment, "fragment is null");
        Objects.requireNonNull(splitAssignments, "splitAssignments is null");
        Objects.requireNonNull(outputBuffers, "outputBuffers is null");
        SqlTask sqlTask = this.tasks.getUnchecked(taskId);
        QueryContext queryContext = sqlTask.getQueryContext();
        // Limits are initialized once per QueryContext, from the session of the update that finds them uninitialized.
        if (!queryContext.isMemoryLimitsInitialized()) {
            RetryPolicy retryPolicy = SystemSessionProperties.getRetryPolicy(session);
            if (retryPolicy == RetryPolicy.TASK) {
                // With task-level retries no per-node query limit is applied here and overcommit is disabled.
                queryContext.initializeMemoryLimits(false, Long.MAX_VALUE);
            }
            else {
                // The session may lower, but never raise, the node-configured per-node limit.
                long sessionQueryMaxMemoryPerNode = SystemSessionProperties.getQueryMaxMemoryPerNode(session).toBytes();
                queryContext.initializeMemoryLimits(SystemSessionProperties.resourceOvercommit(session), Math.min(sessionQueryMaxMemoryPerNode, this.queryMaxMemoryPerNode));
            }
        }
        sqlTask.recordHeartbeat();
        return sqlTask.updateTask(session, fragment, splitAssignments, outputBuffers, dynamicFilterDomains);
    }

    /**
     * Returns a future for the task's output starting at {@code startingSequenceId}, delegating to
     * {@code SqlTask.getTaskResults}.
     *
     * @throws IllegalArgumentException if {@code startingSequenceId} is negative
     */
    public ListenableFuture<BufferResult> getTaskResults(TaskId taskId, OutputBuffers.OutputBufferId bufferId, long startingSequenceId, DataSize maxSize) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(bufferId, "bufferId is null");
        Preconditions.checkArgument(startingSequenceId >= 0L, "startingSequenceId is negative");
        Objects.requireNonNull(maxSize, "maxSize is null");
        return this.tasks.getUnchecked(taskId).getTaskResults(bufferId, startingSequenceId, maxSize);
    }

    /**
     * Delegates to {@code SqlTask.acknowledgeTaskResults} for the given buffer and sequence id.
     *
     * @throws IllegalArgumentException if {@code sequenceId} is negative
     */
    public void acknowledgeTaskResults(TaskId taskId, OutputBuffers.OutputBufferId bufferId, long sequenceId) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(bufferId, "bufferId is null");
        Preconditions.checkArgument(sequenceId >= 0L, "sequenceId is negative");
        this.tasks.getUnchecked(taskId).acknowledgeTaskResults(bufferId, sequenceId);
    }

    /** Delegates to {@code SqlTask.destroyTaskResults} for the given output buffer. */
    public TaskInfo destroyTaskResults(TaskId taskId, OutputBuffers.OutputBufferId bufferId) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(bufferId, "bufferId is null");
        return this.tasks.getUnchecked(taskId).destroyTaskResults(bufferId);
    }

    /** Cancels the task (creating it first if unknown) and returns its task info. */
    public TaskInfo cancelTask(TaskId taskId) {
        Objects.requireNonNull(taskId, "taskId is null");
        return this.tasks.getUnchecked(taskId).cancel();
    }

    /** Aborts the task (creating it first if unknown) and returns its task info. */
    public TaskInfo abortTask(TaskId taskId) {
        Objects.requireNonNull(taskId, "taskId is null");
        return this.tasks.getUnchecked(taskId).abort();
    }

    /** Fails the task with {@code failure} (creating it first if unknown) and returns its task info. */
    public TaskInfo failTask(TaskId taskId, Throwable failure) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(failure, "failure is null");
        return this.tasks.getUnchecked(taskId).failed(failure);
    }

    /**
     * Removes every task whose end time is older than {@code infoCacheTime}. Tasks without an end time are
     * kept. An error while inspecting one task is logged and does not stop the sweep.
     */
    @VisibleForTesting
    void removeOldTasks() {
        DateTime oldestAllowedTask = DateTime.now().minus(this.infoCacheTime.toMillis());
        this.tasks.asMap().values().stream()
                .map(SqlTask::getTaskInfo)
                .filter(Objects::nonNull)
                .forEach(taskInfo -> {
                    TaskId taskId = taskInfo.getTaskStatus().getTaskId();
                    try {
                        DateTime endTime = taskInfo.getStats().getEndTime();
                        if (endTime != null && endTime.isBefore(oldestAllowedTask)) {
                            this.tasks.asMap().remove(taskId);
                        }
                    }
                    catch (RuntimeException e) {
                        log.warn(e, "Error while inspecting age of complete task %s", taskId);
                    }
                });
    }

    /**
     * Fails with {@code ABANDONED_TASK} every task that meets all of these conditions:
     * <ul>
     *   <li>it is not done;</li>
     *   <li>it has a recorded heartbeat;</li>
     *   <li>that heartbeat is older than {@code clientTimeout}.</li>
     * </ul>
     * An error while inspecting one task is logged and does not stop the sweep.
     */
    private void failAbandonedTasks() {
        DateTime now = DateTime.now();
        DateTime oldestAllowedHeartbeat = now.minus(this.clientTimeout.toMillis());
        for (SqlTask sqlTask : this.tasks.asMap().values()) {
            try {
                TaskInfo taskInfo = sqlTask.getTaskInfo();
                TaskStatus taskStatus = taskInfo.getTaskStatus();
                if (taskStatus.getState().isDone()) {
                    continue;
                }
                DateTime lastHeartbeat = taskInfo.getLastHeartbeat();
                if (lastHeartbeat == null || !lastHeartbeat.isBefore(oldestAllowedHeartbeat)) {
                    continue;
                }
                log.info("Failing abandoned task %s", taskStatus.getTaskId());
                sqlTask.failed(new TrinoException(StandardErrorCode.ABANDONED_TASK, String.format("Task %s has not been accessed since %s: currentTime %s", taskStatus.getTaskId(), lastHeartbeat, now)));
            }
            catch (RuntimeException e) {
                log.warn(e, "Error while inspecting age of task %s", sqlTask.getTaskId());
            }
        }
    }

    /**
     * Recomputes {@link #cachedStats} from two sources and publishes the result via {@code resetTo}:
     * <ul>
     *   <li>the accumulated stats of finished tasks;</li>
     *   <li>the current stats of every task that is not yet done.</li>
     * </ul>
     */
    private void updateStats() {
        SqlTaskIoStats tempIoStats = new SqlTaskIoStats();
        tempIoStats.merge(this.finishedTaskStats);
        this.tasks.asMap().values().stream()
                .filter(task -> !task.getTaskState().isDone())
                .forEach(task -> tempIoStats.merge(task.getIoStats()));
        this.cachedStats.resetTo(tempIoStats);
    }

    /** Registers a listener for the task's state changes (creating the task first if unknown). */
    public void addStateChangeListener(TaskId taskId, StateMachine.StateChangeListener<TaskState> stateChangeListener) {
        Objects.requireNonNull(taskId, "taskId is null");
        this.tasks.getUnchecked(taskId).addStateChangeListener(stateChangeListener);
    }

    /** Registers a source-task failure listener on the task (creating the task first if unknown). */
    public void addSourceTaskFailureListener(TaskId taskId, TaskFailureListener listener) {
        Objects.requireNonNull(taskId, "taskId is null");
        this.tasks.getUnchecked(taskId).addSourceTaskFailureListener(listener);
    }

    /** Returns the task's trace token, if any (creating the task first if unknown). */
    public Optional<String> getTraceToken(TaskId taskId) {
        Objects.requireNonNull(taskId, "taskId is null");
        return this.tasks.getUnchecked(taskId).getTraceToken();
    }

    /** Returns the query context for {@code queryId}, creating it if not currently cached. */
    @VisibleForTesting
    public QueryContext getQueryContext(QueryId queryId) {
        return this.queryContexts.getUnchecked(queryId);
    }
}

