/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import smile.classification.Classifier;
import smile.classification.ClassifierTrainer;
import smile.classification.DecisionTree;
import smile.data.Attribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.util.SmileUtils;
import smile.validation.Accuracy;
import smile.validation.ClassificationMeasure;

/**
 * AdaBoost ensemble of decision trees for classification of {@code double[]} samples.
 *
 * <p>Each boosting round draws a weighted bootstrap sample, fits a {@link DecisionTree}
 * on it, and re-weights the training instances so that misclassified ones become more
 * likely to be drawn in the next round. A prediction is the class that collects the
 * largest total tree weight.
 *
 * <p>Class labels must be the consecutive integers {@code 0 .. k-1}, because the
 * predicted label is used directly as an index into per-class vote arrays.
 */
public class AdaBoost implements Classifier<double[]> {
    /** Number of classes. */
    private final int k;
    /** The weak learners, in the order they were trained. */
    private DecisionTree[] trees;
    /** Voting weight of each tree. */
    private double[] alpha;
    /** Weighted training error of each tree. */
    private double[] error;
    /** Per-attribute importance, summed over the trees kept at construction time. */
    private final double[] importance;

    /**
     * Trains with numeric attributes and stumps (at most 2 leaves per tree).
     *
     * @param x training samples
     * @param y class labels, one per sample
     * @param T maximum number of trees
     */
    public AdaBoost(double[][] x, int[] y, int T) {
        this(null, x, y, T);
    }

    /**
     * Trains with numeric attributes.
     *
     * @param x training samples
     * @param y class labels, one per sample
     * @param T maximum number of trees
     * @param J maximum number of leaves per tree
     */
    public AdaBoost(double[][] x, int[] y, int T, int J) {
        this(null, x, y, T, J);
    }

    /**
     * Trains with stumps (at most 2 leaves per tree).
     *
     * @param attributes attribute descriptions, or {@code null} to treat every column as numeric
     * @param x training samples
     * @param y class labels, one per sample
     * @param T maximum number of trees
     */
    public AdaBoost(Attribute[] attributes, double[][] x, int[] y, int T) {
        this(attributes, x, y, T, 2);
    }

    /**
     * Trains the ensemble.
     *
     * <p>Boosting stops early, keeping only the trees trained so far, as soon as a tree's
     * weighted accuracy is no better than random guessing ({@code 1/k}); a message is then
     * written to {@code System.err}.
     *
     * @param attributes attribute descriptions, or {@code null} to treat every column as numeric
     * @param x training samples
     * @param y class labels, one per sample
     * @param T maximum number of trees
     * @param J maximum number of leaves per tree
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length,
     *         {@code T < 1}, {@code J < 2}, or the labels are not exactly {@code 0 .. k-1}
     *         with {@code k >= 2}
     */
    public AdaBoost(Attribute[] attributes, double[][] x, int[] y, int T, int J) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (T < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + T);
        }
        if (J < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + J);
        }
        this.k = validateLabels(y);

        if (attributes == null) {
            // x is non-empty here: at least two distinct labels exist and x.length == y.length.
            int p = x[0].length;
            attributes = new Attribute[p];
            for (int i = 0; i < p; ++i) {
                attributes[i] = new NumericAttribute("V" + (i + 1));
            }
        }

        int[][] order = SmileUtils.sort(attributes, x);
        int n = x.length;
        int[] samples = new int[n];   // how many times each instance was drawn this round
        double[] w = new double[n];   // instance weights
        boolean[] err = new boolean[n];
        Arrays.fill(w, 1.0);

        double guess = 1.0 / (double) this.k;
        // Multi-class correction added to every tree weight; it is log(1) = 0 when k == 2.
        double b = Math.log(this.k - 1);

        this.trees = new DecisionTree[T];
        this.alpha = new double[T];
        this.error = new double[T];

        for (int t = 0; t < T; ++t) {
            // Normalise the weights so they can be used as sampling probabilities.
            double total = Math.sum(w);
            for (int i = 0; i < n; ++i) {
                w[i] /= total;
            }

            Arrays.fill(samples, 0);
            for (int s : Math.random(w, n)) {
                samples[s]++;
            }

            DecisionTree tree = new DecisionTree(attributes, x, y, J, samples, order, DecisionTree.SplitRule.GINI);

            // Weighted error of this tree over the whole training set.
            double e = 0.0;
            for (int i = 0; i < n; ++i) {
                err[i] = tree.predict(x[i]) != y[i];
                if (err[i]) {
                    e += w[i];
                }
            }

            if (1.0 - e <= guess) {
                System.err.format("Weak classifier %d makes %.2f%% weighted error\n", t, 100.0 * e);
                // Discard this tree and the unused slots behind it.
                this.trees = Arrays.copyOf(this.trees, t);
                this.alpha = Arrays.copyOf(this.alpha, t);
                this.error = Arrays.copyOf(this.error, t);
                break;
            }

            this.trees[t] = tree;
            this.error[t] = e;
            // The floor on e keeps the ratio finite when the tree makes no weighted error.
            this.alpha[t] = Math.log((1.0 - e) / Math.max(1.0E-10, e)) + b;

            // Boost the weight of every misclassified instance.
            double a = Math.exp(this.alpha[t]);
            for (int i = 0; i < n; ++i) {
                if (err[i]) {
                    w[i] *= a;
                }
            }
        }

        double[] sum = new double[attributes.length];
        for (DecisionTree tree : this.trees) {
            double[] imp = tree.importance();
            for (int i = 0; i < imp.length; ++i) {
                sum[i] += imp[i];
            }
        }
        this.importance = sum;
    }

    /**
     * Checks that the labels are exactly {@code 0 .. k-1} with {@code k >= 2}.
     *
     * @param y class labels
     * @return the number of distinct classes {@code k}
     * @throws IllegalArgumentException if a label is negative, a class in the range is
     *         missing, or fewer than two classes are present
     */
    private static int validateLabels(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; ++i) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            if (i == 0) {
                // Labels index vote arrays of length k, so the smallest one must be 0.
                if (labels[0] != 0) {
                    throw new IllegalArgumentException("Missing class: 0");
                }
            } else if (labels[i] - labels[i - 1] > 1) {
                // Parenthesised so the missing label is computed, not string-concatenated.
                throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        if (labels.length < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        return labels.length;
    }

    /**
     * Returns the per-attribute importance accumulated over the trees kept at construction.
     *
     * <p>The returned array is the internal one, not a copy, and it is not recomputed by
     * {@link #trim(int)}.
     *
     * @return the importance of each attribute
     */
    public double[] importance() {
        return this.importance;
    }

    /**
     * Returns the number of trees in the model.
     *
     * @return the current ensemble size
     */
    public int size() {
        return this.trees.length;
    }

    /**
     * Shrinks the model to its first {@code T} trees.
     *
     * @param T the new number of trees
     * @throws IllegalArgumentException if {@code T} is not positive or exceeds the current size
     */
    public void trim(int T) {
        if (T > this.trees.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (T <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + T);
        }
        if (T < this.trees.length) {
            this.trees = Arrays.copyOf(this.trees, T);
            this.alpha = Arrays.copyOf(this.alpha, T);
            this.error = Arrays.copyOf(this.error, T);
        }
    }

    /**
     * Predicts the class of a sample by weighted vote.
     *
     * <p>Every tree adds its weight to the class it predicts, for any number of classes.
     * The two-class case uses the same vote: multiplying a 0/1 prediction by a positive
     * weight and testing the sum against zero would let a single vote for class 1 win
     * regardless of how much weight voted for class 0.
     *
     * @param x the sample
     * @return the class with the largest total weight, as selected by {@code Math.whichMax}
     */
    @Override
    public int predict(double[] x) {
        double[] votes = new double[this.k];
        for (int i = 0; i < this.trees.length; ++i) {
            votes[this.trees[i].predict(x)] += this.alpha[i];
        }
        return Math.whichMax(votes);
    }

    /**
     * Not supported: this model does not produce posterior probabilities.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /**
     * Measures accuracy as a function of ensemble size.
     *
     * @param x test samples
     * @param y true labels of the test samples
     * @return an array whose element {@code i} is the accuracy of the first {@code i + 1} trees
     */
    public double[] test(double[][] x, int[] y) {
        int T = this.trees.length;
        double[] accuracy = new double[T];
        int n = x.length;
        int[] label = new int[n];
        double[][] votes = new double[n][this.k];
        Accuracy measure = new Accuracy();
        for (int i = 0; i < T; ++i) {
            addVotes(i, x, votes, label);
            accuracy[i] = measure.measure(y, label);
        }
        return accuracy;
    }

    /**
     * Evaluates the given measures as a function of ensemble size.
     *
     * @param x test samples
     * @param y true labels of the test samples
     * @param measures the measures to evaluate
     * @return a matrix whose element {@code [i][j]} is {@code measures[j]} applied to the
     *         predictions of the first {@code i + 1} trees
     */
    public double[][] test(double[][] x, int[] y, ClassificationMeasure[] measures) {
        int T = this.trees.length;
        int m = measures.length;
        double[][] results = new double[T][m];
        int n = x.length;
        int[] label = new int[n];
        double[][] votes = new double[n][this.k];
        for (int i = 0; i < T; ++i) {
            addVotes(i, x, votes, label);
            for (int j = 0; j < m; ++j) {
                results[i][j] = measures[j].measure(y, label);
            }
        }
        return results;
    }

    /**
     * Adds tree {@code i}'s weighted vote for every sample to the running totals and
     * refreshes each sample's predicted label from them.
     */
    private void addVotes(int i, double[][] x, double[][] votes, int[] label) {
        for (int j = 0; j < x.length; ++j) {
            votes[j][this.trees[i].predict(x[j])] += this.alpha[i];
            label[j] = Math.whichMax(votes[j]);
        }
    }

    /**
     * Trainer that builds {@link AdaBoost} models with a configurable number of trees
     * (default 500) and maximum number of leaves per tree (default 2).
     */
    public static class Trainer
    extends ClassifierTrainer<double[]> {
        /** Number of trees to train. */
        private int T = 500;
        /** Maximum number of leaves per tree. */
        private int J = 2;

        /** Creates a trainer with the default settings. */
        public Trainer() {
        }

        /**
         * Creates a trainer with the given number of trees.
         *
         * @param T number of trees
         * @throws IllegalArgumentException if {@code T < 1}
         */
        public Trainer(int T) {
            if (T < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + T);
            }
            this.T = T;
        }

        /**
         * Creates a trainer with the given attributes and number of trees.
         *
         * @param attributes attribute descriptions passed to the superclass
         * @param T number of trees
         * @throws IllegalArgumentException if {@code T < 1}
         */
        public Trainer(Attribute[] attributes, int T) {
            super(attributes);
            if (T < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + T);
            }
            this.T = T;
        }

        /**
         * Sets the number of trees.
         *
         * @param T number of trees
         * @return this trainer
         * @throws IllegalArgumentException if {@code T < 1}
         */
        public Trainer setNumTrees(int T) {
            if (T < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + T);
            }
            this.T = T;
            return this;
        }

        /**
         * Sets the maximum number of leaves per tree.
         *
         * @param J maximum number of leaves
         * @return this trainer
         * @throws IllegalArgumentException if {@code J < 2}
         */
        public Trainer setMaximumLeafNodes(int J) {
            if (J < 2) {
                throw new IllegalArgumentException("Invalid number of leaf nodes: " + J);
            }
            this.J = J;
            return this;
        }

        /**
         * Trains an {@link AdaBoost} model with this trainer's settings.
         *
         * @param x training samples
         * @param y class labels, one per sample
         * @return the trained model
         */
        public AdaBoost train(double[][] x, int[] y) {
            return new AdaBoost(this.attributes, x, y, this.T, this.J);
        }
    }
}

