/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.examples.util;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryMXBean;
import java.lang.management.MemoryUsage;
import java.lang.management.RuntimeMXBean;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Utility methods that sample JVM heap/non-heap memory, process CPU/RSS and per-GPU memory into a
 * {@link Metrics} object, and append the collected values to a log file.
 */
public final class MemoryUtils {

    private static final Logger logger = LoggerFactory.getLogger(MemoryUtils.class);

    private MemoryUtils() {}

    /**
     * Records the current memory usage into {@code metrics}.
     *
     * <p>The following metrics are added:
     *
     * <ul>
     *   <li>{@code Heap} and {@code NonHeap}: committed bytes, as reported by the {@link MemoryMXBean}.
     *   <li>{@code cpu} and {@code rss}: process values, on Linux and macOS only.
     *   <li>{@code GPU-<i>}: committed bytes for every GPU the default {@link Engine} reports.
     * </ul>
     *
     * <p>This method does nothing unless {@code metrics} is non-null and the system property
     * {@code collect-memory} is {@code true}.
     *
     * @param metrics the metrics sink to append to; may be {@code null}
     */
    public static void collectMemoryInfo(Metrics metrics) {
        if (metrics == null || !Boolean.getBoolean("collect-memory")) {
            return;
        }
        MemoryMXBean memBean = ManagementFactory.getMemoryMXBean();
        long heapCommitted = memBean.getHeapMemoryUsage().getCommitted();
        long nonHeapCommitted = memBean.getNonHeapMemoryUsage().getCommitted();

        getProcessInfo(metrics);
        metrics.addMetric("Heap", heapCommitted, "bytes");
        metrics.addMetric("NonHeap", nonHeapCommitted, "bytes");

        Engine engine = Engine.getInstance();
        int gpuCount = engine.getGpuCount();
        for (int i = 0; i < gpuCount; ++i) {
            MemoryUsage mem = engine.getGpuMemory(Device.gpu(i));
            metrics.addMetric("GPU-" + i, mem.getCommitted(), "bytes");
        }
    }

    /**
     * Appends the memory metrics recorded by {@link #collectMemoryInfo(Metrics)} to
     * {@code memory.log} in {@code logDir}.
     *
     * <p>The metrics are {@code Heap}, {@code NonHeap}, {@code cpu}, {@code rss} and
     * {@code GPU-<i>}. Each one is written as {@link Metric#toString()} on its own line. The
     * directory and file are created if missing, and I/O failures are logged rather than thrown.
     *
     * @param metrics the metrics to dump; nothing is written when {@code null}
     * @param logDir the output directory; nothing is written when {@code null}
     */
    public static void dumpMemoryInfo(Metrics metrics, String logDir) {
        // Mirror collectMemoryInfo: a null metrics object means "nothing collected".
        if (metrics == null || logDir == null) {
            return;
        }
        try {
            Path dir = Paths.get(logDir);
            Files.createDirectories(dir);
            Path file = dir.resolve("memory.log");
            try (BufferedWriter writer =
                    Files.newBufferedWriter(
                            file, StandardOpenOption.CREATE, StandardOpenOption.APPEND)) {
                // Typed list: a raw ArrayList cannot be iterated as Metric in the loop below.
                List<Metric> list = new ArrayList<>();
                list.addAll(metrics.getMetric("Heap"));
                list.addAll(metrics.getMetric("NonHeap"));
                list.addAll(metrics.getMetric("cpu"));
                list.addAll(metrics.getMetric("rss"));
                int gpuCount = Engine.getInstance().getGpuCount();
                for (int i = 0; i < gpuCount; ++i) {
                    list.addAll(metrics.getMetric("GPU-" + i));
                }
                for (Metric metric : list) {
                    writer.append(metric.toString());
                    writer.newLine();
                }
            }
        } catch (IOException e) {
            logger.error("Failed dump memory log", e);
        }
    }

    /**
     * Records this process's CPU usage ({@code cpu}, unit {@code %}) and resident set size
     * ({@code rss}, unit {@code KB}) by running {@code ps}.
     *
     * <p>This only runs on Linux and macOS. Failures to launch {@code ps} or to parse its output
     * are logged, and no metric is added.
     *
     * @param metrics the metrics sink to append to
     */
    private static void getProcessInfo(Metrics metrics) {
        String os = System.getProperty("os.name");
        if (!os.startsWith("Linux") && !os.startsWith("Mac")) {
            return;
        }
        RuntimeMXBean mxBean = ManagementFactory.getRuntimeMXBean();
        // NOTE(review): assumes RuntimeMXBean#getName() has the form "pid@hostname", as on HotSpot.
        String pid = mxBean.getName().split("@")[0];
        String[] command = {"ps", "-o", "%cpu=", "-o", "rss=", "-p", pid};
        String cmd = String.join(" ", command);

        ProcessBuilder builder = new ProcessBuilder(command);
        // ps formats %cpu with the locale's decimal separator (e.g. "0,5").
        // Force the C locale so Float.parseFloat can read it.
        builder.environment().put("LC_ALL", "C");

        Process process = null;
        try {
            process = builder.start();
            String line;
            try (InputStream is = process.getInputStream()) {
                line = new String(readAll(is), StandardCharsets.UTF_8).trim();
            }
            String[] tokens = line.split("\\s+");
            if (tokens.length != 2) {
                logger.error("Invalid ps output: {}", line);
                return;
            }
            float cpu = Float.parseFloat(tokens[0]);
            long rss = Long.parseLong(tokens[1]);
            metrics.addMetric("cpu", cpu, "%");
            metrics.addMetric("rss", rss, "KB");
        } catch (IOException e) {
            logger.error("Failed execute cmd: {}", cmd, e);
        } catch (NumberFormatException e) {
            // Best-effort metric: malformed ps output must not break memory collection.
            logger.error("Invalid ps output from cmd: {}", cmd, e);
        } finally {
            if (process != null) {
                // Output has already been read; this also releases the child's remaining streams.
                process.destroy();
            }
        }
    }

    /**
     * Reads {@code is} until end-of-stream and returns every byte read. The stream is not closed.
     *
     * @param is the stream to drain
     * @return the bytes read, possibly empty
     * @throws IOException if reading fails
     */
    private static byte[] readAll(InputStream is) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        byte[] buf = new byte[8192];
        int read;
        while ((read = is.read(buf)) != -1) {
            bos.write(buf, 0, read);
        }
        return bos.toByteArray();
    }
}

