/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution;

import com.google.common.base.MoreObjects;
import com.google.common.collect.Sets;
import io.airlift.log.Logger;
import io.trino.execution.RemoteTask;
import io.trino.execution.TaskId;
import io.trino.metadata.InternalNode;
import io.trino.util.FinalizerService;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntConsumer;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

/**
 * Per-node bookkeeping of remote tasks and partitioned split counts.
 *
 * <p>For every {@link InternalNode} seen, this map lazily creates a {@code NodeTasks} entry. Each entry holds:
 * <ul>
 *   <li>the set of registered {@link RemoteTask}s that are not yet done;</li>
 *   <li>a running total of the partitioned split counts reported through the
 *       {@link PartitionedSplitCountTracker}s created for that node.</li>
 * </ul>
 * Entries are created on first access and are never removed by this class.
 */
@ThreadSafe
public class NodeTaskMap {
    private static final Logger log = Logger.get(NodeTaskMap.class);

    private final ConcurrentHashMap<InternalNode, NodeTasks> nodeTasksMap = new ConcurrentHashMap<>();
    private final FinalizerService finalizerService;

    /**
     * @param finalizerService service used to register the leak-cleanup callback for each
     *        {@link PartitionedSplitCountTracker} handed out; must not be null
     * @throws NullPointerException if {@code finalizerService} is null
     */
    @Inject
    public NodeTaskMap(FinalizerService finalizerService) {
        this.finalizerService = Objects.requireNonNull(finalizerService, "finalizerService is null");
    }

    /**
     * Registers {@code task} as running on {@code node}.
     *
     * <p>A task is removed from the node's set once its state is done. Adding the same task
     * again while it is still present has no effect.
     */
    public void addTask(InternalNode node, RemoteTask task) {
        createOrGetNodeTasks(node).addTask(task);
    }

    /**
     * Returns the current total of partitioned splits on {@code node}.
     *
     * <p>The total is the sum of the latest counts reported through the trackers created for that
     * node. It does not include trackers whose count was already subtracted by cleanup. If the
     * node has no entry yet, an empty one is created and 0 is returned.
     */
    public int getPartitionedSplitsOnNode(InternalNode node) {
        return createOrGetNodeTasks(node).getPartitionedSplitCount();
    }

    /**
     * Creates a tracker through which the task {@code taskId} reports its partitioned split
     * count on {@code node}. Every value reported replaces that tracker's previous contribution
     * to the node total.
     *
     * @throws NullPointerException if {@code taskId} is null
     */
    public PartitionedSplitCountTracker createPartitionedSplitCountTracker(InternalNode node, TaskId taskId) {
        return createOrGetNodeTasks(node).createPartitionedSplitCountTracker(taskId);
    }

    /** Returns the entry for {@code node}, creating and publishing it atomically if absent. */
    private NodeTasks createOrGetNodeTasks(InternalNode node) {
        // ConcurrentHashMap.get does not block, so the common "already present" path stays cheap.
        // computeIfAbsent then guarantees a single NodeTasks per node, with no throwaway allocation.
        NodeTasks existing = nodeTasksMap.get(node);
        if (existing != null) {
            return existing;
        }
        return nodeTasksMap.computeIfAbsent(node, ignored -> new NodeTasks(finalizerService));
    }

    /**
     * Handle through which a task reports its current number of partitioned splits.
     * Each reported value is forwarded to the wrapped {@link IntConsumer}.
     */
    public static class PartitionedSplitCountTracker {
        private final IntConsumer splitSetter;

        /**
         * @param splitSetter receives every value passed to {@link #setPartitionedSplitCount(int)}
         * @throws NullPointerException if {@code splitSetter} is null
         */
        public PartitionedSplitCountTracker(IntConsumer splitSetter) {
            this.splitSetter = Objects.requireNonNull(splitSetter, "splitSetter is null");
        }

        /** Reports {@code partitionedSplitCount} as the current split count. */
        public void setPartitionedSplitCount(int partitionedSplitCount) {
            splitSetter.accept(partitionedSplitCount);
        }

        @Override
        public String toString() {
            return splitSetter.toString();
        }
    }

    /** State kept for a single node: its not-yet-done tasks and its total partitioned split count. */
    private static class NodeTasks {
        // NOTE(review): membership is maintained here but not read anywhere in this class; confirm it is still needed.
        private final Set<RemoteTask> remoteTasks = Sets.newConcurrentHashSet();
        // Sum of the current contributions of every live TaskPartitionedSplitCountTracker on this node.
        private final AtomicInteger nodeTotalPartitionedSplitCount = new AtomicInteger();
        private final FinalizerService finalizerService;

        public NodeTasks(FinalizerService finalizerService) {
            this.finalizerService = Objects.requireNonNull(finalizerService, "finalizerService is null");
        }

        private int getPartitionedSplitCount() {
            return nodeTotalPartitionedSplitCount.get();
        }

        /** Adds {@code task} and arranges for it to be dropped once its state is done. */
        private void addTask(RemoteTask task) {
            if (!remoteTasks.add(task)) {
                // Already registered; its listener is already in place.
                return;
            }
            task.addStateChangeListener(taskStatus -> {
                if (taskStatus.getState().isDone()) {
                    remoteTasks.remove(task);
                }
            });
            // The task may have finished before the listener was attached, so check once more.
            if (task.getTaskStatus().getState().isDone()) {
                remoteTasks.remove(task);
            }
        }

        /**
         * Creates a public tracker handle backed by a per-task counter that feeds this node's total.
         *
         * @throws NullPointerException if {@code taskId} is null
         */
        public PartitionedSplitCountTracker createPartitionedSplitCountTracker(TaskId taskId) {
            Objects.requireNonNull(taskId, "taskId is null");
            TaskPartitionedSplitCountTracker tracker = new TaskPartitionedSplitCountTracker(taskId);
            PartitionedSplitCountTracker partitionedSplitCountTracker = new PartitionedSplitCountTracker(tracker::setPartitionedSplitCount);
            // The finalizer is keyed on the returned handle, while the cleanup callback references only the
            // inner tracker. This presumably lets cleanup run once callers drop the handle
            // (assumes FinalizerService semantics; confirm).
            finalizerService.addFinalizer(partitionedSplitCountTracker, tracker::cleanup);
            return partitionedSplitCountTracker;
        }

        /** Per-task split counter whose changes are applied as deltas to the enclosing node total. */
        @ThreadSafe
        private class TaskPartitionedSplitCountTracker {
            private final TaskId taskId;
            private final AtomicInteger localPartitionedSplitCount = new AtomicInteger();

            public TaskPartitionedSplitCountTracker(TaskId taskId) {
                this.taskId = Objects.requireNonNull(taskId, "taskId is null");
            }

            /**
             * Replaces this task's split count and adjusts the node total by the difference.
             * Calls on the same tracker are serialized by {@code synchronized}.
             *
             * @throws IllegalArgumentException if {@code partitionedSplitCount} is negative. In that case this
             *         tracker's contribution is first reset to zero and removed from the node total.
             */
            public synchronized void setPartitionedSplitCount(int partitionedSplitCount) {
                if (partitionedSplitCount < 0) {
                    int oldValue = localPartitionedSplitCount.getAndSet(0);
                    nodeTotalPartitionedSplitCount.addAndGet(-oldValue);
                    throw new IllegalArgumentException("partitionedSplitCount is negative");
                }
                int oldValue = localPartitionedSplitCount.getAndSet(partitionedSplitCount);
                nodeTotalPartitionedSplitCount.addAndGet(partitionedSplitCount - oldValue);
            }

            /**
             * Removes any count still attributed to this tracker from the node total.
             *
             * <p>A non-zero remainder means the owner never reported zero before releasing the handle.
             * That is logged as a bug.
             */
            public void cleanup() {
                int leakedSplits = localPartitionedSplitCount.getAndSet(0);
                if (leakedSplits == 0) {
                    return;
                }
                log.error("BUG! %s for %s leaked with %s partitioned splits.  Cleaning up so server can continue to function.", getClass().getName(), taskId, leakedSplits);
                nodeTotalPartitionedSplitCount.addAndGet(-leakedSplits);
            }

            @Override
            public String toString() {
                return MoreObjects.toStringHelper(this)
                        .add("taskId", taskId)
                        .add("splits", localPartitionedSplitCount)
                        .toString();
            }
        }
    }
}

