/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.mlpipeline_evaluation;

import ai.libs.jaicore.ml.WekaUtil;
import ai.libs.jaicore.ml.core.evaluation.measure.IMeasure;
import ai.libs.jaicore.ml.core.evaluation.measure.singlelabel.ZeroOneLoss;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.IClassifierEvaluator;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.MonteCarloCrossValidationEvaluator;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.splitevaluation.ISplitBasedClassifierEvaluator;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.splitevaluation.SimpleSLCSplitBasedClassifierEvaluator;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import org.apache.commons.lang.NotImplementedException;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;

/**
 * Static helpers that evaluate a Weka {@link Classifier} on a dataset split according to a
 * split technique given as a plain string. Splits are driven by a {@link Random} seeded with the
 * caller's seed, so the same (technique, seed) pair is meant to reproduce the same split.
 *
 * <p>Recognized split-technique strings (the text after the first {@code _} is parsed as a
 * {@code double} and used as the portion of the first fold):
 * <ul>
 *   <li>{@code MCCV_<portion>} - a single random split into a training fold and a test fold;
 *       used by {@link #getTrainSplit}, {@link #getTestSplit} and the "single" evaluation
 *       technique.</li>
 *   <li>{@code 3MCCV_<portion>} - Monte Carlo cross-validation with 3 repetitions; used by the
 *       "multi" evaluation technique.</li>
 * </ul>
 */
public final class ConsistentMLPipelineEvaluator {
    /** Number of repetitions used by the {@code 3MCCV} split technique. */
    private static final int THREE_MCCV_REPETITIONS = 3;

    private ConsistentMLPipelineEvaluator() {
        // static utility class, not meant to be instantiated
    }

    /**
     * Two-level evaluation: first carves a training fold out of {@code data} with the "test"
     * split, then evaluates {@code classifier} on that training fold with the "validation"
     * settings. The outer test fold is not used by this method.
     *
     * @param testSplitTechnique split technique for the outer split, e.g. {@code MCCV_<portion>}
     * @param testEvaluationTechnique outer evaluation technique; only {@code "single"} is
     *     supported
     * @param testSeed seed for the outer split
     * @param valSplitTechnique split technique for the inner evaluation
     * @param valEvaluationTechnique inner evaluation technique ({@code "single"} or
     *     {@code "multi"})
     * @param valSeed seed for the inner evaluation
     * @param data the full dataset
     * @param classifier the classifier to evaluate
     * @return the loss returned by
     *     {@link #evaluateClassifier(String, String, int, Instances, Classifier)} on the outer
     *     training fold
     * @throws NotImplementedException if {@code testEvaluationTechnique} is {@code "multi"}
     * @throws IllegalArgumentException if a technique is unknown or malformed
     * @throws Exception if training or evaluating the classifier fails
     */
    public static double evaluateClassifier(String testSplitTechnique, String testEvaluationTechnique, int testSeed,
            String valSplitTechnique, String valEvaluationTechnique, int valSeed, Instances data,
            Classifier classifier) throws Exception {
        switch (testEvaluationTechnique) {
            case "single": {
                Instances trainSplit = getTrainSplit(testSplitTechnique, data, testSeed);
                if (trainSplit == null) {
                    // Fail fast instead of passing null data into the nested evaluation.
                    throw new IllegalArgumentException("Invalid split technique: " + testSplitTechnique);
                }
                return evaluateClassifier(valSplitTechnique, valEvaluationTechnique, valSeed, trainSplit, classifier);
            }
            case "multi":
                throw new NotImplementedException("\"multi\" not yet supported!");
            default:
                throw new IllegalArgumentException("Unknown evaluation technique: " + testEvaluationTechnique);
        }
    }

    /**
     * Evaluates {@code classifier} on {@code data} and returns its loss.
     *
     * @param splitTechnique split description: {@code MCCV_<portion>} for {@code "single"},
     *     {@code 3MCCV_<portion>} for {@code "multi"}
     * @param evaluationTechnique {@code "single"} (one train/test split) or {@code "multi"}
     *     (Monte Carlo cross-validation evaluator)
     * @param seed seed for the random split(s)
     * @param data the dataset to split
     * @param classifier the classifier to evaluate; for {@code "single"} it is built on the
     *     training fold by this method
     * @return for {@code "single"}, {@code 1 - pctCorrect / 100} on the test fold; for
     *     {@code "multi"}, the value returned by the 3MCCV evaluator configured with
     *     {@link ZeroOneLoss}
     * @throws IllegalArgumentException if a technique is unknown or malformed
     * @throws Exception if training or evaluating the classifier fails
     */
    public static double evaluateClassifier(String splitTechnique, String evaluationTechnique, int seed,
            Instances data, Classifier classifier) throws Exception {
        switch (evaluationTechnique) {
            case "single": {
                // Realize the split once and take both folds from it.
                List<?> folds = realizeMccvFolds(splitTechnique, data, seed);
                if (folds == null) {
                    throw new IllegalArgumentException("Invalid split technique: " + splitTechnique);
                }
                Instances trainSplit = (Instances) folds.get(0);
                Instances testSplit = (Instances) folds.get(1);
                Evaluation eval = new Evaluation(trainSplit);
                classifier.buildClassifier(trainSplit);
                eval.evaluateModel(classifier, testSplit);
                // pctCorrect() is a percentage; convert it into a loss in [0, 1].
                return 1.0 - eval.pctCorrect() / 100.0;
            }
            case "multi": {
                IClassifierEvaluator evaluator = getEvaluatorForSplitTechnique(splitTechnique, data, seed);
                if (evaluator == null) {
                    throw new IllegalArgumentException(
                            "Could not find classifier evaluator for split technique: " + splitTechnique);
                }
                return (Double) evaluator.evaluate(classifier);
            }
            default:
                throw new IllegalArgumentException("Invalid evaluation technique: " + evaluationTechnique);
        }
    }

    /**
     * Creates a Monte Carlo cross-validation evaluator for a {@code 3MCCV_<portion>} split
     * technique. It uses 3 repetitions and zero-one loss.
     *
     * @param splitTechnique the split technique string
     * @param data the dataset the evaluator splits
     * @param seed seed for the evaluator's random splits
     * @return the evaluator, or {@code null} if the technique is not {@code 3MCCV}
     * @throws IllegalArgumentException if the portion is missing or not a number
     */
    public static IClassifierEvaluator getEvaluatorForSplitTechnique(String splitTechnique, Instances data, int seed) {
        String[] techniqueAndDescription = splitTechnique.split("_");
        if (!"3MCCV".equals(techniqueAndDescription[0])) {
            return null;
        }
        // Parse as double: parsing as float and widening (e.g. 0.7 -> 0.699999988...) can make
        // the training fold one instance smaller, and disagrees with the MCCV helpers below.
        double trainPortion = parsePortion(splitTechnique, techniqueAndDescription);
        return new MonteCarloCrossValidationEvaluator(new SimpleSLCSplitBasedClassifierEvaluator(new ZeroOneLoss()),
                THREE_MCCV_REPETITIONS, data, trainPortion, seed);
    }

    /**
     * Returns the training fold (fold 0) of an {@code MCCV_<portion>} split of {@code data}.
     *
     * @return the training fold, or {@code null} if the technique is not {@code MCCV}
     * @throws IllegalArgumentException if the portion is missing or not a number
     */
    public static Instances getTrainSplit(String splitTechnique, Instances data, int seed) {
        List<?> folds = realizeMccvFolds(splitTechnique, data, seed);
        return folds == null ? null : (Instances) folds.get(0);
    }

    /**
     * Returns the test fold (fold 1) of an {@code MCCV_<portion>} split of {@code data}. It is
     * the complement of {@link #getTrainSplit} for the same arguments.
     *
     * @return the test fold, or {@code null} if the technique is not {@code MCCV}
     * @throws IllegalArgumentException if the portion is missing or not a number
     */
    public static Instances getTestSplit(String splitTechnique, Instances data, int seed) {
        List<?> folds = realizeMccvFolds(splitTechnique, data, seed);
        return folds == null ? null : (Instances) folds.get(1);
    }

    /**
     * Realizes the two folds of an {@code MCCV_<portion>} split.
     *
     * @return the folds (index 0 = training, index 1 = test), or {@code null} if the technique
     *     is not {@code MCCV}
     */
    private static List<?> realizeMccvFolds(String splitTechnique, Instances data, int seed) {
        String[] techniqueAndDescription = splitTechnique.split("_");
        if (!"MCCV".equals(techniqueAndDescription[0])) {
            return null;
        }
        double[] portions = {parsePortion(splitTechnique, techniqueAndDescription)};
        // A fresh Random per call keeps the partition dependent only on the seed. Train and test
        // folds requested through separate calls therefore still come from the same split.
        return WekaUtil.realizeSplit(data, WekaUtil.getArbitrarySplit(data, new Random(seed), portions));
    }

    /**
     * Parses the portion that follows the first {@code _} of a split-technique string.
     *
     * @throws IllegalArgumentException if no portion is present, or a
     *     {@link NumberFormatException} (a subclass) if it is not a number
     */
    private static double parsePortion(String splitTechnique, String[] techniqueAndDescription) {
        if (techniqueAndDescription.length < 2) {
            throw new IllegalArgumentException("Missing portion in split technique: " + splitTechnique);
        }
        return Double.parseDouble(techniqueAndDescription[1]);
    }
}

