/*
 * Decompiled with CFR 0.152.
 */
package com.github.tjake.jlama.tensor.operations;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.tensor.operations.NaiveTensorOperations;
import com.github.tjake.jlama.tensor.operations.PanamaTensorOperations;
import com.github.tjake.jlama.tensor.operations.TensorOperations;
import com.github.tjake.jlama.tensor.operations.cnative.NativeSimd;
import com.github.tjake.jlama.util.MachineSpec;
import com.github.tjake.jlama.util.RuntimeSupport;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NativeTensorOperations
implements TensorOperations {
    private static final int MAX_BATCH_SIZE = 4;
    private static final ThreadLocal<MemorySegment[]> tmpArr;
    private static final Logger logger;
    public static final int HAS_F16C;
    public static final int HAS_AVX2;
    private static final TensorOperations delegate;
    final int flags;

    public NativeTensorOperations() {
        int f = 0;
        if (RuntimeSupport.isLinux()) {
            f |= HAS_F16C;
        }
        if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) {
            f |= HAS_AVX2;
        }
        this.flags = f;
        this.checkLib();
    }

    NativeTensorOperations(int flags) {
        this.flags = flags;
    }

    public String name() {
        return "Native SIMD Operations";
    }

    private void checkLib() {
        NativeSimd.gemm_f32$MH();
    }

    public boolean requiresOffHeapTensor() {
        return true;
    }

    public int parallelSplitSize() {
        return 128;
    }

    public void batchDotProduct(AbstractTensor result, AbstractTensor at, AbstractTensor bt, int aColumnOffset, int bColumnOffset, int columnLength, int bRowOffset, int rowChunkSize) {
        int M = at.shape().dim(0);
        int N = rowChunkSize;
        int K = columnLength;
        block0 : switch (at.dType()) {
            case F32: {
                switch (bt.dType()) {
                    case F32: {
                        NativeSimd.gemm_f32(this.flags, at.getMemorySegment(), at.getOffset(new int[]{0, aColumnOffset}), bt.getMemorySegment(), bt.getOffset(new int[]{0, bColumnOffset}), result.getMemorySegment(), result.shape().sparseOffset(), M, bRowOffset, N, K, at.getStride(), bt.getStride(), result.getStride());
                        break block0;
                    }
                    case Q4: {
                        switch (MachineSpec.VECTOR_TYPE) {
                            case ARM_128: {
                                throw new UnsupportedOperationException("F32 Q4 Unsupported on Arm");
                            }
                        }
                        Q4ByteBufferTensor b = (Q4ByteBufferTensor)bt;
                        NativeSimd.gemm_f32_q4(this.flags, at.getMemorySegment(), at.getOffset(new int[]{0, aColumnOffset}), b.getBlockF().getMemorySegment(), b.getMemorySegment(), b.getMemorySegmentOffset(b.getOffset(new int[]{0, bColumnOffset})), result.getMemorySegment(), result.shape().sparseOffset(), M, bRowOffset, N, K, at.getStride(), b.getMemorySegmentOffset(b.getStride()), b.getBlockF().getStride(), result.getStride());
                        break block0;
                    }
                }
                throw new UnsupportedOperationException(at.dType().name() + " " + bt.dType().name());
            }
            case I8: {
                switch (bt.dType()) {
                    case Q4: {
                        Q8ByteBufferTensor a = (Q8ByteBufferTensor)at;
                        Q4ByteBufferTensor b = (Q4ByteBufferTensor)bt;
                        NativeSimd.gemm_q8_q4(this.flags, a.getBlockF().getMemorySegment(), a.getMemorySegment(), a.getOffset(new int[]{0, aColumnOffset}), b.getBlockF().getMemorySegment(), b.getMemorySegment(), b.getMemorySegmentOffset(b.getOffset(new int[]{0, bColumnOffset})), result.getMemorySegment(), result.shape().sparseOffset(), M, bRowOffset, N, K, a.getStride(), a.getBlockF().getStride(), b.getMemorySegmentOffset(b.getStride()), b.getBlockF().getStride(), result.getStride());
                        break block0;
                    }
                }
                throw new UnsupportedOperationException(at.dType().name() + " " + bt.dType().name());
            }
            default: {
                throw new UnsupportedOperationException(at.dType().name());
            }
        }
    }

    public void dotProductBatchChunk(AbstractTensor[] r, AbstractTensor a, AbstractTensor[] b, int columnOffset, int columnLength, int bRowOffset, int rowChunkSize) {
        MemorySegment[] tmp = tmpArr.get();
        MemorySegment ra = tmp[0];
        MemorySegment rb = tmp[1];
        MemorySegment rc = tmp[2];
        for (int i = 0; i < r.length; ++i) {
            ra.setAtIndex(ValueLayout.ADDRESS, (long)i, r[i].getMemorySegment());
            rb.setAtIndex(ValueLayout.ADDRESS, (long)i, b[i].getMemorySegment());
        }
        int M = a.shape().dim(0);
        int N = rowChunkSize;
        int K = columnLength;
        block0 : switch (a.dType()) {
            case F32: {
                switch (b[0].dType()) {
                    case F32: {
                        NativeSimd.gemm_f32_batch(this.flags, r.length, a.getMemorySegment(), a.getOffset(new int[]{0, columnOffset}), rb, b[0].getOffset(new int[]{0, columnOffset}), ra, r[0].shape().sparseOffset(), M, bRowOffset, N, K, a.getStride(), b[0].getStride(), r[0].getStride());
                        break block0;
                    }
                    case Q4: {
                        switch (MachineSpec.VECTOR_TYPE) {
                            case ARM_128: {
                                throw new UnsupportedOperationException("F32 Q4 Unsupported on Arm");
                            }
                        }
                        Q4ByteBufferTensor bt = (Q4ByteBufferTensor)b[0];
                        for (int i = 0; i < r.length; ++i) {
                            rc.setAtIndex(ValueLayout.ADDRESS, (long)i, ((Q4ByteBufferTensor)b[i]).getBlockF().getMemorySegment());
                        }
                        NativeSimd.gemm_f32_q4_batch(this.flags, r.length, a.getMemorySegment(), a.getOffset(new int[]{0, columnOffset}), rc, rb, b[0].getMemorySegmentOffset(b[0].getOffset(new int[]{0, columnOffset})), ra, r[0].shape().sparseOffset(), M, bRowOffset, N, K, a.getStride(), b[0].getMemorySegmentOffset(b[0].getStride()), bt.getBlockF().getStride(), r[0].getStride());
                        break block0;
                    }
                }
                throw new UnsupportedOperationException(a.dType().name() + " " + b[0].dType().name());
            }
            case I8: {
                switch (b[0].dType()) {
                    case Q4: {
                        for (int i = 0; i < r.length; ++i) {
                            rc.setAtIndex(ValueLayout.ADDRESS, (long)i, ((Q4ByteBufferTensor)b[i]).getBlockF().getMemorySegment());
                        }
                        Q8ByteBufferTensor at = (Q8ByteBufferTensor)a;
                        Q4ByteBufferTensor bt = (Q4ByteBufferTensor)b[0];
                        NativeSimd.gemm_q8_q4_batch(this.flags, r.length, at.getBlockF().getMemorySegment(), a.getMemorySegment(), a.getOffset(new int[]{0, columnOffset}), rc, rb, bt.getMemorySegmentOffset(bt.getOffset(new int[]{0, columnOffset})), ra, r[0].shape().sparseOffset(), M, bRowOffset, N, K, a.getStride(), at.getBlockF().getStride(), bt.getMemorySegmentOffset(bt.getStride()), bt.getBlockF().getStride(), r[0].getStride());
                        break block0;
                    }
                }
                throw new UnsupportedOperationException(a.dType().name() + " " + b[0].dType().name());
            }
            default: {
                throw new UnsupportedOperationException(a.dType().name());
            }
        }
    }

    public void accumulate(AbstractTensor a, AbstractTensor b, int offset, int length) {
        delegate.accumulate(a, b, offset, length);
    }

    public void maccumulate(AbstractTensor a, AbstractTensor b, int offset, int length) {
        delegate.maccumulate(a, b, offset, length);
    }

    public void saxpy(float alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit) {
        delegate.saxpy(alpha, x, y, xoffset, yoffset, limit);
    }

    public void saxpy(AbstractTensor alpha, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit, int batchSize) {
        delegate.saxpy(alpha, x, y, xoffset, yoffset, limit, batchSize);
    }

    public void sxpby(float beta, AbstractTensor x, AbstractTensor y, int xoffset, int yoffset, int limit) {
        delegate.sxpby(beta, x, y, xoffset, yoffset, limit);
    }

    public void scale(float factor, AbstractTensor x, int offset, int length) {
        delegate.scale(factor, x, offset, length);
    }

    public AbstractTensor quantize(AbstractTensor t, DType qtype, int offset, int length) {
        return delegate.quantize(t, qtype, offset, length);
    }

    static {
        PanamaTensorOperations tmp;
        tmpArr = ThreadLocal.withInitial(() -> new MemorySegment[]{Arena.global().allocateArray(ValueLayout.ADDRESS, 4L), Arena.global().allocateArray(ValueLayout.ADDRESS, 4L), Arena.global().allocateArray(ValueLayout.ADDRESS, 4L)});
        logger = LoggerFactory.getLogger(NativeTensorOperations.class);
        HAS_F16C = NativeSimd.HAS_F16C();
        HAS_AVX2 = NativeSimd.HAS_AVX2();
        try {
            tmp = new PanamaTensorOperations(MachineSpec.VECTOR_TYPE);
        }
        catch (Throwable t) {
            tmp = new NaiveTensorOperations();
        }
        delegate = tmp;
    }
}

