/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server;

import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.inject.Inject;
import io.airlift.bootstrap.LifeCycleManager;
import io.airlift.concurrent.Threads;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.trino.execution.SqlTaskManager;
import io.trino.execution.StateMachine;
import io.trino.execution.TaskId;
import io.trino.execution.TaskInfo;
import io.trino.execution.TaskState;
import io.trino.metadata.NodeState;
import io.trino.server.ServerConfig;
import io.trino.server.ShutdownAction;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.assertj.core.util.VisibleForTesting;

/**
 * Holds the current {@link NodeState} of this node and runs the drain and shutdown sequences
 * that are started by state transitions.
 *
 * <p>Transitions accepted by {@link #transitionState(NodeState)}:
 * <ul>
 *   <li>{@code DRAINING} or {@code DRAINED} to {@code ACTIVE}</li>
 *   <li>{@code ACTIVE} to {@code DRAINING}</li>
 *   <li>any other state to {@code SHUTTING_DOWN}</li>
 * </ul>
 * {@code DRAINED} is only entered internally, once a drain has completed while the node was
 * still in the {@code DRAINING} state that started it.
 *
 * <p>Every write to {@link #nodeState} happens while holding this object's monitor. Background
 * threads read it without the lock and compare by identity against the {@link VersionedState}
 * they were started for, so any later transition makes them stop waiting.
 */
public class NodeStateManager {
    private static final Logger log = Logger.get(NodeStateManager.class);
    private static final Duration LIFECYCLE_STOP_TIMEOUT = new Duration(30.0, TimeUnit.SECONDS);

    private final ScheduledExecutorService shutdownHandler = Executors.newSingleThreadScheduledExecutor(Threads.threadsNamed("shutdown-handler-%s"));
    private final ExecutorService lifeCycleStopper = Executors.newSingleThreadExecutor(Threads.threadsNamed("lifecycle-stopper-%s"));
    private final LifeCycleManager lifeCycleManager;
    private final SqlTasksObservable sqlTasksObservable;
    private final Supplier<List<TaskInfo>> taskInfoSupplier;
    private final boolean isCoordinator;
    private final ShutdownAction shutdownAction;
    private final Duration gracePeriod;
    private final ScheduledExecutorService executor;
    private final AtomicReference<VersionedState> nodeState = new AtomicReference<>(new VersionedState(NodeState.ACTIVE, 0L));
    private final AtomicLong stateVersionProvider = new AtomicLong(0L);

    /**
     * Production constructor: observes tasks through the given {@link SqlTaskManager} and runs
     * drains on a dedicated single-threaded scheduler.
     */
    @Inject
    public NodeStateManager(SqlTaskManager sqlTaskManager, ServerConfig serverConfig, ShutdownAction shutdownAction, LifeCycleManager lifeCycleManager) {
        this(
                Objects.requireNonNull(sqlTaskManager, "sqlTaskManager is null")::addStateChangeListener,
                sqlTaskManager::getAllTaskInfo,
                serverConfig,
                shutdownAction,
                lifeCycleManager,
                Executors.newSingleThreadScheduledExecutor(Threads.threadsNamed("drain-handler-%s")));
    }

    /**
     * @param sqlTasksObservable used to register per-task state listeners
     * @param taskInfoSupplier supplies the tasks currently known to this node
     * @param serverConfig source of the coordinator flag and the grace period
     * @param shutdownAction invoked as the last step of termination
     * @param lifeCycleManager stopped (with a timeout) before {@code shutdownAction} runs
     * @param executor scheduler on which the drain sequence runs
     */
    @VisibleForTesting
    public NodeStateManager(SqlTasksObservable sqlTasksObservable, Supplier<List<TaskInfo>> taskInfoSupplier, ServerConfig serverConfig, ShutdownAction shutdownAction, LifeCycleManager lifeCycleManager, ScheduledExecutorService executor) {
        this.sqlTasksObservable = Objects.requireNonNull(sqlTasksObservable, "sqlTasksObservable is null");
        this.taskInfoSupplier = Objects.requireNonNull(taskInfoSupplier, "taskInfoSupplier is null");
        this.shutdownAction = Objects.requireNonNull(shutdownAction, "shutdownAction is null");
        this.lifeCycleManager = Objects.requireNonNull(lifeCycleManager, "lifeCycleManager is null");
        Objects.requireNonNull(serverConfig, "serverConfig is null");
        this.isCoordinator = serverConfig.isCoordinator();
        this.gracePeriod = serverConfig.getGracePeriod();
        this.executor = Objects.requireNonNull(executor, "executor is null");
    }

    /**
     * Returns the state this node is currently in.
     */
    public NodeState getServerState() {
        return nodeState.get().state();
    }

    /**
     * Moves the node to {@code state} and starts the matching background sequence.
     * Requesting the state the node is already in is a no-op.
     *
     * @throws IllegalStateException if the transition is not allowed from the current state,
     *         or if {@code state} is {@code DRAINED} or {@code INACTIVE}
     * @throws UnsupportedOperationException if a drain or shutdown is requested on a
     *         coordinator; the node state is left unchanged in that case
     */
    public synchronized void transitionState(NodeState state) {
        Objects.requireNonNull(state, "state is null");
        VersionedState currState = nodeState.get();
        if (currState.state() == state) {
            return;
        }
        switch (state) {
            case ACTIVE:
                if (currState.state() == NodeState.DRAINING || currState.state() == NodeState.DRAINED) {
                    // Publishing a new version is what makes an in-flight drain give up.
                    nodeState.set(currState.toActive());
                    return;
                }
                break;
            case SHUTTING_DOWN:
                if (currState.state() == NodeState.DRAINED) {
                    // Already drained, so there are no tasks left to wait for.
                    requestTerminate(currState);
                }
                else {
                    requestGracefulShutdown(currState);
                }
                return;
            case DRAINING:
                if (currState.state() == NodeState.ACTIVE) {
                    requestDrain(currState);
                    return;
                }
                break;
            case DRAINED:
                throw new IllegalStateException(String.format("Invalid state transition from %s to %s, transition to DRAINED is internal only", currState, state));
            case INACTIVE:
                throw new IllegalStateException(String.format("Invalid state transition from %s to %s, INACTIVE is not a valid internal state", currState, state));
        }
        throw new IllegalStateException(String.format("Invalid state transition from %s to %s", currState, state));
    }

    private long nextStateVersion() {
        return stateVersionProvider.incrementAndGet();
    }

    /**
     * Rejects drain and shutdown requests on a coordinator. Must run before the node state
     * is changed so that a rejected request leaves the state untouched.
     */
    private void checkNotCoordinator(String message) {
        if (isCoordinator) {
            throw new UnsupportedOperationException(message);
        }
    }

    /**
     * Publishes {@code DRAINING} and schedules the drain after one grace period.
     * Caller must hold this object's monitor.
     */
    private void requestDrain(VersionedState currState) {
        checkNotCoordinator("Cannot drain coordinator");
        VersionedState draining = currState.toDraining();
        nodeState.set(draining);
        log.debug("Drain requested, NodeState: %s", draining.state());
        executor.schedule(() -> drain(draining), gracePeriod.toMillis(), TimeUnit.MILLISECONDS);
    }

    /**
     * Publishes {@code SHUTTING_DOWN} and schedules termination without waiting for tasks.
     * Caller must hold this object's monitor.
     */
    private void requestTerminate(VersionedState currState) {
        log.info("Immediate Shutdown requested");
        checkNotCoordinator("Cannot shutdown coordinator");
        nodeState.set(currState.toShuttingDown());
        shutdownHandler.schedule(this::terminate, 0L, TimeUnit.MILLISECONDS);
    }

    /**
     * Publishes {@code SHUTTING_DOWN} and schedules the graceful shutdown after one grace
     * period. Caller must hold this object's monitor.
     */
    private void requestGracefulShutdown(VersionedState currState) {
        log.info("Shutdown requested");
        checkNotCoordinator("Cannot shutdown coordinator");
        // The scheduled shutdown only keeps waiting while nodeState is identical to the state
        // it was given, so it must be handed the SHUTTING_DOWN instance that is published here.
        // Handing it the pre-transition state would make it skip the wait for active tasks.
        VersionedState shuttingDown = currState.toShuttingDown();
        nodeState.set(shuttingDown);
        shutdownHandler.schedule(() -> shutdown(shuttingDown), gracePeriod.toMillis(), TimeUnit.MILLISECONDS);
    }

    private void shutdown(VersionedState expectedState) {
        waitActiveTasksToFinish(expectedState);
        terminate();
    }

    /**
     * Stops the life cycle on a separate thread, waits up to {@link #LIFECYCLE_STOP_TIMEOUT}
     * for it to finish, and then always invokes the shutdown action.
     */
    private void terminate() {
        Future<?> shutdownFuture = lifeCycleStopper.submit(() -> {
            lifeCycleManager.stop();
            return null;
        });
        try {
            shutdownFuture.get(LIFECYCLE_STOP_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS);
        }
        catch (TimeoutException e) {
            log.warn(e, "Timed out waiting for the life cycle to stop");
        }
        catch (InterruptedException e) {
            log.warn(e, "Interrupted while waiting for the life cycle to stop");
            Thread.currentThread().interrupt();
        }
        catch (ExecutionException e) {
            log.warn(e, "Problem stopping the life cycle");
        }
        shutdownAction.onShutdown();
    }

    private void drain(VersionedState expectedState) {
        if (nodeState.get() == expectedState) {
            waitActiveTasksToFinish(expectedState);
        }
        drainingComplete(expectedState);
    }

    /**
     * Moves to {@code DRAINED} only if the node is still in the exact state the drain was
     * started for; otherwise the drain was superseded and the state is left alone.
     */
    private synchronized void drainingComplete(VersionedState expectedState) {
        VersionedState drained = expectedState.toDrained();
        if (nodeState.compareAndSet(expectedState, drained)) {
            log.info("Worker State change: DRAINING -> DRAINED, server can be safely SHUT DOWN.");
        }
        else {
            log.info("Worker State change: %s, expected: %s, will not transition to DRAINED", nodeState.get(), expectedState);
        }
    }

    /**
     * Blocks until no active tasks remain, then sleeps for one more grace period. Both steps
     * are abandoned as soon as the node leaves {@code expectedState}.
     */
    private void waitActiveTasksToFinish(VersionedState expectedState) {
        // NOTE(review): if this thread is interrupted, waitTasksToFinish restores the interrupt
        // flag and returns, so this loop re-enters it immediately while tasks remain - confirm
        // whether interruption should end the wait instead.
        while (nodeState.get() == expectedState) {
            List<TaskInfo> activeTasks = getActiveTasks();
            log.info("Waiting for %s active tasks to finish", activeTasks.size());
            if (activeTasks.isEmpty()) {
                break;
            }
            waitTasksToFinish(activeTasks, expectedState);
        }
        if (nodeState.get() == expectedState) {
            Uninterruptibles.sleepUninterruptibly(gracePeriod.toMillis(), TimeUnit.MILLISECONDS);
        }
    }

    /**
     * Registers a listener on every given task and waits until all of them report a done
     * state, re-checking once per second whether the node has left {@code expectedState}.
     */
    private void waitTasksToFinish(List<TaskInfo> activeTasks, VersionedState expectedState) {
        CountDownLatch countDownLatch = new CountDownLatch(activeTasks.size());
        for (TaskInfo taskInfo : activeTasks) {
            TaskId taskId = taskInfo.taskStatus().getTaskId();
            sqlTasksObservable.addStateChangeListener(taskId, newState -> {
                if (newState.isDone()) {
                    log.info("Task %s has finished", taskId);
                    countDownLatch.countDown();
                }
            });
        }
        try {
            while (!countDownLatch.await(1L, TimeUnit.SECONDS)) {
                if (nodeState.get() != expectedState) {
                    log.info("Wait for tasks interrupted by state change, worker is no longer draining.");
                    break;
                }
            }
        }
        catch (InterruptedException e) {
            log.warn("Interrupted while waiting for all tasks to finish");
            Thread.currentThread().interrupt();
        }
    }

    private List<TaskInfo> getActiveTasks() {
        return taskInfoSupplier.get().stream()
                .filter(taskInfo -> !taskInfo.taskStatus().getState().isDone())
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Minimal view of the task manager needed here: the ability to listen for state changes
     * of a single task.
     */
    public static interface SqlTasksObservable {
        public void addStateChangeListener(TaskId var1, StateMachine.StateChangeListener<TaskState> var2);
    }

    /**
     * A {@link NodeState} paired with a version number. Each {@code toXxx()} call draws a
     * fresh version from the enclosing manager, so two visits to the same state are
     * distinguishable.
     */
    private class VersionedState {
        private final NodeState state;
        private final long version;

        private VersionedState(NodeState state, long version) {
            this.state = Objects.requireNonNull(state, "state is null");
            this.version = version;
        }

        public VersionedState toActive() {
            return new VersionedState(NodeState.ACTIVE, nextStateVersion());
        }

        public VersionedState toDraining() {
            return new VersionedState(NodeState.DRAINING, nextStateVersion());
        }

        public VersionedState toDrained() {
            return new VersionedState(NodeState.DRAINED, nextStateVersion());
        }

        public VersionedState toShuttingDown() {
            return new VersionedState(NodeState.SHUTTING_DOWN, nextStateVersion());
        }

        public NodeState state() {
            return state;
        }

        public long version() {
            return version;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            VersionedState that = (VersionedState) o;
            return version == that.version && state == that.state;
        }

        @Override
        public int hashCode() {
            return Objects.hash(state, version);
        }

        @Override
        public String toString() {
            return String.format("%s-%s", state.toString(), version);
        }
    }
}

