/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees;

import java.util.Random;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.RandomTree;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

public class AccessibleRandomTree
extends RandomTree {
    private static final long serialVersionUID = 1L;
    private int nosLeafNodes;
    private int lastNode = 0;
    private static final Logger logger = LoggerFactory.getLogger(AccessibleRandomTree.class);
    protected AccessibleTree tree = null;

    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_zeroR != null) {
            return this.m_zeroR.distributionForInstance(instance);
        }
        return this.tree.distributionForInstance(instance);
    }

    public void buildClassifier(Instances data) throws Exception {
        this.nosLeafNodes = 0;
        if (this.m_computeImpurityDecreases) {
            this.m_impurityDecreasees = new double[data.numAttributes()][2];
        }
        if (this.m_KValue > data.numAttributes() - 1) {
            this.m_KValue = data.numAttributes() - 1;
        }
        if (this.m_KValue < 1) {
            this.m_KValue = (int)Utils.log2((double)((double)data.numAttributes() - 1.0)) + 1;
        }
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        if (data.numAttributes() == 1) {
            logger.error("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            this.m_zeroR = new ZeroR();
            this.m_zeroR.buildClassifier(data);
            return;
        }
        this.m_zeroR = null;
        Instances train = null;
        Instances backfit = null;
        Random rand = data.getRandomNumberGenerator((long)this.m_randomSeed);
        if (this.m_NumFolds <= 0) {
            train = data;
        } else {
            data.randomize(rand);
            data.stratify(this.m_NumFolds);
            train = data.trainCV(this.m_NumFolds, 1, rand);
            backfit = data.testCV(this.m_NumFolds, 1);
        }
        int[] attIndicesWindow = new int[data.numAttributes() - 1];
        int j = 0;
        for (int i = 0; i < attIndicesWindow.length; ++i) {
            if (Thread.currentThread().isInterrupted()) {
                throw new InterruptedException("Thread got interrupted, thus, kill WEKA.");
            }
            if (j == data.classIndex()) {
                // empty if block
            }
            int n = ++j;
            ++j;
            attIndicesWindow[i] = n;
        }
        double totalWeight = 0.0;
        double totalSumSquared = 0.0;
        double[] classProbs = new double[train.numClasses()];
        for (int i = 0; i < train.numInstances(); ++i) {
            if (Thread.currentThread().isInterrupted()) {
                throw new InterruptedException("Thread got interrupted, thus, kill WEKA.");
            }
            Instance inst = train.instance(i);
            if (data.classAttribute().isNominal()) {
                int n = (int)inst.classValue();
                classProbs[n] = classProbs[n] + inst.weight();
                totalWeight += inst.weight();
                continue;
            }
            classProbs[0] = classProbs[0] + inst.classValue() * inst.weight();
            totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
            totalWeight += inst.weight();
        }
        double trainVariance = 0.0;
        if (totalWeight == 0.0) {
            throw new IllegalStateException("Total weight must not be 0 at this point.");
        }
        if (data.classAttribute().isNumeric()) {
            trainVariance = RandomTree.singleVariance((double)classProbs[0], (double)totalSumSquared, (double)totalWeight) / totalWeight;
            classProbs[0] = classProbs[0] / totalWeight;
        }
        this.tree = new AccessibleTree();
        this.m_Info = new Instances(data, 0);
        this.tree.buildTree(train, classProbs, attIndicesWindow, totalWeight, rand, 0, this.m_MinVarianceProp * trainVariance);
        if (backfit != null) {
            this.tree.backfitData(backfit);
        }
    }

    public AccessibleTree getMTree() {
        return this.tree;
    }

    public int getNosLeafNodes() {
        return this.nosLeafNodes;
    }

    public int getLastNode() {
        return this.lastNode;
    }

    protected static double singleVariance(double s, double sS, double weight) {
        return sS - s * s / weight;
    }

    public class AccessibleTree
    extends RandomTree.Tree {
        private static final long serialVersionUID = 1L;
        protected AccessibleTree[] successors;
        private int leafNodeID;

        public AccessibleTree() {
            super((RandomTree)AccessibleRandomTree.this);
        }

        protected void buildTree(Instances data, double[] classProbs, int[] attIndicesWindow, double totalWeight, Random random, int depth, double minVariance) throws Exception {
            if (data.numInstances() == 0) {
                this.m_Attribute = -1;
                this.m_ClassDistribution = null;
                this.m_Prop = null;
                if (data.classAttribute().isNumeric()) {
                    this.m_Distribution = new double[2];
                }
                this.leafNodeID = AccessibleRandomTree.this.nosLeafNodes++;
                return;
            }
            double priorVar = 0.0;
            if (data.classAttribute().isNumeric()) {
                double totalSum = 0.0;
                double totalSumSquared = 0.0;
                double totalSumOfWeights = 0.0;
                for (int i = 0; i < data.numInstances(); ++i) {
                    Instance inst = data.instance(i);
                    totalSum += inst.classValue() * inst.weight();
                    totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
                    totalSumOfWeights += inst.weight();
                }
                priorVar = AccessibleRandomTree.singleVariance(totalSum, totalSumSquared, totalSumOfWeights);
            }
            if (data.classAttribute().isNominal()) {
                totalWeight = Utils.sum((double[])classProbs);
            }
            if (totalWeight < 2.0 * AccessibleRandomTree.this.m_MinNum || data.classAttribute().isNominal() && Utils.eq((double)classProbs[Utils.maxIndex((double[])classProbs)], (double)Utils.sum((double[])classProbs)) || data.classAttribute().isNumeric() && priorVar / totalWeight < minVariance || AccessibleRandomTree.this.getMaxDepth() > 0 && depth >= AccessibleRandomTree.this.getMaxDepth()) {
                this.m_Attribute = -1;
                this.m_ClassDistribution = (double[])classProbs.clone();
                if (data.classAttribute().isNumeric()) {
                    this.m_Distribution = new double[2];
                    this.m_Distribution[0] = priorVar;
                    this.m_Distribution[1] = totalWeight;
                }
                this.leafNodeID = AccessibleRandomTree.this.nosLeafNodes++;
                this.m_Prop = null;
                return;
            }
            double val = -1.7976931348623157E308;
            double split = -1.7976931348623157E308;
            double[][] bestDists = null;
            double[] bestProps = null;
            int bestIndex = 0;
            double[][] props = new double[1][0];
            double[][][] dists = new double[1][0][0];
            double[][] totalSubsetWeights = new double[data.numAttributes()][0];
            int attIndex = 0;
            int windowSize = attIndicesWindow.length;
            int k = AccessibleRandomTree.this.m_KValue;
            boolean gainFound = false;
            double[] tempNumericVals = new double[data.numAttributes()];
            while (!(windowSize <= 0 || k-- <= 0 && gainFound)) {
                double currVal;
                int chosenIndex = random.nextInt(windowSize);
                attIndex = attIndicesWindow[chosenIndex];
                attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
                attIndicesWindow[windowSize - 1] = attIndex;
                --windowSize;
                double currSplit = data.classAttribute().isNominal() ? this.distribution(props, dists, attIndex, data) : this.numericDistribution(props, dists, attIndex, totalSubsetWeights, data, tempNumericVals);
                double d = currVal = data.classAttribute().isNominal() ? this.gain(dists[0], this.priorVal(dists[0])) : tempNumericVals[attIndex];
                if (Utils.gr((double)currVal, (double)0.0)) {
                    gainFound = true;
                }
                if (!(currVal > val) && (AccessibleRandomTree.this.getBreakTiesRandomly() || currVal != val || attIndex >= bestIndex)) continue;
                val = currVal;
                bestIndex = attIndex;
                split = currSplit;
                bestProps = props[0];
                bestDists = dists[0];
            }
            this.m_Attribute = bestIndex;
            if (Utils.gr((double)val, (double)0.0)) {
                this.m_SplitPoint = split;
                this.m_Prop = bestProps;
                Instances[] subsets = this.splitData(data);
                this.successors = new AccessibleTree[bestDists.length];
                double[] attTotalSubsetWeights = totalSubsetWeights[bestIndex];
                for (int i = 0; i < bestDists.length; ++i) {
                    this.successors[i] = new AccessibleTree();
                    this.successors[i].buildTree(subsets[i], (double[])bestDists[i], attIndicesWindow, data.classAttribute().isNominal() ? 0.0 : attTotalSubsetWeights[i], random, depth + 1, minVariance);
                }
                boolean emptySuccessor = false;
                for (int i = 0; i < subsets.length; ++i) {
                    if (this.successors[i].m_ClassDistribution != null) continue;
                    emptySuccessor = true;
                    break;
                }
                if (emptySuccessor) {
                    this.m_ClassDistribution = (double[])classProbs.clone();
                }
            } else {
                this.m_Attribute = -1;
                this.m_ClassDistribution = (double[])classProbs.clone();
                if (data.classAttribute().isNumeric()) {
                    this.m_Distribution = new double[2];
                    this.m_Distribution[0] = priorVar;
                    this.m_Distribution[1] = totalWeight;
                }
            }
        }

        public double[] distributionForInstance(Instance instance) throws Exception {
            double[] returnedDist = null;
            if (this.m_Attribute > -1) {
                if (instance.isMissing(this.m_Attribute)) {
                    returnedDist = new double[AccessibleRandomTree.this.m_Info.numClasses()];
                    for (int i = 0; i < this.successors.length; ++i) {
                        double[] help = this.successors[i].distributionForInstance(instance);
                        if (help == null) continue;
                        for (int j = 0; j < help.length; ++j) {
                            int n = j;
                            returnedDist[n] = returnedDist[n] + this.m_Prop[i] * help[j];
                        }
                    }
                } else {
                    returnedDist = AccessibleRandomTree.this.m_Info.attribute(this.m_Attribute).isNominal() ? this.successors[(int)instance.value(this.m_Attribute)].distributionForInstance(instance) : (instance.value(this.m_Attribute) < this.m_SplitPoint ? this.successors[0].distributionForInstance(instance) : this.successors[1].distributionForInstance(instance));
                }
            }
            if (this.m_Attribute == -1 || returnedDist == null) {
                AccessibleRandomTree.this.lastNode = this.leafNodeID;
                if (this.m_ClassDistribution == null) {
                    if (AccessibleRandomTree.this.getAllowUnclassifiedInstances()) {
                        double[] result = new double[AccessibleRandomTree.this.m_Info.numClasses()];
                        if (AccessibleRandomTree.this.m_Info.classAttribute().isNumeric()) {
                            result[0] = Utils.missingValue();
                        }
                        return result;
                    }
                    throw new PredictionException("Could not obtain a prediction.");
                }
                double[] normalizedDistribution = (double[])this.m_ClassDistribution.clone();
                if (AccessibleRandomTree.this.m_Info.classAttribute().isNominal()) {
                    Utils.normalize((double[])normalizedDistribution);
                }
                return normalizedDistribution;
            }
            return returnedDist;
        }

        public AccessibleTree[] getSuccessors() {
            return this.successors;
        }

        public int getAttribute() {
            return super.getM_Attribute();
        }

        public double getSplitPoint() {
            return super.getM_SplitPoint();
        }
    }
}

