/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import smile.classification.Classifier;
import smile.classification.ClassifierTrainer;
import smile.classification.DecisionTree;
import smile.data.Attribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.math.Random;
import smile.util.MulticoreExecutor;
import smile.util.SmileUtils;
import smile.validation.Accuracy;
import smile.validation.ClassificationMeasure;

/**
 * Random forest classifier: an ensemble of {@link DecisionTree}s combined by majority vote.
 * Each tree is trained on a bootstrap sample of the training rows and is given M, the number
 * of randomly selected variables used for splitting. Rows left out of a tree's bootstrap
 * sample ("out-of-bag") are used to estimate the generalization error.
 */
public class RandomForest
implements Classifier<double[]> {
    /** The trained trees, in training order. */
    private List<DecisionTree> trees;
    /** The number of classes; labels are expected to be 0 .. k-1. */
    private int k = 2;
    /** Out-of-bag misclassification rate measured during training. */
    private double error;
    /** Per-variable importance, summed over all trees. */
    private double[] importance;

    /**
     * Trains a forest of {@code T} trees using generated numeric attributes and
     * M = floor(sqrt(number of variables)).
     *
     * @param x the training instances.
     * @param y the class labels, expected to be 0 .. k-1.
     * @param T the number of trees.
     */
    public RandomForest(double[][] x, int[] y, int T) {
        this(null, x, y, T);
    }

    /**
     * Trains a forest of {@code T} trees using generated numeric attributes.
     *
     * @param x the training instances.
     * @param y the class labels, expected to be 0 .. k-1.
     * @param T the number of trees.
     * @param M the number of randomly selected variables for splitting.
     */
    public RandomForest(double[][] x, int[] y, int T, int M) {
        this(null, x, y, T, M);
    }

    /**
     * Trains a forest of {@code T} trees with M = floor(sqrt(number of variables)).
     *
     * @param attributes the variable attributes, or {@code null} to generate numeric ones.
     * @param x the training instances.
     * @param y the class labels, expected to be 0 .. k-1.
     * @param T the number of trees.
     */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int T) {
        this(attributes, x, y, T, (int)Math.floor(Math.sqrt(x[0].length)));
    }

    /**
     * Trains a forest of {@code T} trees, in parallel when possible, and computes the
     * out-of-bag error and the summed variable importance.
     *
     * @param attributes the variable attributes, or {@code null} to generate numeric
     *                   attributes named V1, V2, ...
     * @param x the training instances.
     * @param y the class labels, expected to be 0 .. k-1.
     * @param T the number of trees.
     * @param M the number of randomly selected variables for splitting.
     * @throws IllegalArgumentException if x and y differ in length, T or M is less than 1,
     *         a label is negative, a label in 0 .. max(y) is missing, or y has only one class.
     */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int T, int M) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (T < 1) {
            throw new IllegalArgumentException("Invlaid number of trees: " + T);
        }
        if (M < 1) {
            throw new IllegalArgumentException("Invalid number of variables for splitting: " + M);
        }
        this.k = countClasses(y);
        if (attributes == null) {
            attributes = defaultAttributes(x[0].length);
        }

        int n = x.length;
        // prediction[i][c] counts the out-of-bag votes for class c on training row i.
        // Every task writes into this shared table.
        int[][] prediction = new int[n][this.k];
        int[][] order = SmileUtils.sort(attributes, x);
        List<TrainingTask> tasks = new ArrayList<>(T);
        for (int t = 0; t < T; ++t) {
            tasks.add(new TrainingTask(attributes, x, y, M, order, prediction));
        }

        try {
            this.trees = MulticoreExecutor.run(tasks);
        } catch (Exception ex) {
            System.err.println(ex);
            // Tasks that finished during the failed parallel attempt have already cast votes.
            // Clear them so the sequential retrain does not count those rows twice.
            // NOTE(review): assumes no parallel worker is still running at this point.
            for (int[] votes : prediction) {
                Arrays.fill(votes, 0);
            }
            this.trees = new ArrayList<>(T);
            for (TrainingTask task : tasks) {
                this.trees.add(task.call());
            }
        }

        this.error = outOfBagError(prediction, y);

        this.importance = new double[attributes.length];
        for (DecisionTree tree : this.trees) {
            double[] imp = tree.importance();
            for (int j = 0; j < imp.length; ++j) {
                this.importance[j] += imp[j];
            }
        }
    }

    /**
     * Validates the class labels and returns the number of classes.
     * The labels must be exactly 0 .. k-1 with k >= 2.
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; ++i) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            if (i > 0 && labels[i] - labels[i - 1] > 1) {
                // Report the first absent label. Parenthesized so it is added numerically,
                // not appended to the string.
                throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        if (labels.length < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        // Votes are indexed by label in arrays of size k, so the labels must start at 0.
        if (labels[0] != 0) {
            throw new IllegalArgumentException("Missing class: 0");
        }
        return labels.length;
    }

    /** Generates p numeric attributes named V1 .. Vp. */
    private static Attribute[] defaultAttributes(int p) {
        Attribute[] attributes = new Attribute[p];
        for (int j = 0; j < p; ++j) {
            attributes[j] = new NumericAttribute("V" + (j + 1));
        }
        return attributes;
    }

    /**
     * Computes the fraction of misclassified rows among those with at least one
     * out-of-bag vote. Returns 0 when no row has any out-of-bag vote.
     */
    private static double outOfBagError(int[][] prediction, int[] y) {
        int voted = 0;
        int wrong = 0;
        for (int i = 0; i < prediction.length; ++i) {
            int pred = Math.whichMax(prediction[i]);
            if (prediction[i][pred] > 0) {
                ++voted;
                if (pred != y[i]) {
                    ++wrong;
                }
            }
        }
        return voted > 0 ? (double) wrong / voted : 0.0;
    }

    /**
     * Returns the out-of-bag misclassification rate measured during training.
     * This is the fraction of wrongly voted rows among those that were out-of-bag for at
     * least one tree, or 0 if there were no such rows.
     */
    public double error() {
        return this.error;
    }

    /** Returns the per-variable importance, summed over all trees (the internal array, not a copy). */
    public double[] importance() {
        return this.importance;
    }

    /** Returns the number of trees in the forest. */
    public int size() {
        return this.trees.size();
    }

    /**
     * Keeps only the first {@code T} trees of the forest.
     *
     * @param T the new number of trees.
     * @throws IllegalArgumentException if T exceeds the current size or is not positive.
     */
    public void trim(int T) {
        if (T > this.trees.size()) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (T <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + T);
        }
        this.trees = new ArrayList<>(this.trees.subList(0, T));
    }

    /**
     * Predicts the class of {@code x} by majority vote over all trees.
     *
     * @param x the instance to classify.
     * @return the class with the most votes, as chosen by {@code Math.whichMax}.
     */
    @Override
    public int predict(double[] x) {
        int[] votes = new int[this.k];
        for (DecisionTree tree : this.trees) {
            votes[tree.predict(x)]++;
        }
        return Math.whichMax(votes);
    }

    /**
     * Predicts the class of {@code x} by majority vote over all trees.
     * Also fills {@code posteriori} with the fraction of trees voting for each class.
     *
     * @param x the instance to classify.
     * @param posteriori output array of length k, receiving the vote fractions.
     * @return the class with the most votes, as chosen by {@code Math.whichMax}.
     * @throws IllegalArgumentException if {@code posteriori.length != k}.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        int[] votes = new int[this.k];
        for (DecisionTree tree : this.trees) {
            votes[tree.predict(x)]++;
        }
        double total = this.trees.size();
        for (int c = 0; c < this.k; ++c) {
            posteriori[c] = votes[c] / total;
        }
        return Math.whichMax(votes);
    }

    /**
     * Evaluates the accuracy of growing sub-forests on a test set.
     *
     * @param x the test instances.
     * @param y the true test labels.
     * @return an array whose element i is the accuracy of the first i+1 trees.
     */
    public double[] test(double[][] x, int[] y) {
        int T = this.trees.size();
        double[] accuracy = new double[T];
        int n = x.length;
        int[] label = new int[n];
        int[][] prediction = new int[n][this.k];
        Accuracy measure = new Accuracy();
        for (int t = 0; t < T; ++t) {
            DecisionTree tree = this.trees.get(t);
            for (int j = 0; j < n; ++j) {
                prediction[j][tree.predict(x[j])]++;
                label[j] = Math.whichMax(prediction[j]);
            }
            accuracy[t] = measure.measure(y, label);
        }
        return accuracy;
    }

    /**
     * Evaluates growing sub-forests on a test set with several measures.
     *
     * @param x the test instances.
     * @param y the true test labels.
     * @param measures the measures to compute.
     * @return an array whose element [i][j] is measure j for the first i+1 trees.
     */
    public double[][] test(double[][] x, int[] y, ClassificationMeasure[] measures) {
        int T = this.trees.size();
        int m = measures.length;
        double[][] results = new double[T][m];
        int n = x.length;
        int[] label = new int[n];
        double[][] prediction = new double[n][this.k];
        for (int t = 0; t < T; ++t) {
            DecisionTree tree = this.trees.get(t);
            for (int j = 0; j < n; ++j) {
                prediction[j][tree.predict(x[j])] += 1.0;
                label[j] = Math.whichMax(prediction[j]);
            }
            for (int j = 0; j < m; ++j) {
                results[t][j] = measures[j].measure(y, label);
            }
        }
        return results;
    }

    /**
     * Trains one tree on a bootstrap sample and records its out-of-bag votes
     * into the shared {@code prediction} table.
     */
    static class TrainingTask
    implements Callable<DecisionTree> {
        Attribute[] attributes;
        double[][] x;
        int[] y;
        int[][] order;
        int M;
        /** Shared out-of-bag vote table; each row is locked while it is updated. */
        int[][] prediction;

        TrainingTask(Attribute[] attributes, double[][] x, int[] y, int M, int[][] order, int[][] prediction) {
            this.attributes = attributes;
            this.x = x;
            this.y = y;
            this.order = order;
            this.M = M;
            this.prediction = prediction;
        }

        @Override
        public DecisionTree call() {
            int n = this.x.length;
            Random random = new Random(Thread.currentThread().getId() * System.currentTimeMillis());
            // Bootstrap sample: draw n row indices with replacement.
            // samples[i] is how many times row i was drawn.
            int[] samples = new int[n];
            for (int draw = 0; draw < n; ++draw) {
                samples[random.nextInt(n)]++;
            }
            DecisionTree tree = new DecisionTree(this.attributes, this.x, this.y, this.M, samples, this.order);
            // Rows never drawn are out-of-bag for this tree; record its vote for each of them.
            for (int i = 0; i < n; ++i) {
                if (samples[i] == 0) {
                    int p = tree.predict(this.x[i]);
                    synchronized (this.prediction[i]) {
                        this.prediction[i][p]++;
                    }
                }
            }
            return tree;
        }
    }

    /**
     * Builder-style trainer for {@link RandomForest}. It defaults to 500 trees.
     * If the number of random features is never set, the forest uses
     * floor(sqrt(number of variables)).
     */
    public static class Trainer
    extends ClassifierTrainer<double[]> {
        /** The number of trees. */
        private int T = 500;
        /** The number of random features for splitting; negative means use the default. */
        private int M = -1;

        /** Creates a trainer with the default settings. */
        public Trainer() {
        }

        /**
         * Creates a trainer for {@code T} trees.
         *
         * @throws IllegalArgumentException if T is less than 1.
         */
        public Trainer(int T) {
            if (T < 1) {
                throw new IllegalArgumentException("Invlaid number of trees: " + T);
            }
            this.T = T;
        }

        /**
         * Creates a trainer for {@code T} trees with the given attributes.
         *
         * @throws IllegalArgumentException if T is less than 1.
         */
        public Trainer(Attribute[] attributes, int T) {
            super(attributes);
            if (T < 1) {
                throw new IllegalArgumentException("Invlaid number of trees: " + T);
            }
            this.T = T;
        }

        /**
         * Sets the number of trees.
         *
         * @throws IllegalArgumentException if T is less than 1.
         */
        public Trainer setNumTrees(int T) {
            if (T < 1) {
                throw new IllegalArgumentException("Invlaid number of trees: " + T);
            }
            this.T = T;
            return this;
        }

        /**
         * Sets the number of randomly selected features used for splitting.
         *
         * @throws IllegalArgumentException if M is less than 1.
         */
        public Trainer setNumRandomFeatures(int M) {
            if (M < 1) {
                throw new IllegalArgumentException("Invalid number of random selected features for splitting: " + M);
            }
            this.M = M;
            return this;
        }

        /** Trains a random forest on the given data using this trainer's settings. */
        public RandomForest train(double[][] x, int[] y) {
            if (this.M < 0) {
                return new RandomForest(this.attributes, x, y, this.T);
            }
            return new RandomForest(this.attributes, x, y, this.T, this.M);
        }
    }
}

