/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.nerveEntity;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.dromara.easyai.conv.ConvCount;
import org.dromara.easyai.entity.ThreeChannelMatrix;
import org.dromara.easyai.i.ActiveFunction;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.nerveEntity.ConvParameter;
import org.dromara.easyai.nerveEntity.ConvSize;

/**
 * Base class for a single neuron ("nerve") of an easyAI network.
 *
 * <p>Two wiring styles are supported, both visible in this class:
 * <ul>
 *   <li><b>Dense wiring</b>: many parents ({@code father}) and many children ({@code son}).
 *       Scalar inputs are buffered per event id in {@link #features}. They are combined with the
 *       weights in {@link #dendrites} by {@link #calculation(long)}. Errors flow back through
 *       {@code backGetMessage} and {@link #updatePower(long)}.</li>
 *   <li><b>Matrix wiring</b>: a single parent ({@code fatherOnly}) and a single child
 *       ({@code sonOnly}). Lists of {@link Matrix} flow forward through {@link #conv(List)}.
 *       Errors flow back through {@link #backMatrix(List)}.</li>
 * </ul>
 *
 * <p>Subclasses override the no-op {@code input*} hooks to receive forward messages.
 * NOTE(review): this class has no synchronization. The per-event buffers and the
 * back-propagation counters are plain fields — confirm the callers' threading model.
 */
public abstract class Nerve extends ConvCount {
    /** Downstream nerves that receive scalar and feature-list messages (dense wiring). */
    private final List<Nerve> son = new ArrayList<Nerve>();
    /** Upstream nerves; parent {@code i} is sent the error term stored under key {@code i + 1} of {@link #wg}. */
    private final List<Nerve> father = new ArrayList<Nerve>();
    /** Single downstream nerve that receives matrix messages (matrix wiring). */
    private Nerve sonOnly;
    /** Single upstream nerve that receives matrix errors (matrix wiring). */
    private Nerve fatherOnly;
    /** Input weights keyed by 1-based input position ({@code 1..upNub}). */
    protected Map<Integer, Float> dendrites = new HashMap<Integer, Float>();
    /** Per-input error terms ({@code weightBeforeUpdate * gradient}), keyed like {@link #dendrites}. */
    protected Map<Integer, Float> wg = new HashMap<Integer, Float>();
    private final int id;
    /** Number of inputs; dense weights are created for keys {@code 1..upNub}. */
    protected int upNub;
    /** Number of back-propagated messages to collect before an update is performed. */
    protected int downNub;
    /** Scalar inputs buffered per event id until consumed. */
    protected Map<Long, List<Float>> features = new HashMap<Long, List<Float>>();
    /** Bias subtracted from the weighted sum in {@link #calculation(long)}. */
    protected float threshold;
    protected String name;
    /** Passed to {@code activeFunction.functionG} when back-propagating; presumably the last output, set by subclasses (TODO confirm). */
    protected float outNub;
    /** Not referenced in this class. */
    protected float E;
    /** Current error gradient; computed in {@code backGetMessage}. */
    protected float gradient;
    /** Step size that scales the gradient when updating weights and threshold. */
    protected float studyPoint;
    /** Sum of scalar errors received from children since the last update. */
    protected float sigmaW;
    /** Sum of matrix errors received from children since the last update; {@code null} when empty. */
    protected List<Matrix> sigmaMatrix;
    /** Number of back-propagated messages received since the last update. */
    private int backNub = 0;
    protected ActiveFunction activeFunction;
    /** Regularization mode: 0 = none, 1 = sign-based penalty, 2 = weight-proportional penalty. */
    private final int rzType;
    /** Regularization strength. */
    private final float lParam;
    /** Side length of the square convolution kernel (each kernel has {@code kernLen * kernLen} weights). */
    private final int kernLen;
    /** Layer depth; depth 1 gets special handling (one-conv weights, no further matrix back-propagation). */
    protected final int depth;
    /** Rows of each error matrix rebuilt from {@link #wg} in {@code backSendMessage}. */
    protected final int matrixX;
    /** Columns of each error matrix rebuilt from {@link #wg} in {@code backSendMessage}. */
    protected final int matrixY;
    private final MatrixOperation matrixOperation;
    /** Number of convolution channels; one kernel matrix is created per channel. */
    protected final int channelNo;
    private final ConvParameter convParameter = new ConvParameter();
    /** Rate passed to {@code backOneConvByList} when updating the depth-1 one-conv weights. */
    protected final float oneConvRate;
    /** When true, depth-1 inputs go through {@code manyOneConv}; otherwise exactly 3 input matrices are required. */
    private final boolean norm;

    /**
     * Returns the live (not copied) weight map.
     *
     * @return the weights keyed by 1-based input position
     */
    public Map<Integer, Float> getDendrites() {
        return this.dendrites;
    }

    /**
     * Returns the convolution state (kernels, sizes, cached feature matrices, one-conv weights).
     *
     * @return this nerve's convolution parameter holder
     */
    public ConvParameter getConvParameter() {
        return this.convParameter;
    }

    /**
     * Replaces the weight map (stored by reference, not copied).
     *
     * @param dendrites weights keyed by 1-based input position
     */
    public void setDendrites(Map<Integer, Float> dendrites) {
        this.dendrites = dendrites;
    }

    /**
     * Returns the bias subtracted in {@link #calculation(long)}.
     *
     * @return the bias
     */
    public float getThreshold() {
        return this.threshold;
    }

    /**
     * Sets the bias subtracted in {@link #calculation(long)}.
     *
     * @param threshold the bias
     */
    public void setThreshold(float threshold) {
        this.threshold = threshold;
    }

    /**
     * Creates a nerve and initializes its weights.
     *
     * @param id             identifier returned by {@link #getId()}
     * @param upNub          number of inputs (dense weights {@code 1..upNub})
     * @param name           nerve name
     * @param downNub        number of back messages to collect before updating
     * @param studyPoint     step size for weight/threshold updates
     * @param init           dense wiring only: random initial weights when true, zeros otherwise
     * @param activeFunction activation used in forward convolution and back-propagation
     * @param isDynamic      true to initialize convolution kernels instead of dense weights
     * @param rzType         regularization mode (0 none, 1 sign-based, 2 weight-proportional)
     * @param lParam         regularization strength
     * @param kernLen        convolution kernel side length
     * @param depth          layer depth (depth 1 is special-cased)
     * @param matrixX        rows of the error matrices rebuilt from {@link #wg}
     * @param matrixY        columns of the error matrices rebuilt from {@link #wg}
     * @param coreNumber     passed to the {@link MatrixOperation} constructor
     * @param channelNo      number of convolution channels
     * @param onConvRate     rate for the depth-1 one-conv weights (stored as {@link #oneConvRate})
     * @param norm           whether depth-1 inputs are mixed by {@code manyOneConv}
     * @throws Exception if weight initialization fails
     */
    protected Nerve(int id, int upNub, String name, int downNub, float studyPoint, boolean init, ActiveFunction activeFunction, boolean isDynamic, int rzType, float lParam, int kernLen, int depth, int matrixX, int matrixY, int coreNumber, int channelNo, float onConvRate, boolean norm) throws Exception {
        this.matrixOperation = new MatrixOperation(coreNumber);
        this.matrixX = matrixX;
        this.norm = norm;
        this.matrixY = matrixY;
        this.channelNo = channelNo;
        this.id = id;
        this.depth = depth;
        this.upNub = upNub;
        this.name = name;
        this.downNub = downNub;
        this.studyPoint = studyPoint;
        this.activeFunction = activeFunction;
        this.rzType = rzType;
        this.lParam = lParam;
        this.kernLen = kernLen;
        this.oneConvRate = onConvRate;
        // Must run last: it reads upNub, kernLen, channelNo and depth assigned above.
        this.initPower(init, isDynamic);
    }

    /**
     * Sets the step size used for weight/threshold updates.
     *
     * @param studyPoint new step size
     */
    protected void setStudyPoint(float studyPoint) {
        this.studyPoint = studyPoint;
    }

    /**
     * Forwards a scalar value to every child via {@code input}.
     *
     * @throws Exception "this layer is lastIndex" when this nerve has no children, or whatever a child throws
     */
    public void sendMessage(long eventId, float parameter, boolean isStudy, Map<Integer, Float> E, OutBack outBack) throws Exception {
        if (this.son.isEmpty()) {
            throw new Exception("this layer is lastIndex");
        }
        for (Nerve nerve : this.son) {
            nerve.input(eventId, parameter, isStudy, E, outBack);
        }
    }

    /**
     * Runs convolution and pooling on the given matrices using this nerve's kernels,
     * channel count, activation and kernel length.
     *
     * @param matrix input matrices
     * @return the convolved/pooled matrices
     * @throws Exception propagated from {@code downConvAndPooling}
     */
    protected List<Matrix> conv(List<Matrix> matrix) throws Exception {
        return this.downConvAndPooling(matrix, this.convParameter, this.channelNo, this.activeFunction, this.kernLen, true, -1L);
    }

    /**
     * Optionally mixes the input channels, convolves the result, and forwards it to {@code sonOnly}.
     *
     * <p>When {@code study} is true, the raw input list is cached in the conv parameters.
     * {@link #backMatrix(List)} later reads it back to update the one-conv weights.
     *
     * @throws Exception if {@code norm} is false and the input does not have exactly 3 matrices,
     *                   if there is no {@code sonOnly}, or if convolution fails
     */
    protected void demRedByMatrixList(long eventId, List<Matrix> matrixList, boolean study, Map<Integer, Float> E, OutBack outBack, boolean needMatrix) throws Exception {
        if (study) {
            this.convParameter.setFeatureMatrixList(matrixList);
        }
        List<Matrix> feature;
        if (this.norm) {
            feature = this.manyOneConv(matrixList, this.convParameter.getOneConvPower());
        } else {
            if (matrixList.size() != 3) {
                // Message: "Without dimension adjustment, the input feature matrices must have exactly 3 channels"
                throw new Exception("\u4e0d\u8fdb\u884c\u7ef4\u5ea6\u8c03\u8282\uff0c\u8f93\u5165\u7684\u7279\u5f81\u77e9\u9635\u901a\u9053\u6570\u5fc5\u987b\u4e3a3");
            }
            feature = matrixList;
        }
        List<Matrix> convMatrix = this.conv(feature);
        this.sendMatrix(eventId, convMatrix, study, E, outBack, needMatrix);
    }

    /**
     * Forwards a list of feature values to every child via {@code inputMatrixFeature}.
     *
     * @throws Exception "this layer is lastIndex" when this nerve has no children, or whatever a child throws
     */
    public void sendMatrixList(long eventId, List<Float> parameter, boolean isStudy, Map<Integer, Float> E, OutBack outBack) throws Exception {
        if (this.son.isEmpty()) {
            throw new Exception("this layer is lastIndex");
        }
        for (Nerve nerve : this.son) {
            nerve.inputMatrixFeature(eventId, parameter, isStudy, E, outBack);
        }
    }

    /**
     * Forwards matrices to {@code sonOnly} via {@code inputMatrix}.
     *
     * @throws Exception "this layer is lastIndex" when {@code sonOnly} is not connected
     */
    public void sendMatrix(long eventId, List<Matrix> parameter, boolean isStudy, Map<Integer, Float> E, OutBack outBack, boolean needMatrix) throws Exception {
        if (this.sonOnly == null) {
            throw new Exception("this layer is lastIndex");
        }
        this.sonOnly.inputMatrix(eventId, parameter, isStudy, E, outBack, needMatrix);
    }

    /**
     * Forwards a three-channel picture to {@code sonOnly} via {@code inputThreeChannelMatrix}.
     *
     * @throws Exception "this layer is lastIndex" when {@code sonOnly} is not connected
     */
    public void sendThreeChannelMatrix(long eventId, ThreeChannelMatrix parameter, boolean isStudy, Map<Integer, Float> E, OutBack outBack, boolean needMatrix) throws Exception {
        if (this.sonOnly == null) {
            throw new Exception("this layer is lastIndex");
        }
        this.sonOnly.inputThreeChannelMatrix(eventId, parameter, isStudy, E, outBack, needMatrix);
    }

    /**
     * Hands a matrix list to {@code sonOnly}'s {@link #demRedByMatrixList}.
     *
     * @throws Exception "this layer is lastIndex" when {@code sonOnly} is not connected
     */
    public void sendListMatrix(long eventId, List<Matrix> parameter, boolean isStudy, Map<Integer, Float> E, OutBack outBack, boolean needMatrix) throws Exception {
        if (this.sonOnly == null) {
            throw new Exception("this layer is lastIndex");
        }
        this.sonOnly.demRedByMatrixList(eventId, parameter, isStudy, E, outBack, needMatrix);
    }

    /**
     * Propagates the error terms in {@link #wg} upstream.
     *
     * <p>With dense parents, parent {@code i} receives {@code wg[i + 1]}. Otherwise, at depth 1 with
     * a {@code fatherOnly}, the values {@code wg[1..]} are sliced into consecutive
     * {@code matrixX * matrixY} chunks. Each chunk is turned into one matrix, and the list is passed
     * to {@code fatherOnly.backMatrix}. Any trailing values that do not fill a whole chunk are dropped
     * (integer division).
     */
    private void backSendMessage(long eventId) throws Exception {
        if (!this.father.isEmpty()) {
            for (int i = 0; i < this.father.size(); ++i) {
                this.father.get(i).backGetMessage(this.wg.get(i + 1).floatValue(), eventId);
            }
        } else if (this.fatherOnly != null && this.depth == 1) {
            List<Matrix> errorMatrixList = new ArrayList<Matrix>();
            int size = this.matrixX * this.matrixY;
            int featureSize = this.wg.size() / size;
            for (int i = 0; i < featureSize; ++i) {
                ArrayList<Float> list = new ArrayList<Float>();
                int startIndex = size * i;
                int endIndex = startIndex + size;
                for (int j = startIndex; j < endIndex; ++j) {
                    list.add(this.wg.get(j + 1));
                }
                errorMatrixList.add(this.matrixOperation.ListToMatrix(list, this.matrixX, this.matrixY));
            }
            this.fatherOnly.backMatrix(errorMatrixList);
        }
    }

    /** Passes matrix errors to {@code fatherOnly}, if one is connected. */
    private void backMatrixMessage(List<Matrix> g) throws Exception {
        if (this.fatherOnly != null) {
            this.fatherOnly.backMatrix(g);
        }
    }

    /** Hook for scalar forward messages; no-op by default. */
    protected void input(long eventId, float parameter, boolean isStudy, Map<Integer, Float> E, OutBack imageBack) throws Exception {
    }

    /** Hook for feature-list forward messages; no-op by default. */
    protected void inputMatrixFeature(long eventId, List<Float> parameters, boolean isStudy, Map<Integer, Float> E, OutBack imageBack) throws Exception {
    }

    /** Hook for matrix forward messages; no-op by default. */
    protected void inputMatrix(long eventId, List<Matrix> matrix, boolean isKernelStudy, Map<Integer, Float> E, OutBack outBack, boolean needMatrix) throws Exception {
    }

    /** Hook for three-channel picture forward messages; no-op by default. */
    protected void inputThreeChannelMatrix(long eventId, ThreeChannelMatrix picture, boolean isKernelStudy, Map<Integer, Float> E, OutBack outBack, boolean needMatrix) throws Exception {
    }

    /**
     * Accumulates one scalar error from a child.
     *
     * <p>Once {@code downNub} errors have arrived, it computes
     * {@code gradient = functionG(outNub) * sigmaW} and performs the update.
     */
    private void backGetMessage(float parameter, long eventId) throws Exception {
        ++this.backNub;
        this.sigmaW += parameter;
        if (this.backNub == this.downNub) {
            this.backNub = 0;
            this.gradient = this.activeFunction.functionG(this.outNub) * this.sigmaW;
            this.updatePower(eventId);
        }
    }

    /**
     * Accumulates one matrix error from a child.
     *
     * <p>Once {@code downNub} errors have arrived, it back-propagates through pooling and
     * convolution, which also updates the kernels inside {@code backAllDownConv}. At depth 1 with
     * {@code norm}, the one-conv weights are then updated from the cached input matrices. At depth 1
     * without {@code norm}, the error stops here. At other depths it is sent to {@code fatherOnly}.
     *
     * @param t error matrices from one child
     * @throws Exception propagated from the matrix back-propagation helpers
     */
    protected void backMatrix(List<Matrix> t) throws Exception {
        ++this.backNub;
        // NOTE(review): the first message's list is stored by reference, not copied.
        this.sigmaMatrix = this.sigmaMatrix == null ? t : this.matrixOperation.addMatrixList(t, this.sigmaMatrix);
        if (this.backNub != this.downNub) {
            return;
        }
        this.backNub = 0;
        List<Matrix> errorMatrix = this.backDownPoolingByList(this.sigmaMatrix, this.convParameter.getOutX(), this.convParameter.getOutY());
        List<Matrix> myErrorMatrix = this.backAllDownConv(this.convParameter, errorMatrix, this.studyPoint, this.activeFunction, this.channelNo, this.kernLen);
        this.sigmaMatrix = null;
        if (this.depth != 1) {
            this.backMatrixMessage(myErrorMatrix);
        } else if (this.norm) {
            this.backOneConvByList(myErrorMatrix, this.convParameter.getFeatureMatrixList(), this.convParameter.getOneConvPower(), this.oneConvRate, false);
        }
    }

    /**
     * Applies the current gradient.
     *
     * <p>Steps: the threshold moves by {@code -gradient * studyPoint}; weights are updated by
     * {@code updateW}; the error accumulator is reset; the new error terms are propagated upstream.
     *
     * @param eventId event whose buffered inputs are used (and consumed) by the weight update
     * @throws Exception propagated from upstream back-propagation
     */
    protected void updatePower(long eventId) throws Exception {
        float h = this.gradient * this.studyPoint;
        this.threshold -= h;
        this.updateW(h, eventId);
        this.sigmaW = 0.0f;
        this.backSendMessage(eventId);
    }

    /**
     * Returns the regularization delta to add to weight {@code w}.
     *
     * <ul>
     *   <li>Mode 2: {@code -param * w}.</li>
     *   <li>Mode 1: {@code -param} for positive {@code w}, {@code +param} for negative {@code w}.</li>
     *   <li>Otherwise, or when {@code w} is 0 or NaN in mode 1: 0.</li>
     * </ul>
     */
    private float regularization(float w, float param) {
        if (this.rzType == 2) {
            return param * -w;
        }
        if (this.rzType == 1) {
            if (w > 0.0f) {
                return -param;
            }
            if (w < 0.0f) {
                return param;
            }
        }
        return 0.0f;
    }

    /**
     * Updates every weight from the inputs buffered for {@code eventId}, then drops that buffer.
     *
     * <p>For each input {@code key}:
     * <ul>
     *   <li>{@code wg[key] = w * gradient}, using the weight before this update;</li>
     *   <li>{@code w = w + regularization(w) + input[key - 1] * h}.</li>
     * </ul>
     * The regularization magnitude is computed once, before the loop, as
     * {@code (sum of |w| or w^2) * lParam * studyPoint}.
     *
     * @param h       gradient already scaled by {@code studyPoint}
     * @param eventId event whose buffered inputs are used
     */
    private void updateW(float h, long eventId) {
        List<Float> list = this.features.get(eventId);
        float param = 0.0f;
        if (this.rzType != 0) {
            float sigma = 0.0f;
            for (Map.Entry<Integer, Float> entry : this.dendrites.entrySet()) {
                if (this.rzType == 2) {
                    sigma += (float) Math.pow(entry.getValue().floatValue(), 2.0);
                } else {
                    sigma += Math.abs(entry.getValue().floatValue());
                }
            }
            param = sigma * this.lParam * this.studyPoint;
        }
        for (Map.Entry<Integer, Float> entry : this.dendrites.entrySet()) {
            int key = entry.getKey();
            float w = entry.getValue().floatValue();
            float bn = list.get(key - 1).floatValue();
            float wp = bn * h;
            // Error for the upstream nerve is based on the weight BEFORE this update.
            float dm = w * this.gradient;
            w += this.regularization(w, param);
            w += wp;
            this.wg.put(key, Float.valueOf(dm));
            // Replacing an existing key does not structurally modify the map, so this is safe mid-iteration.
            this.dendrites.put(key, Float.valueOf(w));
        }
        this.features.remove(eventId);
    }

    /**
     * Appends a batch of inputs to the buffer for {@code eventId}, creating the buffer if needed.
     *
     * @param eventId    event the inputs belong to
     * @param parameters inputs to append, in order
     */
    protected void insertParameters(long eventId, List<Float> parameters) {
        this.getOrCreateFeatureList(eventId).addAll(parameters);
    }

    /**
     * Appends one input to the buffer for {@code eventId}, creating the buffer if needed.
     *
     * @param eventId   event the input belongs to
     * @param parameter input value
     * @return true once at least {@code upNub} inputs are buffered for the event
     */
    protected boolean insertParameter(long eventId, float parameter) {
        List<Float> featuresList = this.getOrCreateFeatureList(eventId);
        featuresList.add(Float.valueOf(parameter));
        return featuresList.size() >= this.upNub;
    }

    /**
     * Returns the input buffer for {@code eventId}, registering a new empty one if absent.
     * Typed as {@code List<Float>} to match {@link #features}; the decompiled {@code List<Object>}
     * form did not compile.
     */
    private List<Float> getOrCreateFeatureList(long eventId) {
        List<Float> featuresList = this.features.get(eventId);
        if (featuresList == null) {
            featuresList = new ArrayList<Float>();
            this.features.put(eventId, featuresList);
        }
        return featuresList;
    }

    /** Discards any inputs buffered for {@code eventId}. */
    protected void destoryParameter(long eventId) {
        this.features.remove(eventId);
    }

    /**
     * Computes {@code sum(w[i + 1] * input[i]) - threshold} over the inputs buffered for {@code eventId}.
     *
     * @param eventId event whose buffered inputs are combined (the buffer is not removed)
     * @return the weighted sum minus the threshold
     * @throws Exception if the number of weights differs from the number of buffered inputs
     *                   (message: "weight count:…,feature count:…")
     */
    protected float calculation(long eventId) throws Exception {
        float sigma = 0.0f;
        List<Float> featuresList = this.features.get(eventId);
        if (this.dendrites.size() != featuresList.size()) {
            throw new Exception("\u6743\u91cd\u6570\u91cf:" + this.dendrites.size() + ",\u7279\u5f81\u6570\u91cf:" + featuresList.size());
        }
        for (int i = 0; i < featuresList.size(); ++i) {
            float value = featuresList.get(i).floatValue();
            float w = this.dendrites.get(i + 1).floatValue();
            sigma = w * value + sigma;
        }
        return sigma - this.threshold;
    }

    /**
     * Initializes the weights.
     *
     * <p>Dense wiring ({@code !isDynamic}, {@code upNub > 0}): weights {@code 1..upNub} and then the
     * threshold are set to {@code nextFloat() / sqrt(upNub)} when {@code init}, else 0.
     * Matrix wiring ({@code isDynamic}): the convolution kernels are initialized instead.
     */
    private void initPower(boolean init, boolean isDynamic) throws Exception {
        Random random = new Random();
        if (isDynamic) {
            this.initMatrixPower(random);
            return;
        }
        if (this.upNub <= 0) {
            return;
        }
        for (int i = 1; i < this.upNub + 1; ++i) {
            float nub = 0.0f;
            if (init) {
                nub = random.nextFloat() / (float) Math.sqrt(this.upNub);
            }
            this.dendrites.put(i, Float.valueOf(nub));
        }
        float nub = 0.0f;
        if (init) {
            nub = random.nextFloat() / (float) Math.sqrt(this.upNub);
        }
        this.threshold = nub;
    }

    /**
     * Creates one {@code (kernLen * kernLen) x 1} kernel per channel, filled with
     * {@code nextFloat() / kernLen}, plus a matching {@link ConvSize} entry.
     *
     * <p>At depth 1 it also creates 3 one-conv weights per channel, each {@code nextFloat() / 3}.
     * These are stored via {@code setOneConvPower}.
     */
    private void initMatrixPower(Random random) throws Exception {
        int nerveNub = this.kernLen * this.kernLen;
        List<Matrix> nerveMatrixList = this.convParameter.getNerveMatrixList();
        List<ConvSize> convSizeList = this.convParameter.getConvSizeList();
        ArrayList<List<Float>> onePowers = new ArrayList<List<Float>>();
        for (int k = 0; k < this.channelNo; ++k) {
            Matrix nerveMatrix = new Matrix(nerveNub, 1);
            convSizeList.add(new ConvSize());
            for (int i = 0; i < nerveMatrix.getX(); ++i) {
                float nub = random.nextFloat() / (float) this.kernLen;
                nerveMatrix.setNub(i, 0, nub);
            }
            nerveMatrixList.add(nerveMatrix);
            if (this.depth == 1) {
                ArrayList<Float> oneConvPowerList = new ArrayList<Float>();
                for (int i = 0; i < 3; ++i) {
                    oneConvPowerList.add(Float.valueOf(random.nextFloat() / 3.0f));
                }
                onePowers.add(oneConvPowerList);
            }
        }
        if (this.depth == 1) {
            this.convParameter.setOneConvPower(onePowers);
        }
    }

    /**
     * Returns the id given at construction.
     *
     * @return this nerve's id
     */
    public int getId() {
        return this.id;
    }

    /**
     * Adds children for dense forward messages.
     *
     * @param nerveList nerves appended to the child list
     */
    public void connect(List<Nerve> nerveList) {
        this.son.addAll(nerveList);
    }

    /**
     * Sets the single child for matrix forward messages.
     *
     * @param nerve the downstream nerve
     */
    public void connectSonOnly(Nerve nerve) {
        this.sonOnly = nerve;
    }

    /**
     * Sets the single parent for matrix back-propagation.
     *
     * @param nerve the upstream nerve
     */
    public void connectFatherOnly(Nerve nerve) {
        this.fatherOnly = nerve;
    }

    /**
     * Adds parents for dense back-propagation; parent {@code i} receives error term {@code wg[i + 1]}.
     *
     * @param nerveList nerves appended to the parent list
     */
    public void connectFather(List<Nerve> nerveList) {
        this.father.addAll(nerveList);
    }
}

