/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution;

import com.google.common.base.MoreObjects;
import com.google.common.collect.Sets;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.trino.execution.PartitionedSplitsInfo;
import io.trino.execution.RemoteTask;
import io.trino.execution.TaskId;
import io.trino.metadata.InternalNode;
import io.trino.util.FinalizerService;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import javax.annotation.concurrent.ThreadSafe;

@ThreadSafe
public class NodeTaskMap {
    private static final Logger log = Logger.get(NodeTaskMap.class);
    // Per-node bookkeeping, created lazily the first time a node is referenced.
    // Uses the diamond operator instead of a raw ConcurrentHashMap to avoid an unchecked conversion.
    private final ConcurrentHashMap<InternalNode, NodeTasks> nodeTasksMap = new ConcurrentHashMap<>();
    // Handed to every NodeTasks instance so split-count trackers can register cleanup callbacks.
    private final FinalizerService finalizerService;

    /**
     * Creates an empty node-to-tasks map.
     *
     * @param finalizerService service passed on to each per-node {@code NodeTasks}; must not be null
     * @throws NullPointerException if {@code finalizerService} is null
     */
    @Inject
    public NodeTaskMap(FinalizerService finalizerService) {
        Objects.requireNonNull(finalizerService, "finalizerService is null");
        this.finalizerService = finalizerService;
    }

    /**
     * Registers {@code task} with the bookkeeping entry of {@code node}, creating that entry if needed.
     *
     * @param node the node the task belongs to
     * @param task the remote task to register
     */
    public void addTask(InternalNode node, RemoteTask task) {
        NodeTasks tasksOnNode = createOrGetNodeTasks(node);
        tasksOnNode.addTask(task);
    }

    /**
     * Returns the current partitioned split count and weight totals recorded for {@code node}.
     * An entry for the node is created if none exists yet.
     *
     * @param node the node to query
     * @return the node-wide partitioned split totals
     */
    public PartitionedSplitsInfo getPartitionedSplitsOnNode(InternalNode node) {
        NodeTasks tasksOnNode = createOrGetNodeTasks(node);
        return tasksOnNode.getPartitionedSplitsInfo();
    }

    /**
     * Creates a tracker through which the task identified by {@code taskId} reports its
     * partitioned splits; the reported values are folded into the totals of {@code node}.
     *
     * @param node the node whose totals the tracker contributes to
     * @param taskId identifier of the reporting task
     * @return a new tracker bound to the node's totals
     */
    public PartitionedSplitCountTracker createPartitionedSplitCountTracker(InternalNode node, TaskId taskId) {
        NodeTasks tasksOnNode = createOrGetNodeTasks(node);
        return tasksOnNode.createPartitionedSplitCountTracker(taskId);
    }

    /**
     * Looks up the bookkeeping entry for {@code node}, falling back to
     * {@link #addNodeTask(InternalNode)} when the map has no entry yet.
     */
    private NodeTasks createOrGetNodeTasks(InternalNode node) {
        NodeTasks existing = nodeTasksMap.get(node);
        return existing != null ? existing : addNodeTask(node);
    }

    /**
     * Inserts a fresh entry for {@code node} unless another thread has already installed one,
     * and returns whichever entry ended up in the map.
     */
    private NodeTasks addNodeTask(InternalNode node) {
        NodeTasks candidate = new NodeTasks(finalizerService);
        NodeTasks alreadyPresent = nodeTasksMap.putIfAbsent(node, candidate);
        // putIfAbsent returns null when our candidate was the one stored.
        return alreadyPresent == null ? candidate : alreadyPresent;
    }

    private static class NodeTasks {
        private final Set<RemoteTask> remoteTasks = Sets.newConcurrentHashSet();
        private final AtomicInteger nodeTotalPartitionedSplitCount = new AtomicInteger();
        private final AtomicLong nodeTotalPartitionedSplitWeight = new AtomicLong();
        private final FinalizerService finalizerService;

        public NodeTasks(FinalizerService finalizerService) {
            this.finalizerService = Objects.requireNonNull(finalizerService, "finalizerService is null");
        }

        private PartitionedSplitsInfo getPartitionedSplitsInfo() {
            return PartitionedSplitsInfo.forSplitCountAndWeightSum(this.nodeTotalPartitionedSplitCount.get(), this.nodeTotalPartitionedSplitWeight.get());
        }

        private void addTask(RemoteTask task) {
            if (this.remoteTasks.add(task)) {
                if (task.getTaskStatus().getState().isDone()) {
                    this.remoteTasks.remove(task);
                    return;
                }
                task.addStateChangeListener(taskStatus -> {
                    if (taskStatus.getState().isDone()) {
                        this.remoteTasks.remove(task);
                    }
                });
            }
        }

        public PartitionedSplitCountTracker createPartitionedSplitCountTracker(TaskId taskId) {
            Objects.requireNonNull(taskId, "taskId is null");
            TaskPartitionedSplitCountTracker tracker = new TaskPartitionedSplitCountTracker(taskId, this.nodeTotalPartitionedSplitCount, this.nodeTotalPartitionedSplitWeight);
            PartitionedSplitCountTracker partitionedSplitCountTracker = new PartitionedSplitCountTracker(tracker);
            this.finalizerService.addFinalizer(partitionedSplitCountTracker, tracker::cleanup);
            return partitionedSplitCountTracker;
        }

        @ThreadSafe
        private static class TaskPartitionedSplitCountTracker
        implements Consumer<PartitionedSplitsInfo> {
            private final TaskId taskId;
            private final AtomicInteger nodeTotalPartitionedSplitCount;
            private final AtomicLong nodeTotalPartitionedSplitWeight;
            private final AtomicInteger localPartitionedSplitCount = new AtomicInteger();
            private final AtomicLong localPartitionedSplitWeight = new AtomicLong();

            public TaskPartitionedSplitCountTracker(TaskId taskId, AtomicInteger nodeTotalPartitionedSplitCount, AtomicLong nodeTotalPartitionedSplitWeight) {
                this.taskId = Objects.requireNonNull(taskId, "taskId is null");
                this.nodeTotalPartitionedSplitCount = Objects.requireNonNull(nodeTotalPartitionedSplitCount, "nodeTotalPartitionedSplitCount is null");
                this.nodeTotalPartitionedSplitWeight = Objects.requireNonNull(nodeTotalPartitionedSplitWeight, "nodeTotalPartitionedSplitWeight is null");
            }

            @Override
            public synchronized void accept(PartitionedSplitsInfo partitionedSplits) {
                if (partitionedSplits == null || partitionedSplits.getCount() < 0 || partitionedSplits.getWeightSum() < 0L) {
                    this.clearLocalSplitInfo(false);
                    Objects.requireNonNull(partitionedSplits, "partitionedSplits is null");
                    throw new IllegalArgumentException("Invalid negative value: " + partitionedSplits);
                }
                int newCount = partitionedSplits.getCount();
                long newWeight = partitionedSplits.getWeightSum();
                int countDelta = newCount - this.localPartitionedSplitCount.getAndSet(newCount);
                long weightDelta = newWeight - this.localPartitionedSplitWeight.getAndSet(newWeight);
                if (countDelta != 0) {
                    this.nodeTotalPartitionedSplitCount.addAndGet(countDelta);
                }
                if (weightDelta != 0L) {
                    this.nodeTotalPartitionedSplitWeight.addAndGet(weightDelta);
                }
            }

            private void clearLocalSplitInfo(boolean reportAsLeaked) {
                int leakedCount = this.localPartitionedSplitCount.getAndSet(0);
                long leakedWeight = this.localPartitionedSplitWeight.getAndSet(0L);
                if (leakedCount == 0 && leakedWeight == 0L) {
                    return;
                }
                if (reportAsLeaked) {
                    log.error("BUG! %s for %s leaked with %s partitioned splits (weight: %s). Cleaning up so server can continue to function.", new Object[]{this.getClass().getName(), this.taskId, leakedCount, leakedWeight});
                }
                this.nodeTotalPartitionedSplitCount.addAndGet(-leakedCount);
                this.nodeTotalPartitionedSplitWeight.addAndGet(-leakedWeight);
            }

            public void cleanup() {
                this.clearLocalSplitInfo(true);
            }

            public String toString() {
                return MoreObjects.toStringHelper((Object)this).add("taskId", (Object)this.taskId).add("splits", (Object)this.localPartitionedSplitCount).add("weight", (Object)this.localPartitionedSplitWeight).toString();
            }
        }
    }

    /**
     * Thin handle through which a task reports its partitioned splits; every report is
     * forwarded to the consumer supplied at construction.
     */
    public static class PartitionedSplitCountTracker {
        private final Consumer<PartitionedSplitsInfo> splitSetter;

        /**
         * @param splitSetter receiver of every reported value; must not be null
         * @throws NullPointerException if {@code splitSetter} is null
         */
        public PartitionedSplitCountTracker(Consumer<PartitionedSplitsInfo> splitSetter) {
            Objects.requireNonNull(splitSetter, "splitSetter is null");
            this.splitSetter = splitSetter;
        }

        /**
         * Forwards {@code partitionedSplits} to the underlying consumer.
         *
         * @param partitionedSplits the task's current partitioned splits
         */
        public void setPartitionedSplits(PartitionedSplitsInfo partitionedSplits) {
            splitSetter.accept(partitionedSplits);
        }

        /** Returns the string form of the underlying consumer. */
        @Override
        public String toString() {
            // splitSetter is never null, so this is the same as splitSetter.toString().
            return String.valueOf(splitSetter);
        }
    }
}

