/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.util.cuda;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.util.cuda.CudaLibrary;
import com.sun.jna.Native;
import java.lang.management.MemoryUsage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for querying the CUDA runtime library ({@code cudart}) through JNA.
 *
 * <p>The native library is loaded once, when this class is initialized. If it cannot be loaded,
 * {@link #hasCuda()} returns {@code false}, {@link #getGpuCount()} returns {@code 0}, and the other
 * query methods throw {@link IllegalStateException}.
 */
public final class CudaUtils {
    private static final Logger logger = LoggerFactory.getLogger(CudaUtils.class);

    /** CUDA runtime status code for a successful call. */
    private static final int CUDA_SUCCESS = 0;
    /** CUDA runtime status code reported when no CUDA-capable device is present. */
    private static final int CUDA_ERROR_NO_DEVICE = 100;
    /** Device attribute id for the compute-capability major version. */
    private static final int ATTR_COMPUTE_CAPABILITY_MAJOR = 75;
    /** Device attribute id for the compute-capability minor version. */
    private static final int ATTR_COMPUTE_CAPABILITY_MINOR = 76;

    /** The loaded CUDA runtime binding, or {@code null} if {@code cudart} could not be loaded. */
    private static final CudaLibrary LIB = CudaUtils.loadLibrary();

    private CudaUtils() {
    }

    /**
     * Returns whether the CUDA runtime library was loaded successfully.
     *
     * @return {@code true} if {@code cudart} is available
     */
    public static boolean hasCuda() {
        return LIB != null;
    }

    /**
     * Returns the number of CUDA devices visible to the runtime.
     *
     * <p>Never throws because of a CUDA error. If the library is missing, or the runtime reports
     * an error, this returns {@code 0`. A "no device" status is logged at debug level; any other
     * error is logged at warn level.
     *
     * @return the number of GPUs, or {@code 0} if none can be detected
     */
    public static int getGpuCount() {
        if (LIB == null) {
            return 0;
        }
        int[] count = new int[1];
        int result = LIB.cudaGetDeviceCount(count);
        switch (result) {
            case CUDA_SUCCESS:
                return count[0];
            case CUDA_ERROR_NO_DEVICE:
                logger.debug("No GPU device found: {} ({})", LIB.cudaGetErrorString(result), result);
                return 0;
            default:
                logger.warn(
                        "Failed to detect GPU count: {} ({})",
                        LIB.cudaGetErrorString(result),
                        result);
                return 0;
        }
    }

    /**
     * Returns the CUDA runtime version as the raw integer reported by the runtime.
     *
     * @return the runtime version, encoded as {@code 1000 * major + 10 * minor}
     * @throws IllegalStateException if the CUDA library is not loaded
     * @throws EngineException if the runtime call fails
     */
    public static int getCudaVersion() {
        CudaLibrary lib = requireLibrary();
        int[] version = new int[1];
        checkCall(lib.cudaRuntimeGetVersion(version));
        return version[0];
    }

    /**
     * Returns the CUDA runtime version as a compact string: the major version followed by the minor
     * version (for example {@code 11020} becomes {@code "112"}).
     *
     * @return the concatenated major and minor version
     * @throws IllegalStateException if the CUDA library is not loaded
     * @throws EngineException if the runtime call fails
     */
    public static String getCudaVersionString() {
        requireLibrary();
        int version = getCudaVersion();
        // The version is encoded as 1000 * major + 10 * minor.
        int major = version / 1000;
        int minor = version / 10 % 10;
        return String.valueOf(major) + minor;
    }

    /**
     * Returns the compute capability of a device as its major version followed by its minor version
     * (for example {@code "75"} for capability 7.5).
     *
     * @param device the CUDA device ordinal
     * @return the concatenated major and minor compute capability
     * @throws IllegalStateException if the CUDA library is not loaded
     * @throws EngineException if either attribute query fails
     */
    public static String getComputeCapability(int device) {
        CudaLibrary lib = requireLibrary();
        int[] major = new int[1];
        int[] minor = new int[1];
        checkCall(lib.cudaDeviceGetAttribute(major, ATTR_COMPUTE_CAPABILITY_MAJOR, device));
        checkCall(lib.cudaDeviceGetAttribute(minor, ATTR_COMPUTE_CAPABILITY_MINOR, device));
        return String.valueOf(major[0]) + minor[0];
    }

    /**
     * Returns memory usage for the given GPU device.
     *
     * <p>The runtime's current device is switched to {@code device} for the query. It is always
     * switched back to its previous value afterwards, even when the memory query fails.
     *
     * @param device a device whose type is {@code "gpu"}
     * @return a {@link MemoryUsage} with these fields:
     *     <ul>
     *       <li>{@code init}: {@code -1} (undefined)
     *       <li>{@code used} and {@code committed}: {@code total - free} bytes
     *       <li>{@code max}: the total device memory
     *     </ul>
     * @throws IllegalArgumentException if {@code device} is not a GPU device
     * @throws IllegalStateException if the CUDA library is not loaded
     * @throws EngineException if any CUDA runtime call fails
     */
    public static MemoryUsage getGpuMemory(Device device) {
        if (!"gpu".equals(device.getDeviceType())) {
            throw new IllegalArgumentException("Only GPU device is allowed.");
        }
        if (LIB == null) {
            throw new IllegalStateException("No GPU device detected.");
        }
        int[] currentDevice = new int[1];
        checkCall(LIB.cudaGetDevice(currentDevice));
        checkCall(LIB.cudaSetDevice(device.getDeviceId()));

        long[] free = new long[1];
        long[] total = new long[1];
        int memInfoResult;
        int restoreResult;
        try {
            memInfoResult = LIB.cudaMemGetInfo(free, total);
        } finally {
            // Restore the caller's device unconditionally so a failed query cannot leave the
            // thread bound to a different GPU.
            restoreResult = LIB.cudaSetDevice(currentDevice[0]);
        }
        // Report the memory-query failure first; it is the root cause if both calls failed.
        checkCall(memInfoResult);
        checkCall(restoreResult);

        long committed = total[0] - free[0];
        return new MemoryUsage(-1L, committed, committed, total[0]);
    }

    /**
     * Loads the {@code cudart} native library through JNA.
     *
     * @return the library binding, or {@code null} if the native library cannot be linked
     */
    private static CudaLibrary loadLibrary() {
        try {
            return Native.load("cudart", CudaLibrary.class);
        } catch (UnsatisfiedLinkError e) {
            // A missing CUDA runtime is expected on CPU-only hosts, so it is not an error.
            logger.debug("cudart library not found.");
            logger.trace("", e);
            return null;
        }
    }

    /**
     * Returns the loaded CUDA library, failing if it is not available.
     *
     * @return the non-null library binding
     * @throws IllegalStateException if the CUDA library is not loaded
     */
    private static CudaLibrary requireLibrary() {
        if (LIB == null) {
            throw new IllegalStateException("No cuda library is loaded.");
        }
        return LIB;
    }

    /**
     * Converts a non-success CUDA runtime status code into an exception.
     *
     * @param ret the status code returned by a CUDA runtime call
     * @throws IllegalStateException if the CUDA library is not loaded
     * @throws EngineException if {@code ret} is not the success code
     */
    private static void checkCall(int ret) {
        CudaLibrary lib = requireLibrary();
        if (ret != CUDA_SUCCESS) {
            throw new EngineException(
                    "CUDA API call failed: " + lib.cudaGetErrorString(ret) + " (" + ret + ')');
        }
    }
}

