/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.net.train;

import deepnetts.data.MLDataItem;
import deepnetts.data.TabularDataSet;
import deepnetts.eval.ClassifierEvaluator;
import deepnetts.eval.RegresionEvaluator;
import deepnetts.net.NeuralNetwork;
import deepnetts.net.train.BackpropagationTrainer;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import javax.visrec.ml.data.DataSet;
import javax.visrec.ml.eval.EvaluationMetrics;
import javax.visrec.ml.eval.Evaluator;
import org.apache.commons.lang3.SerializationUtils;

/**
 * Runs k-fold cross-validation for a {@link NeuralNetwork}.
 *
 * <p>The configured data set is split into {@code splitsNum} folds. Each fold is held out once as
 * the test set. The remaining folds are concatenated into a training set, the configured
 * {@link BackpropagationTrainer} is run on it, and a copy of the configured network is evaluated on
 * the held-out fold. The per-fold metrics are then averaged.
 *
 * <p>Instances are created through {@link #builder()}. No synchronization is performed, so an
 * instance should not be shared between threads while a run is in progress.
 */
public class KFoldCrossValidation {
    private int splitsNum;
    private NeuralNetwork neuralNetwork;
    private BackpropagationTrainer trainer;
    private DataSet<MLDataItem> dataSet;
    private Evaluator<NeuralNetwork, DataSet<? extends MLDataItem>> evaluator;
    private final List<NeuralNetwork> trainedNetworks = new ArrayList<>();

    /**
     * Runs cross-validation over every fold and returns the averaged evaluation metrics.
     *
     * <p>Networks recorded by a previous run are discarded first. When this method returns,
     * {@link #getTrainedNetworks()} holds exactly one network per fold, in fold order.
     *
     * @return the metrics averaged with {@link ClassifierEvaluator#averagePerformance} when the
     *     evaluator is a {@link ClassifierEvaluator}, otherwise averaged with
     *     {@link RegresionEvaluator#averagePerformance}
     * @throws NullPointerException if the model, trainer, data set or evaluator was not configured
     * @throws IllegalStateException if fewer than two splits were configured, or if the data set is
     *     not a {@link TabularDataSet}
     */
    public EvaluationMetrics runCrossValidation() {
        checkConfiguration();
        // Results from an earlier run must not be mixed with this run's networks.
        this.trainedNetworks.clear();

        TabularDataSet source = (TabularDataSet) this.dataSet;
        DataSet[] folds = this.dataSet.split(this.splitsNum);
        List<EvaluationMetrics> measures = new ArrayList<>(folds.length);

        // Bound the loop by the folds actually returned, not the requested count,
        // so the indexing below can never run past the array.
        for (int testFoldIdx = 0; testFoldIdx < folds.length; ++testFoldIdx) {
            DataSet testSet = folds[testFoldIdx];
            TabularDataSet trainingSet = buildTrainingSet(source, folds, testFoldIdx);

            // Deep copy of the template network, so each fold gets its own instance.
            NeuralNetwork foldNet = (NeuralNetwork) SerializationUtils.clone((Serializable) this.neuralNetwork);

            // NOTE(review): foldNet is never handed to the trainer. train() updates whatever network
            // the trainer was constructed with, not the copy evaluated below, so each fold's
            // evaluated model is not trained on that fold's training set.
            // TODO: bind the trainer to foldNet, using the BackpropagationTrainer API, before
            // training. Do not simply clone after training: that would leak earlier folds' data
            // (including the current test fold) into later models.
            this.trainer.train((DataSet<? extends MLDataItem>) trainingSet);

            measures.add(this.evaluator.evaluate(foldNet, testSet));
            this.trainedNetworks.add(foldNet);
        }

        if (this.evaluator instanceof ClassifierEvaluator) {
            return ClassifierEvaluator.averagePerformance(measures);
        }
        return RegresionEvaluator.averagePerformance(measures);
    }

    /** Fails fast, with a descriptive exception, when this instance is not fully configured. */
    private void checkConfiguration() {
        Objects.requireNonNull(this.neuralNetwork, "model must be set");
        Objects.requireNonNull(this.trainer, "trainer must be set");
        Objects.requireNonNull(this.dataSet, "dataSet must be set");
        Objects.requireNonNull(this.evaluator, "evaluator must be set");
        if (this.splitsNum < 2) {
            throw new IllegalStateException("splitsNum must be at least 2, was " + this.splitsNum);
        }
        if (!(this.dataSet instanceof TabularDataSet)) {
            throw new IllegalStateException(
                    "dataSet must be a TabularDataSet, was " + this.dataSet.getClass().getName());
        }
    }

    /**
     * Concatenates every fold except {@code testFoldIdx} into a new tabular data set.
     *
     * <p>The new set copies the input/output counts and column names of {@code source}.
     */
    private static TabularDataSet buildTrainingSet(TabularDataSet source, DataSet[] folds, int testFoldIdx) {
        TabularDataSet trainingSet = new TabularDataSet(source.getNumInputs(), source.getNumOutputs());
        trainingSet.setColumnNames(source.getColumnNames());
        for (int foldIdx = 0; foldIdx < folds.length; ++foldIdx) {
            if (foldIdx != testFoldIdx) {
                trainingSet.addAll(folds[foldIdx]);
            }
        }
        return trainingSet;
    }

    /**
     * Returns the per-fold networks evaluated by the most recent {@link #runCrossValidation()} call,
     * in fold order.
     *
     * @return a read-only view of the per-fold networks; empty before the first run
     */
    public List<NeuralNetwork> getTrainedNetworks() {
        return Collections.unmodifiableList(this.trainedNetworks);
    }

    /**
     * Creates a builder for configuring a cross-validation run.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link KFoldCrossValidation}.
     *
     * <p>{@link #build()} returns the instance this builder configures. Calling setters after
     * {@code build()} therefore mutates the already-returned object.
     */
    public static class Builder {
        private final KFoldCrossValidation kFoldCV = new KFoldCrossValidation();

        /**
         * Sets the number of folds (k).
         *
         * <p>Must be at least 2. This is checked when the run starts.
         */
        public Builder splitsNum(int k) {
            this.kFoldCV.splitsNum = k;
            return this;
        }

        /** Sets the template network that is copied for each fold. */
        public Builder model(NeuralNetwork neuralNet) {
            this.kFoldCV.neuralNetwork = neuralNet;
            return this;
        }

        /** Sets the trainer that is run on each fold's training set. */
        public Builder trainer(BackpropagationTrainer trainer) {
            this.kFoldCV.trainer = trainer;
            return this;
        }

        /**
         * Sets the data set to split into folds.
         *
         * <p>It must be a {@link TabularDataSet}. This is checked when the run starts.
         */
        public Builder dataSet(DataSet dataSet) {
            this.kFoldCV.dataSet = dataSet;
            return this;
        }

        /** Sets the evaluator that is applied to each fold's network and held-out test set. */
        public Builder evaluator(Evaluator<NeuralNetwork, DataSet<? extends MLDataItem>> evaluator) {
            this.kFoldCV.evaluator = evaluator;
            return this;
        }

        /** Returns the configured instance. */
        public KFoldCrossValidation build() {
            return this.kFoldCV;
        }
    }
}

