/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.extensions.cuda;

import jcuda.Pointer;
import jcuda.jcublas.JCublas;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.runtime.JCuda;
import org.dromara.easyai.extensions.cuda.MatrixUtil;
import org.dromara.jcudax.JCudax;

/**
 * Static helpers that run single-precision (FP32) matrix operations on the GPU
 * through JCuda, the legacy JCublas API and JCudnn.
 *
 * <p>Every method takes host {@code float[]} arrays, copies them to freshly
 * allocated device buffers, runs the operation and copies the result back.
 * Device buffers, cuDNN descriptors and handles are released in
 * {@code finally} blocks so that they are not leaked when a call fails.
 *
 * <p>Status codes returned by the individual JCuda/JCublas/JCudnn calls are not
 * inspected here.
 * NOTE(review): the legacy JCublas API normally needs {@code JCublas.cublasInit()}
 * to have been called beforehand; that call is not part of this class - confirm
 * that it happens elsewhere.
 */
public class CudaFp32Util {
    /** Size of one {@code float} element in bytes. */
    private static final int FLOAT_BYTES = 4;
    /** Numeric value of cuDNN's {@code CUDNN_TENSOR_NCHW} tensor format. */
    private static final int TENSOR_FORMAT_NCHW = 0;
    /** Numeric value of cuDNN's {@code CUDNN_DATA_FLOAT} data type. */
    private static final int DATA_TYPE_FLOAT = 0;
    /** Numeric value of cuDNN's {@code CUDNN_SOFTMAX_ACCURATE} algorithm. */
    private static final int SOFTMAX_ALGO_ACCURATE = 1;
    /** Numeric value of cuDNN's {@code CUDNN_SOFTMAX_MODE_INSTANCE} mode. */
    private static final int SOFTMAX_MODE_INSTANCE = 0;

    /**
     * Computes the matrix product {@code A * B}.
     *
     * @param h_A host data of the {@code M x K} matrix A (column-major, leading dimension M)
     * @param h_B host data of the {@code K x N} matrix B (column-major, leading dimension K)
     * @param M   number of rows of A and of the result
     * @param K   number of columns of A and rows of B
     * @param N   number of columns of B and of the result
     * @return a new array holding the {@code M x N} product
     * @throws Exception declared for backward compatibility with existing callers
     */
    public static float[] matrixMulMatrix(float[] h_A, float[] h_B, int M, int K, int N) throws Exception {
        return CudaFp32Util.matrixGemm(h_A, h_B, null, 1.0f, 0.0f, M, K, N);
    }

    /**
     * Adds {@code scalar} to every element of the {@code M x N} matrix A.
     *
     * <p>Implemented by building a matrix filled with {@code scalar} via
     * {@code MatrixUtil.makeFp32Array} and delegating to
     * {@link #matrixAddMatrix(float[], float[], int, int)}.
     *
     * @param h_A    host data of the {@code M x N} matrix
     * @param scalar value to add
     * @param M      number of rows
     * @param N      number of columns
     * @return the array produced by {@code MatrixUtil.makeFp32Array}, overwritten with the sum
     * @throws Exception declared for backward compatibility with existing callers
     */
    public static float[] matrixAddScalar(float[] h_A, float scalar, int M, int N) throws Exception {
        return CudaFp32Util.matrixAddMatrix(h_A, MatrixUtil.makeFp32Array(M, N, scalar), M, N);
    }

    /**
     * Computes {@code A + C} for two {@code M x N} matrices.
     *
     * <p>The sum is evaluated as the GEMM {@code 1 * A * B + 1 * C}, where B comes from
     * {@code MatrixUtil.makeFp32UnitColumnMajorArray(N)} (presumably the {@code N x N}
     * identity matrix - verify against {@code MatrixUtil}).
     *
     * <p><b>Side effect:</b> the result is written into {@code h_C}, which is also the
     * returned array.
     *
     * @param h_A host data of the first operand
     * @param h_C host data of the second operand; overwritten with the result
     * @param M   number of rows
     * @param N   number of columns
     * @return {@code h_C}, now containing the sum
     * @throws Exception declared for backward compatibility with existing callers
     */
    public static float[] matrixAddMatrix(float[] h_A, float[] h_C, int M, int N) throws Exception {
        float[] h_B = MatrixUtil.makeFp32UnitColumnMajorArray(N);
        return CudaFp32Util.matrixGemm(h_A, h_B, h_C, 1.0f, 1.0f, M, N, N);
    }

    /**
     * Runs the general matrix multiply {@code alpha * A * B + beta * C} with
     * {@code cublasSgemm} (no transposition, column-major operands).
     *
     * @param h_A   host data of the {@code M x K} matrix A
     * @param h_B   host data of the {@code K x N} matrix B
     * @param h_C   host data of the {@code M x N} matrix C, or {@code null}; when
     *              {@code null} a new result array is allocated and nothing is uploaded
     *              for C, otherwise {@code h_C} is uploaded and then overwritten
     * @param alpha scale factor applied to {@code A * B}
     * @param beta  scale factor applied to C
     * @param M     number of rows of A and C
     * @param K     number of columns of A and rows of B
     * @param N     number of columns of B and C
     * @return {@code h_C} when it was non-null, otherwise the newly allocated result
     * @throws IllegalArgumentException if a dimension is negative, an element count does
     *                                  not fit in an {@code int}, or an array is too short
     * @throws NullPointerException     if {@code h_A} or {@code h_B} is {@code null}
     */
    public static float[] matrixGemm(float[] h_A, float[] h_B, float[] h_C, float alpha, float beta, int M, int K, int N) {
        // Validate before touching the GPU: the native copies read exactly this many floats.
        int countA = elementCount(M, K);
        int countB = elementCount(K, N);
        int countC = elementCount(M, N);
        requireLength(h_A, countA, "h_A");
        requireLength(h_B, countB, "h_B");
        boolean uploadC = h_C != null;
        if (uploadC) {
            requireLength(h_C, countC, "h_C");
        } else {
            h_C = new float[countC];
        }

        Pointer d_A = new Pointer();
        Pointer d_B = new Pointer();
        Pointer d_C = new Pointer();
        try {
            JCuda.cudaMalloc(d_A, (long) countA * FLOAT_BYTES);
            JCuda.cudaMalloc(d_B, (long) countB * FLOAT_BYTES);
            JCuda.cudaMalloc(d_C, (long) countC * FLOAT_BYTES);
            JCublas.cublasSetMatrix(M, K, FLOAT_BYTES, Pointer.to(h_A), M, d_A, M);
            JCublas.cublasSetMatrix(K, N, FLOAT_BYTES, Pointer.to(h_B), K, d_B, K);
            if (uploadC) {
                JCublas.cublasSetMatrix(M, N, FLOAT_BYTES, Pointer.to(h_C), M, d_C, M);
            }
            JCublas.cublasSgemm('n', 'n', M, N, K, alpha, d_A, M, d_B, K, beta, d_C, M);
            JCublas.cublasGetMatrix(M, N, FLOAT_BYTES, d_C, M, Pointer.to(h_C), M);
        } finally {
            // Release the device buffers even when an upload or the GEMM itself failed.
            JCuda.cudaFree(d_A);
            JCuda.cudaFree(d_B);
            JCuda.cudaFree(d_C);
        }
        return h_C;
    }

    /**
     * Multiplies every element of the {@code M x N} matrix A by {@code alpha}
     * using {@code cublasSscal}.
     *
     * <p><b>Side effect:</b> the scaled values are written back into {@code h_A}.
     *
     * @param h_A   host data of the matrix; overwritten with the scaled values
     * @param alpha scale factor
     * @param M     number of rows
     * @param N     number of columns
     * @return {@code h_A}
     * @throws IllegalArgumentException if a dimension is negative, {@code M * N} does not
     *                                  fit in an {@code int}, or {@code h_A} is too short
     * @throws NullPointerException     if {@code h_A} is {@code null}
     */
    public static float[] matrixScale(float[] h_A, float alpha, int M, int N) {
        int count = elementCount(M, N);
        requireLength(h_A, count, "h_A");

        Pointer d_A = new Pointer();
        try {
            JCuda.cudaMalloc(d_A, (long) count * FLOAT_BYTES);
            JCublas.cublasSetMatrix(M, N, FLOAT_BYTES, Pointer.to(h_A), M, d_A, M);
            JCublas.cublasSscal(count, alpha, d_A, 1);
            JCublas.cublasGetMatrix(M, N, FLOAT_BYTES, d_A, M, Pointer.to(h_A), M);
        } finally {
            JCuda.cudaFree(d_A);
        }
        return h_A;
    }

    /**
     * Applies cuDNN's softmax forward pass to {@code M * N} floats.
     *
     * <p>The tensor is described as NCHW with {@code n = M, c = 1, h = 1, w = N} and the
     * softmax runs in per-instance mode, so each consecutive run of {@code N} floats in
     * {@code h_A} is normalised independently.
     * NOTE(review): that is a row-major view of the data, whereas the GEMM helpers in this
     * class use column-major operands - confirm callers pass the intended layout.
     *
     * @param h_A host input data; not modified
     * @param M   number of independent softmax instances
     * @param N   number of elements per instance
     * @return a new array with the softmax output
     * @throws IllegalArgumentException if a dimension is negative, {@code M * N} does not
     *                                  fit in an {@code int}, or {@code h_A} is too short
     * @throws NullPointerException     if {@code h_A} is {@code null}
     */
    public static float[] matrixSoftmax(float[] h_A, int M, int N) {
        int count = elementCount(M, N);
        requireLength(h_A, count, "h_A");
        float[] h_C = new float[count];

        // Each resource gets its own try/finally so that only successfully created
        // objects are destroyed, and all of them are destroyed on every exit path.
        cudnnHandle handle = new cudnnHandle();
        JCudnn.cudnnCreate(handle);
        try {
            cudnnTensorDescriptor inputDesc = new cudnnTensorDescriptor();
            JCudnn.cudnnCreateTensorDescriptor(inputDesc);
            try {
                cudnnTensorDescriptor outputDesc = new cudnnTensorDescriptor();
                JCudnn.cudnnCreateTensorDescriptor(outputDesc);
                try {
                    JCudnn.cudnnSetTensor4dDescriptor(inputDesc, TENSOR_FORMAT_NCHW, DATA_TYPE_FLOAT, M, 1, 1, N);
                    JCudnn.cudnnSetTensor4dDescriptor(outputDesc, TENSOR_FORMAT_NCHW, DATA_TYPE_FLOAT, M, 1, 1, N);
                    softmaxOnDevice(handle, inputDesc, outputDesc, h_A, h_C, M, N);
                } finally {
                    // The output descriptor used to be leaked on every call.
                    JCudnn.cudnnDestroyTensorDescriptor(outputDesc);
                }
            } finally {
                JCudnn.cudnnDestroyTensorDescriptor(inputDesc);
            }
        } finally {
            JCudnn.cudnnDestroy(handle);
        }
        return h_C;
    }

    /**
     * Forwards to {@code JCudax.matrixSoftMaxPdFp32} with the arguments unchanged.
     *
     * <p>NOTE(review): judging by the name this computes a softmax partial derivative;
     * which arrays are inputs and which are outputs is defined by {@code JCudax} and
     * cannot be determined from this class.
     *
     * @param qkt                 passed through as the first argument
     * @param errorMatrix         passed through as the second argument
     * @param grMatrix            passed through as the third argument
     * @param x                   passed through as the fourth argument
     * @param y                   passed through as the fifth argument
     * @param wordVectorDimension passed through as the sixth argument
     */
    public static void matrixSoftMaxPd(float[] qkt, float[] errorMatrix, float[] grMatrix, int x, int y, float wordVectorDimension) {
        JCudax.matrixSoftMaxPdFp32(qkt, errorMatrix, grMatrix, x, y, wordVectorDimension);
    }

    /** Uploads {@code input}, runs the softmax forward pass and downloads the result into {@code output}. */
    private static void softmaxOnDevice(cudnnHandle handle, cudnnTensorDescriptor inputDesc, cudnnTensorDescriptor outputDesc, float[] input, float[] output, int M, int N) {
        long bytes = (long) M * (long) N * FLOAT_BYTES;
        Pointer d_A = new Pointer();
        Pointer d_C = new Pointer();
        try {
            JCuda.cudaMalloc(d_A, bytes);
            JCuda.cudaMalloc(d_C, bytes);
            JCublas.cublasSetMatrix(M, N, FLOAT_BYTES, Pointer.to(input), M, d_A, M);
            // alpha = 1, beta = 0: the output is the plain softmax of the input.
            JCudnn.cudnnSoftmaxForward(handle, SOFTMAX_ALGO_ACCURATE, SOFTMAX_MODE_INSTANCE, Pointer.to(new float[]{1.0f}), inputDesc, d_A, Pointer.to(new float[]{0.0f}), outputDesc, d_C);
            JCublas.cublasGetMatrix(M, N, FLOAT_BYTES, d_C, M, Pointer.to(output), M);
        } finally {
            JCuda.cudaFree(d_A);
            JCuda.cudaFree(d_C);
        }
    }

    /** Returns {@code rows * cols}, rejecting negative dimensions and products that overflow an {@code int}. */
    private static int elementCount(int rows, int cols) {
        if (rows < 0 || cols < 0) {
            throw new IllegalArgumentException("matrix dimensions must be non-negative: " + rows + " x " + cols);
        }
        long count = (long) rows * (long) cols;
        if (count > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("matrix is too large for a Java array: " + rows + " x " + cols);
        }
        return (int) count;
    }

    /** Ensures {@code data} is non-null and holds at least {@code required} elements. */
    private static void requireLength(float[] data, int required, String name) {
        if (data == null) {
            throw new NullPointerException(name + " must not be null");
        }
        if (data.length < required) {
            throw new IllegalArgumentException(name + " has " + data.length + " elements but at least " + required + " are required");
        }
    }
}

