/*
 * Decompiled with CFR 0.152.
 */
package jama.gpu;

import com.nativelibs4java.opencl.CLBuffer;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLDevice;
import com.nativelibs4java.opencl.CLEvent;
import com.nativelibs4java.opencl.CLException;
import com.nativelibs4java.opencl.CLMem;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.JavaCL;
import jama.FloatMatrix;
import jama.Matrix;
import jama.gpu.MultiplicationKernel;
import java.io.IOException;
import org.bridj.Pointer;

/**
 * Multiplies {@link FloatMatrix} operands on an OpenCL device through JavaCL.
 *
 * <p>Each operand is zero-padded so that both of its dimensions are multiples of
 * {@link #BLOCK_SIZE}. It is then copied into native memory in column-major order and multiplied
 * by {@link MultiplicationKernel}. Finally, the padding is stripped from the product.
 */
public class GPU {
    /**
     * Edge length of the square work-group used to launch the kernel. Operands are padded to a
     * multiple of this value so that the global work size is divisible by the local work size.
     */
    private static final int BLOCK_SIZE = 16;

    /** Context shared by helpers obtained from {@link #create()}; created lazily on first use. */
    private static CLContext defaultContext;

    private final CLContext context;

    /**
     * Returns a helper bound to a shared default context. On first use, that context is created
     * with {@link JavaCL#createBestContext()}.
     *
     * <p>The lazy initialisation is not synchronised, so concurrent first calls may each create a
     * context.
     *
     * @return a helper using the shared default context
     */
    public static GPU create() {
        if (defaultContext == null) {
            defaultContext = JavaCL.createBestContext();
        }
        return new GPU(defaultContext);
    }

    /**
     * Returns a helper bound to the given context.
     *
     * @param context the OpenCL context in which queues and buffers are created
     * @return a helper using {@code context}
     */
    public static GPU create(CLContext context) {
        return new GPU(context);
    }

    private GPU(CLContext context) {
        this.context = context;
    }

    /**
     * Computes {@code A * B} using {@link MultiplicationKernel#floatMatrixMult}.
     *
     * @param A left operand
     * @param B right operand; its row count must equal {@code A}'s column count
     * @return the {@code A.rows x B.cols} product
     * @throws IllegalArgumentException if the inner dimensions differ
     * @throws IOException propagated from creating the {@link MultiplicationKernel}
     */
    public FloatMatrix multiply(FloatMatrix A, FloatMatrix B) throws IOException {
        return multiply(A, B, false);
    }

    /**
     * Computes {@code A * B} using {@link MultiplicationKernel#floatMatrixMultLocals}.
     *
     * @param A left operand
     * @param B right operand; its row count must equal {@code A}'s column count
     * @return the {@code A.rows x B.cols} product
     * @throws IllegalArgumentException if the inner dimensions differ
     * @throws IOException propagated from creating the {@link MultiplicationKernel}
     */
    public FloatMatrix multiplyLocal(FloatMatrix A, FloatMatrix B) throws IOException {
        return multiply(A, B, true);
    }

    /**
     * Pads the operands, runs the selected kernel and returns the unpadded product.
     *
     * <p>Every native pointer and OpenCL object allocated here is released in a {@code finally}
     * block, whether the call succeeds or fails partway through.
     *
     * @param local {@code true} to use {@code floatMatrixMultLocals}, {@code false} to use
     *     {@code floatMatrixMult}
     */
    private FloatMatrix multiply(FloatMatrix A, FloatMatrix B, boolean local) throws IOException {
        if (A.getColumnDimension() != B.getRowDimension()) {
            throw new IllegalArgumentException("Matrix inner dimensions must agree.");
        }
        FloatMatrix padA = zeroPadding(A, BLOCK_SIZE);
        FloatMatrix padB = zeroPadding(B, BLOCK_SIZE);
        int resultRows = padA.getRowDimension();
        int resultCols = padB.getColumnDimension();

        // Declared outside the try so the finally block can release whatever was allocated.
        CLQueue queue = null;
        Pointer<Float> aPtr = null;
        Pointer<Float> bPtr = null;
        Pointer<Float> resultPtr = null;
        Pointer<Integer> innerDimPtr = null;
        Pointer<Float> outPtr = null;
        CLBuffer<Float> aInputBuffer = null;
        CLBuffer<Float> bInputBuffer = null;
        CLBuffer<Integer> innerDimBuffer = null;
        CLBuffer<Float> resultBuffer = null;
        CLEvent clEvent = null;
        try {
            queue = context.createDefaultQueue(new CLDevice.QueueProperties[0]);
            aPtr = matrixToPointer(padA);
            bPtr = matrixToPointer(padB);
            resultPtr = Pointer.allocateFloats((long) resultRows * resultCols);
            // The kernel receives the shared (padded) inner dimension through a one-int buffer.
            innerDimPtr = Pointer.allocateInt();
            innerDimPtr.set(padA.getColumnDimension());

            aInputBuffer = context.createBuffer(CLMem.Usage.Input, aPtr);
            bInputBuffer = context.createBuffer(CLMem.Usage.Input, bPtr);
            innerDimBuffer = context.createIntBuffer(CLMem.Usage.Input, innerDimPtr);
            resultBuffer = context.createBuffer(CLMem.Usage.Output, resultPtr);

            MultiplicationKernel kernel = new MultiplicationKernel(context);
            int[] globalWorkSizes = {resultRows, resultCols};
            int[] localWorkSizes = {BLOCK_SIZE, BLOCK_SIZE};
            clEvent = local
                    ? kernel.floatMatrixMultLocals(queue, resultBuffer, aInputBuffer, bInputBuffer,
                            innerDimBuffer, globalWorkSizes, localWorkSizes, new CLEvent[0])
                    : kernel.floatMatrixMult(queue, resultBuffer, aInputBuffer, bInputBuffer,
                            innerDimBuffer, globalWorkSizes, localWorkSizes, new CLEvent[0]);

            // The read waits on the kernel's event before transferring the result to the host.
            outPtr = resultBuffer.read(queue, new CLEvent[] {clEvent});
            FloatMatrix paddedProduct = pointerToFloatMatrix(outPtr, resultRows, resultCols);
            return removeZeroPadding(paddedProduct, A, B);
        } finally {
            // Null checks: any allocation step may have failed before later objects existed.
            if (clEvent != null) {
                clEvent.release();
            }
            if (aInputBuffer != null) {
                aInputBuffer.release();
            }
            if (bInputBuffer != null) {
                bInputBuffer.release();
            }
            if (innerDimBuffer != null) {
                innerDimBuffer.release();
            }
            if (resultBuffer != null) {
                resultBuffer.release();
            }
            if (queue != null) {
                queue.release();
            }
            releasePointers(aPtr, bPtr, outPtr, resultPtr, innerDimPtr);
        }
    }

    /** Releases each non-null native pointer. */
    private static void releasePointers(Pointer<?>... pointers) {
        for (Pointer<?> pointer : pointers) {
            if (pointer != null) {
                pointer.release();
            }
        }
    }

    /**
     * Copies a {@link Matrix} into newly allocated native memory in column-major order, i.e.
     * element {@code (row, col)} is stored at index {@code row + rows * col}.
     *
     * @param matrix source matrix
     * @return a pointer owning {@code rows * cols} doubles; the caller must release it
     */
    protected static Pointer<Double> matrixToPointer(Matrix matrix) {
        int rows = matrix.getRowDimension();
        int cols = matrix.getColumnDimension();
        Pointer<Double> pointer = Pointer.allocateDoubles((long) rows * cols);
        // Column-outer loop so writes walk the column-major buffer sequentially.
        for (int col = 0; col < cols; ++col) {
            for (int row = 0; row < rows; ++row) {
                pointer.set(row + (long) rows * col, matrix.get(row, col));
            }
        }
        return pointer;
    }

    /**
     * Copies a {@link FloatMatrix} into newly allocated native memory in column-major order, i.e.
     * element {@code (row, col)} is stored at index {@code row + rows * col}.
     *
     * @param matrix source matrix
     * @return a pointer owning {@code rows * cols} floats; the caller must release it
     */
    protected static Pointer<Float> matrixToPointer(FloatMatrix matrix) {
        int rows = matrix.getRowDimension();
        int cols = matrix.getColumnDimension();
        Pointer<Float> pointer = Pointer.allocateFloats((long) rows * cols);
        for (int col = 0; col < cols; ++col) {
            for (int row = 0; row < rows; ++row) {
                pointer.set(row + (long) rows * col, matrix.get(row, col));
            }
        }
        return pointer;
    }

    /**
     * Builds a {@link Matrix} from {@code rows * cols} doubles stored in column-major order.
     *
     * @param pointer native data; not released by this method
     * @param rows number of rows
     * @param cols number of columns
     * @return a new matrix holding a copy of the data
     */
    protected static Matrix pointerToMatrix(Pointer<Double> pointer, int rows, int cols) {
        Matrix matrix = new Matrix(rows, cols);
        for (int col = 0; col < cols; ++col) {
            for (int row = 0; row < rows; ++row) {
                matrix.set(row, col, pointer.get(row + (long) rows * col));
            }
        }
        return matrix;
    }

    /**
     * Builds a {@link FloatMatrix} from {@code rows * cols} floats stored in column-major order.
     *
     * @param pointer native data; not released by this method
     * @param rows number of rows
     * @param cols number of columns
     * @return a new matrix holding a copy of the data
     */
    protected static FloatMatrix pointerToFloatMatrix(Pointer<Float> pointer, int rows, int cols) {
        FloatMatrix matrix = new FloatMatrix(rows, cols);
        for (int col = 0; col < cols; ++col) {
            for (int row = 0; row < rows; ++row) {
                matrix.set(row, col, pointer.get(row + (long) rows * col).floatValue());
            }
        }
        return matrix;
    }

    /**
     * Rounds {@code size} up to the next multiple of {@code blocksize}. The result is never
     * smaller than {@code blocksize}.
     *
     * @param size dimension to pad
     * @param blocksize block edge length
     * @return {@code blocksize} if {@code size <= blocksize}, otherwise the smallest multiple of
     *     {@code blocksize} that is {@code >= size}
     */
    protected static int workgroupSize(int size, int blocksize) {
        if (size <= blocksize) {
            return blocksize;
        }
        int rest = size % blocksize;
        if (rest == 0) {
            return size;
        }
        return size + blocksize - rest;
    }

    /**
     * Returns {@code matrix} enlarged so that both dimensions are multiples of
     * {@code workgroupSize} (see {@link #workgroupSize(int, int)}). The original entries occupy
     * the top-left corner.
     *
     * <p>If no padding is needed, the same instance is returned. The extra cells are assumed to be
     * zero because {@code new FloatMatrix(m, n)} is zero-initialised (as Jama's {@code Matrix}
     * is) — TODO confirm.
     *
     * @param matrix matrix to pad
     * @param workgroupSize block edge length
     * @return the padded matrix, or {@code matrix} itself if already aligned
     */
    protected static FloatMatrix zeroPadding(FloatMatrix matrix, int workgroupSize) {
        int m = workgroupSize(matrix.getRowDimension(), workgroupSize);
        int n = workgroupSize(matrix.getColumnDimension(), workgroupSize);
        if (m == matrix.getRowDimension() && n == matrix.getColumnDimension()) {
            return matrix;
        }
        FloatMatrix paddedMatrix = new FloatMatrix(m, n);
        paddedMatrix.setFloatMatrix(0, matrix.getRowDimension() - 1, 0,
                matrix.getColumnDimension() - 1, matrix);
        return paddedMatrix;
    }

    /**
     * Crops a padded product back to {@code A.rows x B.cols}.
     *
     * @param result padded product
     * @param A original left operand (supplies the row count)
     * @param B original right operand (supplies the column count)
     * @return {@code result} itself if it already has the target shape, otherwise its top-left
     *     sub-matrix
     */
    protected static FloatMatrix removeZeroPadding(FloatMatrix result, FloatMatrix A, FloatMatrix B) {
        if (result.getColumnDimension() == B.getColumnDimension()
                && result.getRowDimension() == A.getRowDimension()) {
            return result;
        }
        return result.getFloatMatrix(0, A.getRowDimension() - 1, 0, B.getColumnDimension() - 1);
    }
}

