/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.functionprediction.learner.learningcurveextrapolation;

import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
import ai.libs.jaicore.ml.core.evaluation.evaluator.SupervisedLearnerExecutor;
import ai.libs.jaicore.ml.core.filter.FilterBasedDatasetSplitter;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.ASamplingAlgorithm;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.LabelBasedStratifiedSamplingFactory;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.interfaces.IRerunnableSamplingAlgorithmFactory;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.interfaces.ISamplingAlgorithmFactory;
import ai.libs.jaicore.ml.functionprediction.learner.learningcurveextrapolation.InvalidAnchorPointsException;
import ai.libs.jaicore.ml.functionprediction.learner.learningcurveextrapolation.LearningCurveExtrapolationMethod;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
import org.api4.java.ai.ml.core.evaluation.learningcurve.ILearningCurve;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Estimates the learning curve of a supervised learner from a small number of
 * training-set sizes ("anchor points").
 *
 * <p>On construction the dataset is split once into a train and a test portion using a
 * label-based stratified sampling factory. {@link #extrapolateLearningCurve()} then, for each
 * anchor point, draws a subsample of that size from the training portion (using the supplied
 * sampling algorithm factory), trains the learner on it, measures the error rate on the fixed
 * test portion, and finally passes the (anchor point, error rate) pairs to the configured
 * {@link LearningCurveExtrapolationMethod}.
 *
 * <p>Instances keep mutable state (the random source, the last sampling algorithm run and the
 * result arrays) and perform no synchronization.
 */
public class LearningCurveExtrapolator
implements ILoggingCustomizable {
    private Logger logger = LoggerFactory.getLogger(LearningCurveExtrapolator.class);
    protected ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> learner;
    /** The full dataset; its size is handed to the extrapolation method. */
    protected ILabeledDataset<? extends ILabeledInstance> dataset;
    /** Training portion of the split; subsamples are drawn from here. */
    protected ILabeledDataset<? extends ILabeledInstance> train;
    /** Test portion of the split; every trained model is evaluated on it. */
    protected ILabeledDataset<? extends ILabeledInstance> test;
    protected ISamplingAlgorithmFactory<ILabeledDataset<?>, ? extends ASamplingAlgorithm<ILabeledDataset<?>>> samplingAlgorithmFactory;
    /** Sampling algorithm of the most recent anchor point; {@code null} until the first subsample is drawn. */
    protected ASamplingAlgorithm<ILabeledDataset<? extends ILabeledInstance>> samplingAlgorithm;
    /** Random source used for subsampling; seeded in the constructor. */
    protected Random random;
    protected LearningCurveExtrapolationMethod extrapolationMethod;
    private final int[] anchorPoints;
    /** Observed error rate per anchor point, filled by {@link #extrapolateLearningCurve()}. */
    private final double[] yValues;
    /** Observed training duration per anchor point, filled by {@link #extrapolateLearningCurve()}. */
    private final int[] trainingTimes;

    /**
     * Creates the extrapolator and immediately splits {@code dataset} into train and test portions.
     *
     * @param extrapolationMethod method used to fit a curve to the observed anchor-point losses
     * @param learner learner that is trained on each subsample
     * @param dataset full dataset; it is split here and its size is passed to the extrapolation method
     * @param trainsplit portion of {@code dataset} assigned to the training split
     * @param anchorPoints subsample sizes at which the learner is evaluated; the array is stored, not copied
     * @param samplingAlgorithmFactory factory creating the algorithm that draws each subsample from the training split
     * @param seed seed for both the train/test split and the subsampling random source
     * @throws NullPointerException if any object argument is {@code null}
     * @throws DatasetCreationException if the train/test split fails
     * @throws InterruptedException if the thread is interrupted while the split is created
     */
    public LearningCurveExtrapolator(LearningCurveExtrapolationMethod extrapolationMethod, ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> learner, ILabeledDataset<?> dataset, double trainsplit, int[] anchorPoints, ISamplingAlgorithmFactory<ILabeledDataset<?>, ? extends ASamplingAlgorithm<ILabeledDataset<?>>> samplingAlgorithmFactory, long seed) throws DatasetCreationException, InterruptedException {
        // Validate before the (potentially expensive) split so bad arguments fail fast and clearly.
        Objects.requireNonNull(extrapolationMethod, "extrapolationMethod");
        Objects.requireNonNull(learner, "learner");
        Objects.requireNonNull(dataset, "dataset");
        Objects.requireNonNull(anchorPoints, "anchorPoints");
        Objects.requireNonNull(samplingAlgorithmFactory, "samplingAlgorithmFactory");
        this.extrapolationMethod = extrapolationMethod;
        this.learner = learner;
        this.dataset = dataset;
        this.anchorPoints = anchorPoints;
        this.samplingAlgorithmFactory = samplingAlgorithmFactory;
        this.samplingAlgorithm = null;
        this.random = new Random(seed);
        this.createSplit(trainsplit, seed);
        this.yValues = new double[this.anchorPoints.length];
        this.trainingTimes = new int[this.anchorPoints.length];
    }

    /**
     * Evaluates the learner at every anchor point and extrapolates a learning curve from the results.
     *
     * <p>For anchor point {@code i} a subsample of size {@code anchorPoints[i]} is drawn from the
     * training split, the learner is trained on it and tested on the test split. The error rate is
     * stored in {@code yValues[i]} and the training duration in {@code trainingTimes[i]}.
     *
     * @return the curve produced by the configured extrapolation method
     * @throws InvalidAnchorPointsException propagated unchanged (presumably raised by the
     *         extrapolation method when it cannot use the anchor points)
     * @throws AlgorithmException if subsampling, training/testing or extrapolation fails
     * @throws InterruptedException propagated unchanged if the thread is interrupted
     */
    public ILearningCurve extrapolateLearningCurve() throws InvalidAnchorPointsException, AlgorithmException, InterruptedException {
        try {
            SupervisedLearnerExecutor learnerExecutor = new SupervisedLearnerExecutor();
            EClassificationPerformanceMeasure metric = EClassificationPerformanceMeasure.ERRORRATE;
            for (int i = 0; i < this.anchorPoints.length; ++i) {
                // Rerunnable factories are told about the previous sampling run before creating the next one.
                // NOTE(review): samplingAlgorithm is never reset between calls, so on a repeated call the
                // first anchor point receives the last run of the previous call — confirm this is intended.
                if (this.samplingAlgorithmFactory instanceof IRerunnableSamplingAlgorithmFactory && this.samplingAlgorithm != null) {
                    ((IRerunnableSamplingAlgorithmFactory) this.samplingAlgorithmFactory).setPreviousRun(this.samplingAlgorithm);
                }
                this.samplingAlgorithm = this.samplingAlgorithmFactory.getAlgorithm(this.anchorPoints[i], this.train, this.random);
                ILabeledDataset<? extends ILabeledInstance> subsampledDataset = this.samplingAlgorithm.call();
                this.logger.debug("Running classifier with {} data points.", this.anchorPoints[i]);
                ILearnerRunReport report = learnerExecutor.execute(this.learner, subsampledDataset, this.test);
                // Difference of the report's train timestamps, truncated to int (unit as reported by the executor).
                this.trainingTimes[i] = (int) (report.getTrainEndTime() - report.getTrainStartTime());
                this.yValues[i] = metric.loss(report.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class));
                this.logger.debug("Training finished. Observed learning curve value (error rate) of {}.", this.yValues[i]);
            }
            if (this.logger.isInfoEnabled()) {
                this.logger.info("Computed error rates of {} for anchor points {}. Now extrapolating a curve from these observations.", Arrays.toString(this.yValues), Arrays.toString(this.anchorPoints));
            }
            return this.extrapolationMethod.extrapolateLearningCurveFromAnchorPoints(this.anchorPoints, this.yValues, this.dataset.size());
        }
        catch (TimeoutException | AlgorithmException | AlgorithmExecutionCanceledException e) {
            throw new AlgorithmException("Error during creation of the subsamples for the anchorpoints", e);
        }
        catch (ExecutionException e) {
            throw new AlgorithmException("Error during learning curve extrapolation", e);
        }
        catch (InvalidAnchorPointsException | InterruptedException e) {
            throw e;
        }
        catch (Exception e) {
            throw new AlgorithmException("Error during training/testing the classifier", e);
        }
    }

    /**
     * Splits {@link #dataset} into {@link #train} and {@link #test} with a label-based stratified
     * splitter and shuffles both portions in place, all driven by one {@link Random} seeded with {@code seed}.
     *
     * @param trainsplit portion of the dataset assigned to the training split
     * @param seed seed of the random source used for splitting and shuffling
     * @throws DatasetCreationException wrapping the {@link SplitFailedException} if the split fails
     * @throws InterruptedException if the thread is interrupted during the split
     */
    private void createSplit(double trainsplit, long seed) throws DatasetCreationException, InterruptedException {
        long start = System.currentTimeMillis();
        this.logger.debug("Creating split with training portion {} and seed {}", trainsplit, seed);
        Random r = new Random(seed);
        try {
            FilterBasedDatasetSplitter<ILabeledDataset<? extends ILabeledInstance>> splitter = new FilterBasedDatasetSplitter<ILabeledDataset<? extends ILabeledInstance>>(new LabelBasedStratifiedSamplingFactory(), trainsplit, r);
            List<ILabeledDataset<? extends ILabeledInstance>> folds = splitter.split(this.dataset);
            this.train = folds.get(0);
            this.test = folds.get(1);
            this.logger.debug("Shuffling train and test data");
            Collections.shuffle(this.train, r);
            Collections.shuffle(this.test, r);
            this.logger.debug("Finished split creation after {}ms", System.currentTimeMillis() - start);
        }
        catch (SplitFailedException e) {
            throw new DatasetCreationException(e);
        }
    }

    /** @return the learner evaluated at each anchor point */
    public ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> getLearner() {
        return this.learner;
    }

    /** @return the full (unsplit) dataset passed to the constructor */
    public ILabeledDataset<?> getDataset() {
        return this.dataset;
    }

    /** @return the method used to extrapolate the curve from the anchor-point observations */
    public LearningCurveExtrapolationMethod getExtrapolationMethod() {
        return this.extrapolationMethod;
    }

    /** @return the anchor-point array passed to the constructor (the internal array itself, not a copy) */
    public int[] getAnchorPoints() {
        return this.anchorPoints;
    }

    /**
     * @return the observed error rate per anchor point (internal array, not a copy); all zeros until
     *         {@link #extrapolateLearningCurve()} has run
     */
    public double[] getyValues() {
        return this.yValues;
    }

    /**
     * @return the observed training duration per anchor point (internal array, not a copy); all zeros
     *         until {@link #extrapolateLearningCurve()} has run
     */
    public int[] getTrainingTimes() {
        return this.trainingTimes;
    }

    /** @return the name of the logger currently used by this instance */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Replaces this instance's logger with the one registered under {@code name}.
     *
     * @param name logger name to switch to
     */
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
    }
}

