/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.PriorityQueue;
import java.util.concurrent.Callable;
import smile.classification.Classifier;
import smile.classification.ClassifierTrainer;
import smile.data.Attribute;
import smile.data.NominalAttribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.sort.QuickSort;
import smile.util.MulticoreExecutor;

public class DecisionTree
implements Classifier<double[]> {
    // Attribute metadata for each feature column; the constructors default it to numeric attributes named "V1".."Vp" when null is passed.
    private Attribute[] attributes;
    // Variable importance: for each attribute, the sum of the split scores of all splits made on it (accumulated in TrainNode.split).
    private double[] importance;
    // Root of the fitted tree; prediction starts here.
    private Node root;
    // Impurity measure used to score candidate splits.
    private SplitRule rule = SplitRule.GINI;
    // Number of classes (labels are expected to be 0..k-1 because they index arrays of length k).
    private int k = 2;
    // Maximum number of leaf nodes; the M-based constructor sets it to Integer.MAX_VALUE (no limit).
    private int J = 100;
    // Number of attributes scored at each split; when smaller than attributes.length, the first M of a permuted index list are used.
    private int M;
    // For each numeric attribute, the sample indices in the order returned by QuickSort.sort on that column
    // (presumably ascending); null entries for non-numeric attributes. Marked transient so it is not serialized.
    private transient int[][] order;

    /**
     * Computes the impurity of a class histogram under the configured {@link SplitRule}.
     *
     * @param count per-class sample counts; non-positive entries are skipped
     * @param n     the total used as the denominator of each class proportion
     * @return the Gini index, the entropy in bits, or the classification error of the histogram
     */
    private double impurity(int[] count, int n) {
        final double total = n;
        switch (rule) {
            case GINI: {
                double gini = 1.0;
                for (int c : count) {
                    if (c > 0) {
                        double p = c / total;
                        gini -= p * p;
                    }
                }
                return gini;
            }
            case ENTROPY: {
                double entropy = 0.0;
                for (int c : count) {
                    if (c > 0) {
                        double p = c / total;
                        entropy -= p * Math.log2(p);
                    }
                }
                return entropy;
            }
            case CLASSIFICATION_ERROR: {
                double majority = 0.0;
                for (int c : count) {
                    if (c > 0) {
                        majority = Math.max(majority, c / total);
                    }
                }
                return Math.abs(1.0 - majority);
            }
            default:
                return 0.0;
        }
    }

    /**
     * Trains a tree that treats every column as numeric and scores splits with the Gini index.
     *
     * @param x training samples, one row per sample
     * @param y class labels
     * @param J maximum number of leaf nodes
     */
    public DecisionTree(double[][] x, int[] y, int J) {
        this(null, x, y, J, SplitRule.GINI);
    }

    /**
     * Trains a tree that treats every column as numeric.
     *
     * @param x    training samples, one row per sample
     * @param y    class labels
     * @param J    maximum number of leaf nodes
     * @param rule impurity measure used to score splits
     */
    public DecisionTree(double[][] x, int[] y, int J, SplitRule rule) {
        this(null, x, y, J, null, null, rule);
    }

    /**
     * Trains a tree with the given attribute metadata, scoring splits with the Gini index.
     *
     * @param attributes attribute metadata (null means all numeric)
     * @param x          training samples, one row per sample
     * @param y          class labels
     * @param J          maximum number of leaf nodes
     */
    public DecisionTree(Attribute[] attributes, double[][] x, int[] y, int J) {
        this(attributes, x, y, J, null, null, SplitRule.GINI);
    }

    /**
     * Trains a tree with the given attribute metadata and split rule.
     *
     * @param attributes attribute metadata (null means all numeric)
     * @param x          training samples, one row per sample
     * @param y          class labels
     * @param J          maximum number of leaf nodes
     * @param rule       impurity measure used to score splits
     */
    public DecisionTree(Attribute[] attributes, double[][] x, int[] y, int J, SplitRule rule) {
        this(attributes, x, y, J, null, null, rule);
    }

    /**
     * Trains a tree with at most {@code J} leaves, grown best-first: the pending node whose best
     * split has the highest score (see {@link TrainNode#compareTo}) is split next.
     *
     * @param attributes attribute metadata; if null, every column is treated as numeric and named V1..Vp
     * @param x          training samples, one row per sample
     * @param y          class labels; must be exactly 0, 1, ..., k-1 with k >= 2
     * @param J          maximum number of leaf nodes (at least 2)
     * @param samples    per-sample counts (presumably bootstrap multiplicities when non-null); rows with
     *                   a count <= 0 are ignored. If null, every sample counts once.
     * @param order      per-attribute sample ordering for numeric attributes; computed here if null
     * @param rule       impurity measure used to score splits
     * @throws IllegalArgumentException if x and y differ in length, J < 2, a label is negative,
     *                                  a label in 0..max is missing, or there is only one class
     */
    DecisionTree(Attribute[] attributes, double[][] x, int[] y, int J, int[] samples, int[][] order, SplitRule rule) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (J < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + J);
        }

        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; ++i) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            // Parenthesized so the missing label is computed rather than string-concatenated
            // (previously "Missing class: " + 5 + 1 printed "51").
            if (i > 0 && labels[i] - labels[i - 1] > 1) {
                throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        this.k = labels.length;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        // Labels index arrays of length k, so they must start at 0; e.g. {1, 2} would otherwise
        // overflow count[] below with an ArrayIndexOutOfBoundsException.
        if (labels[0] != 0) {
            throw new IllegalArgumentException("Missing class: 0");
        }

        if (attributes == null) {
            int p = x[0].length;
            attributes = new Attribute[p];
            for (int j = 0; j < p; ++j) {
                attributes[j] = new NumericAttribute("V" + (j + 1));
            }
        }
        this.attributes = attributes;
        this.J = J;
        this.rule = rule;
        this.M = attributes.length;
        this.importance = new double[attributes.length];

        if (order != null) {
            this.order = order;
        } else {
            // Sort each numeric column once; the split search then scans samples in this order.
            int n = x.length;
            int p = x[0].length;
            double[] column = new double[n];
            this.order = new int[p][];
            for (int j = 0; j < p; ++j) {
                if (attributes[j] instanceof NumericAttribute) {
                    for (int i = 0; i < n; ++i) {
                        column[i] = x[i][j];
                    }
                    this.order[j] = QuickSort.sort(column);
                }
            }
        }

        int n = y.length;
        if (samples == null) {
            samples = new int[n];
            Arrays.fill(samples, 1);
        }
        int[] count = new int[this.k];
        for (int i = 0; i < n; ++i) {
            count[y[i]] += samples[i];
        }

        this.root = new Node(Math.whichMax(count));
        TrainNode trainRoot = new TrainNode(this.root, x, y, samples);
        PriorityQueue<TrainNode> nextSplits = new PriorityQueue<TrainNode>();
        if (trainRoot.findBestSplit()) {
            nextSplits.add(trainRoot);
        }

        int leaves = 1;
        TrainNode node;
        while (leaves < this.J && (node = nextSplits.poll()) != null) {
            // A successful split turns one leaf into two; a rejected split (one empty side)
            // leaves the node a leaf, so the count must not advance.
            if (node.split(nextSplits)) {
                ++leaves;
            }
        }
    }

    /**
     * Trains a tree with no leaf limit ({@code J = Integer.MAX_VALUE}) that scores only {@code M}
     * attributes at each split, splitting depth-first until no node can be split further.
     *
     * @param attributes attribute metadata; if null, every column is treated as numeric and named V1..Vp
     * @param x          training samples, one row per sample
     * @param y          class labels, non-negative; k is taken as max(y) + 1
     * @param M          number of attributes scored per split, in 1..x[0].length
     * @param samples    per-sample counts (presumably bootstrap multiplicities); rows with a count <= 0
     *                   are ignored. Must not be null.
     * @param order      per-attribute sample ordering; must be non-null for numeric attributes because
     *                   the split search iterates it directly
     * @throws IllegalArgumentException if x and y differ in length, M is out of range, samples is null,
     *                                  fewer than two classes are possible, or a label is negative
     */
    DecisionTree(Attribute[] attributes, double[][] x, int[] y, int M, int[] samples, int[][] order) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (M <= 0 || M > x[0].length) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + M);
        }
        if (samples == null) {
            throw new IllegalArgumentException("Sampling array is null.");
        }
        this.k = Math.max(y) + 1;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class or negative class labels.");
        }
        // The check above only catches the case where every label is <= 0; a mix such as {-1, 1}
        // would otherwise fail later with an ArrayIndexOutOfBoundsException on count[].
        for (int label : y) {
            if (label < 0) {
                throw new IllegalArgumentException("Negative class label: " + label);
            }
        }

        if (attributes == null) {
            int p = x[0].length;
            attributes = new Attribute[p];
            for (int j = 0; j < p; ++j) {
                attributes[j] = new NumericAttribute("V" + (j + 1));
            }
        }
        this.attributes = attributes;
        this.J = Integer.MAX_VALUE;
        this.M = M;
        this.order = order;
        this.importance = new double[attributes.length];

        int n = y.length;
        int[] count = new int[this.k];
        for (int i = 0; i < n; ++i) {
            count[y[i]] += samples[i];
        }

        this.root = new Node(Math.whichMax(count));
        TrainNode trainRoot = new TrainNode(this.root, x, y, samples);
        if (trainRoot.findBestSplit()) {
            // null queue: split() recurses into the children immediately instead of queueing them.
            trainRoot.split(null);
        }
    }

    /**
     * Returns the variable importance: for each attribute, the sum of the split scores
     * (impurity decreases) of every split made on it. The internal array is returned, not a copy.
     *
     * @return per-attribute importance scores
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Predicts the class of {@code x} by routing it from the root down to a leaf.
     *
     * @param x the sample to classify
     * @return the predicted class label
     */
    @Override
    public int predict(double[] x) {
        return root.predict(x);
    }

    /**
     * Posterior probabilities are not tracked by this tree.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        throw new UnsupportedOperationException("Not supported.");
    }

    class TrainNode
    implements Comparable<TrainNode> {
        /** The tree node being grown by this training wrapper. */
        Node node;
        /** Training feature matrix; passed unchanged to both children when splitting. */
        double[][] x;
        /** Training class labels; passed unchanged to both children when splitting. */
        int[] y;
        /** Per-row sample counts reaching this node; rows with a count <= 0 are ignored. */
        int[] samples;

        /**
         * Pairs a tree node with the training rows that reach it.
         *
         * @param node    the node to grow
         * @param x       training feature matrix
         * @param y       training class labels
         * @param samples per-row counts of the rows reaching {@code node}
         */
        public TrainNode(Node node, double[][] x, int[] y, int[] samples) {
            this.samples = samples;
            this.y = y;
            this.x = x;
            this.node = node;
        }

        /**
         * Orders nodes by descending split score so that a {@link java.util.PriorityQueue}
         * polls the most promising split first.
         *
         * @param a the node to compare with
         * @return a negative number if this node's split score is higher than {@code a}'s
         */
        @Override
        public int compareTo(TrainNode a) {
            // Double.compare is a total order (NaN included); (int)signum(NaN) was 0, which made a
            // NaN score "equal" to everything and broke Comparable's transitivity.
            return Double.compare(a.node.splitScore, this.node.splitScore);
        }

        /**
         * Searches for the best split of this node and stores it in {@link #node}.
         *
         * <p>When {@code M} is smaller than the number of attributes, the first {@code M} entries of a
         * permuted attribute index list are scored sequentially. Otherwise every attribute is scored
         * via {@link SplitTask}s handed to {@code MulticoreExecutor}; if the executor throws, the
         * attributes are scored sequentially instead.
         *
         * @return {@code true} if a split with positive score was found; {@code false} if the node is
         *         pure (or holds no samples) or no attribute improves its impurity
         */
        public boolean findBestSplit() {
            if (isPure()) {
                return false;
            }

            int N = this.x.length;
            int n = 0;
            int[] count = new int[DecisionTree.this.k];
            for (int i = 0; i < N; ++i) {
                if (this.samples[i] > 0) {
                    n += this.samples[i];
                    count[this.y[i]] += this.samples[i];
                }
            }
            double impurity = DecisionTree.this.impurity(count, n);

            int p = DecisionTree.this.attributes.length;
            int[] variables = new int[p];
            for (int i = 0; i < p; ++i) {
                variables[i] = i;
            }

            if (DecisionTree.this.M < p) {
                // Lock only around the shuffle (presumably Math.permutate uses shared random state).
                // The decompiled code kept the class-wide lock across the whole split search, which
                // serialized split searches running on other threads.
                synchronized (DecisionTree.class) {
                    Math.permutate(variables);
                }
                int[] falseCount = new int[DecisionTree.this.k];
                for (int j = 0; j < DecisionTree.this.M; ++j) {
                    keepIfBetter(findBestSplit(n, count, falseCount, impurity, variables[j]));
                }
            } else {
                ArrayList<SplitTask> tasks = new ArrayList<SplitTask>(DecisionTree.this.M);
                for (int j = 0; j < DecisionTree.this.M; ++j) {
                    tasks.add(new SplitTask(n, count, impurity, variables[j]));
                }
                try {
                    for (Node split : MulticoreExecutor.run(tasks)) {
                        keepIfBetter(split);
                    }
                } catch (Exception ex) {
                    // Best-effort: if the executor fails, score the attributes on this thread instead.
                    if (ex instanceof InterruptedException) {
                        // Do not lose the interrupt request for callers further up the stack.
                        Thread.currentThread().interrupt();
                    }
                    int[] falseCount = new int[DecisionTree.this.k];
                    for (int j = 0; j < DecisionTree.this.M; ++j) {
                        keepIfBetter(findBestSplit(n, count, falseCount, impurity, variables[j]));
                    }
                }
            }
            return this.node.splitFeature != -1;
        }

        /** Returns true if every row with a positive sample count carries the same label, or there are none. */
        private boolean isPure() {
            int label = -1;
            for (int i = 0; i < this.x.length; ++i) {
                if (this.samples[i] <= 0) {
                    continue;
                }
                if (label == -1) {
                    label = this.y[i];
                } else if (this.y[i] != label) {
                    return false;
                }
            }
            return true;
        }

        /** Copies {@code split} into {@link #node} when it scores strictly higher than the current split. */
        private void keepIfBetter(Node split) {
            if (split.splitScore > this.node.splitScore) {
                this.node.splitFeature = split.splitFeature;
                this.node.splitValue = split.splitValue;
                this.node.splitScore = split.splitScore;
                this.node.trueChildOutput = split.trueChildOutput;
                this.node.falseChildOutput = split.falseChildOutput;
            }
        }

        /**
         * Finds the best split of this node's rows on a single attribute.
         *
         * @param n          total sample count at this node
         * @param count      per-class sample counts at this node
         * @param falseCount scratch buffer of length k; overwritten with false-branch class counts
         * @param impurity   impurity of this node before splitting
         * @param j          index of the attribute to split on
         * @return a detached {@link Node} describing the best split found; its {@code splitFeature}
         *         stays -1 and {@code splitScore} 0.0 when no split has positive gain
         * @throws IllegalStateException if attribute {@code j} is neither nominal nor numeric
         */
        public Node findBestSplit(int n, int[] count, int[] falseCount, double impurity, int j) {
            Attribute.Type type = DecisionTree.this.attributes[j].type;
            if (type == Attribute.Type.NOMINAL) {
                return findBestNominalSplit(n, count, falseCount, impurity, j);
            }
            if (type == Attribute.Type.NUMERIC) {
                return findBestNumericSplit(n, count, falseCount, impurity, j);
            }
            throw new IllegalStateException("Unsupported attribute type: " + type);
        }

        /** Scores every "x[j] == level" test of nominal attribute j; each level's value is its integer code. */
        private Node findBestNominalSplit(int n, int[] count, int[] falseCount, double impurity, int j) {
            Node best = new Node();
            int levels = ((NominalAttribute) DecisionTree.this.attributes[j]).size();
            int[][] levelCount = new int[levels][DecisionTree.this.k];
            for (int i = 0; i < this.x.length; ++i) {
                if (this.samples[i] > 0) {
                    levelCount[(int) this.x[i][j]][this.y[i]] += this.samples[i];
                }
            }
            for (int level = 0; level < levels; ++level) {
                int tc = Math.sum(levelCount[level]);
                int fc = n - tc;
                if (tc != 0 && fc != 0) {
                    recordCandidate(best, j, level, levelCount[level], tc, fc, n, count, falseCount, impurity);
                }
            }
            return best;
        }

        /**
         * Scores "x[j] <= threshold" tests of numeric attribute j by scanning rows in the precomputed
         * order, using the midpoint between consecutive values as the threshold.
         */
        private Node findBestNumericSplit(int n, int[] count, int[] falseCount, double impurity, int j) {
            Node best = new Node();
            int[] trueCount = new int[DecisionTree.this.k];
            double prevx = Double.NaN;
            int prevy = -1;
            for (int i : DecisionTree.this.order[j]) {
                int weight = this.samples[i];
                if (weight <= 0) {
                    continue;
                }
                double xi = this.x[i][j];
                int yi = this.y[i];
                // A threshold is only tried where both the value and the label change from the previous
                // row. NOTE(review): this heuristic can skip a useful cut after a group of tied x values
                // whose last label equals the next row's label.
                boolean boundary = !Double.isNaN(prevx) && xi != prevx && yi != prevy;
                if (boundary) {
                    int tc = Math.sum(trueCount);
                    int fc = n - tc;
                    if (tc != 0 && fc != 0) {
                        recordCandidate(best, j, (xi + prevx) / 2.0, trueCount, tc, fc, n, count, falseCount, impurity);
                    }
                }
                prevx = xi;
                prevy = yi;
                trueCount[yi] += weight;
            }
            return best;
        }

        /**
         * Scores the partition whose true side has class counts {@code trueCount} (total {@code tc}) and
         * stores it in {@code best} when its impurity decrease beats {@code best.splitScore}.
         */
        private void recordCandidate(Node best, int j, double value, int[] trueCount, int tc, int fc,
                int n, int[] count, int[] falseCount, double impurity) {
            for (int c = 0; c < DecisionTree.this.k; ++c) {
                falseCount[c] = count[c] - trueCount[c];
            }
            double gain = impurity
                    - (double) tc / (double) n * DecisionTree.this.impurity(trueCount, tc)
                    - (double) fc / (double) n * DecisionTree.this.impurity(falseCount, fc);
            if (gain > best.splitScore) {
                best.splitFeature = j;
                best.splitValue = value;
                best.splitScore = gain;
                best.trueChildOutput = Math.whichMax(trueCount);
                best.falseChildOutput = Math.whichMax(falseCount);
            }
        }

        /**
         * Applies the split stored in {@link #node}, creating its two children.
         *
         * <p>Each child whose own best split is found is either queued in {@code nextSplits}, or, when
         * the queue is null, split immediately (depth-first). The parent's split score is added to
         * the importance of its split attribute after both children have been handled.
         *
         * @param nextSplits queue of pending splits, or null to recurse immediately
         * @return {@code false} (and the node's split is cleared) if one side of the split would be empty
         * @throws IllegalStateException if the node has no split feature or the attribute type is unsupported
         */
        public boolean split(PriorityQueue<TrainNode> nextSplits) {
            int feature = this.node.splitFeature;
            if (feature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }
            Attribute.Type type = DecisionTree.this.attributes[feature].type;
            boolean nominal;
            if (type == Attribute.Type.NOMINAL) {
                nominal = true;
            } else if (type == Attribute.Type.NUMERIC) {
                nominal = false;
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + type);
            }

            int rows = this.x.length;
            int[] trueSamples = new int[rows];
            int[] falseSamples = new int[rows];
            int tc = 0;
            int fc = 0;
            double threshold = this.node.splitValue;
            for (int i = 0; i < rows; ++i) {
                int weight = this.samples[i];
                if (weight <= 0) {
                    continue;
                }
                double v = this.x[i][feature];
                boolean goesTrue = nominal ? v == threshold : v <= threshold;
                if (goesTrue) {
                    trueSamples[i] = weight;
                    tc += weight;
                } else {
                    falseSamples[i] = weight;
                    fc += weight;
                }
            }

            if (tc == 0 || fc == 0) {
                this.node.splitFeature = -1;
                this.node.splitValue = Double.NaN;
                this.node.splitScore = 0.0;
                return false;
            }

            this.node.trueChild = new Node(this.node.trueChildOutput);
            this.node.falseChild = new Node(this.node.falseChildOutput);
            enqueueOrSplit(new TrainNode(this.node.trueChild, this.x, this.y, trueSamples), nextSplits);
            enqueueOrSplit(new TrainNode(this.node.falseChild, this.x, this.y, falseSamples), nextSplits);

            DecisionTree.this.importance[feature] += this.node.splitScore;
            return true;
        }

        /** Queues {@code child} if it can be split, or splits it right away when there is no queue. */
        private void enqueueOrSplit(TrainNode child, PriorityQueue<TrainNode> nextSplits) {
            if (!child.findBestSplit()) {
                return;
            }
            if (nextSplits == null) {
                child.split(null);
            } else {
                nextSplits.add(child);
            }
        }

        /**
         * Finds the best split of the enclosing node on one attribute, so that attributes can be
         * submitted to {@code MulticoreExecutor} (presumably executed in parallel).
         */
        class SplitTask
        implements Callable<Node> {
            /** Total sample count at the node. */
            final int n;
            /** Per-class sample counts at the node (read-only here). */
            final int[] count;
            /** Impurity of the node before splitting. */
            final double impurity;
            /** Index of the attribute to evaluate. */
            final int j;

            SplitTask(int n, int[] count, double impurity, int j) {
                this.j = j;
                this.impurity = impurity;
                this.count = count;
                this.n = n;
            }

            @Override
            public Node call() {
                // Each task gets its own scratch buffer, because findBestSplit overwrites falseCount.
                return TrainNode.this.findBestSplit(n, count, new int[DecisionTree.this.k], impurity, j);
            }
        }
    }

    /**
     * A node of the fitted tree. A node without children is a leaf that predicts {@code output};
     * an internal node routes a sample to {@code trueChild} or {@code falseChild} by testing
     * attribute {@code splitFeature} against {@code splitValue}.
     */
    class Node {
        /** Class predicted when this node is a leaf. */
        int output = -1;
        /** Attribute tested at this node, or -1 if none. */
        int splitFeature = -1;
        /** Nominal level code (equality test) or numeric threshold ({@code <=} test). */
        double splitValue = Double.NaN;
        /** Impurity decrease achieved by the split. */
        double splitScore = 0.0;
        Node trueChild = null;
        Node falseChild = null;
        /** Majority class of the true side, used as the true child's output. */
        int trueChildOutput = -1;
        /** Majority class of the false side, used as the false child's output. */
        int falseChildOutput = -1;

        public Node() {
        }

        public Node(int output) {
            this.output = output;
        }

        /**
         * Walks from this node down to a leaf and returns that leaf's class.
         *
         * @param x the sample to classify
         * @return the predicted class label
         * @throws IllegalStateException if a visited split attribute is neither nominal nor numeric
         */
        public int predict(double[] x) {
            Node current = this;
            while (current.trueChild != null || current.falseChild != null) {
                int feature = current.splitFeature;
                Attribute.Type type = DecisionTree.this.attributes[feature].type;
                boolean goesTrue;
                if (type == Attribute.Type.NOMINAL) {
                    goesTrue = x[feature] == current.splitValue;
                } else if (type == Attribute.Type.NUMERIC) {
                    goesTrue = x[feature] <= current.splitValue;
                } else {
                    throw new IllegalStateException("Unsupported attribute type: " + type);
                }
                current = goesTrue ? current.trueChild : current.falseChild;
            }
            return current.output;
        }
    }

    /**
     * Impurity measure used to score candidate splits.
     */
    public enum SplitRule {
        /** Gini index: 1 minus the sum of squared class proportions. */
        GINI,
        /** Entropy in bits: minus the sum of p * log2(p) over class proportions. */
        ENTROPY,
        /** Classification error: 1 minus the largest class proportion. */
        CLASSIFICATION_ERROR
    }

    /**
     * Configurable trainer for {@link DecisionTree}. Defaults to the Gini split rule and at most
     * 100 leaf nodes.
     */
    public static class Trainer
    extends ClassifierTrainer<double[]> {
        private SplitRule rule = SplitRule.GINI;
        private int J = 100;

        /** Creates a trainer with default settings. */
        public Trainer() {
        }

        /**
         * @param J maximum number of leaf nodes (at least 2)
         * @throws IllegalArgumentException if {@code J < 2}
         */
        public Trainer(int J) {
            this.J = checkLeafNodes(J);
        }

        /**
         * @param attributes attribute metadata passed to the trained trees
         * @param J          maximum number of leaf nodes (at least 2)
         * @throws IllegalArgumentException if {@code J < 2}
         */
        public Trainer(Attribute[] attributes, int J) {
            super(attributes);
            this.J = checkLeafNodes(J);
        }

        /** Sets the impurity measure used to score splits; returns this trainer for chaining. */
        public Trainer setSplitRule(SplitRule rule) {
            this.rule = rule;
            return this;
        }

        /**
         * Sets the maximum number of leaf nodes; returns this trainer for chaining.
         *
         * @throws IllegalArgumentException if {@code J < 2}
         */
        public Trainer setMaximumLeafNodes(int J) {
            this.J = checkLeafNodes(J);
            return this;
        }

        /** Trains a decision tree on {@code x}/{@code y} with this trainer's settings. */
        public DecisionTree train(double[][] x, int[] y) {
            return new DecisionTree(this.attributes, x, y, this.J, this.rule);
        }

        /** Returns {@code J} unchanged, or throws if it is below the minimum of two leaves. */
        private static int checkLeafNodes(int J) {
            if (J < 2) {
                throw new IllegalArgumentException("Invalid number of leaf nodes: " + J);
            }
            return J;
        }
    }
}

