/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.transFormer.seflAttention;

import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.transFormer.model.QKVModel;
import org.dromara.easyai.transFormer.seflAttention.AttentionError;
import org.dromara.easyai.transFormer.seflAttention.EventBody;

/**
 * One attention head: holds the Q/K/V projection weights, runs the forward pass for an event
 * ({@link #sendMatrixFeature}) and applies the weight update for that event ({@link #backError}).
 *
 * <p>The mode is selected by {@code encoder} and {@code depth}:
 * <ul>
 *   <li>decoder with {@code depth == 1}: Q, K and V are projected from the input feature and a
 *       causal mask hides every score at column {@code j > i};</li>
 *   <li>decoder with {@code depth > 1}: Q is projected from the input feature, K and V from the
 *       supplied encoder feature;</li>
 *   <li>otherwise: Q, K and V are all projected from the input feature, without a mask.</li>
 * </ul>
 *
 * <p>Forward-pass state needed for back-propagation is kept per event ID between a
 * {@code sendMatrixFeature(..., isStudy = true, ...)} call and the matching {@code backError} call.
 * The state is held in a plain {@link HashMap}, so an instance is not safe for concurrent use
 * without external synchronization.
 */
public class SelfAttention {
    /** Score written into masked positions before softmax so their attention weight becomes ~0. */
    private static final float MASK_VALUE = -1000.0f;

    /** Forward-pass state per event ID, kept only for study passes until backError consumes it. */
    private final Map<Long, MyFeature> featureMatrix = new HashMap<>();
    private Matrix powerQ;
    private Matrix powerK;
    private Matrix powerV;
    private final int wordVectorDimension;
    private final int depth;
    private final float studyPoint;
    private final int selfID;
    private final boolean encoder;
    private final MatrixOperation matrixOperation;

    /**
     * Returns the identifier of this head.
     *
     * @return the {@code selfID} given at construction
     */
    public int getSelfID() {
        return this.selfID;
    }

    /**
     * Creates a head with randomly initialised square Q/K/V weight matrices.
     *
     * @param studyPoint          learning rate; multiplies the incoming error in {@link #backError}
     * @param depth               layer depth; for a decoder, 1 enables the causal mask and values
     *                            greater than 1 take K/V from the encoder feature
     * @param wordVectorDimension size of each weight matrix (dimension x dimension); also the
     *                            value whose square root scales the attention scores
     * @param selfID              identifier reported in {@link EventBody} and {@link QKVModel}
     * @param encoder             {@code true} for an encoder head, {@code false} for a decoder head
     * @param coreNumber          passed to {@link MatrixOperation} (presumably a worker count)
     * @throws IllegalArgumentException if {@code wordVectorDimension} is not positive
     * @throws Exception                propagated from the matrix library
     */
    public SelfAttention(float studyPoint, int depth, int wordVectorDimension, int selfID, boolean encoder, int coreNumber) throws Exception {
        if (wordVectorDimension <= 0) {
            throw new IllegalArgumentException("wordVectorDimension must be positive, got " + wordVectorDimension);
        }
        this.matrixOperation = new MatrixOperation(coreNumber);
        this.studyPoint = studyPoint;
        this.depth = depth;
        this.encoder = encoder;
        this.wordVectorDimension = wordVectorDimension;
        this.selfID = selfID;
        Random random = new Random();
        this.powerQ = initPowerMatrix(wordVectorDimension, random);
        this.powerK = initPowerMatrix(wordVectorDimension, random);
        this.powerV = initPowerMatrix(wordVectorDimension, random);
    }

    /**
     * Overwrites the Q/K/V weights with the values stored in {@code qkvModel}.
     *
     * <p>All three arrays are validated before anything is written. A malformed model therefore
     * leaves the current weights untouched.
     *
     * @param qkvModel model whose Q, K and V arrays are at least as large as the weight matrices
     * @throws IllegalArgumentException if any array is null, has too few rows, or has a row that
     *                                  is null or too short
     * @throws Exception                propagated from the matrix library
     */
    public void insertModel(QKVModel qkvModel) throws Exception {
        float[][] modelQ = qkvModel.getQ();
        float[][] modelK = qkvModel.getK();
        float[][] modelV = qkvModel.getV();
        checkPowerShape(modelQ, this.powerQ, "Q");
        checkPowerShape(modelK, this.powerK, "K");
        checkPowerShape(modelV, this.powerV, "V");
        insertPower(modelQ, this.powerQ);
        insertPower(modelK, this.powerK);
        insertPower(modelV, this.powerV);
    }

    /** Throws if {@code modelPower} cannot fill every cell of {@code power}. */
    private static void checkPowerShape(float[][] modelPower, Matrix power, String label) throws Exception {
        int rows = power.getX();
        int cols = power.getY();
        if (modelPower == null || modelPower.length < rows) {
            throw new IllegalArgumentException("model " + label + " weights need at least " + rows + " rows");
        }
        for (int i = 0; i < rows; ++i) {
            if (modelPower[i] == null || modelPower[i].length < cols) {
                throw new IllegalArgumentException("model " + label + " weight row " + i + " needs at least " + cols + " columns");
            }
        }
    }

    /** Copies the top-left {@code power.getX() x power.getY()} region of {@code modelPower} into {@code power}. */
    private static void insertPower(float[][] modelPower, Matrix power) throws Exception {
        int rows = power.getX();
        int cols = power.getY();
        for (int i = 0; i < rows; ++i) {
            float[] row = modelPower[i];
            for (int j = 0; j < cols; ++j) {
                power.setNub(i, j, row[j]);
            }
        }
    }

    /**
     * Exports the current weights.
     *
     * @return a model holding whatever {@code getMatrix()} returns for Q, K and V, plus this
     *         head's {@code selfID}
     * @throws Exception propagated from the matrix library
     */
    public QKVModel getModel() throws Exception {
        QKVModel qkvModel = new QKVModel();
        qkvModel.setQ(this.powerQ.getMatrix());
        qkvModel.setK(this.powerK.getMatrix());
        qkvModel.setV(this.powerV.getMatrix());
        qkvModel.setSelfID(this.selfID);
        return qkvModel;
    }

    /**
     * Back-propagates {@code feature} through the forward pass stored for {@code eventID}.
     *
     * <p>This updates the Q/K/V weights and discards the stored state.
     *
     * @param feature error arriving at this head's output (presumably shaped like the forward
     *                result — TODO confirm)
     * @param eventID event whose study forward pass is being trained
     * @return the error for this head's input feature and, for a decoder with depth &gt; 1, the
     *         error for the encoder feature (otherwise {@code null})
     * @throws IllegalStateException if no study forward pass is stored for {@code eventID}
     * @throws Exception             propagated from the matrix library
     */
    public AttentionError backError(Matrix feature, long eventID) throws Exception {
        MyFeature featureBody = this.featureMatrix.get(eventID);
        if (featureBody == null) {
            throw new IllegalStateException("no stored forward state for eventID " + eventID
                    + " (sendMatrixFeature must be called with isStudy=true first)");
        }
        // NOTE(review): only the V path uses the studyPoint-scaled error. The Q/K path below starts
        // from the unscaled `feature`, so (assuming mathMulBySelf returns a new matrix) the Q/K
        // weight updates and their propagated errors are not multiplied by studyPoint. Confirm
        // this is intended.
        Matrix myError = this.matrixOperation.mathMulBySelf(feature, this.studyPoint);
        Matrix q = featureBody.q;
        Matrix kt = featureBody.kt;
        Matrix v = featureBody.v;
        Matrix qkt = featureBody.qkt;
        // Forward was result = qkt * v. The boolean flag appears to select the derivative with
        // respect to the right (false) or left (true) operand — inferred from usage, TODO confirm.
        Matrix errorV = this.matrixOperation.matrixMulPd(myError, qkt, v, false);
        Matrix subQktMax = this.matrixOperation.matrixMulPd(feature, qkt, v, true);
        Matrix grMatrix = this.matrixOperation.matrixSoftMaxPd(qkt, subQktMax, this.wordVectorDimension);
        if (this.isMaskedDecoder()) {
            // Masked scores were forced out of the softmax in the forward pass; drop their gradient.
            fillUpperTriangle(grMatrix, 0.0f);
        }
        Matrix errorKt = this.matrixOperation.matrixMulPd(grMatrix, q, kt, false);
        Matrix errorQ = this.matrixOperation.matrixMulPd(grMatrix, q, kt, true);
        Matrix errorK = this.matrixOperation.transPosition(errorKt);
        Matrix kvFeature = this.usesEncoderKeyValue() ? featureBody.encoderFeature : featureBody.allFeature;
        ErrorFeature qPower = this.updateError(errorQ, featureBody.allFeature, this.powerQ);
        ErrorFeature kPower = this.updateError(errorK, kvFeature, this.powerK);
        ErrorFeature vPower = this.updateError(errorV, kvFeature, this.powerV);
        // Commit the new weights only after all three updates were computed from the old ones.
        this.powerQ = qPower.powerMatrix;
        this.powerK = kPower.powerMatrix;
        this.powerV = vPower.powerMatrix;
        Matrix nextFeatureError;
        Matrix lastEncoderError = null;
        if (this.usesEncoderKeyValue()) {
            // K and V were projected from the encoder feature, so their errors flow back there.
            nextFeatureError = qPower.errorFeatureMatrix;
            lastEncoderError = this.matrixOperation.add(kPower.errorFeatureMatrix, vPower.errorFeatureMatrix);
        } else {
            nextFeatureError = this.matrixOperation.addThreeMatrix(qPower.errorFeatureMatrix, kPower.errorFeatureMatrix, vPower.errorFeatureMatrix);
        }
        AttentionError attentionError = new AttentionError();
        attentionError.setNextFeatureError(nextFeatureError);
        attentionError.setLastEncoderError(lastEncoderError);
        this.featureMatrix.remove(eventID);
        return attentionError;
    }

    /**
     * Computes the weight delta and the input error for one projection (input * power).
     *
     * @return the input-side error and {@code powerMatrix} plus its delta (the field itself is
     *         not reassigned here)
     */
    private ErrorFeature updateError(Matrix errorMatrix, Matrix feature, Matrix powerMatrix) throws Exception {
        Matrix errorPower = this.matrixOperation.matrixMulPd(errorMatrix, feature, powerMatrix, false);
        Matrix featureError = this.matrixOperation.matrixMulPd(errorMatrix, feature, powerMatrix, true);
        ErrorFeature errorFeature = new ErrorFeature();
        errorFeature.errorFeatureMatrix = featureError;
        errorFeature.powerMatrix = this.matrixOperation.add(powerMatrix, errorPower);
        return errorFeature;
    }

    /** True when K and V are projected from the encoder feature instead of the input feature. */
    private boolean usesEncoderKeyValue() {
        return !this.encoder && this.depth > 1;
    }

    /** True when attention scores must be causally masked. */
    private boolean isMaskedDecoder() {
        return !this.encoder && this.depth == 1;
    }

    /** Writes {@code value} into every cell (i, j) with j &gt; i, i.e. strictly above the diagonal. */
    private static void fillUpperTriangle(Matrix matrix, float value) throws Exception {
        int rows = matrix.getX();
        int cols = matrix.getY();
        for (int i = 0; i < rows; ++i) {
            for (int j = i + 1; j < cols; ++j) {
                matrix.setNub(i, j, value);
            }
        }
    }

    /**
     * Runs softmax(Q K^T / sqrt(d)) V for one event.
     *
     * <p>The intermediate q, kt, v and (post-softmax) qkt are recorded on {@code featureBody} so
     * that a study pass can back-propagate through them later.
     */
    private Matrix countSelfAttention(MyFeature featureBody) throws Exception {
        Matrix kvFeature = this.usesEncoderKeyValue() ? featureBody.encoderFeature : featureBody.allFeature;
        Matrix q = this.matrixOperation.mulMatrix(featureBody.allFeature, this.powerQ);
        Matrix k = this.matrixOperation.mulMatrix(kvFeature, this.powerK);
        Matrix v = this.matrixOperation.mulMatrix(kvFeature, this.powerV);
        Matrix kt = this.matrixOperation.transPosition(k);
        Matrix qkt = this.matrixOperation.mulMatrix(q, kt);
        // mathDiv and softMax are used for their effect on qkt (return values are not used).
        this.matrixOperation.mathDiv(qkt, (float) Math.sqrt(this.wordVectorDimension));
        if (this.isMaskedDecoder()) {
            fillUpperTriangle(qkt, MASK_VALUE);
        }
        this.matrixOperation.softMax(qkt);
        featureBody.q = q;
        featureBody.kt = kt;
        featureBody.v = v;
        featureBody.qkt = qkt;
        return this.matrixOperation.mulMatrix(qkt, v);
    }

    /**
     * Runs the forward pass for one event.
     *
     * <p>When {@code isStudy} is true, the state needed by {@link #backError} is kept under
     * {@code eventID}. An inference pass ({@code isStudy == false}) stores nothing. It also leaves
     * any pending study state for the same ID intact.
     *
     * @param eventID        key under which study state is kept
     * @param isStudy        whether a matching {@link #backError} call will follow
     * @param feature        input feature; Q (and, unless K/V come from the encoder, K and V) are
     *                       projected from it
     * @param encoderFeature encoder output; required when this is a decoder head with depth &gt; 1,
     *                       otherwise unused by the forward pass
     * @return an event body carrying {@code eventID}, this head's {@code selfID} and the result
     * @throws IllegalArgumentException if {@code encoderFeature} is required but {@code null}
     * @throws Exception                propagated from the matrix library
     */
    public EventBody sendMatrixFeature(long eventID, boolean isStudy, Matrix feature, Matrix encoderFeature) throws Exception {
        if (this.usesEncoderKeyValue() && encoderFeature == null) {
            throw new IllegalArgumentException("encoderFeature is required for a decoder head with depth > 1");
        }
        MyFeature myFeature = new MyFeature();
        myFeature.allFeature = feature;
        myFeature.encoderFeature = encoderFeature;
        Matrix result = this.countSelfAttention(myFeature);
        if (isStudy) {
            this.featureMatrix.put(eventID, myFeature);
        }
        EventBody eventBody = new EventBody();
        eventBody.setEventID(eventID);
        eventBody.setSelfID(this.selfID);
        eventBody.setFeatureMatrix(result);
        return eventBody;
    }

    /** Builds a {@code dimension x dimension} matrix of values drawn uniformly from [0, 1/dimension). */
    private static Matrix initPowerMatrix(int dimension, Random random) throws Exception {
        Matrix power = new Matrix(dimension, dimension);
        float scale = (float) dimension;
        for (int i = 0; i < dimension; ++i) {
            for (int j = 0; j < dimension; ++j) {
                power.setNub(i, j, random.nextFloat() / scale);
            }
        }
        return power;
    }

    /** Inputs and intermediates of one forward pass, kept for back-propagation. */
    static class MyFeature {
        /** Input feature (source of Q, and of K/V unless they come from the encoder). */
        Matrix allFeature;
        /** Encoder feature supplied by the caller (may be null when not needed). */
        Matrix encoderFeature;
        /** allFeature * powerQ. */
        Matrix q;
        /** Transpose of (K/V input) * powerK. */
        Matrix kt;
        /** (K/V input) * powerV. */
        Matrix v;
        /** Scaled, optionally masked, softmaxed attention scores. */
        Matrix qkt;

        MyFeature() {
        }
    }

    /** Result of {@code updateError}: the error for the input and the updated weights. */
    static class ErrorFeature {
        /** Error to propagate to the projection's input. */
        Matrix errorFeatureMatrix;
        /** Weight matrix after adding the computed delta. */
        Matrix powerMatrix;

        ErrorFeature() {
        }
    }
}

