/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler.faulttolerant;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Table;
import com.google.common.math.Quantiles;
import com.google.common.math.Stats;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.inject.Inject;
import io.airlift.compress.v3.zstd.ZstdCompressor;
import io.airlift.compress.v3.zstd.ZstdDecompressor;
import io.airlift.concurrent.Threads;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.slice.SizeOf;
import io.airlift.units.DataSize;
import io.trino.annotation.NotThreadSafe;
import io.trino.execution.QueryManagerConfig;
import io.trino.execution.StageId;
import io.trino.execution.scheduler.faulttolerant.TaskDescriptor;
import io.trino.metadata.Split;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.sql.planner.plan.PlanNodeId;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

public class TaskDescriptorStorage {
    private static final Logger log = Logger.get(TaskDescriptorStorage.class);

    /**
     * Per-step cap on how many descriptors the background job compresses before yielding
     * (same value as the limit used by compressTaskDescriptorsStep).
     */
    public static final int SINGLE_STEP_COMPRESSION_LIMIT = 1000;

    // Byte limits taken from the constructor arguments.
    private final long maxMemoryInBytes;
    private final long compressingHighWaterMark;
    private final long compressingLowWaterMark;

    private final JsonCodec<TaskDescriptor> taskDescriptorJsonCodec;
    private final JsonCodec<Split> splitJsonCodec;
    private final StorageStats storageStats;

    // Per-query storages plus storage-wide byte counters, which are adjusted by the deltas
    // observed on the per-query storages.
    @GuardedBy("this")
    private final Map<QueryId, TaskDescriptors> storages = new HashMap<>();
    @GuardedBy("this")
    private long reservedUncompressedBytes;
    @GuardedBy("this")
    private long reservedCompressedBytes;
    @GuardedBy("this")
    private long originalCompressedBytes;
    // Toggled by updateCompressingFlag based on the high/low water marks.
    @GuardedBy("this")
    private boolean compressing;

    // Single daemon thread that runs the periodic compression job.
    private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(Threads.daemonThreadsNamed("task-descriptor-storage"));
    private volatile boolean running;

    @Inject
    public TaskDescriptorStorage(QueryManagerConfig config, JsonCodec<TaskDescriptor> taskDescriptorJsonCodec, JsonCodec<Split> splitJsonCodec) {
        this(config.getFaultTolerantExecutionTaskDescriptorStorageMaxMemory(), config.getFaultTolerantExecutionTaskDescriptorStorageHighWaterMark(), config.getFaultTolerantExecutionTaskDescriptorStorageLowWaterMark(), taskDescriptorJsonCodec, splitJsonCodec);
    }

    public TaskDescriptorStorage(DataSize maxMemory, DataSize compressingHighWaterMark, DataSize compressingLowWaterMark, JsonCodec<TaskDescriptor> taskDescriptorJsonCodec, JsonCodec<Split> splitJsonCodec) {
        this.maxMemoryInBytes = maxMemory.toBytes();
        this.compressingHighWaterMark = compressingHighWaterMark.toBytes();
        this.compressingLowWaterMark = compressingLowWaterMark.toBytes();
        this.taskDescriptorJsonCodec = Objects.requireNonNull(taskDescriptorJsonCodec, "taskDescriptorJsonCodec is null");
        this.splitJsonCodec = Objects.requireNonNull(splitJsonCodec, "splitJsonCodec is null");
        this.storageStats = new StorageStats((Supplier<StorageStatsValues>)Suppliers.memoizeWithExpiration(this::computeStats, (long)1L, (TimeUnit)TimeUnit.SECONDS));
    }

    /**
     * Marks the storage as running and schedules the first compression job 10 seconds from now.
     */
    @PostConstruct
    public void start() {
        running = true;
        long initialDelaySeconds = 10L;
        executor.schedule(this::compressTaskDescriptorsJob, initialDelaySeconds, TimeUnit.SECONDS);
    }

    /**
     * Runs one compression step and reschedules itself. The next run is immediate when the step
     * reports that its per-step limit was exhausted; otherwise it runs after 10 seconds. A failing
     * step is logged and also rescheduled after 10 seconds. Nothing is scheduled once
     * {@link #stop()} has cleared {@code running}.
     */
    private void compressTaskDescriptorsJob() {
        if (!running) {
            return;
        }
        int delaySeconds = 10;
        try {
            if (!compressTaskDescriptorsStep()) {
                // the step used up its limit, so more work remains; continue right away
                delaySeconds = 0;
            }
        }
        catch (Throwable e) {
            log.error(e, "Error in compressTaskDescriptorsJob");
        }
        finally {
            // stop() may have shut the executor down while this step was running; scheduling on a
            // shut-down executor is rejected, so only reschedule while still running
            if (running) {
                executor.schedule(this::compressTaskDescriptorsJob, (long) delaySeconds, TimeUnit.SECONDS);
            }
        }
    }

    /**
     * Stops background compression: clears {@code running} (compressTaskDescriptorsJob returns
     * early once it sees the flag cleared) and shuts the executor down immediately.
     */
    @PreDestroy
    public void stop() {
        running = false;
        executor.shutdownNow();
    }

    /**
     * Compresses up to {@link #SINGLE_STEP_COMPRESSION_LIMIT} not-yet-compressed task descriptors
     * across all query storages, keeping the storage-wide byte counters in sync.
     *
     * @return {@code true} when there is nothing more to do right now (compression is inactive, or
     *         the limit was not used up), {@code false} when the limit was exhausted and another
     *         step should run immediately
     */
    private boolean compressTaskDescriptorsStep() {
        int limit = SINGLE_STEP_COMPRESSION_LIMIT;
        synchronized (this) {
            if (!compressing) {
                return true;
            }
            for (TaskDescriptors storage : storages.values()) {
                if (limit <= 0) {
                    return false;
                }
                if (storage.isFullyCompressed()) {
                    continue;
                }
                // a lambda cannot assign a captured local, so carry the result out through an array cell
                int[] compressedCount = new int[1];
                int currentLimit = limit;
                runAndUpdateMemory(storage, () -> compressedCount[0] = storage.compress(currentLimit), false);
                limit -= compressedCount[0];
            }
        }
        return limit > 0;
    }

    /**
     * Registers an empty storage for {@code queryId}.
     *
     * @throws com.google.common.base.VerifyException if the query already has a storage
     */
    public synchronized void initialize(QueryId queryId) {
        TaskDescriptors storage = new TaskDescriptors();
        TaskDescriptors existing = storages.putIfAbsent(queryId, storage);
        Verify.verify(existing == null, "storage is already initialized for query: %s", queryId);
        // a freshly created storage has all counters at zero; account for it uniformly anyway
        updateMemoryReservation(
                storage.getReservedUncompressedBytes(),
                storage.getReservedCompressedBytes(),
                storage.getOriginalCompressedBytes(),
                true);
        updateCompressingFlag();
    }

    /**
     * Stores {@code descriptor} under its stage and partition id. The call is a no-op when the query
     * has no storage, either because it was never initialized or because it was destroyed. If the
     * reservation grows past the memory limit, the query with the largest reservation (possibly this
     * one) is failed.
     */
    public synchronized void put(StageId stageId, TaskDescriptor descriptor) {
        TaskDescriptors storage = storages.get(stageId.getQueryId());
        if (storage != null) {
            runAndUpdateMemory(storage, () -> storage.put(stageId, descriptor.getPartitionId(), descriptor), true);
        }
    }

    /**
     * Runs {@code operation} against {@code storage} and applies the resulting change in the
     * storage's byte counters to the storage-wide counters, then re-evaluates the compressing flag.
     * If {@code operation} throws, no counters are updated.
     */
    @GuardedBy("this")
    private void runAndUpdateMemory(TaskDescriptors storage, Runnable operation, boolean considerKilling) {
        long uncompressedBefore = storage.getReservedUncompressedBytes();
        long compressedBefore = storage.getReservedCompressedBytes();
        long originalBefore = storage.getOriginalCompressedBytes();

        operation.run();

        updateMemoryReservation(
                storage.getReservedUncompressedBytes() - uncompressedBefore,
                storage.getReservedCompressedBytes() - compressedBefore,
                storage.getOriginalCompressedBytes() - originalBefore,
                considerKilling);
        updateCompressingFlag();
    }

    /**
     * Looks up a descriptor, decompressing it if needed.
     *
     * @return empty if the query has no storage, otherwise the stored descriptor
     * @throws java.util.NoSuchElementException if the query has a storage but no such descriptor
     */
    public synchronized Optional<TaskDescriptor> get(StageId stageId, int partitionId) {
        TaskDescriptors storage = storages.get(stageId.getQueryId());
        return storage == null ? Optional.empty() : Optional.of(storage.get(stageId, partitionId));
    }

    /**
     * Removes a descriptor and releases its bytes. The call is a no-op when the query has no storage.
     */
    public synchronized void remove(StageId stageId, int partitionId) {
        TaskDescriptors storage = storages.get(stageId.getQueryId());
        if (storage != null) {
            runAndUpdateMemory(storage, () -> storage.remove(stageId, partitionId), false);
        }
    }

    /**
     * Drops the whole storage of {@code queryId} and returns its bytes to the storage-wide counters.
     */
    public synchronized void destroy(QueryId queryId) {
        TaskDescriptors removed = storages.remove(queryId);
        if (removed == null) {
            return;
        }
        updateMemoryReservation(
                -removed.getReservedUncompressedBytes(),
                -removed.getReservedCompressedBytes(),
                -removed.getOriginalCompressedBytes(),
                false);
        updateCompressingFlag();
    }

    /**
     * Switches compression on when the logical (pre-compression) size of stored data rises above the
     * high water mark, and off when it drops below the low water mark; in between the flag keeps its
     * current value.
     */
    @GuardedBy("this")
    private void updateCompressingFlag() {
        long logicalBytes = originalCompressedBytes + reservedUncompressedBytes;
        if (compressing) {
            if (logicalBytes < compressingLowWaterMark) {
                compressing = false;
            }
        }
        else if (logicalBytes > compressingHighWaterMark) {
            compressing = true;
        }
    }

    /**
     * Applies byte deltas to the storage-wide counters. When {@code considerKilling} is set and the
     * reservation grew, it repeatedly fails the query with the largest reservation until the reserved
     * bytes fit into {@code maxMemoryInBytes}.
     *
     * @throws IllegalStateException if a counter becomes negative
     * @throws VerifyException if the limit is still exceeded but no query holds reclaimable bytes
     */
    @GuardedBy("this")
    private void updateMemoryReservation(long reservedUncompressedDelta, long reservedCompressedDelta, long originalCompressedDelta, boolean considerKilling) {
        reservedUncompressedBytes += reservedUncompressedDelta;
        reservedCompressedBytes += reservedCompressedDelta;
        originalCompressedBytes += originalCompressedDelta;
        checkStatsNotNegative();

        if (reservedUncompressedDelta + reservedCompressedDelta <= 0L || !considerKilling) {
            return;
        }

        while (reservedUncompressedBytes + reservedCompressedBytes > maxMemoryInBytes) {
            // report the storage-wide totals (not the deltas) that still exceed the limit
            QueryId killCandidate = storages.entrySet().stream()
                    .max(Map.Entry.<QueryId, TaskDescriptors>comparingByValue(Comparator.comparingLong(TaskDescriptors::getReservedBytes)))
                    .map(Map.Entry::getKey)
                    .orElseThrow(() -> new VerifyException(String.format(
                            "storage is empty but reservedBytes (%s + %s) is still greater than maxMemoryInBytes (%s)",
                            reservedUncompressedBytes,
                            reservedCompressedBytes,
                            maxMemoryInBytes)));
            TaskDescriptors storage = storages.get(killCandidate);
            long candidateReservedBytes = storage.getReservedBytes();
            // failing a storage that holds nothing cannot make progress; without this the loop would
            // spin forever while holding the storage lock
            if (candidateReservedBytes <= 0L) {
                throw new VerifyException(String.format(
                        "largest reservation (query %s) is empty but reservedBytes (%s + %s) is still greater than maxMemoryInBytes (%s)",
                        killCandidate,
                        reservedUncompressedBytes,
                        reservedCompressedBytes,
                        maxMemoryInBytes));
            }
            if (log.isInfoEnabled()) {
                log.info("Failing query %s; reclaiming %s of %s/%s task descriptor memory from %s queries; extraStorageInfo=%s",
                        killCandidate,
                        candidateReservedBytes,
                        DataSize.succinctBytes(reservedUncompressedBytes),
                        DataSize.succinctBytes(reservedCompressedBytes),
                        storages.size(),
                        storage.getDebugInfo());
            }
            runAndUpdateMemory(
                    storage,
                    () -> storage.fail(new TrinoException(
                            StandardErrorCode.EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY,
                            String.format("Task descriptor storage capacity has been exceeded: %s > %s",
                                    DataSize.succinctBytes(reservedUncompressedBytes + reservedCompressedBytes),
                                    DataSize.succinctBytes(maxMemoryInBytes)))),
                    false);
        }
    }

    /**
     * Asserts that every storage-wide byte counter is non-negative. A negative value means some
     * accounting delta was applied twice or with the wrong sign.
     *
     * @throws IllegalStateException naming the first counter found negative
     */
    @GuardedBy("this")
    private void checkStatsNotNegative() {
        Preconditions.checkState(reservedUncompressedBytes >= 0L, "reservedUncompressedBytes is negative");
        // previously this line re-checked reservedUncompressedBytes, so a negative compressed counter went unnoticed
        Preconditions.checkState(reservedCompressedBytes >= 0L, "reservedCompressedBytes is negative");
        Preconditions.checkState(originalCompressedBytes >= 0L, "originalCompressedBytes is negative");
    }

    /** Test-only accessor for the storage-wide reserved uncompressed bytes. */
    @VisibleForTesting
    synchronized long getReservedUncompressedBytes() {
        return reservedUncompressedBytes;
    }

    /** Test-only accessor for the storage-wide reserved compressed bytes. */
    @VisibleForTesting
    synchronized long getReservedCompressedBytes() {
        return reservedCompressedBytes;
    }

    /** Test-only accessor for the pre-compression size of all compressed descriptors. */
    @VisibleForTesting
    synchronized long getOriginalCompressedBytes() {
        return originalCompressedBytes;
    }

    /** Wraps {@code taskDescriptor} in a new, initially uncompressed holder bound to this storage. */
    private TaskDescriptorHolder createTaskDescriptorHolder(TaskDescriptor taskDescriptor) {
        TaskDescriptorHolder holder = new TaskDescriptorHolder(taskDescriptor);
        return holder;
    }

    /** JMX-exported statistics; values are cached for up to one second. */
    @Managed
    @Nested
    public StorageStats getStats() {
        return storageStats;
    }

    /**
     * Builds a fresh statistics snapshot under the storage lock. This is called through the memoizing
     * supplier created in the constructor.
     */
    private synchronized StorageStatsValues computeStats() {
        int queriesCount = storages.size();
        long stagesCount = 0;
        for (TaskDescriptors storage : storages.values()) {
            stagesCount += storage.getStagesCount();
        }
        StorageStatsValue uncompressed = getStorageStatsValue(queriesCount, stagesCount, reservedUncompressedBytes,
                TaskDescriptors::getReservedUncompressedBytes, TaskDescriptors::getStagesReservedUncompressedBytes);
        StorageStatsValue compressed = getStorageStatsValue(queriesCount, stagesCount, reservedCompressedBytes,
                TaskDescriptors::getReservedCompressedBytes, TaskDescriptors::getStagesReservedCompressedBytes);
        StorageStatsValue original = getStorageStatsValue(queriesCount, stagesCount, originalCompressedBytes,
                TaskDescriptors::getOriginalCompressedBytes, TaskDescriptors::getStagesOriginalCompressedBytes);
        return new StorageStatsValues(queriesCount, stagesCount, compressing, uncompressed, compressed, original);
    }

    /**
     * Computes the total, average and p50/p90/p95 of one byte counter, both per query and per stage.
     * All values are zero when there are no queries.
     *
     * @param queriesCount number of query storages
     * @param stagesCount number of stages that currently hold descriptors
     * @param totalBytes storage-wide value of the counter
     * @param queryBytes extracts the counter from one query storage
     * @param stageBytes streams the per-stage values of the counter for one query storage
     */
    @GuardedBy("this")
    private StorageStatsValue getStorageStatsValue(int queriesCount, long stagesCount, long totalBytes, Function<TaskDescriptors, Long> queryBytes, Function<TaskDescriptors, Stream<? extends Long>> stageBytes) {
        if (queriesCount <= 0) {
            return new StorageStatsValue(totalBytes, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L);
        }
        Quantiles.ScaleAndIndexes percentiles = Quantiles.percentiles().indexes(50, 90, 95);

        List<Long> perQueryBytes = storages.values().stream()
                .map(queryBytes)
                .collect(ImmutableList.toImmutableList());
        Map<Integer, Double> queryPercentiles = percentiles.compute(perQueryBytes);
        long queryBytesAvg = totalBytes / queriesCount;

        long stageBytesP50 = 0L;
        long stageBytesP90 = 0L;
        long stageBytesP95 = 0L;
        long stageBytesAvg = 0L;
        List<Long> perStageBytes = storages.values().stream()
                .flatMap(stageBytes)
                .collect(ImmutableList.toImmutableList());
        if (!perStageBytes.isEmpty()) {
            Map<Integer, Double> stagePercentiles = percentiles.compute(perStageBytes);
            stageBytesP50 = stagePercentiles.get(50).longValue();
            stageBytesP90 = stagePercentiles.get(90).longValue();
            stageBytesP95 = stagePercentiles.get(95).longValue();
            // per-stage counters are never removed, while stagesCount only counts stages that still
            // hold descriptors, so stagesCount can be 0 here; guard the division
            stageBytesAvg = stagesCount > 0 ? totalBytes / stagesCount : 0L;
        }

        return new StorageStatsValue(
                totalBytes,
                queryBytesAvg,
                queryPercentiles.get(50).longValue(),
                queryPercentiles.get(90).longValue(),
                queryPercentiles.get(95).longValue(),
                stageBytesAvg,
                stageBytesP50,
                stageBytesP90,
                stageBytesP95);
    }

    /**
     * JMX view over storage statistics. Each getter reads a snapshot from {@code statsSupplier}.
     */
    public static class StorageStats {
        private final Supplier<StorageStatsValues> statsSupplier;

        StorageStats(Supplier<StorageStatsValues> statsSupplier) {
            this.statsSupplier = Objects.requireNonNull(statsSupplier, "statsSupplier is null");
        }

        private StorageStatsValues stats() {
            return statsSupplier.get();
        }

        private StorageStatsValue uncompressed() {
            return stats().uncompressedReservedStats();
        }

        private StorageStatsValue compressed() {
            return stats().compressedReservedStats();
        }

        private StorageStatsValue original() {
            return stats().originalCompressedStats();
        }

        @Managed
        public long getQueriesCount() {
            return stats().queriesCount();
        }

        @Managed
        public long getStagesCount() {
            return stats().stagesCount();
        }

        @Managed
        public long getCompressionActive() {
            return stats().compressionActive() ? 1L : 0L;
        }

        // --- reserved uncompressed bytes ---

        @Managed
        public long getUncompressedReservedBytes() {
            return uncompressed().bytes();
        }

        @Managed
        public long getQueryUncompressedReservedBytesAvg() {
            return uncompressed().queryBytesAvg();
        }

        @Managed
        public long getQueryUncompressedReservedBytesP50() {
            return uncompressed().queryBytesP50();
        }

        @Managed
        public long getQueryUncompressedReservedBytesP90() {
            return uncompressed().queryBytesP90();
        }

        @Managed
        public long getQueryUncompressedReservedBytesP95() {
            return uncompressed().queryBytesP95();
        }

        @Managed
        public long getStageUncompressedReservedBytesAvg() {
            // previously returned stageBytesP50 by mistake
            return uncompressed().stageBytesAvg();
        }

        @Managed
        public long getStageUncompressedReservedBytesP50() {
            return uncompressed().stageBytesP50();
        }

        @Managed
        public long getStageUncompressedReservedBytesP90() {
            return uncompressed().stageBytesP90();
        }

        @Managed
        public long getStageUncompressedReservedBytesP95() {
            return uncompressed().stageBytesP95();
        }

        // --- reserved compressed bytes ---

        @Managed
        public long getCompressedReservedBytes() {
            return compressed().bytes();
        }

        @Managed
        public long getQueryCompressedReservedBytesAvg() {
            return compressed().queryBytesAvg();
        }

        @Managed
        public long getQueryCompressedReservedBytesP50() {
            return compressed().queryBytesP50();
        }

        @Managed
        public long getQueryCompressedReservedBytesP90() {
            return compressed().queryBytesP90();
        }

        @Managed
        public long getQueryCompressedReservedBytesP95() {
            return compressed().queryBytesP95();
        }

        @Managed
        public long getStageCompressedReservedBytesAvg() {
            // previously returned stageBytesP50 by mistake
            return compressed().stageBytesAvg();
        }

        @Managed
        public long getStageCompressedReservedBytesP50() {
            return compressed().stageBytesP50();
        }

        @Managed
        public long getStageCompressedReservedBytesP90() {
            return compressed().stageBytesP90();
        }

        @Managed
        public long getStageCompressedReservedBytesP95() {
            return compressed().stageBytesP95();
        }

        // --- original (pre-compression) size of compressed descriptors ---

        @Managed
        public long getOriginalCompressedBytes() {
            return original().bytes();
        }

        @Managed
        public long getQueryOriginalCompressedBytesAvg() {
            return original().queryBytesAvg();
        }

        @Managed
        public long getQueryOriginalCompressedBytesP50() {
            return original().queryBytesP50();
        }

        @Managed
        public long getQueryOriginalCompressedBytesP90() {
            return original().queryBytesP90();
        }

        @Managed
        public long getQueryOriginalCompressedBytesP95() {
            return original().queryBytesP95();
        }

        @Managed
        public long getStageOriginalCompressedBytesAvg() {
            // previously returned stageBytesP50 by mistake
            return original().stageBytesAvg();
        }

        @Managed
        public long getStageOriginalCompressedBytesP50() {
            return original().stageBytesP50();
        }

        @Managed
        public long getStageOriginalCompressedBytesP90() {
            return original().stageBytesP90();
        }

        @Managed
        public long getStageOriginalCompressedBytesP95() {
            return original().stageBytesP95();
        }
    }

    @NotThreadSafe
    private class TaskDescriptors {
        private final Table<StageId, Integer, TaskDescriptorHolder> descriptors = HashBasedTable.create();
        public boolean fullyCompressed;
        private long reservedUncompressedBytes;
        private long reservedCompressedBytes;
        private long originalCompressedBytes;
        private final Map<StageId, AtomicLong> stagesReservedUncompressedBytes = new HashMap<StageId, AtomicLong>();
        private final Map<StageId, AtomicLong> stagesReservedCompressedBytes = new HashMap<StageId, AtomicLong>();
        private final Map<StageId, AtomicLong> stagesOriginalCompressedBytes = new HashMap<StageId, AtomicLong>();
        private TrinoException failure;

        private TaskDescriptors() {
        }

        @GuardedBy(value="TaskDescriptorStorage.this")
        public void put(StageId stageId2, int partitionId, TaskDescriptor descriptor) {
            this.throwIfFailed();
            Preconditions.checkState((!this.descriptors.contains((Object)stageId2, (Object)partitionId) ? 1 : 0) != 0, (String)"task descriptor is already present for key %s/%s ", (Object)stageId2, (int)partitionId);
            TaskDescriptorHolder descriptorHolder = TaskDescriptorStorage.this.createTaskDescriptorHolder(descriptor);
            if (TaskDescriptorStorage.this.compressing) {
                descriptorHolder.compress();
            } else {
                this.fullyCompressed = false;
            }
            this.descriptors.put((Object)stageId2, (Object)partitionId, (Object)descriptorHolder);
            if (descriptorHolder.isCompressed()) {
                this.originalCompressedBytes += descriptorHolder.getUncompressedSize();
                this.reservedCompressedBytes += descriptorHolder.getCompressedSize();
                this.stagesOriginalCompressedBytes.computeIfAbsent(stageId2, stageId -> new AtomicLong()).addAndGet(descriptorHolder.getUncompressedSize());
                this.stagesReservedCompressedBytes.computeIfAbsent(stageId2, stageId -> new AtomicLong()).addAndGet(descriptorHolder.getCompressedSize());
            } else {
                this.reservedUncompressedBytes += descriptorHolder.getUncompressedSize();
                this.stagesReservedUncompressedBytes.computeIfAbsent(stageId2, stageId -> new AtomicLong()).addAndGet(descriptorHolder.getUncompressedSize());
            }
        }

        public TaskDescriptor get(StageId stageId, int partitionId) {
            this.throwIfFailed();
            TaskDescriptorHolder descriptor = (TaskDescriptorHolder)this.descriptors.get((Object)stageId, (Object)partitionId);
            if (descriptor == null) {
                throw new NoSuchElementException(String.format("descriptor not found for key %s/%s", stageId, partitionId));
            }
            return descriptor.getTaskDescriptor();
        }

        public void remove(StageId stageId2, int partitionId) {
            this.throwIfFailed();
            TaskDescriptorHolder descriptorHolder = (TaskDescriptorHolder)this.descriptors.remove((Object)stageId2, (Object)partitionId);
            if (descriptorHolder == null) {
                throw new NoSuchElementException(String.format("descriptor not found for key %s/%s", stageId2, partitionId));
            }
            if (descriptorHolder.isCompressed()) {
                this.originalCompressedBytes -= descriptorHolder.getUncompressedSize();
                this.reservedCompressedBytes -= descriptorHolder.getCompressedSize();
                this.stagesOriginalCompressedBytes.computeIfAbsent(stageId2, stageId -> new AtomicLong()).addAndGet(-descriptorHolder.getUncompressedSize());
                this.stagesReservedCompressedBytes.computeIfAbsent(stageId2, stageId -> new AtomicLong()).addAndGet(-descriptorHolder.getCompressedSize());
            } else {
                this.reservedUncompressedBytes -= descriptorHolder.getUncompressedSize();
                this.stagesReservedUncompressedBytes.computeIfAbsent(stageId2, stageId -> new AtomicLong()).addAndGet(-descriptorHolder.getUncompressedSize());
            }
        }

        public long getReservedUncompressedBytes() {
            return this.reservedUncompressedBytes;
        }

        public long getReservedCompressedBytes() {
            return this.reservedCompressedBytes;
        }

        public long getOriginalCompressedBytes() {
            return this.originalCompressedBytes;
        }

        public long getReservedBytes() {
            return this.reservedUncompressedBytes + this.reservedCompressedBytes;
        }

        private String getDebugInfo() {
            Multimap descriptorsByStageId = (Multimap)this.descriptors.cellSet().stream().collect(ImmutableSetMultimap.toImmutableSetMultimap(Table.Cell::getRowKey, Table.Cell::getValue));
            Map debugInfoByStageId = (Map)descriptorsByStageId.asMap().entrySet().stream().collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> this.getDebugInfo((Collection)entry.getValue())));
            List<String> biggestSplits = descriptorsByStageId.entries().stream().flatMap(entry -> ((TaskDescriptorHolder)entry.getValue()).getTaskDescriptor().getSplits().getSplitsFlat().entries().stream().map(splitEntry -> Map.entry("%s/%s".formatted(entry.getKey(), splitEntry.getKey()), (Split)splitEntry.getValue()))).sorted(Comparator.comparingLong(entry -> ((Split)entry.getValue()).getRetainedSizeInBytes()).reversed()).limit(3L).map(entry -> "{nodeId=%s, size=%s, split=%s}".formatted(entry.getKey(), ((Split)entry.getValue()).getRetainedSizeInBytes(), TaskDescriptorStorage.this.splitJsonCodec.toJson((Object)((Split)entry.getValue())))).toList();
            return "stagesInfo=%s; biggestSplits=%s".formatted(debugInfoByStageId, biggestSplits);
        }

        private String getDebugInfo(Collection<TaskDescriptorHolder> taskDescriptors) {
            int taskDescriptorsCount = taskDescriptors.size();
            Stats taskDescriptorsRetainedSizeStats = Stats.of((LongStream)taskDescriptors.stream().mapToLong(TaskDescriptorHolder::getRetainedSizeInBytes));
            Set planNodeIds = (Set)taskDescriptors.stream().flatMap(taskDescriptor -> taskDescriptor.getTaskDescriptor().getSplits().getSplitsFlat().keySet().stream()).collect(ImmutableSet.toImmutableSet());
            HashMap<PlanNodeId, String> splitsDebugInfo = new HashMap<PlanNodeId, String>();
            for (PlanNodeId planNodeId : planNodeIds) {
                Stats splitCountStats = Stats.of((LongStream)taskDescriptors.stream().mapToLong(taskDescriptor -> ((Collection)taskDescriptor.getTaskDescriptor().getSplits().getSplitsFlat().asMap().get(planNodeId)).size()));
                Stats splitSizeStats = Stats.of((LongStream)taskDescriptors.stream().flatMap(taskDescriptor -> taskDescriptor.getTaskDescriptor().getSplits().getSplitsFlat().get((Object)planNodeId).stream()).mapToLong(Split::getRetainedSizeInBytes));
                splitsDebugInfo.put(planNodeId, "{splitCountMean=%s, splitCountStdDev=%s, splitSizeMean=%s, splitSizeStdDev=%s}".formatted(splitCountStats.mean(), splitCountStats.populationStandardDeviation(), splitSizeStats.mean(), splitSizeStats.populationStandardDeviation()));
            }
            return "[taskDescriptorsCount=%s, taskDescriptorsRetainedSizeMean=%s, taskDescriptorsRetainedSizeStdDev=%s, splits=%s]".formatted(taskDescriptorsCount, taskDescriptorsRetainedSizeStats.mean(), taskDescriptorsRetainedSizeStats.populationStandardDeviation(), splitsDebugInfo);
        }

        private void fail(TrinoException failure) {
            if (this.failure == null) {
                this.descriptors.clear();
                this.reservedUncompressedBytes = 0L;
                this.reservedCompressedBytes = 0L;
                this.originalCompressedBytes = 0L;
                this.failure = failure;
            }
        }

        private void throwIfFailed() {
            if (this.failure != null) {
                throw new TrinoException(() -> ((TrinoException)this.failure).getErrorCode(), this.failure.getMessage(), (Throwable)this.failure);
            }
        }

        public int getStagesCount() {
            return this.descriptors.rowMap().size();
        }

        public Stream<Long> getStagesReservedUncompressedBytes() {
            return this.stagesReservedUncompressedBytes.values().stream().map(AtomicLong::get);
        }

        public Stream<Long> getStagesReservedCompressedBytes() {
            return this.stagesReservedCompressedBytes.values().stream().map(AtomicLong::get);
        }

        public Stream<Long> getStagesOriginalCompressedBytes() {
            return this.stagesOriginalCompressedBytes.values().stream().map(AtomicLong::get);
        }

        public boolean isFullyCompressed() {
            return this.fullyCompressed;
        }

        @GuardedBy(value="TaskDescriptorStorage.this")
        public int compress(int limit) {
            if (this.fullyCompressed) {
                return 0;
            }
            List selectedForCompresssion = (List)this.descriptors.values().stream().filter(descriptor -> !descriptor.isCompressed()).limit(limit).collect(ImmutableList.toImmutableList());
            for (TaskDescriptorHolder holder : selectedForCompresssion) {
                long uncompressedSize = holder.getUncompressedSize();
                holder.compress();
                this.reservedUncompressedBytes -= uncompressedSize;
                this.originalCompressedBytes += uncompressedSize;
                this.reservedCompressedBytes += holder.getCompressedSize();
                TaskDescriptorStorage.this.checkStatsNotNegative();
            }
            if (selectedForCompresssion.size() < limit) {
                this.fullyCompressed = true;
            }
            return selectedForCompresssion.size();
        }
    }

    /**
     * Holds one task descriptor either as the live object or as zstd-compressed JSON. Exactly one
     * of {@code taskDescriptor} / {@code compressedTaskDescriptor} is non-null at any time.
     */
    private class TaskDescriptorHolder {
        private static final int INSTANCE_SIZE = SizeOf.instanceSize(TaskDescriptorHolder.class);
        private TaskDescriptor taskDescriptor;
        // retained size of the descriptor when it was handed in; unchanged by compression
        private final long uncompressedSize;
        private byte[] compressedTaskDescriptor;

        private TaskDescriptorHolder(TaskDescriptor taskDescriptor) {
            this.taskDescriptor = Objects.requireNonNull(taskDescriptor, "taskDescriptor is null");
            this.uncompressedSize = taskDescriptor.getRetainedSizeInBytes();
        }

        /** Returns the descriptor; a compressed holder decodes a fresh copy without changing state. */
        public TaskDescriptor getTaskDescriptor() {
            if (taskDescriptor == null) {
                Verify.verify(compressedTaskDescriptor != null, "compressedTaskDescriptor is null");
                byte[] compressed = compressedTaskDescriptor;
                ZstdDecompressor decompressor = ZstdDecompressor.create();
                int jsonLength = Math.toIntExact(decompressor.getDecompressedSize(compressed, 0, compressed.length));
                byte[] json = new byte[jsonLength];
                decompressor.decompress(compressed, 0, compressed.length, json, 0, json.length);
                return taskDescriptorJsonCodec.fromJson(json);
            }
            return taskDescriptor;
        }

        /** Serializes the descriptor to JSON, compresses it and drops the live object. */
        public void compress() {
            Preconditions.checkState(!isCompressed(), "TaskDescriptor is compressed");
            byte[] json = taskDescriptorJsonCodec.toJsonBytes(taskDescriptor);
            ZstdCompressor compressor = ZstdCompressor.create();
            int capacity = compressor.maxCompressedLength(json.length);
            byte[] buffer = new byte[capacity];
            int written = compressor.compress(json, 0, json.length, buffer, 0, capacity);
            // trim the scratch buffer to the bytes actually produced
            byte[] trimmed = new byte[written];
            System.arraycopy(buffer, 0, trimmed, 0, written);
            compressedTaskDescriptor = trimmed;
            taskDescriptor = null;
        }

        /** Restores the live descriptor and discards the compressed bytes. */
        public void decompress() {
            Preconditions.checkState(isCompressed(), "TaskDescriptor is not compressed");
            taskDescriptor = getTaskDescriptor();
            compressedTaskDescriptor = null;
        }

        public boolean isCompressed() {
            return compressedTaskDescriptor != null;
        }

        public long getUncompressedSize() {
            return uncompressedSize;
        }

        public long getCompressedSize() {
            Preconditions.checkState(isCompressed(), "TaskDescriptor is not compressed");
            return compressedTaskDescriptor.length;
        }

        /** Holder overhead plus either the compressed byte count or the original retained size. */
        public long getRetainedSizeInBytes() {
            long payload = isCompressed() ? compressedTaskDescriptor.length : uncompressedSize;
            return INSTANCE_SIZE + payload;
        }
    }

    /** Statistics of one byte counter: the total plus per-query and per-stage averages and percentiles. */
    private record StorageStatsValue(
            long bytes,
            long queryBytesAvg,
            long queryBytesP50,
            long queryBytesP90,
            long queryBytesP95,
            long stageBytesAvg,
            long stageBytesP50,
            long stageBytesP90,
            long stageBytesP95)
    {
    }

    /** Snapshot of storage-wide statistics exported through {@link StorageStats}. */
    private record StorageStatsValues(
            long queriesCount,
            long stagesCount,
            boolean compressionActive,
            StorageStatsValue uncompressedReservedStats,
            StorageStatsValue compressedReservedStats,
            StorageStatsValue originalCompressedStats)
    {
        private StorageStatsValues
        {
            Objects.requireNonNull(uncompressedReservedStats, "uncompressedReservedStats is null");
            Objects.requireNonNull(compressedReservedStats, "compressedReservedStats is null");
            Objects.requireNonNull(originalCompressedStats, "originalCompressedStats is null");
        }
    }
}

