/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree;

import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.RangeQueryPredictor;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.aggregation.AggressiveAggregator;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.aggregation.IntervalAggregator;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.featurespace.CategoricalFeatureDomain;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.featurespace.FeatureDomain;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.featurespace.FeatureSpace;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.featurespace.NumericFeatureDomain;
import ai.libs.jaicore.ml.weka.rangequery.learner.intervaltree.util.RQPHelper;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.math3.geometry.euclidean.oned.Interval;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.trees.RandomTree;

/**
 * Weka {@link RandomTree} extended with two capabilities.
 * <ul>
 * <li>Range queries ({@link #predictInterval}): every leaf reachable from a box of query intervals is
 * collected, and the leaves' predictions are combined by the configured {@link IntervalAggregator}.</li>
 * <li>A variance decomposition in the style of functional ANOVA. For a subset of features, the variance of the
 * tree's marginal prediction over that subset is computed on the tree's partitioning of a {@link FeatureSpace}.
 * The variance already attributed to every proper, non-empty sub-subset is then subtracted.</li>
 * </ul>
 * The variance computations rely on {@link #preprocess()}. The public variance methods trigger it lazily when
 * it has not run yet.
 */
public class ExtendedRandomTree
extends RandomTree
implements RangeQueryPredictor {
    private static final Logger LOGGER = LoggerFactory.getLogger(ExtendedRandomTree.class);
    private static final String LOG_WARN_VARIANCE_ZERO = "The trees total variance is zero, predictions make no sense at this point!";
    private static final String LOG_WARN_NOT_PREPARED = "Tree is not prepared, preprocessing may take a while";
    private static final String LOG_INDIVIDUAL_VAR = "Individual var for {} = {}";
    private static final String LOG_TOTAL_VAR = "current total variance for {} = {}";
    private static final long serialVersionUID = -467555221387281335L;
    // NOTE: the declared field types below are deliberately left unchanged. Changing a field's declared type
    // would break Java deserialization of previously serialized trees (the serialVersionUID is fixed).
    /** Combines the leaf predictions collected by {@link #predictInterval}. */
    private final IntervalAggregator intervalAggregator;
    /** Domain of all features; must be set before {@link #preprocess()} runs. */
    private FeatureSpace featureSpace;
    /** Sub-space of {@link #featureSpace} covered by each leaf; filled by {@link #preprocess()}. */
    private HashMap<RandomTree.Tree, FeatureSpace> partitioning;
    /** All leaves, in the order the partitioning recursion reached them. */
    private ArrayList<RandomTree.Tree> leaves;
    /** Distinct split points per feature index. */
    private ArrayList<Set<Double>> splitPoints;
    /** Variance of the marginal prediction over all features; denominator of the normalized contributions. */
    private double totalVariance;
    /** Per feature: one observation per elementary interval (numeric) or per value (categorical). Not serialized. */
    private transient Observation[][] allObservations;
    /** Cache: feature subset -> variance attributed to exactly that subset (lower-order parts removed, clamped at 0). */
    private HashMap<Set<Integer>, Double> varianceOfSubsetIndividual;
    /** Cache: feature subset -> variance of the marginal prediction over that subset. */
    private HashMap<Set<Integer>, Double> varianceOfSubsetTotal;
    /** Fallback predictions for leaves without a class distribution, copied from their parent node. */
    private HashMap<RandomTree.Tree, Double> mapForEmptyLeaves;
    /** Whether {@link #preprocess()} has completed since the last reset. */
    private boolean isPrepared;

    /**
     * Creates a tree that aggregates range-query results with an {@link AggressiveAggregator}.
     */
    public ExtendedRandomTree() {
        this(new AggressiveAggregator());
    }

    /**
     * Creates a tree with an {@link AggressiveAggregator} and the given feature space.
     *
     * @param featureSpace domain of all features, used by {@link #preprocess()}
     */
    public ExtendedRandomTree(FeatureSpace featureSpace) {
        this();
        this.featureSpace = featureSpace;
    }

    /**
     * Creates a tree that aggregates range-query results with the given aggregator.
     *
     * <p>Applies the Weka option {@code -U}, then explicitly disables unclassified instances.
     * NOTE(review): the error message says this "unprunes" the tree. Confirm what {@code -U} means for the
     * Weka build in use.
     *
     * @param intervalAggregator combines leaf predictions in {@link #predictInterval}
     * @throws IllegalStateException if Weka rejects the option
     */
    public ExtendedRandomTree(IntervalAggregator intervalAggregator) {
        try {
            this.setOptions(new String[]{"-U"});
        }
        catch (Exception e) {
            // Keep the cause so the underlying Weka error is not lost.
            throw new IllegalStateException("Couldn't unprune the tree", e);
        }
        this.intervalAggregator = intervalAggregator;
        this.partitioning = new HashMap<>();
        this.leaves = new ArrayList<>();
        this.setAllowUnclassifiedInstances(false);
        this.varianceOfSubsetTotal = new HashMap<>();
        this.varianceOfSubsetIndividual = new HashMap<>();
        this.mapForEmptyLeaves = new HashMap<>();
        this.isPrepared = false;
    }

    /**
     * Predicts the range of outputs for a box of input intervals.
     *
     * <p>Starting at the root, each pending node carries the box of intervals that can still reach it. At a
     * split on attribute {@code a} with threshold {@code t}, the node's interval {@code [inf, sup]} for {@code a}
     * is routed as follows:
     * <ul>
     * <li>If {@code inf <= t <= sup}, it is split into {@code [inf, t]} for the left child and {@code [t, sup]}
     * for the right child.</li>
     * <li>If it lies entirely below {@code t}, it goes left.</li>
     * <li>If it lies entirely above {@code t}, it goes right.</li>
     * </ul>
     * When none of these comparisons holds (e.g. a NaN threshold), the node contributes nothing. The value
     * {@code m_ClassDistribution[0]} of every reached leaf is aggregated. Presumably that value is the
     * regression mean for a numeric class; confirm.
     *
     * @param intervalAndHeader the queried interval per attribute
     * @return the aggregated interval of reachable leaf predictions
     */
    @Override
    public Interval predictInterval(RQPHelper.IntervalAndHeader intervalAndHeader) {
        Interval[] queriedInterval = intervalAndHeader.getIntervals();
        ArrayDeque<Map.Entry<Interval[], RandomTree.Tree>> stack = new ArrayDeque<>();
        stack.push(RQPHelper.getEntry(queriedInterval, this.m_Tree));
        List<Double> leafPredictions = new ArrayList<>();
        while (!stack.isEmpty()) {
            Map.Entry<Interval[], RandomTree.Tree> toProcess = stack.pop();
            Interval[] currentIntervals = toProcess.getKey();
            RandomTree.Tree node = toProcess.getValue();
            int attribute = node.getM_Attribute();
            if (attribute == -1) {
                leafPredictions.add(node.getM_Classdistribution()[0]);
                continue;
            }
            double threshold = node.getM_SplitPoint();
            RandomTree.Tree[] children = node.getM_Successors();
            RandomTree.Tree leftChild = children[0];
            RandomTree.Tree rightChild = children[1];
            // Use the box narrowed on the way down, which is what the substitutions below maintain.
            Interval intervalForAttribute = currentIntervals[attribute];
            double inf = intervalForAttribute.getInf();
            double sup = intervalForAttribute.getSup();
            if (inf <= threshold && threshold <= sup) {
                // Straddles the threshold: each side receives only its part. The right child must not be
                // pushed again with the full interval (that duplicated every leaf of the right subtree).
                Interval[] leftIntervals = RQPHelper.substituteInterval(currentIntervals, new Interval(inf, threshold), attribute);
                Interval[] rightIntervals = RQPHelper.substituteInterval(currentIntervals, new Interval(threshold, sup), attribute);
                stack.push(RQPHelper.getEntry(leftIntervals, leftChild));
                stack.push(RQPHelper.getEntry(rightIntervals, rightChild));
            } else if (inf <= threshold) {
                // Entire interval lies below the threshold.
                stack.push(RQPHelper.getEntry(currentIntervals, leftChild));
            } else if (sup > threshold) {
                // Entire interval lies above the threshold.
                stack.push(RQPHelper.getEntry(currentIntervals, rightChild));
            }
        }
        return this.intervalAggregator.aggregate(leafPredictions);
    }

    /**
     * Sets the domain of all features used by {@link #preprocess()}.
     *
     * @param featureSpace the feature space
     */
    public void setFeatureSpace(FeatureSpace featureSpace) {
        this.featureSpace = featureSpace;
    }

    /**
     * Returns the domain of all features.
     *
     * @return the feature space, or {@code null} if none was set
     */
    public FeatureSpace getFeatureSpace() {
        return this.featureSpace;
    }

    /**
     * Returns the square root of the variance attributed to exactly the given feature subset.
     *
     * @param features feature indices
     * @return the marginal standard deviation, or NaN if the tree's total variance is zero
     */
    public double computeMarginalStandardDeviationForSubsetOfFeatures(Set<Integer> features) {
        return Math.sqrt(this.computeIndividualVariance(features));
    }

    /**
     * Returns the variance attributed to exactly the given feature subset, divided by the total variance.
     *
     * @param features feature indices
     * @return the normalized marginal variance contribution, or NaN if the total variance is zero
     */
    public double computeMarginalVarianceContributionForSubsetOfFeatures(Set<Integer> features) {
        // The left operand runs first, so preprocessing has updated totalVariance before it is read.
        return this.computeIndividualVariance(features) / this.totalVariance;
    }

    /**
     * Returns the variance attributed to exactly the given feature subset, without normalization.
     *
     * @param features feature indices
     * @return the marginal variance contribution, or NaN if the total variance is zero
     */
    public double computeMarginalVarianceContributionForSubsetOfFeaturesNotNormalized(Set<Integer> features) {
        return this.computeIndividualVariance(features);
    }

    /**
     * Computes and caches the variance attributed to exactly the given features.
     *
     * <p>The result is the total variance of the subset minus the individual variances of all its proper,
     * non-empty sub-subsets, clamped at 0. Sub-subsets whose individual variance is not cached yet are computed
     * recursively; previously a missing entry caused a NullPointerException.
     *
     * @param features feature indices; copied, so later changes by the caller do not affect the cache
     * @return the individual variance, or NaN if the total variance is zero
     */
    private double computeIndividualVariance(Set<Integer> features) {
        if (!this.isPrepared) {
            LOGGER.warn(LOG_WARN_NOT_PREPARED);
            this.preprocess();
        }
        // Defensive copy: using a view of a caller-owned mutable set as a HashMap key corrupts the map if mutated.
        Set<Integer> key = Collections.unmodifiableSet(new HashSet<>(features));
        if (this.totalVariance == 0.0) {
            LOGGER.warn(LOG_WARN_VARIANCE_ZERO);
            return Double.NaN;
        }
        double vU = this.computeTotalVarianceOfSubset(key);
        LOGGER.trace(LOG_TOTAL_VAR, key, vU);
        for (int k = 1; k < key.size(); ++k) {
            for (Set<Integer> subset : Sets.combinations(key, k)) {
                Double individual = this.varianceOfSubsetIndividual.get(subset);
                if (individual == null) {
                    individual = this.computeIndividualVariance(subset);
                }
                LOGGER.trace("Subtracting {} for {}", individual, subset);
                vU -= individual;
            }
        }
        LOGGER.trace(LOG_INDIVIDUAL_VAR, key, vU);
        // Numerical noise can make the difference slightly negative; NaN passes through unchanged.
        vU = Math.max(vU, 0.0);
        this.varianceOfSubsetIndividual.put(key, vU);
        return vU;
    }

    /**
     * Computes the tree's prediction marginalized over all features except {@code indices}.
     *
     * <p>The {@code indices} features are fixed to the observations' mid points. Every leaf whose sub-space
     * contains that partial instance contributes its prediction, weighted by the fraction of the remaining
     * features' range it covers.
     *
     * @param indices feature indices, aligned position by position with {@code observations}
     * @param observations one observation per index
     * @return the marginal prediction; 0.0 if no leaf matches (a warning is logged), and NaN if a matching leaf
     *         has no prediction at all
     */
    private double getMarginalPrediction(List<Integer> indices, List<Observation> observations) {
        Set<Integer> subset = new HashSet<>(indices);
        ArrayList<Double> obsList = new ArrayList<>(observations.size());
        for (Observation obs : observations) {
            obsList.add(obs.midPoint);
        }
        // Loop-invariant: the size of the whole domain outside the fixed features.
        double sizeOfDomain = this.featureSpace.getRangeSizeOfAllButSubset(subset);
        double result = 0.0;
        boolean consistentWithAnyLeaf = false;
        for (Map.Entry<RandomTree.Tree, FeatureSpace> leafEntry : this.partitioning.entrySet()) {
            FeatureSpace leafSpace = leafEntry.getValue();
            if (!leafSpace.containsPartialInstance(indices, obsList)) {
                continue;
            }
            double fractionOfSpaceForThisLeaf = leafSpace.getRangeSizeOfAllButSubset(subset) / sizeOfDomain;
            // A NaN prediction propagates into the result; computeTotalVarianceOfSubset skips such values.
            result += this.getLeafPrediction(leafEntry.getKey()) * fractionOfSpaceForThisLeaf;
            consistentWithAnyLeaf = true;
        }
        if (!consistentWithAnyLeaf) {
            LOGGER.warn("Observation {} is not consistent with any leaf with indices: {}", obsList, indices);
        }
        return result;
    }

    /**
     * Returns a leaf's prediction.
     *
     * <p>Uses the leaf's own class distribution if present, otherwise the fallback recorded for empty leaves,
     * otherwise NaN (with a warning).
     */
    private double getLeafPrediction(RandomTree.Tree leaf) {
        double[] classDistribution = leaf.getM_Classdistribution();
        if (classDistribution != null) {
            return classDistribution[0];
        }
        Double fallback = this.mapForEmptyLeaves.get(leaf);
        if (fallback != null) {
            return fallback;
        }
        LOGGER.warn("No prediction found anywhere!");
        return Double.NaN;
    }

    /**
     * Recursively assigns every leaf below {@code node} the sub-space of {@code subSpace} it covers.
     *
     * <p>A categorical split restricts child {@code i} to value {@code i}. A numeric split caps the left child at
     * the split point and floors the right child there.
     */
    private void computePartitioning(FeatureSpace subSpace, RandomTree.Tree node) {
        int attribute = node.getM_Attribute();
        if (attribute == -1) {
            this.leaves.add(node);
            this.partitioning.put(node, subSpace);
            return;
        }
        RandomTree.Tree[] children = node.getM_Successors();
        FeatureDomain domain = subSpace.getFeatureDomain(attribute);
        if (domain instanceof CategoricalFeatureDomain) {
            for (int i = 0; i < children.length; ++i) {
                this.rememberFallbackForEmptyLeaf(node, children[i]);
                FeatureSpace childSubSpace = new FeatureSpace(subSpace);
                ((CategoricalFeatureDomain)childSubSpace.getFeatureDomain(attribute)).setValues(new double[]{i});
                this.computePartitioning(childSubSpace, children[i]);
            }
        } else if (domain instanceof NumericFeatureDomain) {
            double splitPoint = node.getM_SplitPoint();
            FeatureSpace leftSubSpace = new FeatureSpace(subSpace);
            ((NumericFeatureDomain)leftSubSpace.getFeatureDomain(attribute)).setMax(splitPoint);
            FeatureSpace rightSubSpace = new FeatureSpace(subSpace);
            ((NumericFeatureDomain)rightSubSpace.getFeatureDomain(attribute)).setMin(splitPoint);
            // Same empty-leaf handling as for categorical splits.
            this.rememberFallbackForEmptyLeaf(node, children[0]);
            this.rememberFallbackForEmptyLeaf(node, children[1]);
            this.computePartitioning(leftSubSpace, children[0]);
            this.computePartitioning(rightSubSpace, children[1]);
        }
    }

    /**
     * Records the parent's prediction for {@code child} if the child is a leaf without a class distribution.
     */
    private void rememberFallbackForEmptyLeaf(RandomTree.Tree parent, RandomTree.Tree child) {
        double[] parentDistribution = parent.getM_Classdistribution();
        if (child.getM_Classdistribution() == null && child.getM_Attribute() == -1 && parentDistribution != null) {
            this.mapForEmptyLeaves.put(child, parentDistribution[0]);
        }
    }

    /**
     * Collects the distinct split points of every internal node, grouped by feature index (breadth-first).
     */
    private void collectSplitPoints(RandomTree.Tree root) {
        int dimensionality = this.featureSpace.getDimensionality();
        this.splitPoints = new ArrayList<>(dimensionality);
        for (int i = 0; i < dimensionality; ++i) {
            this.splitPoints.add(new HashSet<>());
        }
        ArrayDeque<RandomTree.Tree> queueOfNodes = new ArrayDeque<>();
        queueOfNodes.add(root);
        while (!queueOfNodes.isEmpty()) {
            RandomTree.Tree node = queueOfNodes.poll();
            int attribute = node.getM_Attribute();
            if (attribute <= -1) {
                continue;
            }
            this.splitPoints.get(attribute).add(node.getM_SplitPoint());
            for (RandomTree.Tree child : node.getM_Successors()) {
                queueOfNodes.add(child);
            }
        }
    }

    /**
     * Builds {@link #allObservations} from the collected split points and the feature domains.
     *
     * <p>Features whose domain is neither numeric nor categorical get no entry (it stays {@code null}).
     */
    private void computeObservations() {
        int dimensionality = this.featureSpace.getDimensionality();
        this.allObservations = new Observation[dimensionality][];
        for (int featureIndex = 0; featureIndex < dimensionality; ++featureIndex) {
            FeatureDomain domain = this.featureSpace.getFeatureDomain(featureIndex);
            if (domain instanceof NumericFeatureDomain) {
                this.allObservations[featureIndex] = this.numericObservations(this.splitPoints.get(featureIndex), (NumericFeatureDomain)domain);
            } else if (domain instanceof CategoricalFeatureDomain) {
                double[] values = ((CategoricalFeatureDomain)domain).getValues();
                Observation[] observations = new Observation[values.length];
                for (int i = 0; i < values.length; ++i) {
                    observations[i] = new Observation(values[i], 1.0);
                }
                this.allObservations[featureIndex] = observations;
            }
        }
    }

    /**
     * Creates one observation per gap between consecutive boundaries of a numeric feature.
     *
     * <p>The boundaries are the split points plus the domain's min and max. Each observation sits at the
     * midpoint of its gap, and its size is the gap width. A non-positive (or NaN) width is replaced by 1.0
     * because weights must be strictly positive.
     */
    private Observation[] numericObservations(Set<Double> featureSplitPoints, NumericFeatureDomain domain) {
        List<Double> boundaries = new ArrayList<>(featureSplitPoints);
        boundaries.add(domain.getMin());
        boundaries.add(domain.getMax());
        Collections.sort(boundaries);
        // min and max were just added, so there are always at least two boundaries.
        Observation[] observations = new Observation[boundaries.size() - 1];
        for (int i = 0; i < observations.length; ++i) {
            double lower = boundaries.get(i);
            double upper = boundaries.get(i + 1);
            double width = upper - lower;
            observations[i] = new Observation((lower + upper) / 2.0, width > 0.0 ? width : 1.0);
        }
        return observations;
    }

    /**
     * Computes and caches the weighted population variance of the marginal prediction over {@code features}.
     *
     * <p>Every combination of one observation per feature is evaluated. Each marginal prediction is weighted by
     * the product of the observations' interval sizes times the range size of all other features. NaN marginal
     * predictions are skipped. Observation lists and feature indices are aligned by iterating the features in
     * ascending order. Previously the lists followed the set's iteration order while the indices were sorted,
     * which mismatched them for sets that do not iterate in ascending order.
     *
     * @param features feature indices; copied, so later changes by the caller do not affect the cache
     * @return the variance, or NaN if no combination produced a usable prediction
     */
    public double computeTotalVarianceOfSubset(Set<Integer> features) {
        Set<Integer> key = Collections.unmodifiableSet(new HashSet<>(features));
        Double cached = this.varianceOfSubsetTotal.get(key);
        if (cached != null) {
            return cached;
        }
        if (this.allObservations == null) {
            // e.g. never preprocessed, or deserialized (allObservations is transient).
            LOGGER.warn(LOG_WARN_NOT_PREPARED);
            this.preprocess();
            cached = this.varianceOfSubsetTotal.get(key);
            if (cached != null) {
                return cached;
            }
        }
        List<Integer> featureList = new ArrayList<>(key);
        Collections.sort(featureList);
        List<List<Observation>> observationsPerFeature = new ArrayList<>(featureList.size());
        for (int featureIndex : featureList) {
            observationsPerFeature.add(Arrays.asList(this.allObservations[featureIndex]));
        }
        double sizeOfAllButFeatures = this.getFeatureSpace().getRangeSizeOfAllButSubset(key);
        WeightedVarianceHelper stat = new WeightedVarianceHelper();
        for (List<Observation> combination : Lists.cartesianProduct(observationsPerFeature)) {
            double marginalPrediction = this.getMarginalPrediction(featureList, combination);
            if (Double.isNaN(marginalPrediction)) {
                continue;
            }
            double prodOfIntervalSizes = 1.0;
            for (Observation obs : combination) {
                if (obs.intervalSize != 0.0) {
                    prodOfIntervalSizes *= obs.intervalSize;
                }
            }
            stat.push(marginalPrediction, sizeOfAllButFeatures * prodOfIntervalSizes);
        }
        double vU = stat.getPopulationVariance();
        this.varianceOfSubsetTotal.put(key, vU);
        return vU;
    }

    /**
     * Returns the variance of the marginal prediction over all features, as computed by {@link #preprocess()}.
     *
     * @return the total variance (0.0 before preprocessing)
     */
    public double getTotalVariance() {
        return this.totalVariance;
    }

    /**
     * Derives all structures needed by the variance computations from the built tree and the feature space.
     *
     * <p>These are the leaf partitioning, the split points, the observations and the total variance. State
     * from a previous run is discarded first, so calling this again (e.g. after rebuilding the tree) neither
     * duplicates leaves nor reuses stale cached variances.
     *
     * @throws IllegalStateException if no feature space is set or the tree has not been built
     */
    public void preprocess() {
        if (this.featureSpace == null) {
            throw new IllegalStateException("No feature space set; call setFeatureSpace before preprocessing");
        }
        if (this.m_Tree == null) {
            throw new IllegalStateException("The tree has not been built yet");
        }
        this.isPrepared = false;
        this.partitioning.clear();
        this.leaves.clear();
        this.mapForEmptyLeaves.clear();
        this.varianceOfSubsetTotal.clear();
        this.varianceOfSubsetIndividual.clear();
        this.computePartitioning(this.featureSpace, this.m_Tree);
        this.collectSplitPoints(this.m_Tree);
        this.computeObservations();
        Set<Integer> allFeatures = new HashSet<>();
        for (int i = 0; i < this.featureSpace.getDimensionality(); ++i) {
            allFeatures.add(i);
        }
        this.totalVariance = this.computeTotalVarianceOfSubset(allFeatures);
        this.isPrepared = true;
    }

    /**
     * Logs (at debug level) the observation mid points of every feature.
     */
    public void printObservations() {
        if (this.allObservations == null) {
            LOGGER.debug("No observations computed yet");
            return;
        }
        for (int i = 0; i < this.allObservations.length; ++i) {
            StringBuilder sb = new StringBuilder();
            for (Observation obs : this.allObservations[i]) {
                sb.append(obs.midPoint).append(", ");
            }
            LOGGER.debug("Observations for feature {}: {}", i, sb);
        }
    }

    /**
     * Logs (at debug level) the sorted split points of every feature.
     *
     * <p>For numeric features the domain's min and max are included. Previously the sorted list was built but
     * never logged.
     */
    public void printSplitPoints() {
        if (this.splitPoints == null) {
            LOGGER.debug("No split points collected yet");
            return;
        }
        for (int i = 0; i < this.splitPoints.size(); ++i) {
            List<Double> sorted = new ArrayList<>(this.splitPoints.get(i));
            FeatureDomain domain = this.getFeatureSpace().getFeatureDomain(i);
            if (domain instanceof NumericFeatureDomain) {
                sorted.add(((NumericFeatureDomain)domain).getMin());
                sorted.add(((NumericFeatureDomain)domain).getMax());
            }
            Collections.sort(sorted);
            LOGGER.debug("Split points for feature {}: {}", i, sorted);
        }
    }

    /**
     * Logs (at debug level) three sizes for sanity checking.
     *
     * <p>These are the feature space's range size, the summed range size of all leaf sub-spaces, and the
     * product over features of their summed observation interval sizes.
     */
    public void printSizeOfFeatureSpaceAndPartitioning() {
        LOGGER.debug("Size of feature space: {}", this.featureSpace.getRangeSize());
        double sizeOfPartitioning = 0.0;
        for (FeatureSpace leafSpace : this.partitioning.values()) {
            sizeOfPartitioning += leafSpace.getRangeSize();
        }
        LOGGER.debug("Complete size of partitioning: {}", sizeOfPartitioning);
        if (this.allObservations == null) {
            LOGGER.debug("No observations computed yet");
            return;
        }
        double sizeOfIntervals = 1.0;
        for (Observation[] observationsOfFeature : this.allObservations) {
            double temp = 0.0;
            for (Observation obs : observationsOfFeature) {
                temp += obs.intervalSize;
            }
            sizeOfIntervals *= temp;
        }
        LOGGER.debug("Complete size of intervals: {}", sizeOfIntervals);
    }

    /**
     * A representative point of a feature together with the size of the interval it stands for.
     */
    private static final class Observation {
        private final double midPoint;
        private final double intervalSize;

        Observation(double midPoint, double intervalSize) {
            this.midPoint = midPoint;
            this.intervalSize = intervalSize;
        }
    }

    /**
     * Accumulates a weighted mean and weighted population variance incrementally.
     *
     * <p>Uses the weighted variant of Welford's online update.
     */
    private static final class WeightedVarianceHelper {
        private double average = 0.0;
        private double squaredDistanceToMean = 0.0;
        private double sumOfWeights = 0.0;

        /**
         * Adds a value with the given weight.
         *
         * @throws IllegalArgumentException if {@code weight <= 0}
         */
        void push(double x, double weight) {
            if (weight <= 0.0) {
                throw new IllegalArgumentException("Weights have to be strictly positive!");
            }
            double delta = x - this.average;
            this.sumOfWeights += weight;
            this.average += delta * weight / this.sumOfWeights;
            this.squaredDistanceToMean += weight * delta * (x - this.average);
        }

        /**
         * Returns the weighted population variance (clamped at 0), or NaN if nothing was pushed.
         */
        double getPopulationVariance() {
            if (this.sumOfWeights > 0.0) {
                return Math.max(0.0, this.squaredDistanceToMean / this.sumOfWeights);
            }
            return Double.NaN;
        }
    }
}

