/*
 * Decompiled with CFR 0.152.
 */
package deepnetts.eval;

import deepnetts.data.MLDataItem;
import deepnetts.eval.MeanSquaredError;
import deepnetts.net.NeuralNetwork;
import java.util.List;
import javax.visrec.ml.data.DataSet;
import javax.visrec.ml.eval.EvaluationMetrics;
import javax.visrec.ml.eval.Evaluator;

public class RegresionEvaluator
implements Evaluator<NeuralNetwork, DataSet<MLDataItem>> {
    public EvaluationMetrics evaluate(NeuralNetwork neuralNet, DataSet<MLDataItem> testSet) {
        EvaluationMetrics pe = new EvaluationMetrics();
        MeanSquaredError mse = new MeanSquaredError();
        float tss = 0.0f;
        int numInputs = ((MLDataItem)testSet.get(0)).getInput().size();
        int numItems = testSet.size();
        float targetMean = this.mean(testSet);
        for (MLDataItem item : testSet) {
            neuralNet.setInput(item.getInput());
            float[] predicted = neuralNet.getOutput();
            mse.add(predicted, item.getTargetOutput().getValues());
            tss += (item.getTargetOutput().getValues()[0] - targetMean) * (item.getTargetOutput().getValues()[0] - targetMean);
        }
        float rss = mse.getSquaredSum();
        float rse = (float)Math.sqrt(rss / (float)(testSet.size() - 2));
        pe.set("ResidualStandardError", rse);
        float r2 = 1.0f - rss / tss;
        pe.set("RSquared", r2);
        float fStat = (tss - rss) / (float)numInputs / (rss / (float)(numItems - numInputs - 1));
        pe.set("FStatistics", fStat);
        pe.set("MeanSquaredError", mse.getMeanSquaredSum());
        return pe;
    }

    private float mean(DataSet<? extends MLDataItem> testSet) {
        float mean = 0.0f;
        for (MLDataItem ditem : testSet) {
            mean += ditem.getTargetOutput().get(0);
        }
        return mean / (float)testSet.size();
    }

    public static EvaluationMetrics averagePerformance(List<EvaluationMetrics> measures) {
        float mse = 0.0f;
        float rse = 0.0f;
        float r2 = 0.0f;
        float fstat = 0.0f;
        for (EvaluationMetrics em : measures) {
            mse += em.get("MeanSquaredError");
            r2 += em.get("ResidualStandardError");
            rse += em.get("RSquared");
            fstat += em.get("FStatistics");
        }
        int count = measures.size();
        EvaluationMetrics total = new EvaluationMetrics();
        total.set("MeanSquaredError", mse / (float)count);
        total.set("ResidualStandardError", rse / (float)count);
        total.set("RSquared", r2 / (float)count);
        total.set("FStatistics", fstat / (float)count);
        return total;
    }
}

