/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.transFormer.seflAttention;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.transFormer.CodecBlock;
import org.dromara.easyai.transFormer.TransWordVector;
import org.dromara.easyai.transFormer.model.MultiSelfAttentionModel;
import org.dromara.easyai.transFormer.model.QKVModel;
import org.dromara.easyai.transFormer.seflAttention.AttentionError;
import org.dromara.easyai.transFormer.seflAttention.EventBody;
import org.dromara.easyai.transFormer.seflAttention.LayNorm;
import org.dromara.easyai.transFormer.seflAttention.SelfAttention;

/**
 * Multi-head attention stage.
 *
 * <p>On the forward pass ({@link #sendMatrixMessage}) the input feature is sent to every
 * {@link SelfAttention} head, the per-head outputs are laid side by side along the second (y)
 * index, the concatenation is multiplied by {@code powerMatrix}, and the result is passed to the
 * configured {@link LayNorm}. On the backward pass ({@link #backError}) {@code powerMatrix} is
 * updated, the error is split per head, and the accumulated per-head errors are forwarded to the
 * {@link CodecBlock} (or to the {@link TransWordVector} when no codec block was supplied).
 *
 * <p>Instances hold mutable training state ({@code powerMatrix}, {@code featureMatrix}); no
 * synchronisation is performed in this class.
 */
public class MultiSelfAttention {
    private final CodecBlock codecBlock;
    private final List<SelfAttention> selfAttentions = new ArrayList<>();
    private LayNorm layNorm;
    private final float studyPoint;
    /** Projection applied to the concatenated head outputs; (wordVectorDimension * multiNumber) x wordVectorDimension. */
    private Matrix powerMatrix;
    private final int multiNumber;
    private final int wordVectorDimension;
    /** Concatenated head outputs remembered from the last forward pass that ran with {@code isStudy == true}. */
    private Matrix featureMatrix;
    private final int depth;
    private final boolean encoder;
    private final MatrixOperation matrixOperation;
    private final TransWordVector transWordVector;

    /**
     * Sets the normalisation stage that receives the output of {@link #sendMatrixMessage}.
     *
     * @param layNorm downstream stage; must be set before the first forward pass
     */
    public void setLayNorm(LayNorm layNorm) {
        this.layNorm = layNorm;
    }

    /**
     * @return the depth value this stage was constructed with
     */
    public int getDepth() {
        return this.depth;
    }

    /**
     * Finds the serialized head whose self ID equals {@code selfID}.
     *
     * @return the first matching model, or {@code null} when the list is {@code null} or has no match
     */
    private QKVModel getQKV(List<QKVModel> qkvModelList, int selfID) {
        if (qkvModelList == null) {
            return null;
        }
        for (QKVModel qkvModel : qkvModelList) {
            if (qkvModel.getSelfID() == selfID) {
                return qkvModel;
            }
        }
        return null;
    }

    /**
     * Loads previously exported weights into this stage and into each of its heads.
     *
     * <p>All heads are resolved before anything is overwritten, so a model with a missing head is
     * rejected without leaving this instance half-loaded.
     *
     * @param multiSelfAttentionModel exported weights, as produced by {@link #getModel()}
     * @throws Exception if the model has no entry for one of this stage's heads, if the power
     *                   model is too small for {@code powerMatrix}, or if a head rejects its model
     */
    public void insertModel(MultiSelfAttentionModel multiSelfAttentionModel) throws Exception {
        List<QKVModel> qkvModelList = multiSelfAttentionModel.getQkvModelList();
        int headCount = this.selfAttentions.size();
        List<QKVModel> modelPerHead = new ArrayList<>(headCount);
        for (int i = 0; i < headCount; ++i) {
            QKVModel qkvModel = this.getQKV(qkvModelList, i);
            if (qkvModel == null) {
                // Message (Chinese): the model does not match the configured parameters; the number
                // of heads in memory differs from the number of heads in the model file.
                throw new Exception("\u6a21\u578b\u4e0e\u6fc0\u6d3b\u53c2\u6570\u4e0d\u5339\u914d!\u5185\u5b58\u4e0e\u6a21\u578b\u6587\u4ef6\u7684\u591a\u5934\u6570\u91cf\u4e0d\u4e00\u81f4\uff01");
            }
            modelPerHead.add(qkvModel);
        }
        this.insertPower(multiSelfAttentionModel.getPowerModel(), this.powerMatrix);
        for (int i = 0; i < headCount; ++i) {
            this.selfAttentions.get(i).insertModel(modelPerHead.get(i));
        }
    }

    /**
     * Copies {@code modelPower} into {@code power}, element by element.
     *
     * <p>The array is validated before the first write so that an undersized model yields a
     * descriptive error instead of an index failure halfway through the copy.
     * NOTE(review): an array larger than {@code power} is still accepted and the surplus is
     * ignored, as before — confirm whether an exact shape match should be enforced.
     *
     * @throws IllegalArgumentException if {@code modelPower} does not cover every cell of {@code power}
     * @throws Exception                if {@code Matrix.setNub} fails
     */
    private void insertPower(float[][] modelPower, Matrix power) throws Exception {
        int x = power.getX();
        int y = power.getY();
        int modelX = modelPower == null ? 0 : modelPower.length;
        if (modelX < x) {
            throw new IllegalArgumentException("power model has " + modelX + " rows but the power matrix needs " + x);
        }
        for (int i = 0; i < x; ++i) {
            int modelY = modelPower[i] == null ? 0 : modelPower[i].length;
            if (modelY < y) {
                throw new IllegalArgumentException("power model row " + i + " has " + modelY + " columns but the power matrix needs " + y);
            }
        }
        for (int i = 0; i < x; ++i) {
            for (int j = 0; j < y; ++j) {
                power.setNub(i, j, modelPower[i][j]);
            }
        }
    }

    /**
     * Exports the current weights of this stage and of every head.
     *
     * @return a model carrying the power matrix, one {@link QKVModel} per head, and the depth
     * @throws Exception if a head or the matrix export fails
     */
    public MultiSelfAttentionModel getModel() throws Exception {
        List<QKVModel> qkvModelList = new ArrayList<>(this.selfAttentions.size());
        for (SelfAttention selfAttention : this.selfAttentions) {
            qkvModelList.add(selfAttention.getModel());
        }
        MultiSelfAttentionModel multiSelfAttentionModel = new MultiSelfAttentionModel();
        multiSelfAttentionModel.setPowerModel(this.powerMatrix.getMatrix());
        multiSelfAttentionModel.setQkvModelList(qkvModelList);
        multiSelfAttentionModel.setDepth(this.depth);
        return multiSelfAttentionModel;
    }

    /**
     * Writes one head's output into its slot of the concatenated feature matrix.
     *
     * @param myMultiFeature destination; slot {@code index} starts at y = {@code wordVectorDimension * index}
     * @param matrix         head output; its first {@code wordVectorDimension} y-entries are copied for every x
     * @param index          position of the head within the concatenation
     */
    private void mergeFeatureMatrix(Matrix myMultiFeature, Matrix matrix, int index) throws Exception {
        int startY = this.wordVectorDimension * index;
        int endY = startY + this.wordVectorDimension;
        for (int i = 0; i < matrix.getX(); ++i) {
            for (int j = startY; j < endY; ++j) {
                myMultiFeature.setNub(i, j, matrix.getNumber(i, j - startY));
            }
        }
    }

    /**
     * Cuts the concatenated error back into one sub-matrix per head (inverse layout of
     * {@link #mergeFeatureMatrix}).
     *
     * @param subFeature error with respect to the concatenated head outputs
     * @return one matrix per head, in head order
     */
    private List<Matrix> splitMatrix(Matrix subFeature) {
        List<Matrix> matrixList = new ArrayList<>(this.selfAttentions.size());
        int maxDeep = subFeature.getX();
        for (int i = 0; i < this.selfAttentions.size(); ++i) {
            // presumably (xStart, yStart, xSize, ySize) — TODO confirm against Matrix.getSonOfMatrix
            Matrix matrix = subFeature.getSonOfMatrix(0, i * this.wordVectorDimension, maxDeep, this.wordVectorDimension);
            matrixList.add(matrix);
        }
        return matrixList;
    }

    /**
     * Back-propagates an error through this stage.
     *
     * <p>Updates {@code powerMatrix}, splits the feature error per head, lets every head
     * back-propagate its share, sums the per-head results, and forwards them: to the codec block
     * when one is present, otherwise to the word-vector stage. For a decoder stage with
     * {@code depth > 1} the summed encoder error is also sent to the codec block.
     *
     * <p>The required collaborators are checked before any weight is touched, so a
     * mis-wired stage fails without having been partially updated.
     *
     * @param allErrorMatrix error arriving from the downstream stage
     * @param eventID        identifier passed through to the heads and the codec block
     * @throws IllegalStateException if a collaborator needed for forwarding the error is missing
     * @throws Exception             if a matrix operation or a collaborator fails
     */
    public void backError(Matrix allErrorMatrix, long eventID) throws Exception {
        boolean sendsEncoderError = !this.encoder && this.depth > 1;
        if (this.codecBlock == null) {
            if (sendsEncoderError) {
                throw new IllegalStateException("decoder stage at depth " + this.depth + " needs a CodecBlock to return the encoder error");
            }
            if (this.transWordVector == null) {
                throw new IllegalStateException("neither a CodecBlock nor a TransWordVector is available to receive the error");
            }
        }
        // Error scaled by the learning rate is used for the weight update only; the feature error
        // below is derived from the unscaled error.
        Matrix error = this.matrixOperation.mathMulBySelf(allErrorMatrix, this.studyPoint);
        // presumably: flag false -> derivative w.r.t. powerMatrix, flag true -> derivative w.r.t.
        // featureMatrix — TODO confirm against MatrixOperation.matrixMulPd
        Matrix subPower = this.matrixOperation.matrixMulPd(error, this.featureMatrix, this.powerMatrix, false);
        Matrix subFeature = this.matrixOperation.matrixMulPd(allErrorMatrix, this.featureMatrix, this.powerMatrix, true);
        // Both derivatives are taken before powerMatrix is replaced.
        this.powerMatrix = this.matrixOperation.add(this.powerMatrix, subPower);
        List<Matrix> matrixList = this.splitMatrix(subFeature);
        Matrix allNextFeatureError = null;
        Matrix allLastEncoderError = null;
        for (int i = 0; i < this.selfAttentions.size(); ++i) {
            AttentionError attentionError = this.getSefAttentionBySelfID(i).backError(matrixList.get(i), eventID);
            Matrix nextFeatureError = attentionError.getNextFeatureError();
            allNextFeatureError = allNextFeatureError == null
                    ? nextFeatureError
                    : this.matrixOperation.add(allNextFeatureError, nextFeatureError);
            if (sendsEncoderError) {
                Matrix lastEncoderError = attentionError.getLastEncoderError();
                allLastEncoderError = allLastEncoderError == null
                        ? lastEncoderError
                        : this.matrixOperation.add(allLastEncoderError, lastEncoderError);
            }
        }
        if (sendsEncoderError) {
            this.codecBlock.backLastEncoderError(allLastEncoderError);
        }
        if (this.codecBlock != null) {
            this.codecBlock.backCodecError(allNextFeatureError, eventID, allErrorMatrix);
        } else {
            this.transWordVector.backDecoderError(allNextFeatureError, allErrorMatrix);
        }
    }

    /**
     * Finds the head whose self ID equals {@code selfID}.
     *
     * @throws IllegalStateException if no head carries that ID
     */
    private SelfAttention getSefAttentionBySelfID(int selfID) {
        for (SelfAttention selfAttention : this.selfAttentions) {
            if (selfAttention.getSelfID() == selfID) {
                return selfAttention;
            }
        }
        throw new IllegalStateException("no self-attention head with selfID " + selfID);
    }

    /**
     * Concatenates the head outputs in self-ID order and projects them through {@code powerMatrix}.
     *
     * @param eventBodies one entry per head, looked up by self ID 0..size-1
     * @param isStudy     when {@code true}, the concatenation is kept for the next {@link #backError} call
     * @return the projected matrix
     */
    private Matrix countMultiSelfAttention(List<EventBody> eventBodies, boolean isStudy) throws Exception {
        int one = this.wordVectorDimension * this.multiNumber;
        // Stays null when eventBodies is empty; mulMatrix then receives null (unchanged behaviour).
        Matrix myMultiFeature = null;
        for (int i = 0; i < eventBodies.size(); ++i) {
            Matrix matrix = this.getEventBodyBySelfID(i, eventBodies).getFeatureMatrix();
            if (myMultiFeature == null) {
                // Sized from the first head's output; every head is expected to share its x size.
                myMultiFeature = new Matrix(matrix.getX(), one);
            }
            this.mergeFeatureMatrix(myMultiFeature, matrix, i);
        }
        Matrix out = this.matrixOperation.mulMatrix(myMultiFeature, this.powerMatrix);
        if (isStudy) {
            this.featureMatrix = myMultiFeature;
        }
        return out;
    }

    /**
     * Finds the head output whose self ID equals {@code selfID}.
     *
     * @throws IllegalStateException if no entry carries that ID
     */
    private EventBody getEventBodyBySelfID(int selfID, List<EventBody> eventBodies) {
        for (EventBody myEventBody : eventBodies) {
            if (myEventBody.getSelfID() == selfID) {
                return myEventBody;
            }
        }
        throw new IllegalStateException("no attention output for selfID " + selfID);
    }

    /**
     * Forward pass: runs every head on {@code feature}, merges and projects the results, and
     * passes the projection together with the original feature to the {@link LayNorm}.
     *
     * @param eventID        identifier passed through to the heads and the norm stage
     * @param feature        input feature matrix
     * @param isStudy        whether this is a training pass
     * @param outBack        passed through to the norm stage
     * @param E              passed through to the norm stage
     * @param encoderFeature passed to every head and to the norm stage
     * @param outAllPro      passed through to the norm stage
     * @throws IllegalStateException if {@link #setLayNorm} has not been called
     * @throws Exception             if a head, a matrix operation, or the norm stage fails
     */
    public void sendMatrixMessage(long eventID, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
        if (this.layNorm == null) {
            // Checked before the heads run so a mis-wired stage does no work first.
            throw new IllegalStateException("LayNorm has not been set; call setLayNorm before sendMatrixMessage");
        }
        List<EventBody> eventBodies = new ArrayList<>(this.selfAttentions.size());
        for (SelfAttention selfAttention : this.selfAttentions) {
            eventBodies.add(selfAttention.sendMatrixFeature(eventID, isStudy, feature, encoderFeature));
        }
        Matrix matrix = this.countMultiSelfAttention(eventBodies, isStudy);
        this.layNorm.addNorm(feature, matrix, eventID, isStudy, outBack, E, encoderFeature, outAllPro);
    }

    /**
     * Creates the stage with {@code multiNumber} heads (self IDs 0..multiNumber-1) and a randomly
     * initialised projection matrix.
     *
     * @param multiNumber         number of attention heads
     * @param studyPoint          learning rate applied to the projection update; also given to every head
     * @param depth               depth of this stage, given to every head
     * @param wordVectorDimension width of one head's output and of the projected result
     * @param encoder             {@code true} for an encoder stage, {@code false} for a decoder stage
     * @param codecBlock          receiver of back-propagated errors; may be {@code null}, in which
     *                            case {@code transWordVector} receives them
     * @param coreNumber          passed to {@link MatrixOperation} and to every head
     * @param transWordVector     fallback receiver of back-propagated errors
     * @throws Exception if a head or a matrix cannot be created
     */
    public MultiSelfAttention(int multiNumber, float studyPoint, int depth, int wordVectorDimension, boolean encoder, CodecBlock codecBlock, int coreNumber, TransWordVector transWordVector) throws Exception {
        Random random = new Random();
        this.matrixOperation = new MatrixOperation(coreNumber);
        this.transWordVector = transWordVector;
        this.codecBlock = codecBlock;
        this.encoder = encoder;
        int yiZhi = wordVectorDimension * multiNumber;
        this.studyPoint = studyPoint;
        this.wordVectorDimension = wordVectorDimension;
        this.multiNumber = multiNumber;
        this.depth = depth;
        for (int k = 0; k < multiNumber; ++k) {
            this.selfAttentions.add(new SelfAttention(studyPoint, depth, wordVectorDimension, k, encoder, coreNumber));
        }
        this.powerMatrix = new Matrix(yiZhi, wordVectorDimension);
        int x = this.powerMatrix.getX();
        int y = this.powerMatrix.getY();
        for (int i = 0; i < x; ++i) {
            for (int j = 0; j < y; ++j) {
                // Uniform in [0, 1/yiZhi): nextFloat() is in [0, 1), scaled down by the concatenated width.
                this.powerMatrix.setNub(i, j, random.nextFloat() / (float)yiZhi);
            }
        }
    }
}

