/*
 * Decompiled with CFR 0.152.
 */
package io.kestra.core.runners;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Charsets;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.Hashing;
import io.kestra.core.metrics.MetricRegistry;
import io.kestra.core.models.executions.ExecutionKilled;
import io.kestra.core.models.executions.TaskRun;
import io.kestra.core.models.executions.TaskRunAttempt;
import io.kestra.core.models.flows.State;
import io.kestra.core.models.tasks.Output;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.models.tasks.retrys.AbstractRetry;
import io.kestra.core.queues.QueueException;
import io.kestra.core.queues.QueueInterface;
import io.kestra.core.queues.WorkerTaskQueueInterface;
import io.kestra.core.runners.RunContext;
import io.kestra.core.runners.WorkerTask;
import io.kestra.core.runners.WorkerTaskResult;
import io.kestra.core.serializers.JacksonMapper;
import io.kestra.core.utils.Await;
import io.kestra.core.utils.ExecutorsUtils;
import io.micronaut.context.ApplicationContext;
import io.micronaut.inject.qualifiers.Qualifiers;
import java.io.Closeable;
import java.io.IOException;
import java.time.Duration;
import java.time.ZonedDateTime;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Generated;
import net.jodah.failsafe.Failsafe;
import net.jodah.failsafe.Policy;
import net.jodah.failsafe.RetryPolicy;
import net.jodah.failsafe.Timeout;
import net.jodah.failsafe.TimeoutExceededException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Consumes {@link WorkerTask}s from the worker task queue, runs them, and publishes their
 * intermediate and final states on the worker task result queue.
 *
 * <p>Each attempt of a {@link RunnableTask} runs on a dedicated {@link WorkerThread}. Failed
 * attempts are retried according to the task's retry policy. Running attempts of an execution are
 * killed when a matching {@link ExecutionKilled} message is received.
 */
public class Worker implements Runnable, Closeable {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(Worker.class);
    private static final ObjectMapper MAPPER = JacksonMapper.ofJson();

    private final ApplicationContext applicationContext;
    private final WorkerTaskQueueInterface workerTaskQueue;
    private final QueueInterface<WorkerTaskResult> workerTaskResultQueue;
    private final QueueInterface<ExecutionKilled> executionKilledQueue;
    private final MetricRegistry metricRegistry;

    // Ids of executions for which a kill message was received.
    // NOTE(review): entries are never removed, so this set grows for the lifetime of the worker.
    private final Set<String> killedExecution = ConcurrentHashMap.newKeySet();
    private final ExecutorService executors;
    // "worker.running.count" gauges, keyed by a hash of the sorted metric tags of a worker task.
    private final Map<Long, AtomicInteger> metricRunningCount = new ConcurrentHashMap<>();
    // Threads currently executing an attempt; every access is guarded by synchronized (this).
    private final List<WorkerThread> workerThreadReferences = new ArrayList<>();

    /**
     * Creates a worker whose beans are resolved from the given application context.
     *
     * @param applicationContext context used to look up queues, metrics and executor utilities
     * @param thread             size passed to {@code ExecutorsUtils.maxCachedThreadPool}
     */
    @SuppressWarnings("unchecked")
    public Worker(ApplicationContext applicationContext, int thread) {
        this.applicationContext = applicationContext;
        this.workerTaskQueue = applicationContext.getBean(WorkerTaskQueueInterface.class);
        this.workerTaskResultQueue = applicationContext.getBean(QueueInterface.class, Qualifiers.byName("workerTaskResultQueue"));
        this.executionKilledQueue = applicationContext.getBean(QueueInterface.class, Qualifiers.byName("executionKilledQueue"));
        this.metricRegistry = applicationContext.getBean(MetricRegistry.class);

        ExecutorsUtils executorsUtils = applicationContext.getBean(ExecutorsUtils.class);
        this.executors = executorsUtils.maxCachedThreadPool(thread, "worker");
    }

    /**
     * Starts consuming the execution killed queue and the worker task queue.
     *
     * <p>A kill message records the execution id and kills every currently running
     * {@link WorkerThread} of that execution. Each received worker task is handed to the worker
     * executor. A {@link RunnableTask} is run directly. An {@code io.kestra.core.tasks.flows.Worker}
     * has its child tasks run one after another.
     */
    @Override
    public void run() {
        this.executionKilledQueue.receive(executionKilled -> {
            if (executionKilled == null) {
                return;
            }

            this.killedExecution.add(executionKilled.getExecutionId());

            synchronized (this) {
                this.workerThreadReferences
                    .stream()
                    .filter(workerThread -> executionKilled.getExecutionId().equals(workerThread.getWorkerTask().getTaskRun().getExecutionId()))
                    .forEach(WorkerThread::kill);
            }
        });

        this.workerTaskQueue.receive(Worker.class, workerTask -> this.executors.execute(() -> {
            if (workerTask.getTask() instanceof RunnableTask) {
                this.run(workerTask, true);
            } else if (workerTask.getTask() instanceof io.kestra.core.tasks.flows.Worker) {
                this.runWorkerSubTasks(workerTask);
            }
        }));
    }

    /**
     * Runs the child tasks of an {@code io.kestra.core.tasks.flows.Worker} task sequentially.
     *
     * <p>After each child, its result is merged into the run context used for the next child. The
     * loop stops at the first child whose task run ends in a failed state. Children are run with
     * {@code cleanUp=false}; the run context is cleaned up once, here, when the loop ends.
     *
     * @param workerTask worker task whose task is an {@code io.kestra.core.tasks.flows.Worker}
     * @throws QueueException if a child result cannot be emitted
     */
    private void runWorkerSubTasks(WorkerTask workerTask) throws QueueException {
        RunContext runContext = workerTask.getRunContext();
        try {
            io.kestra.core.tasks.flows.Worker workerTasks = (io.kestra.core.tasks.flows.Worker) workerTask.getTask();
            for (Task currentTask : workerTasks.getTasks()) {
                WorkerTask currentWorkerTask = workerTasks.workerTask(workerTask.getTaskRun(), currentTask, runContext);
                WorkerTaskResult workerTaskResult = this.run(currentWorkerTask, false);
                if (workerTaskResult.getTaskRun().getState().isFailed()) {
                    break;
                }
                runContext = runContext.updateVariables(workerTaskResult, workerTask.getTaskRun());
            }
        } finally {
            runContext.cleanup();
        }
    }

    /** Returns the current time truncated to whole seconds. */
    private static ZonedDateTime now() {
        return ZonedDateTime.now().truncatedTo(ChronoUnit.SECONDS);
    }

    /**
     * Round-trips the worker task through JSON. This presumably drops state that is not serialized
     * (e.g. transient fields) — confirm against {@code WorkerTask}'s Jackson mapping.
     *
     * @return the deserialized copy, or the original task if serialization fails (a warning is logged)
     */
    private WorkerTask cleanUpTransient(WorkerTask workerTask) {
        try {
            return MAPPER.readValue(MAPPER.writeValueAsString(workerTask), WorkerTask.class);
        } catch (JsonProcessingException e) {
            log.warn("Unable to cleanup transient", e);
            return workerTask;
        }
    }

    /**
     * Runs a runnable worker task, retrying attempts per its retry policy, and emits its results.
     *
     * <p>A RUNNING result is emitted first. If the execution was already killed, a KILLED result is
     * emitted and nothing runs. Otherwise attempts run until the retry policy stops retrying a
     * FAILED last attempt. The final state is the state of the last attempt. It is downgraded from
     * SUCCESS to WARNING when {@code warningOnRetry} is set and more than one attempt was needed.
     *
     * @param workerTask the task to run
     * @param cleanUp    whether run contexts of retried attempts and of the final attempt are cleaned up here
     * @return the result that was emitted last
     * @throws QueueException if a result cannot be emitted (for the final result: only when the FAILED
     *                        fallback cannot be emitted either)
     */
    private WorkerTaskResult run(WorkerTask workerTask, Boolean cleanUp) throws QueueException {
        this.metricRegistry.counter("worker.started.count", this.metricRegistry.tags(workerTask)).increment();

        if (workerTask.getTaskRun().getState().getCurrent() == State.Type.CREATED) {
            this.metricRegistry
                .timer("worker.queued.duration", this.metricRegistry.tags(workerTask))
                .record(Duration.between(workerTask.getTaskRun().getState().getStartDate(), now()));
        }

        workerTask.logger().info(
            "[namespace: {}] [flow: {}] [task: {}] [execution: {}] [taskrun: {}] [value: {}] Type {} started",
            workerTask.getTaskRun().getNamespace(),
            workerTask.getTaskRun().getFlowId(),
            workerTask.getTaskRun().getTaskId(),
            workerTask.getTaskRun().getExecutionId(),
            workerTask.getTaskRun().getId(),
            workerTask.getTaskRun().getValue(),
            workerTask.getTask().getClass().getSimpleName()
        );

        if (workerTask.logger().isDebugEnabled()) {
            workerTask.logger().debug("Variables\n{}", JacksonMapper.log(workerTask.getRunContext().getVariables()));
        }

        workerTask = workerTask.withTaskRun(workerTask.getTaskRun().withState(State.Type.RUNNING));
        this.workerTaskResultQueue.emit(new WorkerTaskResult(workerTask));

        // The execution was killed before this task started: report it as KILLED without running it.
        if (this.killedExecution.contains(workerTask.getTaskRun().getExecutionId())) {
            workerTask = workerTask.withTaskRun(workerTask.getTaskRun().withState(State.Type.KILLED));
            WorkerTaskResult workerTaskResult = new WorkerTaskResult(workerTask);
            this.workerTaskResultQueue.emit(workerTaskResult);
            this.logTerminated(workerTask);
            return workerTaskResult;
        }

        // Holds the task to run on the next attempt; onRetry swaps in the previous attempt's result.
        AtomicReference<WorkerTask> current = new AtomicReference<>(workerTask);

        WorkerTask finalWorkerTask = Failsafe
            .with(AbstractRetry.<WorkerTask>retryPolicy(workerTask.getTask().getRetry())
                // retry while the last attempt recorded on the returned task run is FAILED
                .handleResultIf(result -> {
                    TaskRunAttempt attempt = result.getTaskRun().lastAttempt();
                    return attempt != null && attempt.getState().getCurrent() == State.Type.FAILED;
                })
                .onRetry(event -> {
                    WorkerTask lastResult = event.getLastResult();
                    if (cleanUp) {
                        lastResult.getRunContext().cleanup();
                    }
                    lastResult = this.cleanUpTransient(lastResult);
                    current.set(lastResult);

                    this.metricRegistry
                        .counter("worker.retryed.count", this.metricRegistry.tags(lastResult, "attempt_count", String.valueOf(event.getAttemptCount())))
                        .increment();

                    this.workerTaskResultQueue.emit(new WorkerTaskResult(lastResult));
                })
            )
            .get(() -> this.runAttempt(current.get()));

        // Read before the optional cleanup below, which operates on the same run context.
        List<WorkerTaskResult> dynamicWorkerResults = finalWorkerTask.getRunContext().dynamicWorkerResults();
        if (cleanUp) {
            finalWorkerTask.getRunContext().cleanup();
        }
        finalWorkerTask = this.cleanUpTransient(finalWorkerTask);

        TaskRunAttempt lastAttempt = finalWorkerTask.getTaskRun().lastAttempt();
        if (lastAttempt == null) {
            throw new IllegalStateException("Can't find lastAttempt on taskRun '" + finalWorkerTask.getTaskRun().toString(true) + "'");
        }

        State.Type state = lastAttempt.getState().getCurrent();

        // Only a success that needed at least one retry (i.e. more than one attempt) becomes WARNING.
        if (workerTask.getTask().getRetry() != null
            && Boolean.TRUE.equals(workerTask.getTask().getRetry().getWarningOnRetry())
            && finalWorkerTask.getTaskRun().getAttempts().size() > 1
            && state == State.Type.SUCCESS) {
            state = State.Type.WARNING;
        }

        finalWorkerTask = finalWorkerTask.withTaskRun(finalWorkerTask.getTaskRun().withState(state));

        try {
            WorkerTaskResult workerTaskResult = new WorkerTaskResult(finalWorkerTask, dynamicWorkerResults);
            this.workerTaskResultQueue.emit(workerTaskResult);
            return workerTaskResult;
        } catch (QueueException e) {
            // The final result could not be emitted: fall back to the RUNNING task run from the start
            // of this method (without the attempts/outputs produced here), marked FAILED.
            finalWorkerTask = workerTask.withTaskRun(workerTask.getTaskRun().withState(State.Type.FAILED));
            WorkerTaskResult workerTaskResult = new WorkerTaskResult(finalWorkerTask, dynamicWorkerResults);
            this.workerTaskResultQueue.emit(workerTaskResult);
            return workerTaskResult;
        } finally {
            this.logTerminated(finalWorkerTask);
        }
    }

    /** Records the "worker.ended" count and duration metrics and logs the task's terminal state. */
    private void logTerminated(WorkerTask workerTask) {
        this.metricRegistry
            .counter("worker.ended.count", this.metricRegistry.tags(workerTask))
            .increment();

        this.metricRegistry
            .timer("worker.ended.duration", this.metricRegistry.tags(workerTask))
            .record(workerTask.getTaskRun().getState().getDuration());

        workerTask.logger().info(
            "[namespace: {}] [flow: {}] [task: {}] [execution: {}] [taskrun: {}] [value: {}] Type {} with state {} completed in {}",
            workerTask.getTaskRun().getNamespace(),
            workerTask.getTaskRun().getFlowId(),
            workerTask.getTaskRun().getTaskId(),
            workerTask.getTaskRun().getExecutionId(),
            workerTask.getTaskRun().getId(),
            workerTask.getTaskRun().getValue(),
            workerTask.getTask().getClass().getSimpleName(),
            workerTask.getTaskRun().getState().getCurrent(),
            workerTask.getTaskRun().getState().humanDuration()
        );
    }

    /**
     * Runs a single attempt of the task on a new {@link WorkerThread} and waits for it to finish.
     *
     * <p>A RUNNING attempt is emitted before the thread starts. The thread is registered in
     * {@link #workerThreadReferences} (so kill messages reach it) and counted in the running gauge
     * for exactly the time it is being waited on.
     *
     * @return the worker task with the finished attempt (state, metrics) and outputs appended
     */
    private WorkerTask runAttempt(WorkerTask workerTask) {
        RunnableTask<?> task = (RunnableTask<?>) workerTask.getTask();
        RunContext runContext = workerTask.getRunContext().forWorker(this.applicationContext, workerTask);
        Logger logger = runContext.logger();

        TaskRunAttempt.TaskRunAttemptBuilder builder = TaskRunAttempt.builder()
            .state(new State().withState(State.Type.RUNNING));

        this.workerTaskResultQueue.emit(new WorkerTaskResult(
            workerTask.withTaskRun(workerTask.getTaskRun().withAttempts(this.addAttempt(workerTask, builder.build())))
        ));

        WorkerThread workerThread = new WorkerThread(logger, workerTask, task, runContext, this.metricRegistry);
        AtomicInteger metricRunningCount = this.getMetricRunningCount(workerTask);

        // Register before starting so a kill received from now on always reaches this thread.
        synchronized (this) {
            this.workerThreadReferences.add(workerThread);
        }
        metricRunningCount.incrementAndGet();

        State.Type state;
        try {
            workerThread.start();
            workerThread.join();
            state = workerThread.getTaskState();
        } catch (InterruptedException e) {
            // NOTE(review): the interrupt status is not restored and the WorkerThread is not stopped,
            // so it may keep running after being reported FAILED — confirm this is intended.
            logger.error("Failed to join WorkerThread {}", e.getMessage(), e);
            state = State.Type.FAILED;
        } finally {
            synchronized (this) {
                this.workerThreadReferences.remove(workerThread);
            }
            metricRunningCount.decrementAndGet();
        }

        TaskRunAttempt taskRunAttempt = builder
            .metrics(runContext.metrics())
            .build()
            .withState(state);

        Output taskOutput = workerThread.getTaskOutput();
        if (taskOutput != null) {
            log.debug("Outputs\n{}", JacksonMapper.log(taskOutput));
        }

        if (runContext.metrics().size() > 0) {
            log.trace("Metrics\n{}", JacksonMapper.log(runContext.metrics()));
        }

        TaskRun taskRun = workerTask.getTaskRun().withAttempts(this.addAttempt(workerTask, taskRunAttempt));
        try {
            taskRun = taskRun.withOutputs(taskOutput != null ? taskOutput.toMap() : ImmutableMap.of());
        } catch (Exception e) {
            logger.warn("Unable to save output on taskRun '{}'", taskRun, e);
        }

        return workerTask.withTaskRun(taskRun);
    }

    /** Returns an immutable copy of the task run's attempts (none if null) with {@code taskRunAttempt} appended. */
    private List<TaskRunAttempt> addAttempt(WorkerTask workerTask, TaskRunAttempt taskRunAttempt) {
        List<TaskRunAttempt> attempts = workerTask.getTaskRun().getAttempts();

        ImmutableList.Builder<TaskRunAttempt> builder = ImmutableList.builder();
        if (attempts != null) {
            builder.addAll(attempts);
        }
        return builder.add(taskRunAttempt).build();
    }

    /**
     * Returns the "worker.running.count" gauge for the task's metric tags, registering it on first use.
     * Gauges are cached by a 64-bit hash of the sorted tags joined with "-".
     */
    public AtomicInteger getMetricRunningCount(WorkerTask workerTask) {
        String[] tags = this.metricRegistry.tags(workerTask);
        Arrays.sort(tags);

        long index = Hashing
            .goodFastHash(64)
            .hashString(String.join("-", tags), Charsets.UTF_8)
            .asLong();

        return this.metricRunningCount.computeIfAbsent(index, l -> this.metricRegistry.gauge(
            "worker.running.count",
            new AtomicInteger(0),
            this.metricRegistry.tags(workerTask)
        ));
    }

    /** Returns how many attempts are currently being waited on, read under the list's lock. */
    private synchronized int runningWorkerThreadCount() {
        return this.workerThreadReferences.size();
    }

    /**
     * Stops consuming, waits (polling every second) until the executor has terminated and no attempt
     * is running, then closes all queues. The executor is given up to 5 minutes to terminate.
     */
    @Override
    public void close() throws IOException {
        this.workerTaskQueue.pause();
        this.executionKilledQueue.pause();

        new Thread(() -> {
            try {
                this.executors.shutdown();
                this.executors.awaitTermination(5L, TimeUnit.MINUTES);
            } catch (InterruptedException e) {
                log.error("Failed to shutdown workers executors", e);
                Thread.currentThread().interrupt();
            }
        }, "worker-shutdown").start();

        Await.until(() -> {
            if (this.executors.isTerminated() && this.runningWorkerThreadCount() == 0) {
                log.info("No more workers busy, shutting down!");
                try {
                    this.workerTaskResultQueue.close();
                } catch (IOException e) {
                    log.error("Failed to close workerTaskResultQueue", e);
                }
                return true;
            }

            log.warn("Waiting worker with still {} thread(s) running, waiting!", this.runningWorkerThreadCount());
            return false;
        }, Duration.ofSeconds(1L));

        // workerTaskResultQueue was already closed above, once every worker was idle.
        this.workerTaskQueue.close();
        this.executionKilledQueue.close();
    }

    @Generated
    public Map<Long, AtomicInteger> getMetricRunningCount() {
        return this.metricRunningCount;
    }

    /** Returns the live list of running threads; callers must synchronize on this worker to access it safely. */
    @Generated
    public List<WorkerThread> getWorkerThreadReferences() {
        return this.workerThreadReferences;
    }

    /**
     * Thread executing one attempt of a {@link RunnableTask}. When the task has a timeout, the run is
     * bounded by it and interrupted on expiry.
     *
     * <p>Resulting {@link #getTaskState() state}:
     * <ul>
     *   <li>KILLED once {@link #kill()} was called, whatever happens afterwards;</li>
     *   <li>otherwise FAILED if the task threw;</li>
     *   <li>otherwise the output's final state if present, else SUCCESS.</li>
     * </ul>
     */
    public static class WorkerThread extends Thread {
        // Guards taskState and killed, which are written both by this thread and by kill() callers.
        private final Object lock = new Object();
        Logger logger;
        WorkerTask workerTask;
        RunnableTask<?> task;
        RunContext runContext;
        MetricRegistry metricRegistry;
        Output taskOutput;
        State.Type taskState;
        boolean killed = false;

        public WorkerThread(Logger logger, WorkerTask workerTask, RunnableTask<?> task, RunContext runContext, MetricRegistry metricRegistry) {
            super("WorkerThread");
            this.setUncaughtExceptionHandler(this::exceptionHandler);
            this.logger = logger;
            this.workerTask = workerTask;
            this.task = task;
            this.runContext = runContext;
            this.metricRegistry = metricRegistry;
        }

        @Override
        public void run() {
            try {
                if (this.workerTask.getTask().getTimeout() != null) {
                    Failsafe
                        .with(Timeout.of(this.workerTask.getTask().getTimeout())
                            .withInterrupt(true)
                            .onFailure(event -> this.metricRegistry
                                .counter("worker.timeout.count", this.metricRegistry.tags(this.workerTask, "attempt_count", String.valueOf(event.getAttemptCount())))
                                .increment()
                            )
                        )
                        .run(() -> this.taskOutput = this.task.run(this.runContext));
                } else {
                    this.taskOutput = this.task.run(this.runContext);
                }

                State.Type state = State.Type.SUCCESS;
                if (this.taskOutput != null && this.taskOutput.finalState().isPresent()) {
                    state = this.taskOutput.finalState().get();
                }
                this.completed(state);
            } catch (TimeoutExceededException e) {
                this.exceptionHandler(this, new io.kestra.core.exceptions.TimeoutExceededException(this.workerTask.getTask().getTimeout(), e));
            } catch (Exception e) {
                this.exceptionHandler(this, e);
            }
        }

        /** Records the state of a normally-completed task unless the thread was killed meanwhile. */
        private void completed(State.Type state) {
            synchronized (this.lock) {
                // a task that ignores the interrupt and returns normally must still be reported KILLED
                if (!this.killed) {
                    this.taskState = state;
                }
            }
        }

        /** Marks this attempt as KILLED and interrupts the thread. */
        public void kill() {
            synchronized (this.lock) {
                this.killed = true;
                this.taskState = State.Type.KILLED;
                this.interrupt();
            }
        }

        /** Logs the error and marks the attempt FAILED, unless it was killed (then it stays KILLED). */
        private void exceptionHandler(Thread t, Throwable e) {
            synchronized (this.lock) {
                if (!this.killed) {
                    this.logger.error(e.getMessage(), e);
                    this.taskState = State.Type.FAILED;
                }
            }
        }

        @Generated
        public Logger getLogger() {
            return this.logger;
        }

        @Generated
        public WorkerTask getWorkerTask() {
            return this.workerTask;
        }

        @Generated
        public RunnableTask<?> getTask() {
            return this.task;
        }

        @Generated
        public RunContext getRunContext() {
            return this.runContext;
        }

        @Generated
        public MetricRegistry getMetricRegistry() {
            return this.metricRegistry;
        }

        @Generated
        public Output getTaskOutput() {
            return this.taskOutput;
        }

        /** Returns the attempt state; read under the lock because kill() writes it from another thread. */
        public State.Type getTaskState() {
            synchronized (this.lock) {
                return this.taskState;
            }
        }

        /** Returns whether {@link #kill()} was called; read under the same lock as the write. */
        public boolean isKilled() {
            synchronized (this.lock) {
                return this.killed;
            }
        }
    }
}

