/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.classification.pipeline;

import ai.libs.jaicore.ml.weka.WekaUtil;
import ai.libs.jaicore.ml.weka.classification.pipeline.FeaturePreprocessor;
import ai.libs.jaicore.ml.weka.classification.pipeline.PreprocessingException;
import ai.libs.jaicore.ml.weka.classification.pipeline.featuregen.FeatureGenerator;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

public class MLSophisticatedPipeline
implements Classifier,
FeatureGenerator,
Serializable {
    private final List<FeatureGenerator> featureGenerators = new ArrayList<FeatureGenerator>();
    private final List<FeaturePreprocessor> featurePreprocessors = new ArrayList<FeaturePreprocessor>();
    private final List<FeaturePreprocessor> featureSelectors = new ArrayList<FeaturePreprocessor>();
    private final Classifier classifier;
    private boolean trained = false;
    private long timeForTrainingPreprocessors;
    private long timeForTrainingClassifier;
    private long timeForExecutingPreprocessor;
    private long timeForExecutingClassifier;
    private Instances emptyReferenceDataset;

    public MLSophisticatedPipeline(List<FeatureGenerator> featureGenerators, List<FeaturePreprocessor> preprocessors, List<FeaturePreprocessor> featureSelectors, Classifier baseClassifier) {
        if (baseClassifier == null) {
            throw new IllegalArgumentException("Base classifier must not be null!");
        }
        this.featureGenerators.addAll(featureGenerators);
        this.featurePreprocessors.addAll(preprocessors);
        this.featureSelectors.addAll(featureSelectors);
        this.classifier = baseClassifier;
    }

    public void buildClassifier(Instances data) throws Exception {
        long start;
        Instances mergedInstances = new Instances(data);
        int f = data.numAttributes();
        for (FeatureGenerator featureGenerator : this.featureGenerators) {
            Instances modifiedInstances;
            if (!featureGenerator.isPrepared()) {
                start = System.currentTimeMillis();
                featureGenerator.prepare(data);
                this.timeForTrainingPreprocessors = System.currentTimeMillis() - start;
            }
            if ((modifiedInstances = featureGenerator.apply(data)) == null) {
                throw new IllegalStateException("Feature Generator " + featureGenerator + " has generated a null-dataset!");
            }
            for (int i = 0; i < modifiedInstances.numAttributes(); ++i) {
                modifiedInstances.renameAttribute(modifiedInstances.attribute(i), "f" + f++);
            }
            mergedInstances = Instances.mergeInstances((Instances)mergedInstances, (Instances)modifiedInstances);
            mergedInstances.setClassIndex(data.classIndex());
        }
        data = mergedInstances;
        for (FeaturePreprocessor featurePreprocessor : this.featurePreprocessors) {
            featurePreprocessor.prepare(data);
            if ((data = featurePreprocessor.apply(data)).classIndex() >= 0) continue;
            throw new IllegalStateException("Preprocessor " + featurePreprocessor + " has removed class index!");
        }
        for (FeaturePreprocessor featurePreprocessor : this.featureSelectors) {
            featurePreprocessor.prepare(data);
            if ((data = featurePreprocessor.apply(data)).classIndex() >= 0) continue;
            throw new IllegalStateException("Preprocessor " + featurePreprocessor + " has removed class index!");
        }
        this.emptyReferenceDataset = new Instances(data);
        this.emptyReferenceDataset.clear();
        start = System.currentTimeMillis();
        this.classifier.buildClassifier(data);
        this.timeForTrainingClassifier = System.currentTimeMillis() - start;
        this.trained = true;
    }

    private Instance applyPreprocessors(Instance data) throws PreprocessingException {
        long start = System.currentTimeMillis();
        DenseInstance mergedInstance = new DenseInstance(data);
        mergedInstance.setDataset(data.dataset());
        for (FeatureGenerator featureGenerator : this.featureGenerators) {
            Instances mergedDatasetA = new Instances(mergedInstance.dataset());
            mergedDatasetA.clear();
            mergedDatasetA.add((Instance)mergedInstance);
            Instance modifiedInstance = featureGenerator.apply(data);
            if (modifiedInstance.dataset() == null) {
                throw new IllegalStateException("Instance was detached from dataset by " + featureGenerator);
            }
            Instances mergedDatasetB = modifiedInstance.dataset();
            Instances mergedDataset = Instances.mergeInstances((Instances)mergedDatasetA, (Instances)mergedDatasetB);
            mergedDataset.setClassIndex(mergedDatasetA.classIndex());
            mergedInstance = mergedInstance.mergeInstance(modifiedInstance);
            mergedInstance.setDataset(mergedDataset);
            this.timeForExecutingPreprocessor = System.currentTimeMillis() - start;
        }
        data = mergedInstance;
        for (FeaturePreprocessor featurePreprocessor : this.featurePreprocessors) {
            data = featurePreprocessor.apply(data);
        }
        for (FeaturePreprocessor featurePreprocessor : this.featureSelectors) {
            data = featurePreprocessor.apply(data);
        }
        return data;
    }

    public double classifyInstance(Instance arg0) throws Exception {
        if (!this.trained) {
            throw new IllegalStateException("Cannot make predictions on untrained pipeline!");
        }
        arg0 = this.applyPreprocessors(arg0);
        long start = System.currentTimeMillis();
        double result = this.classifier.classifyInstance(arg0);
        this.timeForExecutingClassifier = System.currentTimeMillis() - start;
        return result;
    }

    public double[] distributionForInstance(Instance arg0) throws Exception {
        if (!this.trained) {
            throw new IllegalStateException("Cannot make predictions on untrained pipeline!");
        }
        if (arg0 == null) {
            throw new IllegalArgumentException("Cannot make predictions for null-instance");
        }
        if ((arg0 = this.applyPreprocessors(arg0)) == null) {
            throw new IllegalStateException("The filter has turned the instance into NULL");
        }
        long start = System.currentTimeMillis();
        double[] result = this.classifier.distributionForInstance(arg0);
        this.timeForExecutingClassifier = System.currentTimeMillis() - start;
        return result;
    }

    public Capabilities getCapabilities() {
        return this.classifier.getCapabilities();
    }

    public Classifier getBaseClassifier() {
        return this.classifier;
    }

    public long getTimeForTrainingPreprocessor() {
        return this.timeForTrainingPreprocessors;
    }

    public long getTimeForTrainingClassifier() {
        return this.timeForTrainingClassifier;
    }

    public long getTimeForExecutingPreprocessor() {
        return this.timeForExecutingPreprocessor;
    }

    public long getTimeForExecutingClassifier() {
        return this.timeForExecutingClassifier;
    }

    @Override
    public void prepare(Instances data) throws PreprocessingException {
        try {
            this.buildClassifier(data);
        }
        catch (Exception e) {
            throw new PreprocessingException(e);
        }
    }

    private Instances getEmptyProbingResultDataset() {
        if (!this.isPrepared()) {
            throw new IllegalStateException("Cannot determine empty dataset, because the pipeline has not been trained yet.");
        }
        ArrayList<Attribute> atts = new ArrayList<Attribute>();
        List<String> attributeValues = WekaUtil.getClassesDeclaredInDataset(this.emptyReferenceDataset);
        for (String att : attributeValues) {
            atts.add(new Attribute("probe_classprob_" + att + "_" + this));
        }
        return new Instances("probing", atts, 0);
    }

    @Override
    public Instance apply(Instance data) throws PreprocessingException {
        try {
            double[] classProbs = this.distributionForInstance(data);
            DenseInstance newInst = new DenseInstance(classProbs.length);
            Instances dataset = this.getEmptyProbingResultDataset();
            dataset.add((Instance)newInst);
            newInst.setDataset(dataset);
            for (int i = 0; i < classProbs.length; ++i) {
                newInst.setValue(i, classProbs[i]);
            }
            return newInst;
        }
        catch (Exception e) {
            throw new PreprocessingException(e);
        }
    }

    @Override
    public Instances apply(Instances data) throws PreprocessingException {
        Instances probingResults = new Instances(this.getEmptyProbingResultDataset());
        for (Instance inst : data) {
            Instance probedInst = this.apply(inst);
            probedInst.setDataset(probingResults);
            probingResults.add(probedInst);
        }
        return probingResults;
    }

    @Override
    public boolean isPrepared() {
        return this.trained;
    }
}

