/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net;

import deepnetts.data.MLDataItem;
import deepnetts.eval.Evaluators;
import deepnetts.net.layers.AbstractLayer;
import deepnetts.net.layers.InputLayer;
import deepnetts.net.layers.OutputLayer;
import deepnetts.net.loss.BinaryCrossEntropyLoss;
import deepnetts.net.loss.CrossEntropyLoss;
import deepnetts.net.loss.LossFunction;
import deepnetts.net.loss.LossType;
import deepnetts.net.loss.MeanSquaredErrorLoss;
import deepnetts.net.train.Trainer;
import deepnetts.net.train.TrainerProvider;
import deepnetts.util.Tensor;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import javax.visrec.ml.data.DataSet;
import javax.visrec.ml.eval.EvaluationMetrics;

/**
 * Base class for neural networks: holds an ordered list of layers, the loss function, the
 * trainer and optional output labels, and drives forward/backward passes over the layers.
 *
 * <p>Layers are kept in insertion order (see {@link #addLayer(AbstractLayer)}).
 * {@link #forward()}, {@link #backward()}, {@link #getL1Reg()} and {@link #getL2Reg()} skip the
 * layer at index 0, which is presumably the input layer added first by a builder (TODO confirm).
 *
 * @param <T> type of trainer used by {@link #train(DataSet)}
 */
public class NeuralNetwork<T extends Trainer> implements TrainerProvider<T>, Serializable {

    private static final long serialVersionUID = 1L;

    private T trainer;
    private final List<AbstractLayer> layers = new ArrayList<>();
    private LossFunction lossFunction;
    private InputLayer inputLayer;
    private OutputLayer outputLayer;
    private String[] outputLabels;
    // Not referenced in this class; retained so the Java-serialized form stays stable.
    private Tensor inputWrapper;
    private String label;
    // Formerly used as a scratch accumulator by getL1Reg/getL2Reg (which now sum into locals so
    // they do not mutate shared state); retained so the Java-serialized form stays stable.
    private float regularizationSum = 0.0f;

    protected NeuralNetwork() {
    }

    /**
     * Sets the given tensor as the input of the input layer and immediately runs
     * {@link #forward()}.
     *
     * @param inputs input values for the network
     */
    public void setInput(Tensor inputs) {
        this.inputLayer.setInput(inputs);
        this.forward();
    }

    /**
     * Returns the values of the output layer's output tensor.
     *
     * <p>NOTE(review): this returns whatever {@code Tensor.getValues()} returns, which may be the
     * tensor's backing array; copy it before mutating (TODO confirm).
     *
     * @return output values of the output layer
     */
    public float[] getOutput() {
        return this.outputLayer.getOutputs().getValues();
    }

    /**
     * Passes the given error values to the output layer.
     *
     * @param outputErrors error values for the output layer
     */
    public void setOutputError(float[] outputErrors) {
        this.outputLayer.setOutputErrors(outputErrors);
    }

    /**
     * Delegates training on the given data set to the configured trainer.
     *
     * @param trainingSet data set to train on
     * @throws IllegalStateException if no trainer has been set
     */
    public void train(DataSet<? extends MLDataItem> trainingSet) {
        if (this.trainer == null) {
            // Fail fast with a meaningful message instead of an anonymous NullPointerException.
            throw new IllegalStateException("No trainer has been set for this network");
        }
        this.trainer.train(trainingSet);
    }

    /**
     * Evaluates this network on the given test set.
     *
     * <p>When the loss function is a {@link CrossEntropyLoss} or {@link BinaryCrossEntropyLoss},
     * classifier metrics are computed. Otherwise, including when no loss function is set,
     * regressor metrics are computed.
     *
     * @param testSet data set to evaluate on
     * @return evaluation metrics produced by {@link Evaluators}
     */
    public EvaluationMetrics test(DataSet<MLDataItem> testSet) {
        if (isCrossEntropyLoss(this.getLossFunction())) {
            return Evaluators.evaluateClassifier(this, testSet);
        }
        return Evaluators.evaluateRegressor(this, testSet);
    }

    /** Calls {@code applyWeightChanges()} on every layer, including the one at index 0. */
    public void applyWeightChanges() {
        this.layers.forEach(AbstractLayer::applyWeightChanges);
    }

    /** Calls {@code forward()} on each layer in list order, skipping the layer at index 0. */
    public void forward() {
        for (int i = 1; i < this.layers.size(); ++i) {
            this.layers.get(i).forward();
        }
    }

    /**
     * Calls {@code backward()} on each layer in reverse list order, from the last layer down to
     * index 1 (the layer at index 0 is skipped).
     */
    public void backward() {
        for (int i = this.layers.size() - 1; i > 0; --i) {
            this.layers.get(i).backward();
        }
    }

    /**
     * Appends a layer to the end of this network's layer list.
     *
     * @param layer layer to append
     */
    protected void addLayer(AbstractLayer layer) {
        this.layers.add(layer);
    }

    /**
     * Returns this network's layers in insertion order.
     *
     * <p>This is the live internal list, not a copy; changes to it affect the network.
     *
     * @return the list of layers
     */
    public List<AbstractLayer> getLayers() {
        return this.layers;
    }

    /** @return the input layer, or {@code null} if it has not been set */
    public InputLayer getInputLayer() {
        return this.inputLayer;
    }

    /** @return the output layer, or {@code null} if it has not been set */
    public OutputLayer getOutputLayer() {
        return this.outputLayer;
    }

    /**
     * Sets the labels associated with the network's outputs. The array is stored as given (not
     * copied).
     *
     * @param outputLabels labels for the outputs
     */
    public void setOutputLabels(String ... outputLabels) {
        this.outputLabels = outputLabels;
    }

    /** @return the output labels array as stored (not copied), or {@code null} if never set */
    public String[] getOutputLabels() {
        return this.outputLabels;
    }

    /**
     * Returns the output label at the given index.
     *
     * @param i index of the label
     * @return the label at index {@code i}
     * @throws NullPointerException if no output labels have been set
     * @throws ArrayIndexOutOfBoundsException if {@code i} is out of range
     */
    public String getOutputLabel(int i) {
        return this.outputLabels[i];
    }

    protected void setInputLayer(InputLayer inputLayer) {
        this.inputLayer = inputLayer;
    }

    protected void setOutputLayer(OutputLayer outputLayer) {
        this.outputLayer = outputLayer;
    }

    /** @return the loss function, or {@code null} if it has not been set */
    public LossFunction getLossFunction() {
        return this.lossFunction;
    }

    /**
     * Sets the loss function and updates the output layer's loss type to match.
     *
     * <p>{@link MeanSquaredErrorLoss} maps to {@link LossType#MEAN_SQUARED_ERROR}.
     * {@link CrossEntropyLoss} and {@link BinaryCrossEntropyLoss} map to
     * {@link LossType#CROSS_ENTROPY}. For any other loss type, the output layer's loss type is
     * left unchanged. The output layer must already be set.
     *
     * @param lossFunction loss function to use
     */
    public void setLossFunction(LossFunction lossFunction) {
        this.lossFunction = lossFunction;
        if (lossFunction instanceof MeanSquaredErrorLoss) {
            this.outputLayer.setLossType(LossType.MEAN_SQUARED_ERROR);
        } else if (isCrossEntropyLoss(lossFunction)) {
            this.outputLayer.setLossType(LossType.CROSS_ENTROPY);
        }
    }

    /** Returns whether the loss is one of the cross-entropy losses that select classifier evaluation. */
    private static boolean isCrossEntropyLoss(LossFunction loss) {
        return loss instanceof CrossEntropyLoss || loss instanceof BinaryCrossEntropyLoss;
    }

    /** @return the descriptive label of this network, or {@code null} if not set */
    public String getLabel() {
        return this.label;
    }

    /** @param label descriptive label for this network */
    public void setLabel(String label) {
        this.label = label;
    }

    /**
     * Sums {@code getL2()} over every layer except the one at index 0.
     *
     * <p>Accumulates into a local rather than an instance field, so concurrent or interleaved
     * calls (including with {@link #getL1Reg()}) cannot corrupt each other's sums.
     *
     * @return the summed L2 regularization terms
     */
    public float getL2Reg() {
        float sum = 0.0f;
        for (int i = 1; i < this.layers.size(); ++i) {
            sum += this.layers.get(i).getL2();
        }
        return sum;
    }

    /**
     * Sums {@code getL1()} over every layer except the one at index 0.
     *
     * <p>Accumulates into a local rather than an instance field, so concurrent or interleaved
     * calls (including with {@link #getL2Reg()}) cannot corrupt each other's sums.
     *
     * @return the summed L1 regularization terms
     */
    public float getL1Reg() {
        float sum = 0.0f;
        for (int i = 1; i < this.layers.size(); ++i) {
            sum += this.layers.get(i).getL1();
        }
        return sum;
    }

    @Override
    public T getTrainer() {
        return this.trainer;
    }

    @Override
    public void setTrainer(T trainer) {
        this.trainer = trainer;
    }

    /**
     * Returns each layer's {@code toString()} on its own line, in list order, each followed by
     * the platform line separator.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (AbstractLayer layer : this.layers) {
            sb.append(layer.toString()).append(System.lineSeparator());
        }
        return sb.toString();
    }
}

