/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net.train;

import deepnetts.core.DeepNetts;
import deepnetts.data.MLDataItem;
import deepnetts.data.TabularDataSet;
import deepnetts.eval.ClassifierEvaluator;
import deepnetts.net.NeuralNetwork;
import deepnetts.net.layers.AbstractLayer;
import deepnetts.net.loss.LossFunction;
import deepnetts.net.train.Trainer;
import deepnetts.net.train.TrainingEvent;
import deepnetts.net.train.TrainingListener;
import deepnetts.net.train.opt.OptimizerType;
import deepnetts.util.FileIO;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.concurrent.CopyOnWriteArrayList;
import javax.visrec.ml.data.DataSet;
import javax.visrec.ml.eval.EvaluationMetrics;
import javax.visrec.ml.eval.Evaluator;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Trains a {@link NeuralNetwork} by repeatedly feeding it the items of a training set,
 * propagating the loss function's output error backward and applying the resulting
 * weight changes, either after every item or once per mini batch.
 *
 * <p>Training runs epoch by epoch until one of the following holds: {@link #stop()} was
 * called, the configured maximum number of epochs was reached, the total training loss
 * dropped to or below the configured maximum error, the training loss became NaN, or
 * (when enabled and a validation set is available) early stopping triggered.
 *
 * <p>Registered {@link TrainingListener}s are notified when training starts and stops and
 * after every iteration and epoch.
 *
 * <p>Compatibility note: this class is {@link Serializable} and declares no explicit
 * {@code serialVersionUID}, so the set of non-transient fields and the signatures and
 * modifiers of all non-private members are intentionally left untouched.
 */
public class BackpropagationTrainer implements Trainer, Serializable {
    private long maxEpochs = 100000L;
    private float maxError = 0.01f;
    private float learningRate = 0.01f;
    private OptimizerType optType = OptimizerType.SGD;
    private float momentum = 0.0f;
    private boolean batchMode = false;
    // A value of 0 in batch mode means "use the whole training set as a single batch" (see train).
    private int batchSize;
    private boolean stopTraining = false;
    private int epoch;
    private float valLoss = 0.0f;
    private float prevValLoss = 0.0f;
    private float trainAccuracy = 0.0f;
    private float valAccuracy = 0.0f;
    private float totalTrainingLoss;
    private boolean shuffle = false;
    private NeuralNetwork<?> neuralNet;
    private transient DataSet<? extends MLDataItem> trainingSet;
    private transient DataSet<? extends MLDataItem> validationSet;
    private LossFunction lossFunction;
    private boolean trainingSnapshots = false;
    private int snapshotEpochs = 5;
    private String snapshotPath = "";
    private boolean earlyStopping = false;
    private int checkpointEpochs = 1;
    private float earlyStoppingMinDelta = 1.0E-6f;
    private int earlyStoppingPatience = 2;
    private int earlyStoppingCheckpointCount = 0;
    // Starts at the largest float so the first checkpoint is never counted as "no improvement".
    private float prevCheckpointTestLoss = Float.MAX_VALUE;
    private transient Evaluator<NeuralNetwork, DataSet<? extends MLDataItem>> eval = new ClassifierEvaluator();
    private float regL2;
    // NOTE(review): stored by setL1Regularization but never passed to the layers in this class.
    private float regL1;
    // Copy-on-write so listeners can be added or removed while training events are being dispatched.
    private transient CopyOnWriteArrayList<TrainingListener> listeners = new CopyOnWriteArrayList<>();
    private static final Logger LOGGER = LogManager.getLogger(DeepNetts.class.getName());

    /** Property key for the maximum total training loss at which training stops. */
    public static final String PROP_MAX_ERROR = "maxError";
    /** Property key for the maximum number of training epochs. */
    public static final String PROP_MAX_EPOCHS = "maxEpochs";
    /** Property key for the learning rate. */
    public static final String PROP_LEARNING_RATE = "learningRate";
    /** Property key for the momentum. */
    public static final String PROP_MOMENTUM = "momentum";
    /** Property key for the batch mode flag. */
    public static final String PROP_BATCH_MODE = "batchMode";
    /** Property key for the batch size. */
    public static final String PROP_BATCH_SIZE = "batchSize";
    /** Property key for the optimizer type (the name of an {@link OptimizerType} constant). */
    public static final String PROP_OPTIMIZER_TYPE = "optimizer";

    /**
     * Creates a trainer for the given network using the default settings.
     *
     * @param neuralNet the network to train
     */
    public BackpropagationTrainer(NeuralNetwork neuralNet) {
        this.neuralNet = neuralNet;
    }

    /**
     * Creates a trainer configured from the given properties. Keys that are absent keep
     * their default values. No network is associated with a trainer created this way.
     *
     * @param prop the settings, keyed by the {@code PROP_*} constants
     * @throws NullPointerException if {@code prop} is null
     * @throws NumberFormatException if a numeric value cannot be parsed
     * @throws IllegalArgumentException if the optimizer name is not an {@link OptimizerType} constant
     */
    public BackpropagationTrainer(Properties prop) {
        this.applyProperties(prop);
    }

    /**
     * Trains the network on {@code trainingSet} and evaluates it on {@code validationSet}
     * after every epoch.
     *
     * @param trainingSet   the items to train on
     * @param validationSet the items used to compute validation loss and accuracy
     */
    public void train(DataSet<MLDataItem> trainingSet, DataSet<MLDataItem> validationSet) {
        this.validationSet = validationSet;
        this.train(trainingSet);
    }

    /**
     * Splits {@code trainingSet} into a training part and a validation part and trains on them.
     *
     * @param trainingSet the full data set to split
     * @param valPart     the fraction of the data set passed to {@code split} for the validation part;
     *                    the training part receives {@code 1.0 - valPart}
     * @throws IllegalArgumentException if {@code trainingSet} is null
     */
    public void train(DataSet<?> trainingSet, double valPart) {
        if (trainingSet == null) {
            throw new IllegalArgumentException("Argument trainingSet cannot be null!");
        }
        DataSet[] trainValSets = trainingSet.split(new double[]{1.0 - valPart, valPart});
        this.validationSet = trainValSets[1];
        this.train((DataSet<? extends MLDataItem>) trainValSets[0]);
    }

    /**
     * Trains the network on the given data set until a stop condition is met.
     *
     * <p>If a validation set was supplied earlier (through one of the other {@code train}
     * overloads or {@link #setTestSet}), validation loss and accuracy are computed after
     * every epoch.
     *
     * @param trainingSet the items to train on
     * @throws IllegalArgumentException if {@code trainingSet} is null or empty
     * @throws IllegalStateException if this trainer has no network to train
     */
    @Override
    public void train(DataSet<? extends MLDataItem> trainingSet) {
        if (trainingSet == null) {
            throw new IllegalArgumentException("Argument trainingSet cannot be null!");
        }
        if (trainingSet.size() == 0) {
            throw new IllegalArgumentException("Training set cannot be empty!");
        }
        if (this.neuralNet == null) {
            throw new IllegalStateException("Neural network to train is not set!");
        }
        this.trainingSet = trainingSet;
        // Only tabular data sets expose target names here; other data set types are trained
        // without touching the network's output labels instead of failing on the cast.
        if (trainingSet instanceof TabularDataSet) {
            this.neuralNet.setOutputLabels(((TabularDataSet) trainingSet).getTargetNames());
        }
        int trainingSamplesCount = trainingSet.size();
        this.stopTraining = false;
        if (this.batchMode && this.batchSize == 0) {
            this.batchSize = trainingSamplesCount;
        }
        for (AbstractLayer layer : this.neuralNet.getLayers()) {
            layer.setLearningRate(this.learningRate);
            layer.setMomentum(this.momentum);
            layer.setRegularization(this.regL2);
            layer.setBatchMode(this.batchMode);
            layer.setBatchSize(this.batchSize);
            layer.setOptimizerType(this.optType);
        }
        this.lossFunction = this.neuralNet.getLossFunction();
        this.epoch = 0;
        this.totalTrainingLoss = 0.0f;
        // Every call to train() is a fresh run, so early stopping must not inherit state from a previous run.
        this.earlyStoppingCheckpointCount = 0;
        this.prevCheckpointTestLoss = Float.MAX_VALUE;
        if (this.earlyStopping && this.validationSet == null) {
            LOGGER.warn("Early stopping is enabled but no validation set is available, so early stopping will be ignored.");
        }
        float prevTotalLoss = 0.0f;
        LOGGER.info("------------------------------------------------------------------------------------------------------------------------------------------------");
        LOGGER.info("TRAINING NEURAL NETWORK");
        LOGGER.info("------------------------------------------------------------------------------------------------------------------------------------------------");
        this.fireTrainingEvent(TrainingEvent.Type.STARTED);
        long startTraining = System.currentTimeMillis();
        do {
            ++this.epoch;
            this.lossFunction.reset();
            // Remember the previous epoch's validation loss before it is cleared for this epoch.
            this.prevValLoss = this.epoch > 1 ? this.valLoss : 0.0f;
            this.valLoss = 0.0f;
            this.trainAccuracy = 0.0f;
            this.valAccuracy = 0.0f;
            if (this.shuffle) {
                trainingSet.shuffle();
            }
            int sampleCounter = 0;
            long startEpoch = System.currentTimeMillis();
            for (MLDataItem dataSetItem : trainingSet) {
                ++sampleCounter;
                this.neuralNet.setInput(dataSetItem.getInput());
                float[] outputError = this.lossFunction.addPatternError(this.neuralNet.getOutput(), dataSetItem.getTargetOutput().getValues());
                this.neuralNet.setOutputError(outputError);
                this.neuralNet.backward();
                if (!this.isBatchMode()) {
                    this.neuralNet.applyWeightChanges();
                } else if (sampleCounter % this.batchSize == 0) {
                    this.neuralNet.applyWeightChanges();
                    float miniBatchError = this.lossFunction.getTotal();
                    LOGGER.info("Epoch:{}, Mini Batch:{}, Batch Loss:{}", this.epoch, sampleCounter / this.batchSize, miniBatchError);
                }
                this.fireTrainingEvent(TrainingEvent.Type.ITERATION_FINISHED);
                // A listener (or another caller) may have requested a stop during this iteration.
                if (this.stopTraining) {
                    break;
                }
            }
            if (this.regL2 != 0.0f) {
                this.lossFunction.addRegularizationSum(this.regL2 * this.neuralNet.getL2Reg());
            }
            long endEpoch = System.currentTimeMillis();
            // Apply the changes accumulated by a trailing, incomplete mini batch.
            if (this.isBatchMode() && trainingSamplesCount % this.batchSize != 0) {
                this.neuralNet.applyWeightChanges();
            }
            this.totalTrainingLoss = this.lossFunction.getTotal();
            float totalLossChange = this.totalTrainingLoss - prevTotalLoss;
            prevTotalLoss = this.totalTrainingLoss;
            this.trainAccuracy = this.calculateAccuracy(this.trainingSet);
            long epochTime = endEpoch - startEpoch;
            if (this.validationSet != null) {
                this.valLoss = this.validationLoss(this.validationSet);
                this.valAccuracy = this.calculateAccuracy(this.validationSet);
                LOGGER.info("Epoch:{}, Time:{}ms, TrainError:{}, TrainErrorChange:{}, TrainAccuracy: {}, ValError:{}, ValAccuracy: {}",
                        this.epoch, epochTime, this.totalTrainingLoss, totalLossChange, this.trainAccuracy, this.valLoss, this.valAccuracy);
            } else {
                LOGGER.info("Epoch:{}, Time:{}ms, TrainError:{}, TrainErrorChange:{}, TrainAccuracy: {}",
                        this.epoch, epochTime, this.totalTrainingLoss, totalLossChange, this.trainAccuracy);
            }
            if (Float.isNaN(this.totalTrainingLoss)) {
                this.stopTraining = true;
                LOGGER.info("The training was interrupted due to NaN value before completing all Epochs. Epochs completed: {}/{}", this.epoch, this.maxEpochs);
            }
            this.fireTrainingEvent(TrainingEvent.Type.EPOCH_FINISHED);
            this.checkEarlyStopping();
            this.saveSnapshotIfDue();
            this.stopTraining = this.stopTraining
                    || (long) this.epoch >= this.maxEpochs
                    || this.totalTrainingLoss <= this.maxError;
        } while (!this.stopTraining);
        long endTraining = System.currentTimeMillis();
        long trainingTime = endTraining - startTraining;
        LOGGER.info(System.lineSeparator() + "TRAINING COMPLETED");
        LOGGER.info("Total Training Time: {}ms", trainingTime);
        LOGGER.info("------------------------------------------------------------------------");
        this.fireTrainingEvent(TrainingEvent.Type.STOPPED);
    }

    /**
     * Evaluates the early stopping rule at the end of an epoch.
     *
     * <p>The rule is only evaluated when early stopping is enabled, a validation set is
     * available, the checkpoint interval is positive and the current epoch is a multiple
     * of it. Without a validation set the validation loss is always zero, which would
     * otherwise stop training after a few checkpoints regardless of progress.
     */
    private void checkEarlyStopping() {
        if (!this.earlyStopping || this.validationSet == null || this.checkpointEpochs <= 0) {
            return;
        }
        if (this.epoch % this.checkpointEpochs != 0) {
            return;
        }
        boolean improved = !(this.prevCheckpointTestLoss - this.valLoss < this.earlyStoppingMinDelta);
        if (improved) {
            this.earlyStoppingCheckpointCount = 0;
        } else if (this.earlyStoppingCheckpointCount >= this.earlyStoppingPatience) {
            this.stop();
        } else {
            ++this.earlyStoppingCheckpointCount;
        }
        this.prevCheckpointTestLoss = this.valLoss;
    }

    /**
     * Writes the network to a snapshot file when snapshots are enabled, the snapshot
     * interval is positive and the current epoch is a multiple of it. I/O failures are
     * logged and do not interrupt training.
     */
    private void saveSnapshotIfDue() {
        if (!this.trainingSnapshots || this.snapshotEpochs <= 0) {
            return;
        }
        if (this.epoch % this.snapshotEpochs != 0) {
            return;
        }
        try {
            FileIO.writeToFile(this.neuralNet, this.snapshotPath + "_epoch_" + this.epoch + ".dnet");
        } catch (IOException ex) {
            LOGGER.catching(ex);
        }
    }

    /**
     * @return the maximum number of epochs to train for
     */
    public long getMaxEpochs() {
        return this.maxEpochs;
    }

    /**
     * Sets the maximum number of epochs to train for.
     *
     * @param maxEpochs the epoch limit, must be greater than zero
     * @return this trainer, for chaining
     * @throws IllegalArgumentException if {@code maxEpochs} is zero or negative
     */
    public BackpropagationTrainer setMaxEpochs(long maxEpochs) {
        if (maxEpochs <= 0L) {
            throw new IllegalArgumentException("Max epochs should be greater then zero : " + maxEpochs);
        }
        this.maxEpochs = maxEpochs;
        return this;
    }

    /**
     * @return the total training loss at or below which training stops
     */
    public float getMaxError() {
        return this.maxError;
    }

    /**
     * Sets the total training loss at or below which training stops.
     *
     * @param maxError the loss threshold, must not be negative
     * @return this trainer, for chaining
     * @throws IllegalArgumentException if {@code maxError} is negative
     */
    public BackpropagationTrainer setMaxError(float maxError) {
        if (maxError < 0.0f) {
            throw new IllegalArgumentException("Max error cannot be negative : " + maxError);
        }
        this.maxError = maxError;
        return this;
    }

    /**
     * Sets the learning rate passed to every layer when training starts.
     *
     * @param learningRate the learning rate, between 0 and 1 inclusive
     * @return this trainer, for chaining
     * @throws IllegalArgumentException if {@code learningRate} is negative or greater than 1
     */
    public BackpropagationTrainer setLearningRate(float learningRate) {
        if (learningRate < 0.0f) {
            throw new IllegalArgumentException("Learning rate cannot be negative : " + learningRate);
        }
        if (learningRate > 1.0f) {
            throw new IllegalArgumentException("Learning rate cannot be greater then 1 : " + learningRate);
        }
        this.learningRate = learningRate;
        return this;
    }

    /**
     * Sets the L2 regularization factor. It is passed to every layer when training starts
     * and, when non-zero, used to add a regularization term to the loss after each epoch.
     *
     * @param regL2 the L2 regularization factor
     * @return this trainer, for chaining
     */
    public BackpropagationTrainer setL2Regularization(float regL2) {
        this.regL2 = regL2;
        return this;
    }

    /**
     * Stores the L1 regularization factor.
     *
     * @param regL1 the L1 regularization factor
     * @return this trainer, for chaining
     */
    public BackpropagationTrainer setL1Regularization(float regL1) {
        this.regL1 = regL1;
        return this;
    }

    /**
     * @return true if the training set is shuffled at the start of every epoch
     */
    public boolean getShuffle() {
        return this.shuffle;
    }

    /**
     * @param shuffle true to shuffle the training set at the start of every epoch
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    /**
     * Notifies every registered listener with a new event of the given type.
     */
    private void fireTrainingEvent(TrainingEvent.Type type) {
        for (TrainingListener l : this.listeners) {
            l.handleEvent(new TrainingEvent(this, type));
        }
    }

    /**
     * Registers a listener for training events. Registering the same listener again has no effect.
     *
     * @param listener the listener to add
     * @throws NullPointerException if {@code listener} is null
     */
    public void addListener(TrainingListener listener) {
        Objects.requireNonNull(listener, "Training listener cannot be null!");
        this.listeners.addIfAbsent(listener);
    }

    /**
     * Unregisters a previously added listener. Unknown listeners are ignored.
     *
     * @param listener the listener to remove
     */
    public synchronized void removeListener(TrainingListener listener) {
        this.listeners.remove(listener);
    }

    /**
     * @return true if weight changes are applied once per mini batch instead of after every item
     */
    public boolean isBatchMode() {
        return this.batchMode;
    }

    /**
     * @param batchMode true to apply weight changes once per mini batch
     * @return this trainer, for chaining
     */
    public BackpropagationTrainer setBatchMode(boolean batchMode) {
        this.batchMode = batchMode;
        return this;
    }

    /**
     * @return the mini batch size
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets the mini batch size. In batch mode a size of zero is replaced by the size of
     * the training set when training starts.
     *
     * @param batchSize the number of items per mini batch
     * @return this trainer, for chaining
     */
    public BackpropagationTrainer setBatchSize(int batchSize) {
        this.batchSize = batchSize;
        return this;
    }

    /**
     * @param momentum the momentum passed to every layer when training starts
     * @return this trainer, for chaining
     */
    public BackpropagationTrainer setMomentum(float momentum) {
        this.momentum = momentum;
        return this;
    }

    /**
     * @return the momentum
     */
    public float getMomentum() {
        return this.momentum;
    }

    /**
     * @return the learning rate
     */
    public float getLearningRate() {
        return this.learningRate;
    }

    /**
     * Requests that training stops. The training loop checks the request after every
     * iteration and at the end of every epoch.
     */
    public void stop() {
        this.stopTraining = true;
    }

    /**
     * @return the total training loss recorded at the end of the most recent epoch
     */
    public float getTrainingLoss() {
        return this.totalTrainingLoss;
    }

    /**
     * @return the validation loss of the most recent epoch, or zero if none was computed
     */
    public float getValidationLoss() {
        return this.valLoss;
    }

    /**
     * @return the accuracy on the training set computed at the end of the most recent epoch
     */
    public float getTrainingAccuracy() {
        return this.trainAccuracy;
    }

    /**
     * @return the accuracy on the validation set of the most recent epoch, or zero if none was computed
     */
    public float getValidationAccuracy() {
        return this.valAccuracy;
    }

    /**
     * @return the number of the current epoch, starting at 1 for the first epoch of a run
     */
    public int getCurrentEpoch() {
        return this.epoch;
    }

    /**
     * @return the optimizer type passed to every layer when training starts
     */
    public OptimizerType getOptimizer() {
        return this.optType;
    }

    /**
     * @param optimizer the optimizer type passed to every layer when training starts
     * @return this trainer, for chaining
     */
    public BackpropagationTrainer setOptimizer(OptimizerType optimizer) {
        this.optType = optimizer;
        return this;
    }

    /**
     * @return the data set used for validation after every epoch, or null if none is set
     */
    public DataSet<?> getTestSet() {
        return this.validationSet;
    }

    /**
     * @param testSet the data set to use for validation after every epoch
     */
    public void setTestSet(DataSet<MLDataItem> testSet) {
        this.validationSet = testSet;
    }

    /**
     * @return true if early stopping is enabled
     */
    public boolean getEarlyStopping() {
        return this.earlyStopping;
    }

    /**
     * Enables or disables early stopping. Early stopping only takes effect when a
     * validation set is available.
     *
     * @param earlyStopping true to enable early stopping
     */
    public void setEarlyStopping(boolean earlyStopping) {
        this.earlyStopping = earlyStopping;
    }

    /**
     * Sets the path prefix of snapshot files; the epoch number and the {@code .dnet}
     * extension are appended to it.
     *
     * @param snapshotPath the snapshot file path prefix
     * @return this trainer, for chaining
     */
    public BackpropagationTrainer setSnapshotPath(String snapshotPath) {
        this.snapshotPath = snapshotPath;
        return this;
    }

    /**
     * @return the number of epochs between two snapshots
     */
    public int getSnapshotEpochs() {
        return this.snapshotEpochs;
    }

    /**
     * @param snapshotEpochs the number of epochs between two snapshots; zero or a negative
     *                       value results in no snapshots being written
     */
    public void setSnapshotEpochs(int snapshotEpochs) {
        this.snapshotEpochs = snapshotEpochs;
    }

    /**
     * @return the snapshot file path prefix
     */
    public String getSnapshotPath() {
        return this.snapshotPath;
    }

    /**
     * @return true if network snapshots are written during training
     */
    public boolean createsTrainingSnaphots() {
        return this.trainingSnapshots;
    }

    /**
     * @param trainingSnapshots true to write network snapshots during training
     */
    public void setTrainingSnapshots(boolean trainingSnapshots) {
        this.trainingSnapshots = trainingSnapshots;
    }

    /**
     * @return the number of epochs between two early stopping checkpoints
     */
    public int getCheckpointEpochs() {
        return this.checkpointEpochs;
    }

    /**
     * @param checkpointEpochs the number of epochs between two early stopping checkpoints;
     *                         zero or a negative value results in no checkpoints being evaluated
     * @return this trainer, for chaining
     */
    public BackpropagationTrainer setCheckpointEpochs(int checkpointEpochs) {
        this.checkpointEpochs = checkpointEpochs;
        return this;
    }

    /**
     * @return the minimum decrease in validation loss between checkpoints that counts as an improvement
     */
    public float getEarlyStoppingMinDelta() {
        return this.earlyStoppingMinDelta;
    }

    /**
     * @param earlyStoppingMinDelta the minimum decrease in validation loss between checkpoints
     *                              that counts as an improvement
     * @return this trainer, for chaining
     */
    public BackpropagationTrainer setEarlyStoppingMinDelta(float earlyStoppingMinDelta) {
        this.earlyStoppingMinDelta = earlyStoppingMinDelta;
        return this;
    }

    /**
     * @return the number of consecutive checkpoints without improvement tolerated before training stops
     */
    public int getEarlyStoppingPatience() {
        return this.earlyStoppingPatience;
    }

    /**
     * @param earlyStoppingPatience the number of consecutive checkpoints without improvement
     *                              tolerated before training stops
     * @return this trainer, for chaining
     */
    public BackpropagationTrainer setEarlyStoppingPatience(int earlyStoppingPatience) {
        this.earlyStoppingPatience = earlyStoppingPatience;
        return this;
    }

    /**
     * Updates this trainer's settings from the given properties. Keys that are absent
     * leave the corresponding setting unchanged, except for the batch mode flag, which is
     * read with {@link Boolean#parseBoolean(String)} and therefore becomes false when absent.
     *
     * @param prop the settings, keyed by the {@code PROP_*} constants
     * @throws NullPointerException if {@code prop} is null
     * @throws NumberFormatException if a numeric value cannot be parsed
     * @throws IllegalArgumentException if the optimizer name is not an {@link OptimizerType} constant
     */
    public void setProperties(Properties prop) {
        this.applyProperties(prop);
    }

    /**
     * Shared implementation behind the properties constructor and {@link #setProperties}.
     * Kept private so the constructor never calls an overridable method.
     */
    private void applyProperties(Properties prop) {
        Objects.requireNonNull(prop, "Properties cannot be null!");
        String value = trimmedProperty(prop, PROP_MAX_ERROR);
        if (value != null) {
            this.maxError = Float.parseFloat(value);
        }
        value = trimmedProperty(prop, PROP_MAX_EPOCHS);
        if (value != null) {
            // maxEpochs is a long, so parse the full long range rather than int.
            this.maxEpochs = Long.parseLong(value);
        }
        value = trimmedProperty(prop, PROP_LEARNING_RATE);
        if (value != null) {
            this.learningRate = Float.parseFloat(value);
        }
        value = trimmedProperty(prop, PROP_MOMENTUM);
        if (value != null) {
            this.momentum = Float.parseFloat(value);
        }
        // parseBoolean tolerates null and yields false, matching the behaviour for a missing key.
        this.batchMode = Boolean.parseBoolean(trimmedProperty(prop, PROP_BATCH_MODE));
        value = trimmedProperty(prop, PROP_BATCH_SIZE);
        if (value != null) {
            this.batchSize = Integer.parseInt(value);
        }
        value = trimmedProperty(prop, PROP_OPTIMIZER_TYPE);
        if (value != null) {
            this.optType = OptimizerType.valueOf(value);
        }
    }

    /**
     * Returns the property value with surrounding whitespace removed, or null if the key is absent.
     */
    private static String trimmedProperty(Properties prop, String key) {
        String value = prop.getProperty(key);
        return value == null ? null : value.trim();
    }

    /**
     * Resets the loss function and returns its value for the network on the given data set.
     */
    private float validationLoss(DataSet<? extends MLDataItem> validationSet) {
        this.lossFunction.reset();
        return this.lossFunction.valueFor(this.neuralNet, validationSet);
    }

    /**
     * Evaluates the network on the given data set and returns the "Accuracy" metric.
     */
    private float calculateAccuracy(DataSet<? extends MLDataItem> validationSet) {
        EvaluationMetrics pm = this.eval.evaluate(this.neuralNet, validationSet);
        return pm.get("Accuracy");
    }

    /**
     * Restores the serialized state and re-creates the transient collaborators.
     *
     * <p>{@code defaultReadObject} must be called here: a custom {@code readObject} that
     * skips it leaves every non-transient field at its JVM default. Field initializers do
     * not run during deserialization, so the listener list and the evaluator are rebuilt
     * explicitly.
     */
    private void readObject(ObjectInputStream aInputStream) throws ClassNotFoundException, IOException {
        aInputStream.defaultReadObject();
        this.listeners = new CopyOnWriteArrayList<>();
        this.eval = new ClassifierEvaluator();
    }
}

