/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.mlplan.core;

import ai.libs.jaicore.ml.core.evaluation.measure.IMeasure;
import ai.libs.jaicore.ml.core.evaluation.measure.singlelabel.ZeroOneLoss;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.factory.IClassifierEvaluatorFactory;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.factory.MonteCarloCrossValidationEvaluatorFactory;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.splitevaluation.ISplitBasedClassifierEvaluator;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.splitevaluation.SimpleSLCSplitBasedClassifierEvaluator;
import ai.libs.jaicore.ml.weka.dataset.splitter.IDatasetSplitter;
import ai.libs.jaicore.ml.weka.dataset.splitter.MulticlassClassStratifiedSplitter;
import ai.libs.mlplan.core.AbstractMLPlanBuilder;

/**
 * Base builder for ML-Plan configurations on single-label problems.
 *
 * <p>Offers convenience methods that switch the search phase and/or the selection phase to
 * Monte Carlo cross-validation (MCCV) with a {@link MulticlassClassStratifiedSplitter}, and that
 * attach a loss function via a {@link SimpleSLCSplitBasedClassifierEvaluator}.
 */
public abstract class AbstractMLPlanSingleLabelBuilder
extends AbstractMLPlanBuilder {
    /** Number of MCCV iterations used when {@link #withPerformanceMeasure} installs a default search-phase factory. */
    protected static final int SEARCH_NUM_MC_ITERATIONS = 5;
    /** Train fold size used when {@link #withPerformanceMeasure} installs a default search-phase factory. */
    protected static final double SEARCH_TRAIN_FOLD_SIZE = 0.7;
    /** Number of MCCV iterations used when {@link #withPerformanceMeasure} installs a default selection-phase factory. */
    protected static final int SELECTION_NUM_MC_ITERATIONS = 5;
    /** Train fold size used when {@link #withPerformanceMeasure} installs a default selection-phase factory. */
    protected static final double SELECTION_TRAIN_FOLD_SIZE = 0.7;
    /** Shared zero-one loss instance; not referenced in this class — presumably the default loss for subclasses (verify). */
    protected static final IMeasure<Double, Double> LOSS_FUNCTION = new ZeroOneLoss();

    protected AbstractMLPlanSingleLabelBuilder() {
    }

    /**
     * Configures the search phase to use Monte Carlo cross-validation.
     *
     * <p>If the current search-phase evaluator factory is not an MCCV factory, it is replaced by a
     * new one that uses a {@link MulticlassClassStratifiedSplitter}. The (new or existing) MCCV
     * factory is then given the iteration count, train fold size, and loss function.
     *
     * @param numIterations number of MCCV iterations
     * @param trainFoldSize train fold size passed to the factory (presumably the training fraction per split)
     * @param lossFunction loss wrapped in a {@link SimpleSLCSplitBasedClassifierEvaluator}
     * @return this builder
     */
    public AbstractMLPlanSingleLabelBuilder withMonteCarloCrossValidationInSearchPhase(int numIterations, double trainFoldSize, IMeasure<Double, Double> lossFunction) {
        if (!(this.getSearchEvaluatorFactory() instanceof MonteCarloCrossValidationEvaluatorFactory)) {
            this.withSearchPhaseEvaluatorFactory((IClassifierEvaluatorFactory)new MonteCarloCrossValidationEvaluatorFactory().withDatasetSplitter((IDatasetSplitter)new MulticlassClassStratifiedSplitter()));
        }
        configureMonteCarloCrossValidation((MonteCarloCrossValidationEvaluatorFactory)this.getSearchEvaluatorFactory(), numIterations, trainFoldSize, lossFunction);
        return this;
    }

    /**
     * Configures the selection phase to use Monte Carlo cross-validation.
     *
     * <p>If the current selection-phase evaluator factory is not an MCCV factory, it is replaced by
     * a new one that uses a {@link MulticlassClassStratifiedSplitter}. The (new or existing) MCCV
     * factory is then given the iteration count, train fold size, and loss function.
     *
     * @param numIterations number of MCCV iterations
     * @param trainFoldSize train fold size passed to the factory (presumably the training fraction per split)
     * @param lossFunction loss wrapped in a {@link SimpleSLCSplitBasedClassifierEvaluator}
     * @return this builder
     */
    public AbstractMLPlanSingleLabelBuilder withMonteCarloCrossValidationInSelectionPhase(int numIterations, double trainFoldSize, IMeasure<Double, Double> lossFunction) {
        if (!(this.getSelectionEvaluatorFactory() instanceof MonteCarloCrossValidationEvaluatorFactory)) {
            this.withSelectionPhaseEvaluatorFactory(new MonteCarloCrossValidationEvaluatorFactory().withDatasetSplitter((IDatasetSplitter)new MulticlassClassStratifiedSplitter()));
        }
        configureMonteCarloCrossValidation((MonteCarloCrossValidationEvaluatorFactory)this.getSelectionEvaluatorFactory(), numIterations, trainFoldSize, lossFunction);
        return this;
    }

    /**
     * Sets the loss function used by both the search phase and the selection phase.
     *
     * <p>A phase whose evaluator factory is not yet an MCCV factory first receives a default MCCV
     * factory (stratified splitter, {@code SEARCH_*} / {@code SELECTION_*} iteration count and
     * train fold size). An existing MCCV factory keeps its iteration count and train fold size;
     * only its split-based evaluator is replaced.
     *
     * @param lossFunction loss wrapped in a {@link SimpleSLCSplitBasedClassifierEvaluator} for each phase
     * @return this builder
     */
    public AbstractMLPlanSingleLabelBuilder withPerformanceMeasure(IMeasure<Double, Double> lossFunction) {
        if (!(this.getSearchEvaluatorFactory() instanceof MonteCarloCrossValidationEvaluatorFactory)) {
            this.withSearchPhaseEvaluatorFactory((IClassifierEvaluatorFactory)new MonteCarloCrossValidationEvaluatorFactory().withDatasetSplitter((IDatasetSplitter)new MulticlassClassStratifiedSplitter()).withNumMCIterations(SEARCH_NUM_MC_ITERATIONS).withTrainFoldSize(SEARCH_TRAIN_FOLD_SIZE));
        }
        // Previously this guard re-checked the search phase, so a non-MCCV selection factory
        // made the cast below throw ClassCastException.
        if (!(this.getSelectionEvaluatorFactory() instanceof MonteCarloCrossValidationEvaluatorFactory)) {
            this.withSelectionPhaseEvaluatorFactory(new MonteCarloCrossValidationEvaluatorFactory().withDatasetSplitter((IDatasetSplitter)new MulticlassClassStratifiedSplitter()));
            ((MonteCarloCrossValidationEvaluatorFactory)this.getSelectionEvaluatorFactory()).withNumMCIterations(SELECTION_NUM_MC_ITERATIONS).withTrainFoldSize(SELECTION_TRAIN_FOLD_SIZE);
        }
        // Both phases are MCCV now; attach the loss to each (previously only selection got it).
        useLossFunction((MonteCarloCrossValidationEvaluatorFactory)this.getSearchEvaluatorFactory(), lossFunction);
        useLossFunction((MonteCarloCrossValidationEvaluatorFactory)this.getSelectionEvaluatorFactory(), lossFunction);
        return this;
    }

    /**
     * Returns the dataset splitter used by default for single-label problems.
     *
     * @return a new {@link MulticlassClassStratifiedSplitter}
     */
    protected IDatasetSplitter getDefaultDatasetSplitter() {
        return new MulticlassClassStratifiedSplitter();
    }

    /** Applies iteration count, train fold size and loss function to an MCCV factory. */
    private static void configureMonteCarloCrossValidation(MonteCarloCrossValidationEvaluatorFactory factory, int numIterations, double trainFoldSize, IMeasure<Double, Double> lossFunction) {
        factory.withNumMCIterations(numIterations).withTrainFoldSize(trainFoldSize).withSplitBasedEvaluator((ISplitBasedClassifierEvaluator)new SimpleSLCSplitBasedClassifierEvaluator(lossFunction));
    }

    /** Replaces the split-based evaluator of an MCCV factory with one wrapping the given loss. */
    private static void useLossFunction(MonteCarloCrossValidationEvaluatorFactory factory, IMeasure<Double, Double> lossFunction) {
        factory.withSplitBasedEvaluator((ISplitBasedClassifierEvaluator)new SimpleSLCSplitBasedClassifierEvaluator(lossFunction));
    }
}

