/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.transFormer.seflAttention;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixList;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.transFormer.CodecBlock;
import org.dromara.easyai.transFormer.FirstDecoderBlock;
import org.dromara.easyai.transFormer.model.LayNormModel;
import org.dromara.easyai.transFormer.nerve.HiddenNerve;
import org.dromara.easyai.transFormer.seflAttention.MultiSelfAttention;

/**
 * Residual "Add &amp; Norm" layer used by the transformer blocks.
 *
 * <p>Forward pass ({@link #addNorm}): the residual input and the sub-layer output are added
 * element-wise. Every row of the sum is standardized to zero mean and unit standard deviation,
 * and the standardized row is then mapped through a learned affine transform
 * {@code row * power + bTa}. Note that {@code power} is a full
 * {@code featureDimension x featureDimension} matrix, not a per-feature gain vector.
 *
 * <p>Backward pass ({@link #backErrorFromLine}): {@code power} and {@code bTa} are updated from
 * the incoming error. An error matrix is then forwarded either to the hidden nerves
 * ({@code type == 2}) or to the attached {@link MultiSelfAttention} (any other type).
 *
 * <p>NOTE(review): instances keep mutable per-call state ({@code number}, {@code myFinalError},
 * {@code myNormData}, {@code reMatrixMap}) without any synchronization. Confirm that callers
 * drive a given instance from one thread at a time.
 */
public class LayNorm {
    // Backward target when type != 2; injected via setMultiSelfAttention.
    private MultiSelfAttention multiSelfAttention;
    // For type 1 its presence selects hidden-nerve forwarding; for other types it receives the output.
    private final CodecBlock myEncoderBlock;
    // Width of each feature row; also the side length of the square power matrix.
    private final int featureDimension;
    // Forward targets for type 1 (with myEncoderBlock set); backward targets for type 2.
    private List<HiddenNerve> hiddenNerves;
    // Placement selector; only the values 1 and 2 get special treatment in this class.
    private final int type;
    // Per-event buffer of partial feature slices, flushed once it is featureDimension wide (getY()).
    private final Map<Long, MatrixList> reMatrixMap = new HashMap<Long, MatrixList>();
    // Forward target for type 1 when myEncoderBlock is null.
    private final FirstDecoderBlock firstDecoderBlock;
    // Learned bias row (1 x featureDimension).
    private Matrix bTa;
    // Learned weight matrix (featureDimension x featureDimension).
    private Matrix power;
    // Standardized rows cached by the last study-mode forward pass; consumed by backErrorFromLine.
    private Matrix myNormData;
    // Scale factor applied to errors in the backward pass (presumably the learning rate).
    private final float study;
    // Error accumulated across calls, awaiting a backward pass.
    private Matrix myFinalError;
    // Error contributions received by backErrorFromFNN since the last flush.
    private int number;
    private final MatrixOperation matrixOperation;
    // Only consulted by the constructor to pick the initial weight scale.
    private final boolean encoder;
    // Only consulted by the constructor to pick the initial weight scale.
    private final int depth;

    /**
     * Snapshots the learned parameters.
     *
     * @return a model holding the current {@code bTa} and {@code power} values
     */
    public LayNormModel getModel() throws Exception {
        LayNormModel layNormModel = new LayNormModel();
        layNormModel.setbTa(this.bTa.getMatrix());
        layNormModel.setPower(this.power.getMatrix());
        return layNormModel;
    }

    /**
     * Loads previously saved parameters into this layer's existing matrices.
     *
     * @param layNormModel model produced by {@link #getModel()} for a layer of the same dimension
     * @throws IllegalArgumentException if a stored matrix is smaller than the target matrix
     */
    public void insertModel(LayNormModel layNormModel) throws Exception {
        this.insertPower(layNormModel.getPower(), this.power);
        this.insertPower(layNormModel.getbTa(), this.bTa);
    }

    /**
     * Copies {@code modelPower} into {@code power} element by element.
     *
     * <p>The shape is validated before anything is written, so a mismatched model never leaves
     * the weights half-overwritten.
     *
     * @param modelPower source values; must cover {@code power} in both dimensions
     * @param power destination matrix, modified in place
     * @throws IllegalArgumentException if {@code modelPower} is null or smaller than {@code power}
     */
    private void insertPower(float[][] modelPower, Matrix power) throws Exception {
        int rows = power.getX();
        int cols = power.getY();
        if (modelPower == null || modelPower.length < rows) {
            throw new IllegalArgumentException("model matrix must have at least " + rows + " rows");
        }
        for (int i = 0; i < rows; ++i) {
            if (modelPower[i] == null || modelPower[i].length < cols) {
                throw new IllegalArgumentException(
                        "model matrix row " + i + " must have at least " + cols + " columns");
            }
        }
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                power.setNub(i, j, modelPower[i][j]);
            }
        }
    }

    /**
     * Creates the layer and randomly initializes {@code bTa} and {@code power}.
     *
     * <p>Each weight is drawn uniformly from [0, 1). For decoder layers at depth 1 it is divided
     * by {@code featureDimension^2}, which gives much smaller initial values. The bias is drawn
     * first, then the weights in row-major order.
     *
     * @param type placement selector (1 and 2 are handled specially; see addNorm / backErrorFromLine)
     * @param featureDimension width of each feature row
     * @param myEncoderBlock downstream block (may be null depending on {@code type})
     * @param firstDecoderBlock forward target for type 1 when {@code myEncoderBlock} is null
     * @param study scale factor applied to errors in the backward pass
     * @param coreNumber passed to {@link MatrixOperation}
     * @param encoder whether this layer belongs to an encoder
     * @param depth block depth; together with {@code encoder} selects the initial weight scale
     */
    public LayNorm(int type, int featureDimension, CodecBlock myEncoderBlock, FirstDecoderBlock firstDecoderBlock, float study, int coreNumber, boolean encoder, int depth) throws Exception {
        this.study = study;
        this.myEncoderBlock = myEncoderBlock;
        this.encoder = encoder;
        this.depth = depth;
        this.type = type;
        this.featureDimension = featureDimension;
        this.firstDecoderBlock = firstDecoderBlock;
        this.matrixOperation = new MatrixOperation(coreNumber);
        this.bTa = new Matrix(1, featureDimension);
        this.power = new Matrix(featureDimension, featureDimension);
        Random random = new Random();
        float sh = 1.0f;
        if (!encoder && depth == 1) {
            // Multiply in float: the int product featureDimension * featureDimension overflows
            // for large dimensions, while the float product is identical for all smaller ones.
            sh = (float) featureDimension * featureDimension;
        }
        for (int i = 0; i < featureDimension; ++i) {
            float value = random.nextFloat() / sh;
            this.bTa.setNub(0, i, value);
        }
        for (int i = 0; i < featureDimension; ++i) {
            for (int j = 0; j < featureDimension; ++j) {
                float value = random.nextFloat() / sh;
                this.power.setNub(i, j, value);
            }
        }
    }

    /**
     * Backward step for a single row.
     *
     * <p>Updates {@code power} and derives the error to pass further back.
     *
     * <p>The input-side error {@code sub} is spread with a fixed heuristic. Each component
     * contributes {@code n = sqrt(dim)} times itself at its own position and
     * {@code nt = -n / (n - 1)} times itself at every other position. When {@code dim == 1},
     * {@code nt} is infinite but never used, because the off-diagonal loop has no iterations.
     *
     * <p>NOTE(review): {@code errorMatrix} arrives already scaled by {@code study} (see
     * backErrorFromLine), and {@code sub} is multiplied by {@code study} again below. Confirm
     * this double scaling is intended.
     *
     * @param errorMatrix error row for this position (1 x featureDimension)
     * @param myData the cached standardized input row for this position
     * @return the error row to propagate to the previous layer
     */
    private Matrix back(Matrix errorMatrix, Matrix myData) throws Exception {
        Matrix subPower = this.matrixOperation.matrixMulPd(errorMatrix, myData, this.power, false);
        // Computed against the pre-update weights: power is only replaced on the next line.
        Matrix sub = this.matrixOperation.matrixMulPd(errorMatrix, myData, this.power, true);
        this.power = this.matrixOperation.add(subPower, this.power);
        float n = (float) Math.sqrt(sub.getY());
        float nt = -n / (n - 1.0f);
        Matrix subMatrix = new Matrix(1, sub.getY());
        for (int i = 0; i < sub.getY(); ++i) {
            float subValue = sub.getNumber(0, i) * this.study;
            float value = subValue * n + subMatrix.getNumber(0, i);
            subMatrix.setNub(0, i, value);
            for (int j = 0; j < sub.getY(); ++j) {
                if (i == j) continue;
                float otherValue = subValue * nt + subMatrix.getNumber(0, j);
                subMatrix.setNub(0, j, otherValue);
            }
        }
        return subMatrix;
    }

    /**
     * Accumulates one error contribution from the feed-forward side.
     *
     * <p>Once {@code featureDimension} contributions have been summed, the backward pass runs:
     * the last column of the sum is dropped, {@code allError} is added, and the result goes to
     * {@link #backErrorFromLine}.
     *
     * <p>NOTE(review): the dropped last column is presumably an extra column appended by the
     * sender. Confirm against the caller.
     *
     * @param errorMatrix one contribution; the first one is stored by reference, not copied
     * @param eventID identifier passed through to the backward pass
     * @param allError error added to the trimmed sum before back-propagating
     */
    public void backErrorFromFNN(Matrix errorMatrix, long eventID, Matrix allError) throws Exception {
        ++this.number;
        this.myFinalError = this.myFinalError == null ? errorMatrix : this.matrixOperation.add(this.myFinalError, errorMatrix);
        if (this.number == this.featureDimension) {
            this.number = 0;
            Matrix error = this.myFinalError.getSonOfMatrix(0, 0, this.myFinalError.getX(), this.myFinalError.getY() - 1);
            this.myFinalError = null;
            Matrix myError = this.matrixOperation.add(error, allError);
            this.backErrorFromLine(myError, eventID);
        }
    }

    /**
     * Adds {@code errorMatrix} to the pending error without triggering a backward pass.
     *
     * <p>Call {@link #encoderBackStart} to run the backward pass afterwards.
     *
     * @param errorMatrix error to accumulate; the first one is stored by reference
     */
    public void backLastError(Matrix errorMatrix) throws Exception {
        this.myFinalError = this.myFinalError == null ? errorMatrix : this.matrixOperation.add(this.myFinalError, errorMatrix);
    }

    /**
     * Runs the backward pass on the error accumulated by {@link #backLastError}, then clears it.
     *
     * @param eventID identifier passed through to the backward pass
     * @throws IllegalStateException if no error has been accumulated
     */
    public void encoderBackStart(long eventID) throws Exception {
        if (this.myFinalError == null) {
            throw new IllegalStateException("no accumulated error: call backLastError before encoderBackStart");
        }
        // Copy so the in-place scaling in backErrorFromLine cannot touch a caller-owned matrix
        // (backLastError stores its first argument by reference).
        Matrix error = this.myFinalError.copy();
        this.myFinalError = null;
        this.backErrorFromLine(error, eventID);
    }

    /**
     * Back-propagates {@code errorMatrix} through this layer.
     *
     * <p>For each row, {@code bTa} and {@code power} are updated. The resulting error is sent to
     * the hidden nerves column by column when {@code type == 2}, otherwise to the
     * multi-self-attention.
     *
     * <p>{@code errorMatrix} is scaled by {@code study} through {@code mathMul}, whose return
     * value is ignored. This relies on the scaling happening in place, which mutates the
     * argument.
     *
     * @param errorMatrix one error row per cached input row
     * @param eventID identifier passed to the downstream receivers
     * @throws IllegalStateException if no study-mode forward pass has cached normalized input
     * @throws IllegalArgumentException if {@code errorMatrix} has no rows or more rows than cached
     */
    public void backErrorFromLine(Matrix errorMatrix, long eventID) throws Exception {
        if (this.myNormData == null) {
            throw new IllegalStateException("no cached normalized input: run addNorm with isStudy=true first");
        }
        int x = errorMatrix.getX();
        if (x == 0) {
            throw new IllegalArgumentException("error matrix has no rows");
        }
        if (x > this.myNormData.getX()) {
            throw new IllegalArgumentException("error matrix has " + x + " rows but only "
                    + this.myNormData.getX() + " normalized rows are cached");
        }
        this.matrixOperation.mathMul(errorMatrix, this.study);
        MatrixList errorMatrixList = null;
        for (int i = 0; i < x; ++i) {
            Matrix error = errorMatrix.getRow(i);
            Matrix myData = this.myNormData.getRow(i);
            this.bTa = this.matrixOperation.add(error, this.bTa);
            Matrix myRowError = this.back(error, myData);
            if (i == 0) {
                errorMatrixList = new MatrixList(myRowError, true);
                continue;
            }
            errorMatrixList.add(myRowError);
        }
        Matrix myError = errorMatrixList.getMatrix();
        if (this.type == 2) {
            int size = this.hiddenNerves.size();
            for (int i = 0; i < size; ++i) {
                this.hiddenNerves.get(i).receiveErrorMatrix(myError.getColumn(i), eventID, myError);
            }
        } else {
            this.multiSelfAttention.backError(myError, eventID);
        }
    }

    /**
     * Adds the residual, normalizes the sum, and forwards the result.
     *
     * <p>The result is routed by {@code type}:
     * <ul>
     *   <li>type 1 with {@code myEncoderBlock} set: sent to every hidden nerve;</li>
     *   <li>type 1 otherwise: sent to {@code firstDecoderBlock} if it is non-null, and silently
     *       dropped if it is null;</li>
     *   <li>any other type: sent to {@code myEncoderBlock}.</li>
     * </ul>
     *
     * @param feature residual input
     * @param outMatrix sub-layer output, added element-wise to {@code feature}
     * @param isStudy when true, the standardized rows are cached for the backward pass
     */
    public void addNorm(Matrix feature, Matrix outMatrix, long eventID, boolean isStudy, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
        Matrix myMatrix = this.matrixOperation.add(feature, outMatrix);
        Matrix out = this.layNorm(myMatrix, isStudy);
        if (this.type == 1) {
            if (this.myEncoderBlock != null) {
                this.sendHiddenParameter(out, eventID, isStudy, outBack, E, encoderFeature, outAllPro);
            } else if (this.firstDecoderBlock != null) {
                this.firstDecoderBlock.sendOutputMatrix(eventID, out, isStudy, outBack, E, outAllPro);
            }
        } else {
            this.myEncoderBlock.sendOutputMatrix(eventID, out, isStudy, outBack, E, encoderFeature, outAllPro);
        }
    }

    /**
     * Collects partial outputs from individual nerves for one event.
     *
     * <p>Once the buffered matrix reaches {@code featureDimension} columns, it is removed from
     * the buffer and passed to {@link #addNorm} together with {@code allFeature} as the residual.
     *
     * @param parameter the next partial slice for {@code eventID}
     * @param allFeature residual input used once the slices are complete
     */
    public void addNormFromNerve(long eventID, boolean isStudy, Matrix parameter, Matrix allFeature, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
        MatrixList matrixFeature;
        if (this.reMatrixMap.containsKey(eventID)) {
            matrixFeature = this.reMatrixMap.get(eventID);
            matrixFeature.add(parameter);
        } else {
            matrixFeature = new MatrixList(parameter, false);
            this.reMatrixMap.put(eventID, matrixFeature);
        }
        if (matrixFeature.getY() == this.featureDimension) {
            this.reMatrixMap.remove(eventID);
            this.addNorm(matrixFeature.getMatrix(), allFeature, eventID, isStudy, outBack, E, encoderFeature, outAllPro);
        }
    }

    /** Sends the same normalized feature matrix to every hidden nerve. */
    private void sendHiddenParameter(Matrix feature, long eventId, boolean isStudy, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
        for (HiddenNerve hiddenNerve : this.hiddenNerves) {
            hiddenNerve.receive(feature, eventId, isStudy, outBack, E, encoderFeature, outAllPro);
        }
    }

    /**
     * Standardizes one row: {@code (value - mean) / sd}.
     *
     * <p>The 1e-7 passed to {@code getSdByMatrix} is presumably an epsilon that keeps {@code sd}
     * away from zero.
     *
     * @param row a 1 x n row
     * @return a new 1 x n row with the standardized values
     */
    private Matrix norm(Matrix row) throws Exception {
        Matrix result = new Matrix(1, row.getY());
        float avg = row.getAVG();
        float sd = this.matrixOperation.getSdByMatrix(row, avg, 1.0E-7f);
        for (int i = 0; i < row.getY(); ++i) {
            float value = (row.getNumber(0, i) - avg) / sd;
            result.setNub(0, i, value);
        }
        return result;
    }

    /**
     * Standardizes each row of {@code feature} and applies {@code row * power + bTa}.
     *
     * @param feature input rows
     * @param isStudy when true, the standardized rows are cached in {@code myNormData}
     * @return the transformed rows, stacked in input order
     * @throws IllegalArgumentException if {@code feature} has no rows
     */
    private Matrix layNorm(Matrix feature, boolean isStudy) throws Exception {
        int x = feature.getX();
        if (x == 0) {
            throw new IllegalArgumentException("feature matrix has no rows");
        }
        MatrixList normMatrixList = null;
        MatrixList outMatrixList = null;
        for (int i = 0; i < x; ++i) {
            Matrix normData = this.norm(feature.getRow(i));
            if (isStudy) {
                if (i == 0) {
                    normMatrixList = new MatrixList(normData, true);
                } else {
                    normMatrixList.add(normData);
                }
            }
            Matrix want = this.matrixOperation.add(this.matrixOperation.mulMatrix(normData, this.power), this.bTa);
            if (i == 0) {
                outMatrixList = new MatrixList(want, true);
                continue;
            }
            outMatrixList.add(want);
        }
        if (isStudy) {
            this.myNormData = normMatrixList.getMatrix();
        }
        return outMatrixList.getMatrix();
    }

    /** Sets the hidden nerves used as forward (type 1) or backward (type 2) targets. */
    public void setHiddenNerves(List<HiddenNerve> hiddenNerves) {
        this.hiddenNerves = hiddenNerves;
    }

    /** Sets the attention layer that receives the backward error when {@code type != 2}. */
    public void setMultiSelfAttention(MultiSelfAttention multiSelfAttention) {
        this.multiSelfAttention = multiSelfAttention;
    }
}

