/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.mlplan.bigdata;

import ai.libs.jaicore.basic.StatisticsUtil;
import ai.libs.jaicore.basic.algorithm.AAlgorithm;
import ai.libs.jaicore.components.model.ComponentInstance;
import ai.libs.jaicore.ml.core.evaluation.evaluator.events.MCCVSplitEvaluationEvent;
import ai.libs.jaicore.ml.core.filter.sampling.infiles.ReservoirSampling;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.SimpleRandomSamplingFactory;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.interfaces.ISamplingAlgorithmFactory;
import ai.libs.jaicore.ml.functionprediction.learner.learningcurveextrapolation.LearningCurveExtrapolatedEvent;
import ai.libs.jaicore.ml.functionprediction.learner.learningcurveextrapolation.LearningCurveExtrapolationMethod;
import ai.libs.jaicore.ml.functionprediction.learner.learningcurveextrapolation.ipl.InversePowerLawExtrapolationMethod;
import ai.libs.jaicore.ml.weka.classification.learner.IWekaClassifier;
import ai.libs.jaicore.ml.weka.dataset.WekaInstances;
import ai.libs.mlplan.core.MLPlan;
import ai.libs.mlplan.core.events.SupervisedLearnerCreatedEvent;
import ai.libs.mlplan.multiclass.wekamlplan.MLPlanWekaBuilder;
import com.google.common.eventbus.Subscribe;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.api4.java.ai.ml.core.dataset.IDataSource;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.algorithm.Timeout;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

public class MLPlan4BigFileInput
extends AAlgorithm<File, Classifier>
implements ILoggingCustomizable {
    private Logger logger = LoggerFactory.getLogger(MLPlan4BigFileInput.class);
    private File intermediateSizeDownsampledFile = new File("testrsc/sampled/intermediate/" + ((File)this.getInput()).getName());
    private final int[] anchorpointsTraining = new int[]{8, 16, 64, 128};
    private Map<ISupervisedLearner<?, ?>, ComponentInstance> classifier2modelMap = new HashMap();
    private Map<ComponentInstance, int[]> trainingTimesDuringSearch = new HashMap<ComponentInstance, int[]>();
    private Map<ComponentInstance, List<Integer>> trainingTimesDuringSelection = new HashMap<ComponentInstance, List<Integer>>();
    private int numTrainingInstancesUsedInSelection;
    private MLPlan<IWekaClassifier> mlplan;

    public MLPlan4BigFileInput(File input) {
        super((Object)input);
    }

    private void downsampleData(File from, File to, int size) throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmException, AlgorithmTimeoutedException {
        ReservoirSampling sampler = new ReservoirSampling(new Random(0L), (File)this.getInput());
        try {
            File outputFolder = to.getParentFile();
            if (!outputFolder.exists()) {
                this.logger.info("Creating data output folder {}", (Object)outputFolder.getAbsolutePath());
                outputFolder.mkdirs();
            }
            this.logger.info("Starting sampler {} for data source {}", (Object)sampler.getClass().getName(), (Object)from.getAbsolutePath());
            sampler.setOutputFileName(to.getAbsolutePath());
            sampler.setSampleSize(size);
            sampler.call();
            this.logger.info("Reduced dataset size to {}", (Object)size);
        }
        catch (IOException e) {
            throw new AlgorithmException("Could not create a sub-sample of the given data.", (Throwable)e);
        }
    }

    public IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        switch (this.getState()) {
            case CREATED: {
                Instances data;
                this.downsampleData((File)this.getInput(), this.intermediateSizeDownsampledFile, 10000);
                File downsampledFile = new File("testrsc/sampled/" + ((File)this.getInput()).getName());
                this.downsampleData(this.intermediateSizeDownsampledFile, downsampledFile, 1000);
                if (!downsampledFile.exists()) {
                    throw new AlgorithmException("The file " + downsampledFile.getAbsolutePath() + " that should be used for ML-Plan does not exist!");
                }
                try {
                    data = new Instances((Reader)new FileReader(downsampledFile));
                    data.setClassIndex(data.numAttributes() - 1);
                    this.logger.info("Loaded {}x{} dataset", (Object)data.size(), (Object)data.numAttributes());
                }
                catch (IOException e) {
                    throw new AlgorithmException("Could not create a sub-sample of the given data.", (Throwable)e);
                }
                try {
                    MLPlanWekaBuilder builder = new MLPlanWekaBuilder();
                    builder.withLearningCurveExtrapolationEvaluation(this.anchorpointsTraining, (ISamplingAlgorithmFactory)new SimpleRandomSamplingFactory(), 0.7, (LearningCurveExtrapolationMethod)new InversePowerLawExtrapolationMethod());
                    builder.withNodeEvaluationTimeOut(new Timeout(15L, TimeUnit.MINUTES));
                    builder.withCandidateEvaluationTimeOut(new Timeout(5L, TimeUnit.MINUTES));
                    this.mlplan = builder.withDataset((ILabeledDataset)new WekaInstances(data)).build();
                    this.mlplan.setLoggerName(this.getLoggerName() + ".mlplan");
                    this.mlplan.registerListener((Object)this);
                    this.mlplan.setTimeout(new Timeout(this.getTimeout().seconds() - 30L, TimeUnit.SECONDS));
                    this.mlplan.setNumCPUs(3);
                    this.mlplan.setBuildSelectedClasifierOnGivenData(false);
                    this.logger.info("ML-Plan initialized, activation finished!");
                    return this.activate();
                }
                catch (IOException e) {
                    throw new AlgorithmException("Could not initialize ML-Plan!", (Throwable)e);
                }
            }
            case ACTIVE: {
                int numInstances;
                this.logger.info("Starting ML-Plan.");
                this.mlplan.call();
                this.logger.info("ML-Plan has finished. Selected classifier is {} with observed internal performance {}. Will now try to determine the portion of training data that may be used for final training.", (Object)this.mlplan.getSelectedClassifier(), (Object)this.mlplan.getInternalValidationErrorOfSelectedClassifier());
                int[] trainingTimesDuringSearch = this.trainingTimesDuringSearch.get(this.mlplan.getComponentInstanceOfSelectedClassifier());
                List<Integer> trainingTimesDuringSelection = this.trainingTimesDuringSelection.get(this.mlplan.getComponentInstanceOfSelectedClassifier());
                this.logger.info("Observed training times of selected classifier: {} (search) and {} (selection on {} training instances)", new Object[]{Arrays.toString(trainingTimesDuringSearch), trainingTimesDuringSelection, this.numTrainingInstancesUsedInSelection});
                Instances observedRuntimeData = this.getTrainingTimeInstancesForClassifier(this.mlplan.getComponentInstanceOfSelectedClassifier());
                this.logger.info("Infered the following data:\n{}", (Object)observedRuntimeData);
                LinearRegression lr = new LinearRegression();
                try {
                    lr.buildClassifier(observedRuntimeData);
                    this.logger.info("Obtained the following output for the regression model: {}", (Object)lr);
                }
                catch (Exception e1) {
                    throw new AlgorithmException("Could not build a regression model for the runtime.", (Throwable)e1);
                }
                int remainingTime = (int)this.getRemainingTimeToDeadline().milliseconds();
                this.logger.info("Determining number of instances that can be used for training given that {}s are remaining.", (Object)((int)Math.round((double)remainingTime / 1000.0)));
                for (numInstances = 500; numInstances < 10000; numInstances += 50) {
                    Instance low = this.getInstanceForRuntimeAnalysis(numInstances);
                    try {
                        double predictedRuntime = lr.classifyInstance(low);
                        if (predictedRuntime > (double)remainingTime) {
                            this.logger.info("Obtained predicted runtime of {}ms for {} training instances, which is more time than we still have. Choosing this number.", (Object)predictedRuntime, (Object)numInstances);
                            break;
                        }
                        this.logger.info("Obtained predicted runtime of {}ms for {} training instances, which still seems managable.", (Object)predictedRuntime, (Object)numInstances);
                        continue;
                    }
                    catch (Exception e) {
                        throw new AlgorithmException("Could not obtain a runtime prediction for " + numInstances + " instances.", (Throwable)e);
                    }
                }
                this.logger.info("Believe that {} instances can be used for training in time!", (Object)numInstances);
                try {
                    File finalDataFile = new File("testrsc/sampled/final/" + ((File)this.getInput()).getName());
                    this.downsampleData(this.intermediateSizeDownsampledFile, finalDataFile, numInstances);
                    Instances completeData = new Instances((Reader)new FileReader(finalDataFile));
                    completeData.setClassIndex(completeData.numAttributes() - 1);
                    this.logger.info("Created final dataset with {} instances. Now building the final classifier.", (Object)completeData.size());
                    long startFinalTraining = System.currentTimeMillis();
                    ((IWekaClassifier)this.mlplan.getSelectedClassifier()).fit((IDataSource)new WekaInstances(completeData));
                    this.logger.info("Classifier has been fully trained within {}ms.", (Object)(System.currentTimeMillis() - startFinalTraining));
                }
                catch (Exception e) {
                    throw new AlgorithmException("Could not train the final classifier with the full data.", (Throwable)e);
                }
                return this.terminate();
            }
        }
        throw new IllegalStateException();
    }

    private Instances getTrainingTimeInstancesForClassifier(ComponentInstance ci) {
        ArrayList<Attribute> attributes = new ArrayList<Attribute>();
        attributes.add(new Attribute("numInstances"));
        attributes.add(new Attribute("runtime"));
        Instances data = new Instances("Runtime Analysis Regression Data for " + ci, attributes, 0);
        for (int i = 0; i < this.anchorpointsTraining.length; ++i) {
            Instance inst = this.getInstanceForRuntimeAnalysis(this.anchorpointsTraining[i]);
            inst.setValue(1, (double)this.trainingTimesDuringSearch.get(ci)[i]);
            data.add(inst);
        }
        if (this.trainingTimesDuringSelection.containsKey(ci)) {
            Instance inst = this.getInstanceForRuntimeAnalysis(this.numTrainingInstancesUsedInSelection);
            inst.setValue(1, StatisticsUtil.mean((Collection)this.trainingTimesDuringSelection.get(ci)));
            data.add(inst);
        } else {
            this.logger.warn("Classifier {} has not been evaluated in selection phase. Cannot use this information to fit its regression model.", (Object)ci);
        }
        data.setClassIndex(1);
        return data;
    }

    private Instance getInstanceForRuntimeAnalysis(int numberOfInstances) {
        DenseInstance inst = new DenseInstance(3);
        inst.setValue(0, (double)numberOfInstances);
        return inst;
    }

    @Subscribe
    public void receiveClassifierCreatedEvent(SupervisedLearnerCreatedEvent e) {
        this.logger.info("Binding component instance {} to classifier {}", (Object)e.getInstance(), (Object)e.getClassifier());
        this.classifier2modelMap.put(e.getClassifier(), e.getInstance());
    }

    @Subscribe
    public void receiveExtrapolationFinishedEvent(LearningCurveExtrapolatedEvent e) {
        ComponentInstance ci = this.classifier2modelMap.get(e.getExtrapolator().getLearner());
        this.logger.info("Storing training times {} for classifier {}", (Object)Arrays.toString(e.getExtrapolator().getTrainingTimes()), (Object)ci);
        this.trainingTimesDuringSearch.put(ci, e.getExtrapolator().getTrainingTimes());
    }

    @Subscribe
    public void receiveMCCVFinishedEvent(MCCVSplitEvaluationEvent e) {
        ComponentInstance ci = this.classifier2modelMap.get(e.getClassifier());
        this.logger.info("Storing training time {} for classifier {} in selection phase with {} training instances and {} validation instances", new Object[]{e.getSplitEvaluationTime(), ci, e.getNumInstancesUsedForTraining(), e.getNumInstancesUsedForValidation()});
        if (this.numTrainingInstancesUsedInSelection == 0) {
            this.numTrainingInstancesUsedInSelection = e.getNumInstancesUsedForTraining();
        } else if (this.numTrainingInstancesUsedInSelection != e.getNumInstancesUsedForTraining()) {
            this.logger.warn("Memorized {} as number of instances used for training in selection phase, but now observed one classifier using {} instances.", (Object)this.numTrainingInstancesUsedInSelection, (Object)e.getNumInstancesUsedForTraining());
        }
        if (!this.trainingTimesDuringSelection.containsKey(ci)) {
            this.trainingTimesDuringSelection.put(ci, new ArrayList());
        }
        this.trainingTimesDuringSelection.get(ci).add(e.getSplitEvaluationTime());
    }

    public Classifier call() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        while (this.hasNext()) {
            this.next();
        }
        return ((IWekaClassifier)this.mlplan.getSelectedClassifier()).getClassifier();
    }

    public void setLoggerName(String loggerName) {
        this.logger = LoggerFactory.getLogger((String)loggerName);
    }

    public String getLoggerName() {
        return this.logger.getName();
    }
}

