/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.mlplan.metamining;

import ai.libs.hasco.core.HASCOUtil;
import ai.libs.hasco.metamining.IMetaMiner;
import ai.libs.hasco.metamining.MetaMinerBasedSorter;
import ai.libs.jaicore.components.api.IComponent;
import ai.libs.jaicore.components.api.IComponentInstance;
import ai.libs.jaicore.components.model.ComponentInstance;
import ai.libs.jaicore.logging.LoggerUtil;
import ai.libs.jaicore.logic.fol.structure.Monom;
import ai.libs.jaicore.ml.classification.loss.dataset.EAggregatedClassifierMetric;
import ai.libs.jaicore.ml.core.evaluation.MLEvaluationUtil;
import ai.libs.jaicore.ml.weka.classification.learner.IWekaClassifier;
import ai.libs.jaicore.ml.weka.dataset.WekaInstances;
import ai.libs.jaicore.planning.hierarchical.algorithms.forwarddecomposition.graphgenerators.tfd.TFDNode;
import ai.libs.jaicore.search.algorithms.standard.lds.BestFirstLimitedDiscrepancySearch;
import ai.libs.jaicore.search.algorithms.standard.lds.BestFirstLimitedDiscrepancySearchFactory;
import ai.libs.jaicore.search.algorithms.standard.lds.NodeOrderList;
import ai.libs.jaicore.search.model.other.SearchGraphPath;
import ai.libs.jaicore.search.model.travesaltree.ReducedGraphGenerator;
import ai.libs.jaicore.search.probleminputs.GraphSearchWithNodeRecommenderInput;
import ai.libs.mlplan.metamining.IntermediateSolutionEvent;
import ai.libs.mlplan.metamining.WEKAMetaminer;
import ai.libs.mlplan.metamining.databaseconnection.ExperimentRepository;
import ai.libs.mlplan.weka.MLPlan4Weka;
import ai.libs.mlplan.weka.MLPlanWekaBuilder;
import ai.libs.mlplan.weka.weka.MLPipelineComponentInstanceFactory;
import ai.libs.mlplan.weka.weka.WekaPipelineFactory;
import com.google.common.eventbus.EventBus;
import java.io.File;
import java.io.IOException;
import java.sql.SQLException;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Timer;
import java.util.TimerTask;
import org.apache.commons.lang3.time.StopWatch;
import org.api4.java.ai.graphsearch.problem.IPathSearchInput;
import org.api4.java.ai.ml.classification.IClassifier;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.datastructure.graph.implicit.IGraphGenerator;
import org.openml.webapplication.fantail.dc.GlobalCharacterizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.AbstractClassifier;
import weka.core.Instance;
import weka.core.Instances;

public class MetaMLPlan
extends AbstractClassifier {
    private transient Logger logger = LoggerFactory.getLogger(MetaMLPlan.class);
    private static final long serialVersionUID = 4772178784402396834L;
    private static final File resourceFile = new File("resources/automl/searchmodels/weka/weka-all-autoweka.json");
    private String algorithmId = "MetaMLPlan";
    private transient BestFirstLimitedDiscrepancySearch<GraphSearchWithNodeRecommenderInput<TFDNode, String>, TFDNode, String, NodeOrderList> lds;
    private transient WEKAMetaminer metaMiner;
    private transient WekaPipelineFactory factory = new WekaPipelineFactory();
    private long timeoutInSeconds = 60L;
    private long safetyInSeconds = 1L;
    private int cpus = 1;
    private String metaFeatureSetName = "all";
    private String datasetSetName = "all";
    private int seed = 0;
    private IWekaClassifier bestModel;
    private transient Collection<IComponent> components;
    private transient EventBus eventBus = new EventBus();

    public MetaMLPlan(ILabeledDataset<?> data) throws IOException {
        this(resourceFile, data);
    }

    public MetaMLPlan(File configFile, ILabeledDataset<?> data) throws IOException {
        MLPlanWekaBuilder builder = new MLPlanWekaBuilder();
        builder.withSearchSpaceConfigFile(configFile);
        builder.withDataset(data);
        MLPlan4Weka mlPlan = builder.build();
        mlPlan.next();
        this.components = builder.getComponents();
        this.metaMiner = new WEKAMetaminer(builder.getComponentParameterConfigurations());
        BestFirstLimitedDiscrepancySearchFactory ldsFactory = new BestFirstLimitedDiscrepancySearchFactory();
        IPathSearchInput originalInput = mlPlan.getSearchProblemInputGenerator();
        GraphSearchWithNodeRecommenderInput problemInput = new GraphSearchWithNodeRecommenderInput((IGraphGenerator)new ReducedGraphGenerator(originalInput.getGraphGenerator()), originalInput.getGoalTester(), (Comparator)new MetaMinerBasedSorter((IMetaMiner)this.metaMiner, (Collection)builder.getComponents()));
        ldsFactory.setProblemInput((Object)problemInput);
        this.lds = ldsFactory.getAlgorithm();
    }

    public void buildMetaComponents(String host, String user, String password) throws AlgorithmException, InterruptedException, SQLException, IOException {
        ExperimentRepository repo = new ExperimentRepository(host, user, password, new MLPipelineComponentInstanceFactory(this.components), this.cpus, this.metaFeatureSetName, this.datasetSetName);
        this.metaMiner.build(repo.getDistinctPipelines(), repo.getDatasetCahracterizations(), repo.getPipelineResultsOnDatasets());
    }

    public void buildMetaComponents(String host, String user, String password, int limit) throws AlgorithmException, InterruptedException, SQLException, IOException {
        this.logger.info("Get past experiment data from data base and build MetaMiner.");
        ExperimentRepository repo = new ExperimentRepository(host, user, password, new MLPipelineComponentInstanceFactory(this.components), this.cpus, this.metaFeatureSetName, this.datasetSetName);
        repo.setLimit(limit);
        this.metaMiner.build(repo.getDistinctPipelines(), repo.getDatasetCahracterizations(), repo.getPipelineResultsOnDatasets());
    }

    public void buildClassifier(final Instances data) throws Exception {
        StopWatch totalTimer = new StopWatch();
        totalTimer.start();
        this.logger.info("Characterizing data set");
        this.metaMiner.setDataSetCharacterization(new GlobalCharacterizer().characterize(data));
        this.logger.info("Preparing validation split");
        this.logger.info("Searching for solutions");
        StopWatch trainingTimer = new StopWatch();
        this.bestModel = null;
        double bestScore = 1.0;
        double bestModelMaxTrainingTime = 0.0;
        boolean thereIsEnoughTime = true;
        boolean thereAreMoreElements = true;
        while (!this.lds.isCanceled() && thereIsEnoughTime && thereAreMoreElements) {
            try {
                SearchGraphPath searchGraphPath = (SearchGraphPath)this.lds.nextSolutionCandidate();
                List solution = searchGraphPath.getNodes();
                if (solution == null) {
                    this.logger.info("Ran out of solutions. Search is over.");
                    break;
                }
                ComponentInstance ci = HASCOUtil.getSolutionCompositionFromState(this.components, (Monom)((TFDNode)solution.get(solution.size() - 1)).getState(), (boolean)true);
                IWekaClassifier pl = this.factory.getComponentInstantiation((IComponentInstance)ci);
                trainingTimer.reset();
                trainingTimer.start();
                this.logger.info("Evaluate Pipeline: {}", (Object)pl);
                double score = MLEvaluationUtil.mccv((ISupervisedLearner)pl, (ILabeledDataset)new WekaInstances(data), (int)5, (double)0.7, (long)this.seed, (EAggregatedClassifierMetric)EAggregatedClassifierMetric.MEAN_ERRORRATE);
                this.logger.info("Pipeline Score: {}", (Object)score);
                trainingTimer.stop();
                this.eventBus.post((Object)new IntermediateSolutionEvent(null, (IClassifier)pl, score));
                if (score < bestScore) {
                    this.bestModel = pl;
                    bestScore = score;
                }
                if ((double)trainingTimer.getTime() > bestModelMaxTrainingTime) {
                    bestModelMaxTrainingTime = trainingTimer.getTime();
                }
                thereIsEnoughTime = this.checkTermination(totalTimer, bestModelMaxTrainingTime, thereIsEnoughTime);
            }
            catch (NoSuchElementException e) {
                this.logger.info("Finished search (Exhaustive search conducted).");
                thereAreMoreElements = false;
            }
            catch (Exception e) {
                this.logger.warn("Continuing search despite error: {}", (Object)LoggerUtil.getExceptionInfo((Throwable)e));
            }
        }
        final Thread finalEval = new Thread(){

            @Override
            public void run() {
                MetaMLPlan.this.logger.info("Evaluating best model on whole training data ({})", (Object)MetaMLPlan.this.bestModel);
                try {
                    MetaMLPlan.this.bestModel.getClassifier().buildClassifier(data);
                }
                catch (Exception e) {
                    MetaMLPlan.this.bestModel = null;
                    MetaMLPlan.this.logger.error("Evaluation of best model failed with an exception: {}", (Object)LoggerUtil.getExceptionInfo((Throwable)e));
                }
            }
        };
        TimerTask newT = new TimerTask(){

            @Override
            public void run() {
                MetaMLPlan.this.logger.error("MetaMLPlan: Interrupt building of final classifier because time is running out.");
                finalEval.interrupt();
            }
        };
        try {
            new Timer().schedule(newT, this.timeoutInSeconds * 1000L - this.safetyInSeconds * 1000L - totalTimer.getTime());
        }
        catch (IllegalArgumentException e) {
            this.logger.error("No time anymore to start evaluation of final model. Abort search.");
            return;
        }
        finalEval.start();
        finalEval.join();
        this.logger.info("Ready. Best solution: {}", (Object)this.bestModel);
    }

    private boolean checkTermination(StopWatch totalTimer, double bestModelMaxTrainingTime, boolean thereIsEnoughTime) {
        if ((double)((this.timeoutInSeconds - this.safetyInSeconds) * 1000L) <= (double)totalTimer.getTime() + bestModelMaxTrainingTime) {
            this.logger.info("Stopping search to train best model on whole training data which is expected to take {} ms", (Object)bestModelMaxTrainingTime);
            thereIsEnoughTime = false;
        }
        return thereIsEnoughTime;
    }

    public double classifyInstance(Instance instance) throws Exception {
        return this.bestModel.getClassifier().classifyInstance(instance);
    }

    public void registerListenerForIntermediateSolutions(Object listener) {
        this.eventBus.register(listener);
    }

    public void setTimeOutInSeconds(int timeOutInSeconds) {
        this.timeoutInSeconds = timeOutInSeconds;
    }

    public void setMetaFeatureSetName(String metaFeatureSetName) {
        this.metaFeatureSetName = metaFeatureSetName;
    }

    public void setDatasetSetName(String datasetSetName) {
        this.datasetSetName = datasetSetName;
    }

    public void setCPUs(int cPUs) {
        this.cpus = cPUs;
    }

    public WEKAMetaminer getMetaMiner() {
        return this.metaMiner;
    }

    public void setSeed(int seed) {
        this.seed = seed;
    }

    public String getAlgorithmId() {
        return this.algorithmId;
    }

    public void setAlgorithmId(String algorithmId) {
        this.algorithmId = algorithmId;
    }
}

