/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.classification.pipeline;

import ai.libs.jaicore.ml.weka.WekaUtil;
import ai.libs.jaicore.ml.weka.classification.pipeline.PreprocessingException;
import ai.libs.jaicore.ml.weka.classification.pipeline.SupervisedFilterSelector;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.ASSearch;
import weka.attributeSelection.AttributeSelection;
import weka.classifiers.Classifier;
import weka.classifiers.SingleClassifierEnhancer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;

public class MLPipeline
extends SingleClassifierEnhancer
implements Classifier,
Serializable {
    private static final Logger logger = LoggerFactory.getLogger(MLPipeline.class);
    private final List<SupervisedFilterSelector> preprocessors = new ArrayList<SupervisedFilterSelector>();
    private boolean trained = false;
    private int timeForTrainingPreprocessors;
    private int timeForTrainingClassifier;
    private DescriptiveStatistics timeForExecutingPreprocessors;
    private DescriptiveStatistics timeForExecutingClassifier;

    public MLPipeline(List<SupervisedFilterSelector> preprocessors, Classifier baseClassifier) {
        if (baseClassifier == null) {
            throw new IllegalArgumentException("Base classifier must not be null!");
        }
        this.preprocessors.addAll(preprocessors);
        super.setClassifier(baseClassifier);
    }

    public MLPipeline(ASSearch searcher, ASEvaluation evaluator, Classifier baseClassifier) {
        if (baseClassifier == null) {
            throw new IllegalArgumentException("Base classifier must not be null!");
        }
        if (searcher != null && evaluator != null) {
            AttributeSelection selector = new AttributeSelection();
            selector.setSearch(searcher);
            selector.setEvaluator(evaluator);
            this.preprocessors.add(new SupervisedFilterSelector(searcher, evaluator, selector));
        }
        super.setClassifier(baseClassifier);
    }

    public void buildClassifier(Instances data) throws Exception {
        long start;
        int numAttributesBefore = data.numAttributes();
        logger.info("Starting to build the preprocessors of the pipeline.");
        for (SupervisedFilterSelector pp : this.preprocessors) {
            if (!pp.isPrepared()) {
                try {
                    start = System.currentTimeMillis();
                    pp.prepare(data);
                    this.timeForTrainingPreprocessors = (int)(System.currentTimeMillis() - start);
                    int newNumberOfClasses = pp.apply(data).numClasses();
                    if (data.numClasses() != newNumberOfClasses) {
                        logger.info("{} changed number of classes from {} to {}", new Object[]{pp.getSelector(), data.numClasses(), newNumberOfClasses});
                    }
                }
                catch (NullPointerException e) {
                    logger.error("Could not apply preprocessor", (Throwable)e);
                }
            }
            data = pp.apply(data);
        }
        logger.info("Reduced number of attributes from {} to {}", (Object)numAttributesBefore, (Object)data.numAttributes());
        start = System.currentTimeMillis();
        super.getClassifier().buildClassifier(data);
        this.timeForTrainingClassifier = (int)(System.currentTimeMillis() - start);
        this.trained = true;
        this.timeForExecutingPreprocessors = new DescriptiveStatistics();
        this.timeForExecutingClassifier = new DescriptiveStatistics();
    }

    private Instance applyPreprocessors(Instance data) throws PreprocessingException {
        long start = System.currentTimeMillis();
        for (SupervisedFilterSelector pp : this.preprocessors) {
            data = pp.apply(data);
        }
        this.timeForExecutingPreprocessors.addValue((double)((int)(System.currentTimeMillis() - start)));
        return data;
    }

    public double classifyInstance(Instance arg0) throws Exception {
        if (!this.trained) {
            throw new IllegalStateException("Cannot make predictions on untrained pipeline!");
        }
        int numAttributesBefore = arg0.numAttributes();
        if (numAttributesBefore != (arg0 = this.applyPreprocessors(arg0)).numAttributes()) {
            logger.info("Reduced number of attributes from {} to {}", (Object)numAttributesBefore, (Object)arg0.numAttributes());
        }
        long start = System.currentTimeMillis();
        double result = super.getClassifier().classifyInstance(arg0);
        this.timeForExecutingClassifier.addValue((double)(System.currentTimeMillis() - start));
        return result;
    }

    public double[] classifyInstances(Instances arg0) throws Exception {
        int n = arg0.size();
        double[] answers = new double[n];
        for (int i = 0; i < n; ++i) {
            answers[i] = this.classifyInstance(arg0.get(i));
        }
        return answers;
    }

    public double[] distributionForInstance(Instance arg0) throws Exception {
        if (!this.trained) {
            throw new IllegalStateException("Cannot make predictions on untrained pipeline!");
        }
        if (arg0 == null) {
            throw new IllegalArgumentException("Cannot make predictions for null-instance");
        }
        if ((arg0 = this.applyPreprocessors(arg0)) == null) {
            throw new IllegalStateException("The filter has turned the instance into NULL");
        }
        long start = System.currentTimeMillis();
        double[] result = super.getClassifier().distributionForInstance(arg0);
        this.timeForExecutingClassifier.addValue((double)((int)(System.currentTimeMillis() - start)));
        return result;
    }

    public Capabilities getCapabilities() {
        return super.getClassifier().getCapabilities();
    }

    public Classifier getBaseClassifier() {
        return super.getClassifier();
    }

    public List<SupervisedFilterSelector> getPreprocessors() {
        return this.preprocessors;
    }

    public String toString() {
        return this.getPreprocessors() + " (preprocessors), " + WekaUtil.getClassifierDescriptor(this.getBaseClassifier()) + " (classifier)";
    }

    public long getTimeForTrainingPreprocessor() {
        return this.timeForTrainingPreprocessors;
    }

    public long getTimeForTrainingClassifier() {
        return this.timeForTrainingClassifier;
    }

    public DescriptiveStatistics getTimeForExecutingPreprocessor() {
        return this.timeForExecutingPreprocessors;
    }

    public DescriptiveStatistics getTimeForExecutingClassifier() {
        return this.timeForExecutingClassifier;
    }
}

