/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.learningcurve.extrapolation;

import ai.libs.jaicore.basic.ILoggingCustomizable;
import ai.libs.jaicore.basic.algorithm.AlgorithmExecutionCanceledException;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.ml.core.dataset.IDataset;
import ai.libs.jaicore.ml.core.dataset.IInstance;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.ASamplingAlgorithm;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.factories.interfaces.IRerunnableSamplingAlgorithmFactory;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.factories.interfaces.ISamplingAlgorithmFactory;
import ai.libs.jaicore.ml.core.dataset.weka.WekaInstances;
import ai.libs.jaicore.ml.interfaces.LearningCurve;
import ai.libs.jaicore.ml.learningcurve.extrapolation.InvalidAnchorPointsException;
import ai.libs.jaicore.ml.learningcurve.extrapolation.LearningCurveExtrapolationMethod;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.UnsupportedAttributeTypeException;

/**
 * Computes a learning curve for a WEKA classifier and extrapolates it.
 *
 * <p>The dataset is split once (stratified by target value) into a training and a test portion.
 * For every anchor point, a subsample of that size is drawn from the training portion. The learner
 * is trained on the subsample and its accuracy on the test portion is recorded. A curve is then
 * extrapolated through the observed (anchor point, accuracy) pairs by the configured
 * {@link LearningCurveExtrapolationMethod}.
 *
 * <p>The test split and every subsample are cast to {@link WekaInstances}, so the dataset and the
 * sampling output are expected to be WEKA-backed.
 *
 * @param <I> the instance type of the dataset
 */
public class LearningCurveExtrapolator<I extends IInstance>
implements ILoggingCustomizable {
    private Logger logger = LoggerFactory.getLogger(LearningCurveExtrapolator.class);
    protected Classifier learner;
    protected IDataset<I> dataset;
    protected IDataset<I> train;
    protected IDataset<I> test;
    protected ISamplingAlgorithmFactory<I, ? extends ASamplingAlgorithm<I>> samplingAlgorithmFactory;
    /** The sampling algorithm of the most recent anchor point; {@code null} before the first run. */
    protected ASamplingAlgorithm<I> samplingAlgorithm;
    /** Source of randomness handed to the sampling algorithms. */
    protected Random random;
    protected LearningCurveExtrapolationMethod extrapolationMethod;
    private final int[] anchorPoints;
    /** Observed accuracy per anchor point, filled by {@link #extrapolateLearningCurve()}. */
    private final double[] yValues;
    /** Training time in milliseconds per anchor point, filled by {@link #extrapolateLearningCurve()}. */
    private final int[] trainingTimes;

    /**
     * Creates the extrapolator and immediately computes the stratified train/test split.
     *
     * @param extrapolationMethod method used to fit a curve through the observed anchor points
     * @param learner the classifier whose learning curve is computed
     * @param dataset the full dataset; its size is the extrapolation target
     * @param trainsplit portion of each class assigned to the training split
     * @param anchorPoints subsample sizes at which the learner is evaluated (stored, not copied)
     * @param samplingAlgorithmFactory factory producing the subsampling algorithm per anchor point
     * @param seed seed for both the split and the sampling randomness
     */
    public LearningCurveExtrapolator(LearningCurveExtrapolationMethod extrapolationMethod, Classifier learner, IDataset<I> dataset, double trainsplit, int[] anchorPoints, ISamplingAlgorithmFactory<I, ? extends ASamplingAlgorithm<I>> samplingAlgorithmFactory, long seed) {
        this.extrapolationMethod = extrapolationMethod;
        this.learner = learner;
        this.dataset = dataset;
        this.anchorPoints = anchorPoints;
        this.samplingAlgorithmFactory = samplingAlgorithmFactory;
        this.samplingAlgorithm = null;
        this.random = new Random(seed);
        this.createSplit(trainsplit, seed);
        this.yValues = new double[this.anchorPoints.length];
        this.trainingTimes = new int[this.anchorPoints.length];
    }

    /**
     * Evaluates the learner at every anchor point and extrapolates a learning curve from the results.
     *
     * <p>For each anchor point, the following steps are performed:
     * <ol>
     *   <li>A sampling algorithm is obtained from the factory. If the factory is an
     *       {@link IRerunnableSamplingAlgorithmFactory}, the previous run is passed to it first.</li>
     *   <li>A subsample is drawn from the training split and the learner is trained on it.</li>
     *   <li>The training time (ms) and the accuracy on the test split are recorded.</li>
     * </ol>
     *
     * @return the curve produced by the configured extrapolation method, with the full dataset size
     *     as target
     * @throws InvalidAnchorPointsException propagated unchanged from the extrapolation step
     * @throws AlgorithmException wrapping any sampling, conversion, training or evaluation failure
     * @throws InterruptedException propagated unchanged
     */
    public LearningCurve extrapolateLearningCurve() throws InvalidAnchorPointsException, AlgorithmException, InterruptedException {
        try {
            Instances testInstances = (Instances) ((WekaInstances) this.test).getList();
            for (int i = 0; i < this.anchorPoints.length; ++i) {
                // Rerunnable factories can reuse state from the previous (smaller) sample.
                if (this.samplingAlgorithmFactory instanceof IRerunnableSamplingAlgorithmFactory && this.samplingAlgorithm != null) {
                    ((IRerunnableSamplingAlgorithmFactory) this.samplingAlgorithmFactory).setPreviousRun(this.samplingAlgorithm);
                }
                this.samplingAlgorithm = this.samplingAlgorithmFactory.getAlgorithm(this.anchorPoints[i], this.train, this.random);
                Object subsampledDataset = this.samplingAlgorithm.call();

                this.logger.debug("Running classifier with {} data points.", this.anchorPoints[i]);
                long start = System.currentTimeMillis();
                this.learner.buildClassifier((Instances) ((WekaInstances) subsampledDataset).getList());
                this.trainingTimes[i] = (int) (System.currentTimeMillis() - start);

                this.yValues[i] = this.computeAccuracy(testInstances);
                this.logger.debug("Training finished. Observed learning curve value (accuracy) of {}.", this.yValues[i]);
            }
            this.logger.info("Computed accuracies of {} for anchor points {}. Now extrapolating a curve from these observations.", Arrays.toString(this.yValues), Arrays.toString(this.anchorPoints));
            return this.extrapolationMethod.extrapolateLearningCurveFromAnchorPoints(this.anchorPoints, this.yValues, this.dataset.size());
        }
        catch (UnsupportedAttributeTypeException e) {
            throw new AlgorithmException(e, "Error during convertion of the dataset to WEKA instances");
        }
        catch (AlgorithmExecutionCanceledException | AlgorithmException | TimeoutException e) {
            throw new AlgorithmException(e, "Error during creation of the subsamples for the anchorpoints");
        }
        catch (ExecutionException e) {
            throw new AlgorithmException(e, "Error during learning curve extrapolation");
        }
        catch (InvalidAnchorPointsException | InterruptedException e) {
            throw e;
        }
        catch (Exception e) {
            throw new AlgorithmException(e, "Error during training/testing the classifier");
        }
    }

    /**
     * Computes the fraction of instances whose class value the learner predicts exactly.
     *
     * <p>An empty instance set yields {@code NaN} (0/0).
     *
     * @param testInstances instances to classify
     * @return the accuracy of {@link #learner} on {@code testInstances}
     * @throws Exception if the WEKA classifier fails to classify an instance
     */
    private double computeAccuracy(Instances testInstances) throws Exception {
        int correct = 0;
        for (Instance instance : testInstances) {
            if (this.learner.classifyInstance(instance) == instance.classValue()) {
                correct++;
            }
        }
        return (double) correct / testInstances.size();
    }

    /**
     * Splits {@link #dataset} into {@link #train} and {@link #test}, stratified by target value.
     *
     * <p>The split proceeds as follows:
     * <ol>
     *   <li>A copy of the data is shuffled with a {@code Random} seeded by {@code seed}.</li>
     *   <li>The shuffled instances are grouped by target value.</li>
     *   <li>Every class contributes one instance to the training split and, if it has a second
     *       instance, one to the test split.</li>
     *   <li>The remaining instances of each class are distributed by the {@code trainsplit} quota.</li>
     * </ol>
     * Both splits are shuffled again at the end.
     *
     * @param trainsplit portion of each class assigned to the training split
     * @param seed seed for the shuffles
     */
    private void createSplit(double trainsplit, long seed) {
        long start = System.currentTimeMillis();
        this.logger.debug("Creating split with training portion {} and seed {}", trainsplit, seed);
        this.train = this.dataset.createEmpty();
        this.test = this.dataset.createEmpty();

        IDataset<I> data = this.dataset.createEmpty();
        data.addAll(this.dataset);
        Random r = new Random(seed);
        Collections.shuffle(data, r);

        // Strata are built from the shuffled copy (not from this.dataset). This makes the choice of
        // train/test instances depend on the seed instead of on the dataset's stored order.
        Map<Object, IDataset<I>> classStrati = new HashMap<>();
        for (I instance : data) {
            Object targetValue = instance.getTargetValue(Object.class).getValue();
            classStrati.computeIfAbsent(targetValue, key -> this.dataset.createEmpty()).add(instance);
        }

        // Phase 1: every class appears in the training split; classes with at least two instances
        // also appear in the test split. Strata are never empty (created on first insertion).
        for (IDataset<I> stratum : classStrati.values()) {
            this.train.add(stratum.get(0));
            if (stratum.size() > 1) {
                this.test.add(stratum.get(1));
            }
        }

        // Phase 2: distribute the rest of each stratum. Phase 1 consumed exactly min(size, 2)
        // items, so indexing continues from there instead of removing from the head of the list.
        // NOTE(review): the quotas below are computed from the full stratum size and do not subtract
        // the phase-1 instances, so a class's training share can exceed ceil(trainsplit * size) by one.
        for (IDataset<I> stratum : classStrati.values()) {
            int stratumSize = stratum.size();
            int next = Math.min(stratumSize, 2);

            int trainItems = (int) Math.min(stratumSize - next, Math.ceil(trainsplit * stratumSize));
            for (int end = next + trainItems; next < end; ++next) {
                this.train.add(stratum.get(next));
            }

            int testItems = (int) Math.min(stratumSize - next, Math.ceil((1.0 - trainsplit) * stratumSize));
            for (int end = next + testItems; next < end; ++next) {
                this.test.add(stratum.get(next));
            }
        }

        this.logger.debug("Shuffling train and test data");
        Collections.shuffle(this.train, r);
        Collections.shuffle(this.test, r);
        this.logger.debug("Finished split creation after {}ms", System.currentTimeMillis() - start);
    }

    /** @return the classifier whose learning curve is computed */
    public Classifier getLearner() {
        return this.learner;
    }

    /** @return the full dataset passed to the constructor */
    public IDataset<I> getDataset() {
        return this.dataset;
    }

    /** @return the method used to extrapolate the curve */
    public LearningCurveExtrapolationMethod getExtrapolationMethod() {
        return this.extrapolationMethod;
    }

    /** @return the anchor points (the internal array, not a copy) */
    public int[] getAnchorPoints() {
        return this.anchorPoints;
    }

    /** @return the observed accuracies per anchor point (the internal array, not a copy) */
    public double[] getyValues() {
        return this.yValues;
    }

    /** @return the training times in milliseconds per anchor point (the internal array, not a copy) */
    public int[] getTrainingTimes() {
        return this.trainingTimes;
    }

    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }

    @Override
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
    }
}

