/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net;

import deepnetts.net.NeuralNetwork;
import deepnetts.net.layers.AbstractLayer;
import deepnetts.net.layers.FullyConnectedLayer;
import deepnetts.net.layers.InputLayer;
import deepnetts.net.layers.OutputLayer;
import deepnetts.net.layers.SoftmaxOutputLayer;
import deepnetts.net.layers.activation.ActivationType;
import deepnetts.net.loss.BinaryCrossEntropyLoss;
import deepnetts.net.loss.CrossEntropyLoss;
import deepnetts.net.loss.LossFunction;
import deepnetts.net.loss.LossType;
import deepnetts.net.loss.MeanSquaredErrorLoss;
import deepnetts.net.train.BackpropagationTrainer;
import deepnetts.util.RandomGenerator;
import deepnetts.util.Tensor;

/**
 * Feed-forward neural network trained by a {@link BackpropagationTrainer}.
 *
 * <p>Instances are created only through {@link #builder()}; the constructor is private.
 */
public final class FeedForwardNetwork
extends NeuralNetwork<BackpropagationTrainer> {
    /**
     * Reusable tensor that receives the values passed to {@link #setInput(float[])}.
     * It is created by {@link Builder#addInputLayer(int)} with the input layer's width and
     * stays {@code null} if the input layer was never added through that method.
     */
    private Tensor inputTensor;

    /** Creates an empty network and attaches a backpropagation trainer bound to it. */
    private FeedForwardNetwork() {
        this.setTrainer(new BackpropagationTrainer(this));
    }

    /**
     * Copies {@code inputs} into the network's reusable input tensor and passes that tensor
     * on to {@code setInput(Tensor)}.
     *
     * @param inputs input values; presumably their length must equal the input layer width
     *               (any such check is left to {@code Tensor.setValues} - TODO confirm)
     * @throws NullPointerException  if {@code inputs} is null
     * @throws IllegalStateException if no input layer was added via {@link Builder#addInputLayer(int)}
     */
    public void setInput(float[] inputs) {
        if (inputs == null) {
            throw new NullPointerException("inputs must not be null");
        }
        if (this.inputTensor == null) {
            throw new IllegalStateException(
                    "Network has no input tensor; add an input layer with Builder.addInputLayer(int)");
        }
        this.inputTensor.setValues(inputs);
        this.setInput(this.inputTensor);
    }

    /**
     * Sets {@code inputs} as the network input and returns the result of {@code getOutput()}.
     *
     * <p>NOTE(review): this presumably relies on {@code setInput} running the forward pass. The
     * returned array may be an internal buffer that the next call overwrites; copy it if it must
     * be kept - TODO confirm in NeuralNetwork.
     *
     * @param inputs input values, see {@link #setInput(float[])}
     * @return the network output
     * @throws NullPointerException  if {@code inputs} is null
     * @throws IllegalStateException if the network has no input tensor
     */
    public float[] predict(float[] inputs) {
        this.setInput(inputs);
        return this.getOutput();
    }

    /**
     * Returns a new builder for configuring a {@code FeedForwardNetwork}.
     *
     * @return a fresh builder wrapping a new, empty network
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link FeedForwardNetwork}.
     *
     * <p>Layers are added to the network in the order the {@code add*} methods are called.
     * {@link #build()} then links and initializes them.
     */
    public static class Builder {
        /** The network under construction; {@link #build()} returns this same instance. */
        private final FeedForwardNetwork network = new FeedForwardNetwork();
        /** Activation applied to hidden layers in {@link #build()} when {@link #setDefaultActivation} is set. */
        private ActivationType defaultActivationType = ActivationType.TANH;
        /** Whether {@link #hiddenActivationFunction(ActivationType)} has been called. */
        private boolean setDefaultActivation = false;

        /**
         * Adds an input layer of the given width, registers it as the network's input layer,
         * and allocates the tensor used by {@link FeedForwardNetwork#setInput(float[])}.
         *
         * @param width number of inputs
         * @return this builder
         */
        public Builder addInputLayer(int width) {
            InputLayer inLayer = new InputLayer(width);
            this.network.addLayer(inLayer);
            this.network.setInputLayer(inLayer);
            this.network.inputTensor = new Tensor(width);
            return this;
        }

        /**
         * Adds a fully connected layer of the given width, using the layer's own default activation.
         *
         * @param width number of units in the layer
         * @return this builder
         */
        public Builder addFullyConnectedLayer(int width) {
            this.network.addLayer(new FullyConnectedLayer(width));
            return this;
        }

        /**
         * Adds one fully connected layer per entry of {@code widths}, in order.
         *
         * @param widths widths of the layers to add
         * @return this builder
         */
        public Builder addFullyConnectedLayers(int ... widths) {
            for (int width : widths) {
                this.addFullyConnectedLayer(width);
            }
            return this;
        }

        /**
         * Adds a fully connected layer with an explicit activation type.
         *
         * <p>Note: {@link #hiddenActivationFunction(ActivationType)} overrides this choice at
         * {@link #build()} time if it has been called.
         *
         * @param width          number of units in the layer
         * @param activationType activation for this layer
         * @return this builder
         */
        public Builder addFullyConnectedLayer(int width, ActivationType activationType) {
            this.network.addLayer(new FullyConnectedLayer(width, activationType));
            return this;
        }

        /**
         * Adds one fully connected layer with the given activation per entry of {@code widths}.
         *
         * @param activationType activation shared by all added layers
         * @param widths         widths of the layers to add, in order
         * @return this builder
         */
        public Builder addFullyConnectedLayers(ActivationType activationType, int ... widths) {
            for (int width : widths) {
                this.addFullyConnectedLayer(width, activationType);
            }
            return this;
        }

        /**
         * Adds an arbitrary, caller-constructed layer.
         *
         * @param layer layer to append
         * @return this builder
         */
        public Builder addLayer(AbstractLayer layer) {
            this.network.addLayer(layer);
            return this;
        }

        /**
         * Adds the output layer and registers it as the network's output layer.
         * {@link ActivationType#SOFTMAX} creates a {@link SoftmaxOutputLayer}; any other type
         * creates a plain {@link OutputLayer} with that activation.
         *
         * @param width          number of outputs
         * @param activationType output activation
         * @return this builder
         * @throws NullPointerException if {@code activationType} is null
         */
        public Builder addOutputLayer(int width, ActivationType activationType) {
            if (activationType == null) {
                throw new NullPointerException("activationType must not be null");
            }
            OutputLayer outputLayer = activationType == ActivationType.SOFTMAX
                    ? new SoftmaxOutputLayer(width)
                    : new OutputLayer(width, activationType);
            this.network.setOutputLayer(outputLayer);
            this.network.addLayer(outputLayer);
            return this;
        }

        /**
         * Sets an activation type that {@link #build()} applies to every layer that is neither an
         * {@link InputLayer} nor an {@link OutputLayer}. This includes layers added with an explicit
         * activation and layers added through {@link #addLayer(AbstractLayer)}, whose activations
         * it overrides.
         *
         * @param activationType activation for hidden layers
         * @return this builder
         */
        public Builder hiddenActivationFunction(ActivationType activationType) {
            this.defaultActivationType = activationType;
            this.setDefaultActivation = true;
            return this;
        }

        /**
         * Creates and installs the loss function for the given type.
         * {@link LossType#CROSS_ENTROPY} selects {@link BinaryCrossEntropyLoss} when the output
         * layer has width 1, and {@link CrossEntropyLoss} otherwise.
         *
         * @param lossType loss to use
         * @return this builder
         * @throws IllegalStateException    if cross-entropy is requested before
         *                                  {@link #addOutputLayer(int, ActivationType)}
         * @throws IllegalArgumentException if {@code lossType} is not supported by this builder
         */
        public Builder lossFunction(LossType lossType) {
            LossFunction loss;
            switch (lossType) {
                case MEAN_SQUARED_ERROR:
                    loss = new MeanSquaredErrorLoss(this.network);
                    break;
                case CROSS_ENTROPY:
                    // The loss variant depends on the output width, so the output layer must exist already.
                    if (this.network.getOutputLayer() == null) {
                        throw new IllegalStateException(
                                "Add the output layer before choosing CROSS_ENTROPY loss");
                    }
                    loss = this.network.getOutputLayer().getWidth() == 1
                            ? new BinaryCrossEntropyLoss(this.network)
                            : new CrossEntropyLoss(this.network);
                    break;
                default:
                    // Previously an unhandled type silently installed a null loss function.
                    throw new IllegalArgumentException("Unsupported loss type: " + lossType);
            }
            this.network.setLossFunction(loss);
            return this;
        }

        /**
         * Seeds the instance returned by {@code RandomGenerator.getDefault()}. That instance is
         * presumably shared by other code, so this affects more than this builder. Call it before
         * {@link #build()} if layer initialization draws from that generator (TODO confirm).
         *
         * @param seed random seed
         * @return this builder
         */
        public Builder randomSeed(long seed) {
            RandomGenerator.getDefault().initSeed(seed);
            return this;
        }

        /**
         * Finishes construction by doing the following, in order:
         * <ol>
         *   <li>applies the hidden activation from {@link #hiddenActivationFunction(ActivationType)},
         *       if one was set, to every layer that is neither an input nor an output layer;</li>
         *   <li>links each layer to its predecessor and successor in insertion order;</li>
         *   <li>calls {@code init()} on every layer.</li>
         * </ol>
         * Each call returns the same network instance, so calling this again re-links and
         * re-initializes it.
         *
         * @return the configured network
         */
        public FeedForwardNetwork build() {
            AbstractLayer prevLayer = null;
            for (AbstractLayer layer : this.network.getLayers()) {
                if (this.setDefaultActivation && !(layer instanceof InputLayer) && !(layer instanceof OutputLayer)) {
                    layer.setActivationType(this.defaultActivationType);
                }
                layer.setPrevLayer(prevLayer);
                if (prevLayer != null) {
                    prevLayer.setNextlayer(layer);
                }
                prevLayer = layer;
            }
            // Initialize only after the whole chain is linked; init() may depend on neighbouring layers.
            for (AbstractLayer layer : this.network.getLayers()) {
                layer.init();
            }
            return this.network;
        }
    }
}

