/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.extensions.cuda;

import java.util.Objects;
import jcuda.Pointer;
import jcuda.jcublas.JCublas;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.runtime.JCuda;
import org.dromara.easyai.extensions.cuda.MatrixUtil;
import org.dromara.jcudax.JCudax;

/**
 * Double-precision (FP64) matrix helpers backed by the legacy JCublas API and JCudnn.
 *
 * <p>Every call allocates its own device buffers, copies the host arrays to the device, runs
 * the operation and copies the result back. Device buffers, cuDNN descriptors and the cuDNN
 * handle are always released, including when an exception is thrown part-way through.
 *
 * <p>The cuBLAS-based methods treat host arrays as column-major matrices: leading dimension
 * equal to the row count, which is the cuBLAS convention.
 */
public class CudaFp64Util {
    /** Size in bytes of one FP64 element, as passed to cudaMalloc and the cuBLAS copy routines. */
    private static final int FP64_BYTES = 8;

    // cuDNN enum values used by matrixSoftmax. They mirror cudnnTensorFormat.CUDNN_TENSOR_NCHW,
    // cudnnDataType.CUDNN_DATA_DOUBLE, cudnnSoftmaxAlgorithm.CUDNN_SOFTMAX_ACCURATE and
    // cudnnSoftmaxMode.CUDNN_SOFTMAX_MODE_INSTANCE, and equal the literals used originally.
    private static final int CUDNN_TENSOR_NCHW = 0;
    private static final int CUDNN_DATA_DOUBLE = 1;
    private static final int CUDNN_SOFTMAX_ACCURATE = 1;
    private static final int CUDNN_SOFTMAX_MODE_INSTANCE = 0;

    /**
     * Computes {@code A * B} on the GPU.
     *
     * @param h_A column-major M x K matrix
     * @param h_B column-major K x N matrix
     * @param M rows of A and of the result
     * @param K columns of A / rows of B
     * @param N columns of B and of the result
     * @return a new column-major M x N array holding the product
     */
    public static double[] matrixMulMatrix(double[] h_A, double[] h_B, int M, int K, int N) throws Exception {
        return CudaFp64Util.matrixGemm(h_A, h_B, null, 1.0, 0.0, M, K, N);
    }

    /**
     * Adds {@code scalar} to the M x N matrix {@code h_A} by building an M x N operand with
     * {@link MatrixUtil#makeFp64Array(int, int, double)} and delegating to
     * {@link #matrixAddMatrix(double[], double[], int, int)}.
     *
     * <p>NOTE(review): this assumes {@code makeFp64Array} returns an array filled with
     * {@code scalar}; confirm in MatrixUtil.
     *
     * @return the element-wise sum, in a new array
     */
    public static double[] matrixAddScalar(double[] h_A, double scalar, int M, int N) throws Exception {
        return CudaFp64Util.matrixAddMatrix(h_A, MatrixUtil.makeFp64Array(M, N, scalar), M, N);
    }

    /**
     * Computes {@code A * U + C} through {@link #matrixGemm}, where {@code U} comes from
     * {@link MatrixUtil#makeFp64UnitColumnMajorArray(int)}.
     *
     * <p>NOTE(review): {@code U} is presumably the N x N identity, which makes this the
     * element-wise sum {@code A + C}; confirm in MatrixUtil.
     *
     * <p>{@code h_C} is overwritten with the result, and that same array is returned.
     *
     * @param h_A column-major M x N matrix
     * @param h_C column-major M x N matrix; receives the result
     */
    public static double[] matrixAddMatrix(double[] h_A, double[] h_C, int M, int N) throws Exception {
        double[] h_B = MatrixUtil.makeFp64UnitColumnMajorArray(N);
        return CudaFp64Util.matrixGemm(h_A, h_B, h_C, 1.0, 1.0, M, N, N);
    }

    /**
     * Computes {@code alpha * A * B + beta * C} with {@code cublasDgemm} (no transposition).
     *
     * @param h_A column-major M x K matrix
     * @param h_B column-major K x N matrix
     * @param h_C column-major M x N matrix, or {@code null}. When non-null it is uploaded as C,
     *     overwritten with the result and returned. When {@code null}, C is treated as the zero
     *     matrix, so {@code beta} has no effect, and a new result array is allocated.
     * @param alpha scale applied to {@code A * B}
     * @param beta scale applied to C
     * @param M rows of A and C
     * @param K columns of A / rows of B
     * @param N columns of B and C
     * @return the column-major M x N result
     * @throws NullPointerException if {@code h_A} or {@code h_B} is {@code null}
     * @throws IllegalArgumentException if a dimension is negative, a matrix would exceed the
     *     maximum Java array size, or an array is shorter than its stated dimensions require
     */
    public static double[] matrixGemm(double[] h_A, double[] h_B, double[] h_C, double alpha, double beta, int M, int K, int N) {
        long sizeA = checkedElementCount(M, K);
        long sizeB = checkedElementCount(K, N);
        long sizeC = checkedElementCount(M, N);
        requireCapacity(h_A, sizeA, "h_A");
        requireCapacity(h_B, sizeB, "h_B");
        boolean uploadC = h_C != null;
        double effectiveBeta;
        if (uploadC) {
            requireCapacity(h_C, sizeC, "h_C");
            effectiveBeta = beta;
        } else {
            h_C = new double[(int) sizeC];
            // d_C is never written when no C is supplied. cuBLAS does not read C when
            // beta == 0, so this avoids mixing uninitialised device memory into the result.
            effectiveBeta = 0.0;
        }
        Pointer d_A = new Pointer();
        Pointer d_B = new Pointer();
        Pointer d_C = new Pointer();
        try {
            JCuda.cudaMalloc(d_A, sizeA * FP64_BYTES);
            JCuda.cudaMalloc(d_B, sizeB * FP64_BYTES);
            JCuda.cudaMalloc(d_C, sizeC * FP64_BYTES);
            JCublas.cublasSetMatrix(M, K, FP64_BYTES, Pointer.to(h_A), M, d_A, M);
            JCublas.cublasSetMatrix(K, N, FP64_BYTES, Pointer.to(h_B), K, d_B, K);
            if (uploadC) {
                JCublas.cublasSetMatrix(M, N, FP64_BYTES, Pointer.to(h_C), M, d_C, M);
            }
            JCublas.cublasDgemm('n', 'n', M, N, K, alpha, d_A, M, d_B, K, effectiveBeta, d_C, M);
            JCublas.cublasGetMatrix(M, N, FP64_BYTES, d_C, M, Pointer.to(h_C), M);
        } finally {
            // cudaFree on a pointer that was never allocated (address 0) is a no-op.
            JCuda.cudaFree(d_A);
            JCuda.cudaFree(d_B);
            JCuda.cudaFree(d_C);
        }
        return h_C;
    }

    /**
     * Multiplies every element of the M x N matrix {@code h_A} by {@code alpha} using
     * {@code cublasDscal}.
     *
     * <p>{@code h_A} is updated in place, and that same array is returned.
     *
     * @throws NullPointerException if {@code h_A} is {@code null}
     * @throws IllegalArgumentException if the dimensions are negative or too large, or
     *     {@code h_A} holds fewer than {@code M * N} elements
     */
    public static double[] matrixScale(double[] h_A, double alpha, int M, int N) {
        long count = checkedElementCount(M, N);
        requireCapacity(h_A, count, "h_A");
        Pointer d_A = new Pointer();
        try {
            JCuda.cudaMalloc(d_A, count * FP64_BYTES);
            JCublas.cublasSetMatrix(M, N, FP64_BYTES, Pointer.to(h_A), M, d_A, M);
            // The whole buffer is contiguous, so it is scaled as one vector of M * N elements.
            JCublas.cublasDscal((int) count, alpha, d_A, 1);
            JCublas.cublasGetMatrix(M, N, FP64_BYTES, d_A, M, Pointer.to(h_A), M);
        } finally {
            JCuda.cudaFree(d_A);
        }
        return h_A;
    }

    /**
     * Applies cuDNN's accurate softmax to {@code h_A}, viewed as a tensor of shape
     * (n=M, c=1, h=1, w=N) in NCHW layout.
     *
     * <p>The upload copies the M * N values contiguously, because the leading dimension equals
     * the row count. In INSTANCE mode cuDNN normalises over c*h*w for each n. The softmax is
     * therefore taken over each consecutive run of N elements of {@code h_A}: each row of a
     * row-major M x N matrix, not each column of a column-major one.
     *
     * @param h_A input values; left unchanged
     * @return a new array of M * N softmax outputs in the same order as the input
     * @throws NullPointerException if {@code h_A} is {@code null}
     * @throws IllegalArgumentException if the dimensions are negative or too large, or
     *     {@code h_A} holds fewer than {@code M * N} elements
     */
    public static double[] matrixSoftmax(double[] h_A, int M, int N) {
        long count = checkedElementCount(M, N);
        requireCapacity(h_A, count, "h_A");
        double[] h_C = new double[(int) count];
        cudnnHandle handle = new cudnnHandle();
        JCudnn.cudnnCreate(handle);
        try {
            cudnnTensorDescriptor inputDesc = createFp64RowTensorDescriptor(M, N);
            try {
                // Previously this descriptor was never destroyed; it is now released in finally.
                cudnnTensorDescriptor outputDesc = createFp64RowTensorDescriptor(M, N);
                try {
                    softmaxOnDevice(handle, inputDesc, outputDesc, h_A, h_C, M, N, count);
                } finally {
                    JCudnn.cudnnDestroyTensorDescriptor(outputDesc);
                }
            } finally {
                JCudnn.cudnnDestroyTensorDescriptor(inputDesc);
            }
        } finally {
            JCudnn.cudnnDestroy(handle);
        }
        return h_C;
    }

    /**
     * Delegates to {@link JCudax#matrixSoftMaxPd}; the arguments are passed through unchanged.
     *
     * <p>NOTE(review): the name suggests a softmax partial-derivative kernel, and the
     * {@code void} return suggests it writes into one of the arrays. Neither is visible here;
     * confirm against JCudax.
     */
    public static void matrixSoftMaxPd(double[] qkt, double[] errorMatrix, double[] grMatrix, int x, int y, double wordVectorDimension) {
        JCudax.matrixSoftMaxPd(qkt, errorMatrix, grMatrix, x, y, wordVectorDimension);
    }

    /** Creates an FP64 NCHW tensor descriptor of shape (M, 1, 1, N); the caller must destroy it. */
    private static cudnnTensorDescriptor createFp64RowTensorDescriptor(int M, int N) {
        cudnnTensorDescriptor desc = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(desc);
        try {
            JCudnn.cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_DOUBLE, M, 1, 1, N);
        } catch (RuntimeException e) {
            // Do not leak the freshly created descriptor if configuring it fails.
            JCudnn.cudnnDestroyTensorDescriptor(desc);
            throw e;
        }
        return desc;
    }

    /** Uploads h_A, runs cudnnSoftmaxForward (y = 1 * softmax(x) + 0 * y) and downloads into h_C. */
    private static void softmaxOnDevice(cudnnHandle handle, cudnnTensorDescriptor inputDesc,
            cudnnTensorDescriptor outputDesc, double[] h_A, double[] h_C, int M, int N, long count) {
        Pointer d_A = new Pointer();
        Pointer d_C = new Pointer();
        try {
            JCuda.cudaMalloc(d_A, count * FP64_BYTES);
            JCuda.cudaMalloc(d_C, count * FP64_BYTES);
            JCublas.cublasSetMatrix(M, N, FP64_BYTES, Pointer.to(h_A), M, d_A, M);
            JCudnn.cudnnSoftmaxForward(handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE,
                    Pointer.to(new double[]{1.0}), inputDesc, d_A,
                    Pointer.to(new double[]{0.0}), outputDesc, d_C);
            JCublas.cublasGetMatrix(M, N, FP64_BYTES, d_C, M, Pointer.to(h_C), M);
        } finally {
            JCuda.cudaFree(d_A);
            JCuda.cudaFree(d_C);
        }
    }

    /**
     * Returns {@code rows * cols} as a long. Fails fast on negative dimensions and on sizes
     * that no Java array can hold.
     */
    private static long checkedElementCount(int rows, int cols) {
        if (rows < 0 || cols < 0) {
            throw new IllegalArgumentException(
                    "Matrix dimensions must be non-negative: " + rows + "x" + cols);
        }
        long count = (long) rows * cols;
        if (count > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "Matrix " + rows + "x" + cols + " exceeds the maximum Java array size");
        }
        return count;
    }

    /**
     * Ensures {@code array} is non-null and holds at least {@code required} elements. Otherwise
     * the native copy would access memory past the end of the Java array.
     */
    private static void requireCapacity(double[] array, long required, String name) {
        Objects.requireNonNull(array, name);
        if (array.length < required) {
            throw new IllegalArgumentException(
                    name + " has " + array.length + " elements but " + required + " are required");
        }
    }
}

