/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.intervaltree;

import ai.libs.jaicore.ml.core.FeatureSpace;
import ai.libs.jaicore.ml.intervaltree.ExtendedRandomTree;
import ai.libs.jaicore.ml.intervaltree.RangeQueryPredictor;
import ai.libs.jaicore.ml.intervaltree.aggregation.AggressiveAggregator;
import ai.libs.jaicore.ml.intervaltree.aggregation.IntervalAggregator;
import ai.libs.jaicore.ml.intervaltree.aggregation.QuantileAggregator;
import ai.libs.jaicore.ml.intervaltree.util.RQPHelper;
import java.util.ArrayList;
import java.util.Set;
import org.apache.commons.math3.geometry.euclidean.oned.Interval;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.classifiers.trees.RandomForest;
import weka.core.Instance;
import weka.core.Instances;

/**
 * A Weka {@link RandomForest} whose base learner is an {@link ExtendedRandomTree}, adding interval
 * (range-query) prediction and per-feature-subset marginal variance computations.
 *
 * <p>Each tree combines its own leaf values with a tree-level {@link IntervalAggregator} that is handed to
 * the {@link ExtendedRandomTree} prototype at construction time. The lower and upper bounds predicted by all
 * trees are then combined by a forest-level {@link IntervalAggregator}.
 *
 * <p>Methods that operate on the individual trees require the forest to have been built first. If the
 * tree array is not available, they throw {@link IllegalStateException}.
 */
public class ExtendedRandomForest
extends RandomForest
implements RangeQueryPredictor {
    private static final long serialVersionUID = 8774800172762290733L;
    private static final Logger log = LoggerFactory.getLogger(ExtendedRandomForest.class);

    /** Combines the bounds reported by all trees into the forest's interval prediction. */
    private final IntervalAggregator forestAggregator;

    /** Feature space shared with the trees; {@code null} unless set by a constructor or {@link #prepareForest}. */
    private FeatureSpace featureSpace;

    /**
     * Creates a forest with the default aggregators: a {@link QuantileAggregator} constructed with
     * {@code 0.15} for the trees and an {@link AggressiveAggregator} for the forest.
     */
    public ExtendedRandomForest() {
        this(new QuantileAggregator(0.15), new AggressiveAggregator());
    }

    /**
     * Creates a forest with the default aggregators and the given random seed.
     *
     * @param seed seed forwarded to {@link #setSeed(int)}
     */
    public ExtendedRandomForest(final int seed) {
        this();
        this.setSeed(seed);
    }

    /**
     * Creates a forest with explicit aggregators and no feature space.
     *
     * @param treeAggregator aggregator given to the {@link ExtendedRandomTree} prototype
     * @param forestAggregator aggregator combining the per-tree interval bounds
     */
    public ExtendedRandomForest(final IntervalAggregator treeAggregator, final IntervalAggregator forestAggregator) {
        this.setClassifier(new ExtendedRandomTree(treeAggregator));
        this.forestAggregator = forestAggregator;
    }

    /**
     * Creates a forest with the default aggregators (see {@link #ExtendedRandomForest()}) and the given
     * feature space, which is also set on the tree prototype.
     *
     * @param featureSpace feature space for this forest and its tree prototype
     */
    public ExtendedRandomForest(final FeatureSpace featureSpace) {
        this(new QuantileAggregator(0.15), new AggressiveAggregator(), featureSpace);
    }

    /**
     * Creates a forest with explicit aggregators and a feature space, which is also set on the tree prototype.
     *
     * @param treeAggregator aggregator given to the {@link ExtendedRandomTree} prototype
     * @param forestAggregator aggregator combining the per-tree interval bounds
     * @param featureSpace feature space for this forest and its tree prototype
     */
    public ExtendedRandomForest(final IntervalAggregator treeAggregator, final IntervalAggregator forestAggregator, final FeatureSpace featureSpace) {
        this.forestAggregator = forestAggregator;
        this.featureSpace = featureSpace;
        ExtendedRandomTree prototype = new ExtendedRandomTree(treeAggregator);
        prototype.setFeatureSpace(featureSpace);
        this.setClassifier(prototype);
    }

    /**
     * Prepares the built trees for variance computations. It derives a new {@link FeatureSpace} from
     * {@code data}, stores it on this forest, hands it to every tree, and then calls each tree's
     * {@code preprocess()}.
     *
     * @param data instances from which the feature space is derived
     * @throws IllegalStateException if the forest's trees are not available (forest not built)
     */
    public void prepareForest(final Instances data) {
        // Resolve the trees first so an unbuilt forest fails before any state is modified.
        Classifier[] trees = this.requireTrees();
        this.featureSpace = new FeatureSpace(data);
        for (Classifier classifier : trees) {
            ExtendedRandomTree tree = (ExtendedRandomTree) classifier;
            tree.setFeatureSpace(this.featureSpace);
            tree.preprocess();
        }
    }

    /**
     * Logs, at debug level, the total variance reported by each tree.
     *
     * @throws IllegalStateException if the forest's trees are not available (forest not built)
     */
    public void printVariances() {
        for (Classifier classifier : this.requireTrees()) {
            ExtendedRandomTree tree = (ExtendedRandomTree) classifier;
            log.debug("cur var: {}", tree.getTotalVariance());
        }
    }

    /**
     * Averages, over all trees, each tree's marginal variance contribution for the given feature subset.
     *
     * @param features indices of the features in the subset
     * @return the mean of the per-tree values, or {@code 0.0} if the forest has no trees
     * @throws IllegalStateException if the forest's trees are not available (forest not built)
     */
    public double computeMarginalVarianceContributionForFeatureSubset(final Set<Integer> features) {
        Classifier[] trees = this.requireTrees();
        if (trees.length == 0) {
            return 0.0;
        }
        double sum = 0.0;
        for (Classifier classifier : trees) {
            sum += ((ExtendedRandomTree) classifier).computeMarginalVarianceContributionForSubsetOfFeatures(features);
        }
        return sum / trees.length;
    }

    /**
     * Averages, over all trees, each tree's non-normalized marginal variance contribution for the given
     * feature subset.
     *
     * @param features indices of the features in the subset
     * @return the mean of the per-tree values, or {@code 0.0} if the forest has no trees
     * @throws IllegalStateException if the forest's trees are not available (forest not built)
     */
    public double computeMarginalVarianceContributionForFeatureSubsetNotNormalized(final Set<Integer> features) {
        Classifier[] trees = this.requireTrees();
        if (trees.length == 0) {
            return 0.0;
        }
        double sum = 0.0;
        for (Classifier classifier : trees) {
            sum += ((ExtendedRandomTree) classifier).computeMarginalVarianceContributionForSubsetOfFeaturesNotNormalized(features);
        }
        return sum / trees.length;
    }

    /**
     * Returns the number of trees in the built forest.
     *
     * @return the number of trees
     * @throws IllegalStateException if the forest's trees are not available (forest not built)
     */
    public int getSize() {
        return this.requireTrees().length;
    }

    /**
     * Returns the feature space of this forest.
     *
     * @return the feature space, or {@code null} if none has been set
     */
    public FeatureSpace getFeatureSpace() {
        return this.featureSpace;
    }

    /**
     * Returns the fully qualified class name of the default base learner. It is derived from the class
     * literal so it cannot drift out of sync with the package the class actually lives in.
     */
    @Override
    protected String defaultClassifierString() {
        return ExtendedRandomTree.class.getName();
    }

    /**
     * Predicts an interval for the given range query. Each tree's lower and upper bound is collected, and
     * all bounds are combined with the forest aggregator.
     *
     * @param rangeQuery the range query passed to every tree
     * @return the aggregated interval
     * @throws IllegalStateException if the forest's trees are not available (forest not built)
     */
    @Override
    public Interval predictInterval(final Instance rangeQuery) {
        Classifier[] trees = this.requireTrees();
        ArrayList<Double> bounds = new ArrayList<Double>(trees.length * 2);
        for (Classifier classifier : trees) {
            addBounds(bounds, ((ExtendedRandomTree) classifier).predictInterval(rangeQuery));
        }
        return this.forestAggregator.aggregate(bounds);
    }

    /**
     * Predicts an interval for the given interval-and-header query. Each tree's lower and upper bound is
     * collected, and all bounds are combined with the forest aggregator.
     *
     * @param intervalAndHeader the query passed to every tree
     * @return the aggregated interval
     * @throws IllegalStateException if the forest's trees are not available (forest not built)
     */
    @Override
    public Interval predictInterval(final RQPHelper.IntervalAndHeader intervalAndHeader) {
        Classifier[] trees = this.requireTrees();
        ArrayList<Double> bounds = new ArrayList<Double>(trees.length * 2);
        for (Classifier classifier : trees) {
            addBounds(bounds, ((ExtendedRandomTree) classifier).predictInterval(intervalAndHeader));
        }
        return this.forestAggregator.aggregate(bounds);
    }

    /**
     * Returns the trees of this forest. If the tree array is {@code null} (presumably because
     * {@code buildClassifier} has not run yet), it fails with a descriptive message instead of an NPE.
     */
    private Classifier[] requireTrees() {
        if (this.m_Classifiers == null) {
            throw new IllegalStateException("The forest has not been built yet; call buildClassifier first.");
        }
        return this.m_Classifiers;
    }

    /** Appends the lower bound and then the upper bound of {@code interval} to {@code bounds}. */
    private static void addBounds(final ArrayList<Double> bounds, final Interval interval) {
        bounds.add(interval.getInf());
        bounds.add(interval.getSup());
    }
}

