/*
 * Decompiled with CFR 0.152.
 */
package io.trino.memory;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.MoreCollectors;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import com.google.common.io.Closer;
import io.airlift.http.client.HttpClient;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.ExceededMemoryLimitException;
import io.trino.SystemSessionProperties;
import io.trino.execution.LocationFactory;
import io.trino.execution.QueryExecution;
import io.trino.execution.QueryIdGenerator;
import io.trino.execution.QueryInfo;
import io.trino.execution.StageInfo;
import io.trino.execution.TaskId;
import io.trino.execution.TaskInfo;
import io.trino.memory.ClusterMemoryLeakDetector;
import io.trino.memory.ClusterMemoryPool;
import io.trino.memory.ForMemoryManager;
import io.trino.memory.KillTarget;
import io.trino.memory.LowMemoryKiller;
import io.trino.memory.MemoryInfo;
import io.trino.memory.MemoryManagerConfig;
import io.trino.memory.NodeMemoryConfig;
import io.trino.memory.RemoteNodeMemory;
import io.trino.metadata.InternalNode;
import io.trino.metadata.InternalNodeManager;
import io.trino.metadata.NodeState;
import io.trino.operator.RetryPolicy;
import io.trino.server.BasicQueryInfo;
import io.trino.server.ServerConfig;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.memory.ClusterMemoryPoolManager;
import io.trino.spi.memory.MemoryPoolInfo;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Supplier;
import javax.annotation.PreDestroy;
import javax.annotation.concurrent.GuardedBy;
import javax.inject.Inject;
import org.weakref.jmx.JmxException;
import org.weakref.jmx.MBeanExporter;
import org.weakref.jmx.Managed;

/**
 * Coordinator-side view of cluster memory.
 * <p>
 * Keeps one {@link RemoteNodeMemory} per alive node and aggregates their reported pools into a single
 * {@link ClusterMemoryPool}, which is exported through JMX under the generated name "general". The class
 * enforces the cluster-wide per-query memory limits. When the cluster stays out of memory for longer than
 * the configured delay, it asks the configured {@link LowMemoryKiller}s to choose a query, or a set of tasks,
 * to kill.
 * <p>
 * Mutable state is guarded by {@code this}. The {@code @Managed} counters are atomics and can be read without
 * holding the lock.
 */
public class ClusterMemoryManager
        implements ClusterMemoryPoolManager {
    private static final Logger log = Logger.get(ClusterMemoryManager.class);
    private static final String EXPORTED_POOL_NAME = "general";

    // Change listeners run on this executor so that they never run while the manager's lock is held.
    private final ExecutorService listenerExecutor = Executors.newSingleThreadExecutor();
    private final ClusterMemoryLeakDetector memoryLeakDetector = new ClusterMemoryLeakDetector();
    private final InternalNodeManager nodeManager;
    private final LocationFactory locationFactory;
    private final HttpClient httpClient;
    private final MBeanExporter exporter;
    private final JsonCodec<MemoryInfo> memoryInfoCodec;
    private final DataSize maxQueryMemory;
    private final DataSize maxQueryTotalMemory;
    // Consulted in order: the task killer first, then the query killer.
    private final List<LowMemoryKiller> lowMemoryKillers;
    private final Duration killOnOutOfMemoryDelay;

    private final AtomicLong totalAvailableProcessors = new AtomicLong();
    private final AtomicLong clusterUserMemoryReservation = new AtomicLong();
    private final AtomicLong clusterTotalMemoryReservation = new AtomicLong();
    private final AtomicLong clusterMemoryBytes = new AtomicLong();
    private final AtomicLong queriesKilledDueToOutOfMemory = new AtomicLong();
    private final AtomicLong tasksKilledDueToOutOfMemory = new AtomicLong();

    // Keyed by InternalNode#getNodeIdentifier().
    @GuardedBy("this")
    private final Map<String, RemoteNodeMemory> nodes = new HashMap<>();
    @GuardedBy("this")
    private final List<Consumer<MemoryPoolInfo>> changeListeners = new ArrayList<>();
    private final ClusterMemoryPool pool;
    @GuardedBy("this")
    private long lastTimeNotOutOfMemory = System.nanoTime();
    // The most recent query or task set that was killed. No new kill is issued until this target is gone.
    @GuardedBy("this")
    private Optional<KillTarget> lastKillTarget = Optional.empty();

    @Inject
    public ClusterMemoryManager(@ForMemoryManager HttpClient httpClient, InternalNodeManager nodeManager, LocationFactory locationFactory, MBeanExporter exporter, JsonCodec<MemoryInfo> memoryInfoCodec, QueryIdGenerator queryIdGenerator, @LowMemoryKiller.ForTaskLowMemoryKiller LowMemoryKiller taskLowMemoryKiller, @LowMemoryKiller.ForQueryLowMemoryKiller LowMemoryKiller queryLowMemoryKiller, ServerConfig serverConfig, MemoryManagerConfig config, NodeMemoryConfig nodeMemoryConfig) {
        Objects.requireNonNull(config, "config is null");
        Objects.requireNonNull(nodeMemoryConfig, "nodeMemoryConfig is null");
        Objects.requireNonNull(serverConfig, "serverConfig is null");
        Preconditions.checkState(serverConfig.isCoordinator(), "ClusterMemoryManager must not be bound on worker");

        this.nodeManager = Objects.requireNonNull(nodeManager, "nodeManager is null");
        this.locationFactory = Objects.requireNonNull(locationFactory, "locationFactory is null");
        this.httpClient = Objects.requireNonNull(httpClient, "httpClient is null");
        this.exporter = Objects.requireNonNull(exporter, "exporter is null");
        this.memoryInfoCodec = Objects.requireNonNull(memoryInfoCodec, "memoryInfoCodec is null");
        Objects.requireNonNull(taskLowMemoryKiller, "taskLowMemoryKiller is null");
        Objects.requireNonNull(queryLowMemoryKiller, "queryLowMemoryKiller is null");
        this.lowMemoryKillers = ImmutableList.of(taskLowMemoryKiller, queryLowMemoryKiller);

        this.maxQueryMemory = config.getMaxQueryMemory();
        this.maxQueryTotalMemory = config.getMaxQueryTotalMemory();
        this.killOnOutOfMemoryDelay = config.getKillOnOutOfMemoryDelay();
        Verify.verify(maxQueryMemory.toBytes() <= maxQueryTotalMemory.toBytes(), "maxQueryMemory cannot be greater than maxQueryTotalMemory");

        this.pool = new ClusterMemoryPool();
        exportMemoryPool();
    }

    /** Exports the cluster pool through JMX. Failures are logged and do not prevent construction. */
    private void exportMemoryPool() {
        try {
            exporter.exportWithGeneratedName(pool, ClusterMemoryPool.class, EXPORTED_POOL_NAME);
        }
        catch (JmxException e) {
            log.error(e, "Error exporting memory pool");
        }
    }

    /**
     * Registers a listener that receives the pool's {@link MemoryPoolInfo} after every pool update.
     * The listener is invoked asynchronously on an internal single-threaded executor.
     */
    public synchronized void addChangeListener(Consumer<MemoryPoolInfo> listener) {
        changeListeners.add(listener);
    }

    /**
     * Runs one memory-management pass over the running queries.
     * <p>
     * The pass does the following, in order:
     * <ol>
     *   <li>Checks for leaked queries.</li>
     *   <li>Fails queries that exceed their user or total memory limit, and fails queries that run with
     *       resource overcommit while the cluster is out of memory.</li>
     *   <li>Publishes the cluster-wide reservation totals.</li>
     *   <li>Invokes the low memory killers if the cluster has been out of memory for longer than
     *       {@code killOnOutOfMemoryDelay}.</li>
     *   <li>Refreshes the memory pool and the set of tracked nodes.</li>
     * </ol>
     *
     * @param runningQueries queries to examine. The iterable is traversed more than once, so it must be re-iterable.
     * @param allQueryInfoSupplier supplies query info to the memory leak detector
     */
    public synchronized void process(Iterable<QueryExecution> runningQueries, Supplier<List<BasicQueryInfo>> allQueryInfoSupplier) {
        memoryLeakDetector.checkForMemoryLeaks(allQueryInfoSupplier, pool.getQueryMemoryReservations());

        boolean outOfMemory = isClusterOutOfMemory();
        if (!outOfMemory) {
            lastTimeNotOutOfMemory = System.nanoTime();
        }

        boolean queryKilled = false;
        long totalUserMemoryBytes = 0L;
        long totalMemoryBytes = 0L;
        for (QueryExecution query : runningQueries) {
            long userMemoryReservation = query.getUserMemoryReservation().toBytes();
            long totalMemoryReservation = query.getTotalMemoryReservation().toBytes();
            // Every running query counts toward the cluster totals, including queries exempt from the limits below.
            totalUserMemoryBytes += userMemoryReservation;
            totalMemoryBytes += totalMemoryReservation;
            if (enforceQueryMemoryLimits(query, outOfMemory, userMemoryReservation, totalMemoryReservation)) {
                queryKilled = true;
            }
        }
        clusterUserMemoryReservation.set(totalUserMemoryBytes);
        clusterTotalMemoryReservation.set(totalMemoryBytes);

        // Only consult the killers when all of the following hold:
        // - no query was already failed in this pass (that presumably frees memory on its own);
        // - the cluster has been continuously out of memory for longer than the configured delay;
        // - the previous kill target has gone away.
        if (!lowMemoryKillers.isEmpty()
                && outOfMemory
                && !queryKilled
                && Duration.nanosSince(lastTimeNotOutOfMemory).compareTo(killOnOutOfMemoryDelay) > 0) {
            if (isLastKillTargetGone()) {
                callOomKiller(runningQueries);
            }
            else {
                log.debug("Last killed target is still not gone: %s", lastKillTarget);
            }
        }

        updateMemoryPool(Iterables.size(runningQueries));
        updateNodes();
    }

    /**
     * Applies the per-query checks to a single query.
     * <p>
     * Queries with the {@code TASK} retry policy are exempt. Queries with resource overcommit are failed only
     * when the cluster is out of memory. All other queries are failed when they exceed the user limit or the
     * total limit. Both limits are checked, so {@code fail} may be called twice.
     *
     * @return {@code true} if {@code query.fail(...)} was called for this query
     */
    private boolean enforceQueryMemoryLimits(QueryExecution query, boolean outOfMemory, long userMemoryReservation, long totalMemoryReservation) {
        if (SystemSessionProperties.getRetryPolicy(query.getSession()) == RetryPolicy.TASK) {
            return false;
        }

        if (SystemSessionProperties.resourceOvercommit(query.getSession())) {
            if (!outOfMemory) {
                return false;
            }
            DataSize memory = DataSize.succinctBytes(getQueryMemoryReservation(query));
            query.fail(new TrinoException(StandardErrorCode.CLUSTER_OUT_OF_MEMORY, String.format("The cluster is out of memory and %s=true, so this query was killed. It was using %s of memory", "resource_overcommit", memory)));
            return true;
        }

        boolean failed = false;
        // The effective limit is the smaller of the cluster-wide config and the session property.
        long userMemoryLimit = Math.min(maxQueryMemory.toBytes(), SystemSessionProperties.getQueryMaxMemory(query.getSession()).toBytes());
        if (userMemoryReservation > userMemoryLimit) {
            query.fail(ExceededMemoryLimitException.exceededGlobalUserLimit(DataSize.succinctBytes(userMemoryLimit)));
            failed = true;
        }
        long totalMemoryLimit = Math.min(maxQueryTotalMemory.toBytes(), SystemSessionProperties.getQueryMaxTotalMemory(query.getSession()).toBytes());
        if (totalMemoryReservation > totalMemoryLimit) {
            query.fail(ExceededMemoryLimitException.exceededGlobalTotalLimit(DataSize.succinctBytes(totalMemoryLimit)));
            failed = true;
        }
        return failed;
    }

    /**
     * Asks each low memory killer in turn to choose a target and kills the first target returned.
     * Once a killer returns a target, the remaining killers are not consulted, even if nothing could
     * actually be killed.
     */
    private synchronized void callOomKiller(Iterable<QueryExecution> runningQueries) {
        List<LowMemoryKiller.RunningQueryInfo> runningQueryInfos = Streams.stream(runningQueries)
                .map(this::createQueryMemoryInfo)
                .collect(ImmutableList.toImmutableList());

        // Read each node's info once, so the filter and the value come from the same snapshot.
        ImmutableMap.Builder<String, MemoryInfo> nodeMemoryInfosBuilder = ImmutableMap.builder();
        nodes.forEach((nodeId, nodeMemory) -> nodeMemory.getInfo().ifPresent(info -> nodeMemoryInfosBuilder.put(nodeId, info)));
        Map<String, MemoryInfo> nodeMemoryInfosByNode = nodeMemoryInfosBuilder.buildOrThrow();
        // The snapshot does not change inside the loop, so build the list once.
        List<MemoryInfo> nodeMemoryInfos = ImmutableList.copyOf(nodeMemoryInfosByNode.values());

        for (LowMemoryKiller lowMemoryKiller : lowMemoryKillers) {
            Optional<KillTarget> killTarget = lowMemoryKiller.chooseTargetToKill(runningQueryInfos, nodeMemoryInfos);
            if (killTarget.isEmpty()) {
                continue;
            }
            if (killTarget.get().isWholeQuery()) {
                killQuery(runningQueries, killTarget.get(), nodeMemoryInfosByNode);
            }
            else {
                killTasks(runningQueries, killTarget.get().getTasks(), nodeMemoryInfosByNode);
            }
            return;
        }
    }

    /**
     * Fails the chosen query if it is still among {@code runningQueries}, then records it as the last kill target.
     * If the query is no longer running, nothing is recorded.
     */
    @GuardedBy("this")
    private void killQuery(Iterable<QueryExecution> runningQueries, KillTarget killTarget, Map<String, MemoryInfo> nodeMemoryInfosByNode) {
        QueryId queryId = killTarget.getQuery();
        log.debug("Low memory killer chose %s", queryId);
        Optional<QueryExecution> chosenQuery = findRunningQuery(runningQueries, queryId);
        if (chosenQuery.isEmpty()) {
            return;
        }
        chosenQuery.get().fail(new TrinoException(StandardErrorCode.CLUSTER_OUT_OF_MEMORY, "Query killed because the cluster is out of memory. Please try again in a few minutes."));
        queriesKilledDueToOutOfMemory.incrementAndGet();
        lastKillTarget = Optional.of(killTarget);
        logQueryKill(queryId, nodeMemoryInfosByNode);
    }

    /**
     * Fails each chosen task whose query is still running.
     * Only tasks that were actually failed are recorded as the last kill target.
     */
    @GuardedBy("this")
    private void killTasks(Iterable<QueryExecution> runningQueries, Set<TaskId> tasks, Map<String, MemoryInfo> nodeMemoryInfosByNode) {
        log.debug("Low memory killer chose %s", tasks);
        ImmutableSet.Builder<TaskId> killedTasksBuilder = ImmutableSet.builder();
        for (TaskId task : tasks) {
            Optional<QueryExecution> runningQuery = findRunningQuery(runningQueries, task.getQueryId());
            if (runningQuery.isEmpty()) {
                continue;
            }
            runningQuery.get().failTask(task, new TrinoException(StandardErrorCode.CLUSTER_OUT_OF_MEMORY, "Task killed because the cluster is out of memory."));
            tasksKilledDueToOutOfMemory.incrementAndGet();
            killedTasksBuilder.add(task);
        }

        Set<TaskId> killedTasks = killedTasksBuilder.build();
        if (!killedTasks.isEmpty()) {
            lastKillTarget = Optional.of(KillTarget.selectedTasks(killedTasks));
            logTasksKill(killedTasks, nodeMemoryInfosByNode);
        }
    }

    /** Returns {@code true} if there is no previous kill target, or if the previous target no longer holds memory. */
    @GuardedBy("this")
    private boolean isLastKillTargetGone() {
        if (lastKillTarget.isEmpty()) {
            return true;
        }
        KillTarget target = lastKillTarget.get();
        return target.isWholeQuery() ? isQueryGone(target.getQuery()) : areTasksGone(target.getTasks());
    }

    /**
     * A killed query counts as gone when one of the following holds:
     * <ul>
     *   <li>The leak detector flagged it as possibly leaked. In this case {@code lastKillTarget} is also
     *       cleared, so that a leak cannot block further kills.</li>
     *   <li>It no longer has a reservation in the cluster pool. The pool is rebuilt from node infos in
     *       {@code updateMemoryPool}.</li>
     * </ul>
     */
    @GuardedBy("this")
    private boolean isQueryGone(QueryId killedQuery) {
        if (memoryLeakDetector.wasQueryPossiblyLeaked(killedQuery)) {
            lastKillTarget = Optional.empty();
            return true;
        }
        return !pool.getQueryMemoryReservations().containsKey(killedQuery);
    }

    /** Returns {@code true} if none of {@code tasks} still holds a reservation on any node. */
    @GuardedBy("this")
    private boolean areTasksGone(Set<TaskId> tasks) {
        Set<TaskId> runningTasks = getRunningTasks();
        return tasks.stream().noneMatch(runningTasks::contains);
    }

    /** Returns the IDs of all tasks that hold a reservation in the last known memory info of any node. */
    @GuardedBy("this")
    private ImmutableSet<TaskId> getRunningTasks() {
        return nodes.values().stream()
                .map(RemoteNodeMemory::getInfo)
                .flatMap(Optional::stream)
                .flatMap(memoryInfo -> memoryInfo.getPool().getTaskMemoryReservations().keySet().stream())
                .map(TaskId::valueOf)
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Finds the running query with {@code queryId}.
     *
     * @throws IllegalArgumentException if more than one running query has that ID
     */
    private Optional<QueryExecution> findRunningQuery(Iterable<QueryExecution> runningQueries, QueryId queryId) {
        return Streams.stream(runningQueries)
                .filter(query -> queryId.equals(query.getQueryId()))
                .collect(MoreCollectors.toOptional());
    }

    private void logQueryKill(QueryId killedQueryId, Map<String, MemoryInfo> nodeMemoryInfosByNode) {
        logKillDecision("Query Kill Decision: Killed " + killedQueryId, nodeMemoryInfosByNode);
    }

    private void logTasksKill(Set<TaskId> tasks, Map<String, MemoryInfo> nodeMemoryInfosByNode) {
        logKillDecision("Query Kill Decision: Tasks Killed " + tasks, nodeMemoryInfosByNode);
    }

    /** Logs a kill decision followed by each node's memory state, at INFO level only. */
    private void logKillDecision(String decision, Map<String, MemoryInfo> nodeMemoryInfosByNode) {
        if (!log.isInfoEnabled()) {
            return;
        }
        String description = decision + "\n" + formatKillScenario(nodeMemoryInfosByNode);
        log.info("%s", description);
    }

    /**
     * Renders one line per node. Each line shows the node's max bytes, its free bytes (free plus reserved
     * revocable), and its per-query and per-task reservations.
     */
    private String formatKillScenario(Map<String, MemoryInfo> nodes) {
        StringBuilder stringBuilder = new StringBuilder();
        for (Map.Entry<String, MemoryInfo> entry : nodes.entrySet()) {
            MemoryPoolInfo memoryPoolInfo = entry.getValue().getPool();
            stringBuilder.append("Node[").append(entry.getKey()).append("]: ");
            stringBuilder.append("MaxBytes ").append(memoryPoolInfo.getMaxBytes()).append(' ');
            stringBuilder.append("FreeBytes ").append(memoryPoolInfo.getFreeBytes() + memoryPoolInfo.getReservedRevocableBytes()).append(' ');
            stringBuilder.append("Queries ");
            Joiner.on(",").withKeyValueSeparator("=").appendTo(stringBuilder, memoryPoolInfo.getQueryMemoryReservations()).append(' ');
            stringBuilder.append("Tasks ");
            Joiner.on(",").withKeyValueSeparator("=").appendTo(stringBuilder, memoryPoolInfo.getTaskMemoryReservations());
            stringBuilder.append('\n');
        }
        return stringBuilder.toString();
    }

    @VisibleForTesting
    ClusterMemoryPool getPool() {
        return pool;
    }

    /** The cluster counts as out of memory while at least one node reports a blocked pool. */
    private boolean isClusterOutOfMemory() {
        return pool.getBlockedNodes() > 0;
    }

    /** Builds the killer's view of a query: its ID, total reservation, all of its tasks, and its retry policy. */
    private LowMemoryKiller.RunningQueryInfo createQueryMemoryInfo(QueryExecution query) {
        QueryInfo queryInfo = query.getQueryInfo();
        ImmutableMap.Builder<TaskId, TaskInfo> taskInfosBuilder = ImmutableMap.builder();
        queryInfo.getOutputStage().ifPresent(stage -> getTaskInfos(stage, taskInfosBuilder));
        return new LowMemoryKiller.RunningQueryInfo(
                query.getQueryId(),
                query.getTotalMemoryReservation().toBytes(),
                taskInfosBuilder.buildOrThrow(),
                SystemSessionProperties.getRetryPolicy(query.getSession()));
    }

    /** Recursively adds the tasks of {@code stageInfo} and all of its sub-stages to {@code taskInfosBuilder}. */
    private void getTaskInfos(StageInfo stageInfo, ImmutableMap.Builder<TaskId, TaskInfo> taskInfosBuilder) {
        for (TaskInfo taskInfo : stageInfo.getTasks()) {
            taskInfosBuilder.put(taskInfo.getTaskStatus().getTaskId(), taskInfo);
        }
        for (StageInfo subStage : stageInfo.getSubStages()) {
            getTaskInfos(subStage, taskInfosBuilder);
        }
    }

    private long getQueryMemoryReservation(QueryExecution query) {
        return query.getTotalMemoryReservation().toBytes();
    }

    /**
     * Updates the tracked node set to match the ACTIVE and SHUTTING_DOWN nodes, then starts an asynchronous
     * refresh of every tracked node. Nodes that are no longer alive are dropped. Newly seen nodes get a
     * {@link RemoteNodeMemory}.
     */
    private synchronized void updateNodes() {
        Set<InternalNode> aliveNodes = ImmutableSet.<InternalNode>builder()
                .addAll(nodeManager.getNodes(NodeState.ACTIVE))
                .addAll(nodeManager.getNodes(NodeState.SHUTTING_DOWN))
                .build();
        Set<String> aliveNodeIds = aliveNodes.stream()
                .map(InternalNode::getNodeIdentifier)
                .collect(ImmutableSet.toImmutableSet());

        // Sets.difference is a live view of nodes.keySet(), so copy it before mutating that key set.
        Set<String> deadNodes = ImmutableSet.copyOf(Sets.difference(nodes.keySet(), aliveNodeIds));
        nodes.keySet().removeAll(deadNodes);

        for (InternalNode node : aliveNodes) {
            nodes.computeIfAbsent(
                    node.getNodeIdentifier(),
                    ignored -> new RemoteNodeMemory(node, httpClient, memoryInfoCodec, locationFactory.createMemoryInfoLocation(node)));
        }

        for (RemoteNodeMemory remoteNodeMemory : nodes.values()) {
            remoteNodeMemory.asyncRefresh();
        }
    }

    /**
     * Recomputes the processor and memory totals from the node infos currently available, then updates the
     * cluster pool. After the update, every change listener is notified asynchronously. Listeners are
     * notified on every call, not only when something changed.
     */
    private synchronized void updateMemoryPool(int queryCount) {
        List<MemoryInfo> nodeMemoryInfos = nodes.values().stream()
                .map(RemoteNodeMemory::getInfo)
                .flatMap(Optional::stream)
                .collect(ImmutableList.toImmutableList());

        totalAvailableProcessors.set(nodeMemoryInfos.stream().mapToLong(MemoryInfo::getAvailableProcessors).sum());
        clusterMemoryBytes.set(nodeMemoryInfos.stream().mapToLong(memoryInfo -> memoryInfo.getPool().getMaxBytes()).sum());
        pool.update(nodeMemoryInfos, queryCount);

        if (!changeListeners.isEmpty()) {
            MemoryPoolInfo info = pool.getInfo();
            for (Consumer<MemoryPoolInfo> listener : changeListeners) {
                listenerExecutor.execute(() -> listener.accept(info));
            }
        }
    }

    /**
     * Returns a mutable snapshot of the last known memory info of each tracked node, keyed by node identifier.
     * A value is empty if that node has not reported yet.
     */
    public synchronized Map<String, Optional<MemoryInfo>> getWorkerMemoryInfo() {
        Map<String, Optional<MemoryInfo>> memoryInfo = new HashMap<>();
        for (Map.Entry<String, RemoteNodeMemory> entry : nodes.entrySet()) {
            memoryInfo.put(entry.getKey(), entry.getValue().getInfo());
        }
        return memoryInfo;
    }

    /** Unexports the JMX pool and stops the listener executor. Both steps are attempted even if one fails. */
    @PreDestroy
    public synchronized void destroy() throws IOException {
        try (Closer closer = Closer.create()) {
            closer.register(() -> exporter.unexportWithGeneratedName(ClusterMemoryPool.class, EXPORTED_POOL_NAME));
            closer.register(listenerExecutor::shutdownNow);
        }
    }

    @Managed
    public long getTotalAvailableProcessors() {
        return totalAvailableProcessors.get();
    }

    @Managed
    public int getNumberOfLeakedQueries() {
        return memoryLeakDetector.getNumberOfLeakedQueries();
    }

    @Managed
    public long getClusterUserMemoryReservation() {
        return clusterUserMemoryReservation.get();
    }

    @Managed
    public long getClusterTotalMemoryReservation() {
        return clusterTotalMemoryReservation.get();
    }

    @Managed
    public long getClusterMemoryBytes() {
        return clusterMemoryBytes.get();
    }

    @Managed
    public long getQueriesKilledDueToOutOfMemory() {
        return queriesKilledDueToOutOfMemory.get();
    }

    @Managed
    public long getTasksKilledDueToOutOfMemory() {
        return tasksKilledDueToOutOfMemory.get();
    }
}

