/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import io.airlift.units.DataSize;
import io.trino.execution.QueryManagerConfig;
import io.trino.execution.StageId;
import io.trino.execution.scheduler.TaskDescriptor;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.NotThreadSafe;
import javax.inject.Inject;
import org.weakref.jmx.Managed;

/**
 * In-memory storage of {@link TaskDescriptor}s, grouped per query and keyed by stage and partition.
 * <p>
 * The total retained size of all stored descriptors is bounded by a configured maximum. When adding
 * a descriptor pushes the total over that maximum, the query storage retaining the most bytes is
 * failed: its descriptors are dropped, its memory is released, and every later {@code put}, {@code get}
 * or {@code remove} for that query throws a {@link TrinoException} with
 * {@link StandardErrorCode#EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY}.
 * <p>
 * All public methods synchronize on this instance.
 */
public class TaskDescriptorStorage {
    private final long maxMemoryInBytes;
    @GuardedBy(value="this")
    private final Map<QueryId, TaskDescriptors> storages = new HashMap<>();
    // running sum of getReservedBytes() over all entries in storages
    @GuardedBy(value="this")
    private long reservedBytes;

    @Inject
    public TaskDescriptorStorage(QueryManagerConfig config) {
        this(Objects.requireNonNull(config, "config is null").getFaultTolerantExecutionTaskDescriptorStorageMaxMemory());
    }

    /**
     * @param maxMemory upper bound on the total retained size of all stored descriptors
     */
    public TaskDescriptorStorage(DataSize maxMemory) {
        this.maxMemoryInBytes = Objects.requireNonNull(maxMemory, "maxMemory is null").toBytes();
    }

    /**
     * Creates an empty descriptor storage for the given query.
     *
     * @throws VerifyException if storage for {@code queryId} is already initialized
     */
    public synchronized void initialize(QueryId queryId) {
        TaskDescriptors storage = new TaskDescriptors();
        Verify.verify(this.storages.putIfAbsent(queryId, storage) == null, "storage is already initialized for query: %s", queryId);
        this.updateMemoryReservation(storage.getReservedBytes());
    }

    /**
     * Stores {@code descriptor} under its stage and partition. Does nothing if the query has no storage.
     * <p>
     * The added bytes may trigger eviction of the largest query storage, which can be this same query.
     *
     * @throws IllegalStateException if a descriptor for the same stage and partition is already stored
     * @throws RuntimeException the failure recorded when this query's storage was previously evicted
     */
    public synchronized void put(StageId stageId, TaskDescriptor descriptor) {
        TaskDescriptors storage = this.storages.get(stageId.getQueryId());
        if (storage == null) {
            // query was never initialized or has already been destroyed
            return;
        }
        long previousReservedBytes = storage.getReservedBytes();
        storage.put(stageId, descriptor.getPartitionId(), descriptor);
        this.updateMemoryReservation(storage.getReservedBytes() - previousReservedBytes);
    }

    /**
     * Looks up the descriptor for a stage partition.
     *
     * @return the descriptor, or empty if the query has no storage
     * @throws NoSuchElementException if the query storage exists but holds no such descriptor
     * @throws RuntimeException the failure recorded when this query's storage was evicted
     */
    public synchronized Optional<TaskDescriptor> get(StageId stageId, int partitionId) {
        TaskDescriptors storage = this.storages.get(stageId.getQueryId());
        if (storage == null) {
            return Optional.empty();
        }
        return Optional.of(storage.get(stageId, partitionId));
    }

    /**
     * Removes the descriptor for a stage partition and releases its memory.
     * Does nothing if the query has no storage.
     *
     * @throws NoSuchElementException if the query storage exists but holds no such descriptor
     * @throws RuntimeException the failure recorded when this query's storage was evicted
     */
    public synchronized void remove(StageId stageId, int partitionId) {
        TaskDescriptors storage = this.storages.get(stageId.getQueryId());
        if (storage == null) {
            return;
        }
        long previousReservedBytes = storage.getReservedBytes();
        storage.remove(stageId, partitionId);
        this.updateMemoryReservation(storage.getReservedBytes() - previousReservedBytes);
    }

    /**
     * Drops all storage for the query and releases its memory. Does nothing if the query has no storage.
     */
    public synchronized void destroy(QueryId queryId) {
        TaskDescriptors storage = this.storages.remove(queryId);
        if (storage != null) {
            this.updateMemoryReservation(-storage.getReservedBytes());
        }
    }

    /**
     * Applies {@code delta} to the total reservation. If it grew past the limit, repeatedly fails the
     * query storage retaining the most bytes until the total fits again.
     */
    private synchronized void updateMemoryReservation(long delta) {
        this.reservedBytes += delta;
        if (delta <= 0L) {
            // releasing memory can never push the total over the limit
            return;
        }
        while (this.reservedBytes > this.maxMemoryInBytes) {
            // evict the query that currently retains the most descriptor bytes
            QueryId killCandidate = this.storages.entrySet().stream()
                    .max(Comparator.comparingLong(entry -> entry.getValue().getReservedBytes()))
                    .map(Map.Entry::getKey)
                    .orElseThrow(() -> new VerifyException(String.format("storage is empty but reservedBytes (%s) is still greater than maxMemoryInBytes (%s)", this.reservedBytes, this.maxMemoryInBytes)));
            TaskDescriptors storage = this.storages.get(killCandidate);
            long previousReservedBytes = storage.getReservedBytes();
            // report "used > limit"; the original argument order printed "limit > used"
            storage.fail(new TrinoException(
                    StandardErrorCode.EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY,
                    String.format("Task descriptor storage capacity has been exceeded: %s > %s",
                            DataSize.succinctBytes(this.reservedBytes),
                            DataSize.succinctBytes(this.maxMemoryInBytes))));
            // fail() drops the descriptors, so this adjustment is non-positive
            this.reservedBytes += storage.getReservedBytes() - previousReservedBytes;
        }
    }

    /**
     * Total bytes currently retained across all query storages; exposed as a managed (JMX) attribute.
     */
    @Managed
    public synchronized long getReservedBytes() {
        return this.reservedBytes;
    }

    /**
     * Map key identifying a descriptor by stage and partition.
     */
    private static class TaskDescriptorKey {
        private final StageId stageId;
        private final int partitionId;

        private TaskDescriptorKey(StageId stageId, int partitionId) {
            this.stageId = Objects.requireNonNull(stageId, "stageId is null");
            this.partitionId = partitionId;
        }

        public StageId getStageId() {
            return this.stageId;
        }

        public int getPartitionId() {
            return this.partitionId;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            TaskDescriptorKey key = (TaskDescriptorKey) o;
            return this.partitionId == key.partitionId && Objects.equals(this.stageId, key.stageId);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.stageId, this.partitionId);
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this)
                    .add("stageId", this.stageId)
                    .add("partitionId", this.partitionId)
                    .toString();
        }
    }

    /**
     * Descriptors of a single query plus their retained size. Once failed, the storage is emptied and
     * every put/get/remove rethrows the recorded failure.
     */
    @NotThreadSafe
    private static class TaskDescriptors {
        private final Map<TaskDescriptorKey, TaskDescriptor> descriptors = new HashMap<>();
        private long reservedBytes;
        private RuntimeException failure;

        private TaskDescriptors() {
        }

        public void put(StageId stageId, int partitionId, TaskDescriptor descriptor) {
            this.throwIfFailed();
            TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId);
            Preconditions.checkState(this.descriptors.putIfAbsent(key, descriptor) == null, "task descriptor is already present for key %s ", key);
            this.reservedBytes += descriptor.getRetainedSizeInBytes();
        }

        public TaskDescriptor get(StageId stageId, int partitionId) {
            this.throwIfFailed();
            TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId);
            TaskDescriptor descriptor = this.descriptors.get(key);
            if (descriptor == null) {
                throw new NoSuchElementException(String.format("descriptor not found for key %s", key));
            }
            return descriptor;
        }

        public void remove(StageId stageId, int partitionId) {
            this.throwIfFailed();
            TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId);
            TaskDescriptor descriptor = this.descriptors.remove(key);
            if (descriptor == null) {
                throw new NoSuchElementException(String.format("descriptor not found for key %s", key));
            }
            this.reservedBytes -= descriptor.getRetainedSizeInBytes();
        }

        public long getReservedBytes() {
            return this.reservedBytes;
        }

        /**
         * Drops all descriptors and records {@code failure}; only the first failure is kept.
         */
        private void fail(RuntimeException failure) {
            if (this.failure == null) {
                this.descriptors.clear();
                this.reservedBytes = 0L;
                this.failure = failure;
            }
        }

        private void throwIfFailed() {
            if (this.failure != null) {
                throw this.failure;
            }
        }
    }
}

