/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler.faulttolerant;

import com.google.common.collect.Ordering;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.airlift.stats.TDigest;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.execution.scheduler.ErrorCodes;
import io.trino.execution.scheduler.faulttolerant.PartitionMemoryEstimator;
import io.trino.execution.scheduler.faulttolerant.PartitionMemoryEstimatorFactory;
import io.trino.memory.ClusterMemoryManager;
import io.trino.memory.MemoryInfo;
import io.trino.memory.MemoryManagerConfig;
import io.trino.spi.ErrorCode;
import io.trino.spi.memory.MemoryPoolInfo;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.plan.PlanFragmentId;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.assertj.core.util.VisibleForTesting;

/**
 * {@link PartitionMemoryEstimator} that starts from a default memory limit and
 * multiplies the requirement by a growth factor whenever a partition fails with
 * an error that {@link #shouldIncreaseMemoryRequirement(ErrorCode)} accepts.
 *
 * <p>Peak memory of finished partitions is recorded in a {@link TDigest}; the value
 * of that distribution at {@code estimationQuantile} acts as a lower bound for every
 * requirement handed out. Every result is capped at the node pool size reported by
 * {@code maxNodePoolSizeSupplier}, when that supplier has a value.
 *
 * <p>All access to {@code memoryUsageDistribution} is guarded by this instance's monitor.
 */
public class ExponentialGrowthPartitionMemoryEstimator
        implements PartitionMemoryEstimator {
    private final DataSize defaultInitialMemoryLimit;
    private final boolean memoryRequirementIncreaseOnWorkerCrashEnabled;
    private final double growthFactor;
    private final double estimationQuantile;
    private final Supplier<Optional<DataSize>> maxNodePoolSizeSupplier;
    // Guarded by "this".
    private final TDigest memoryUsageDistribution = new TDigest();

    private ExponentialGrowthPartitionMemoryEstimator(
            DataSize defaultInitialMemoryLimit,
            boolean memoryRequirementIncreaseOnWorkerCrashEnabled,
            double growthFactor,
            double estimationQuantile,
            Supplier<Optional<DataSize>> maxNodePoolSizeSupplier) {
        this.defaultInitialMemoryLimit = Objects.requireNonNull(defaultInitialMemoryLimit, "defaultInitialMemoryLimit is null");
        this.memoryRequirementIncreaseOnWorkerCrashEnabled = memoryRequirementIncreaseOnWorkerCrashEnabled;
        this.growthFactor = growthFactor;
        this.estimationQuantile = estimationQuantile;
        this.maxNodePoolSizeSupplier = Objects.requireNonNull(maxNodePoolSizeSupplier, "maxNodePoolSizeSupplier is null");
    }

    /**
     * Returns the requirement for a first attempt: the larger of the default limit and
     * the current distribution-based estimate, capped at the max node pool size.
     */
    @Override
    public PartitionMemoryEstimator.MemoryRequirements getInitialMemoryRequirements() {
        DataSize memory = Ordering.natural().max(defaultInitialMemoryLimit, getEstimatedMemoryUsage());
        return new PartitionMemoryEstimator.MemoryRequirements(capMemoryToMaxNodeSize(memory));
    }

    /**
     * Returns the requirement for a retry.
     *
     * <p>Starts from the larger of the previous requirement and the observed peak usage,
     * multiplies it by the growth factor when {@code errorCode} warrants an increase,
     * raises it to the distribution-based estimate if that is larger, and finally caps
     * it at the max node pool size.
     *
     * @param previousMemoryRequirements requirements used by the failed attempt
     * @param peakMemoryUsage peak memory observed during the failed attempt
     * @param errorCode error the failed attempt ended with
     */
    @Override
    public PartitionMemoryEstimator.MemoryRequirements getNextRetryMemoryRequirements(
            PartitionMemoryEstimator.MemoryRequirements previousMemoryRequirements,
            DataSize peakMemoryUsage,
            ErrorCode errorCode) {
        DataSize previousMemory = previousMemoryRequirements.getRequiredMemory();
        DataSize newMemory = Ordering.natural().max(peakMemoryUsage, previousMemory);
        if (shouldIncreaseMemoryRequirement(errorCode)) {
            // The double product is truncated to whole bytes by the cast.
            newMemory = DataSize.of((long) (newMemory.toBytes() * growthFactor), DataSize.Unit.BYTE);
        }
        newMemory = Ordering.natural().max(newMemory, getEstimatedMemoryUsage());
        return new PartitionMemoryEstimator.MemoryRequirements(capMemoryToMaxNodeSize(newMemory));
    }

    /**
     * Returns the smaller of {@code memory} and the currently known max node pool size,
     * or {@code memory} unchanged when no pool size is known.
     */
    private DataSize capMemoryToMaxNodeSize(DataSize memory) {
        Optional<DataSize> currentMaxNodePoolSize = maxNodePoolSizeSupplier.get();
        if (currentMaxNodePoolSize.isEmpty()) {
            return memory;
        }
        return Ordering.natural().min(memory, currentMaxNodePoolSize.get());
    }

    /**
     * Records the outcome of a finished partition in the memory usage distribution.
     *
     * <p>On success the observed peak is recorded. On a failure whose error code warrants
     * a memory increase, the larger of the previous requirement and the observed peak,
     * multiplied by the growth factor, is recorded instead. Other failures are ignored.
     */
    @Override
    public synchronized void registerPartitionFinished(
            PartitionMemoryEstimator.MemoryRequirements previousMemoryRequirements,
            DataSize peakMemoryUsage,
            boolean success,
            Optional<ErrorCode> errorCode) {
        if (success) {
            memoryUsageDistribution.add(peakMemoryUsage.toBytes());
            return;
        }
        if (errorCode.isPresent() && shouldIncreaseMemoryRequirement(errorCode.get())) {
            long previousRequiredBytes = previousMemoryRequirements.getRequiredMemory().toBytes();
            long previousPeakBytes = peakMemoryUsage.toBytes();
            memoryUsageDistribution.add(Math.max(previousRequiredBytes, previousPeakBytes) * growthFactor);
        }
    }

    /**
     * Returns the distribution value at {@code estimationQuantile}, or zero bytes when
     * the digest yields NaN for that quantile.
     */
    private synchronized DataSize getEstimatedMemoryUsage() {
        double estimation = memoryUsageDistribution.valueAt(estimationQuantile);
        if (Double.isNaN(estimation)) {
            return DataSize.ofBytes(0);
        }
        return DataSize.ofBytes((long) estimation);
    }

    /**
     * Renders the distribution at a fixed set of quantiles as
     * {@code [q1=v1, q2=v2, ...]}.
     */
    private String memoryUsageDistributionInfo() {
        double[] quantiles = new double[] {0.01, 0.05, 0.1, 0.2, 0.5, 0.8, 0.9, 0.95, 0.99};
        final double[] values;
        // Only the digest read needs the lock; formatting happens outside of it.
        synchronized (this) {
            values = memoryUsageDistribution.valuesAt(quantiles);
        }
        return IntStream.range(0, quantiles.length)
                .mapToObj(i -> quantiles[i] + "=" + values[i])
                .collect(Collectors.joining(", ", "[", "]"));
    }

    @Override
    public String toString() {
        return "memoryUsageDistribution=" + memoryUsageDistributionInfo();
    }

    /**
     * Returns true for out-of-memory errors, and for worker-crash-associated errors
     * when the increase-on-worker-crash flag is enabled.
     */
    private boolean shouldIncreaseMemoryRequirement(ErrorCode errorCode) {
        if (ErrorCodes.isOutOfMemoryError(errorCode)) {
            return true;
        }
        return memoryRequirementIncreaseOnWorkerCrashEnabled && ErrorCodes.isWorkerCrashAssociatedError(errorCode);
    }

    /**
     * Creates {@link ExponentialGrowthPartitionMemoryEstimator} instances and keeps track
     * of the largest memory pool size among workers that currently report memory info.
     */
    public static class Factory
            implements PartitionMemoryEstimatorFactory {
        private static final Logger log = Logger.get(Factory.class);
        private final Supplier<Map<String, Optional<MemoryInfo>>> workerMemoryInfoSupplier;
        private final boolean memoryRequirementIncreaseOnWorkerCrashEnabled;
        private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
        private final AtomicReference<Optional<DataSize>> maxNodePoolSize = new AtomicReference<>(Optional.empty());

        @Inject
        public Factory(ClusterMemoryManager clusterMemoryManager, MemoryManagerConfig memoryManagerConfig) {
            this(
                    clusterMemoryManager::getWorkerMemoryInfo,
                    memoryManagerConfig.isFaultTolerantExecutionMemoryRequirementIncreaseOnWorkerCrashEnabled());
        }

        @VisibleForTesting
        Factory(Supplier<Map<String, Optional<MemoryInfo>>> workerMemoryInfoSupplier, boolean memoryRequirementIncreaseOnWorkerCrashEnabled) {
            this.workerMemoryInfoSupplier = Objects.requireNonNull(workerMemoryInfoSupplier, "workerMemoryInfoSupplier is null");
            this.memoryRequirementIncreaseOnWorkerCrashEnabled = memoryRequirementIncreaseOnWorkerCrashEnabled;
        }

        /**
         * Refreshes the max node pool size once synchronously, then schedules a refresh
         * with a one second delay between runs.
         */
        @PostConstruct
        public void start() {
            refreshNodePoolMemoryInfos();
            executor.scheduleWithFixedDelay(() -> {
                try {
                    refreshNodePoolMemoryInfos();
                }
                catch (Throwable e) {
                    // An exception escaping the task would suppress all subsequent executions.
                    log.error(e, "Unexpected error while refreshing node pool memory infos");
                }
            }, 1, 1, TimeUnit.SECONDS);
        }

        /** Stops the periodic refresh. */
        @PreDestroy
        public void stop() {
            executor.shutdownNow();
        }

        /**
         * Recomputes the largest pool size among workers that have memory info and
         * publishes it; publishes {@link Optional#empty()} when no worker reports any.
         */
        @VisibleForTesting
        void refreshNodePoolMemoryInfos() {
            Map<String, Optional<MemoryInfo>> workerMemoryInfos = workerMemoryInfoSupplier.get();
            // -1 is the sentinel for "no worker reported memory info".
            long maxNodePoolSizeBytes = -1;
            for (Optional<MemoryInfo> memoryInfo : workerMemoryInfos.values()) {
                if (memoryInfo.isEmpty()) {
                    continue;
                }
                MemoryPoolInfo poolInfo = memoryInfo.get().getPool();
                maxNodePoolSizeBytes = Math.max(poolInfo.getMaxBytes(), maxNodePoolSizeBytes);
            }
            maxNodePoolSize.set(maxNodePoolSizeBytes == -1 ? Optional.empty() : Optional.of(DataSize.ofBytes(maxNodePoolSizeBytes)));
        }

        /**
         * Creates an estimator configured from session properties. Fragments using
         * coordinator distribution get the coordinator task memory default; all others
         * get the regular task memory default.
         */
        @Override
        public PartitionMemoryEstimator createPartitionMemoryEstimator(
                Session session,
                PlanFragment planFragment,
                Function<PlanFragmentId, PlanFragment> sourceFragmentLookup) {
            DataSize defaultInitialMemoryLimit = planFragment.getPartitioning().equals(SystemPartitioningHandle.COORDINATOR_DISTRIBUTION)
                    ? SystemSessionProperties.getFaultTolerantExecutionDefaultCoordinatorTaskMemory(session)
                    : SystemSessionProperties.getFaultTolerantExecutionDefaultTaskMemory(session);
            return new ExponentialGrowthPartitionMemoryEstimator(
                    defaultInitialMemoryLimit,
                    memoryRequirementIncreaseOnWorkerCrashEnabled,
                    SystemSessionProperties.getFaultTolerantExecutionTaskMemoryGrowthFactor(session),
                    SystemSessionProperties.getFaultTolerantExecutionTaskMemoryEstimationQuantile(session),
                    maxNodePoolSize::get);
        }
    }
}

