/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Arrays;
import smile.data.Attribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.regression.Regression;
import smile.regression.RegressionTrainer;
import smile.regression.RegressionTree;
import smile.sort.QuickSelect;
import smile.util.SmileUtils;
import smile.validation.RMSE;
import smile.validation.RegressionMeasure;

/**
 * Gradient tree boosting for regression.
 *
 * <p>The fitted model predicts {@code b + shrinkage * (tree_1(x) + ... + tree_T(x))}, where
 * {@code b} is an initial constant (mean or median of the training responses, depending on the
 * loss). Each tree is fit on a random subsample (fraction {@code f}) of the training rows to a
 * pseudo-response derived from the current residuals under the chosen {@link Loss}.
 */
public class GradientTreeBoost
implements Regression<double[]> {
    /** The fitted regression trees, in boosting order. */
    private RegressionTree[] trees;
    /** Initial constant prediction: mean of y (least squares) or median of y (LAD / Huber). */
    private double b = 0.0;
    /** Sum of the per-tree importance of every tree fitted during training. */
    private double[] importance;
    /** The loss function used during training. */
    private Loss loss = Loss.LeastAbsoluteDeviation;
    /** The shrinkage (learning rate) applied to every tree's contribution. */
    private double shrinkage = 0.005;
    /** The maximum number of leaf nodes per tree. */
    private int J = 6;
    /** The number of trees in the model; kept equal to {@code trees.length}. */
    private int T = 500;
    /** The fraction of training rows sampled (without replacement) for each tree. */
    private double f = 0.7;

    /**
     * Trains a model with default settings and generated numeric attributes "V1".."Vp".
     *
     * @param x the training instances.
     * @param y the response values.
     * @param T the number of trees.
     */
    public GradientTreeBoost(double[][] x, double[] y, int T) {
        this(null, x, y, T);
    }

    /**
     * Trains a model with generated numeric attributes "V1".."Vp".
     *
     * @param x the training instances.
     * @param y the response values.
     * @param loss the loss function.
     * @param T the number of trees.
     * @param J the maximum number of leaf nodes per tree.
     * @param shrinkage the shrinkage (learning rate), in (0, 1].
     * @param f the sampling fraction, in (0, 1].
     */
    public GradientTreeBoost(double[][] x, double[] y, Loss loss, int T, int J, double shrinkage, double f) {
        this(null, x, y, loss, T, J, shrinkage, f);
    }

    /**
     * Trains a model with least absolute deviation loss, at most 6 leaf nodes per tree, sampling
     * fraction 0.7, and shrinkage 0.005 when there are fewer than 2000 training rows (0.05
     * otherwise).
     *
     * @param attributes the attribute properties, or null to generate numeric attributes.
     * @param x the training instances.
     * @param y the response values.
     * @param T the number of trees.
     */
    public GradientTreeBoost(Attribute[] attributes, double[][] x, double[] y, int T) {
        this(attributes, x, y, Loss.LeastAbsoluteDeviation, T, 6, x.length < 2000 ? 0.005 : 0.05, 0.7);
    }

    /**
     * Trains a gradient tree boosting model.
     *
     * @param attributes the attribute properties, or null to generate numeric attributes.
     * @param x the training instances.
     * @param y the response values.
     * @param loss the loss function.
     * @param T the number of trees; must be at least 1.
     * @param J the maximum number of leaf nodes per tree; must be at least 2.
     * @param shrinkage the shrinkage (learning rate), in (0, 1].
     * @param f the sampling fraction, in (0, 1].
     * @throws IllegalArgumentException if the inputs are inconsistent or a parameter is out of range.
     */
    public GradientTreeBoost(Attribute[] attributes, double[][] x, double[] y, Loss loss, int T, int J, double shrinkage, double f) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        if (loss == null) {
            throw new IllegalArgumentException("Invalid loss function: null");
        }
        if (T < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + T);
        }
        if (J < 2) {
            throw new IllegalArgumentException("Invalid number of leaf nodes: " + J);
        }
        if (shrinkage <= 0.0 || shrinkage > 1.0) {
            throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
        }
        if (f <= 0.0 || f > 1.0) {
            throw new IllegalArgumentException("Invalid sampling fraction: " + f);
        }
        if (attributes == null) {
            int p = x[0].length;
            attributes = new Attribute[p];
            for (int j = 0; j < p; ++j) {
                attributes[j] = new NumericAttribute("V" + (j + 1));
            }
        }

        this.loss = loss;
        this.T = T;
        this.J = J;
        this.shrinkage = shrinkage;
        this.f = f;

        int n = x.length;
        // Number of rows drawn (without replacement) for each tree.
        int N = (int) Math.round((double) n * f);
        int[] perm = new int[n];
        int[] samples = new int[n];
        for (int i = 0; i < n; ++i) {
            perm[i] = i;
        }

        double[] residual = new double[n];
        double[] response = null;
        RegressionTree.NodeOutput output = null;
        if (loss == Loss.LeastSquares) {
            // The pseudo-response is the residual itself, so both names share one array.
            response = residual;
            b = Math.mean(y);
            for (int i = 0; i < n; ++i) {
                residual[i] = y[i] - b;
            }
        } else if (loss == Loss.LeastAbsoluteDeviation) {
            output = new LADNodeOutput(residual);
            // Use residual as scratch space for the median; it is overwritten right after.
            System.arraycopy(y, 0, residual, 0, n);
            b = QuickSelect.median(residual);
            response = new double[n];
            for (int i = 0; i < n; ++i) {
                residual[i] = y[i] - b;
                response[i] = Math.signum(residual[i]);
            }
        } else {
            // Huber: the pseudo-response is (re)computed by HuberNodeOutput before each tree.
            response = new double[n];
            System.arraycopy(y, 0, residual, 0, n);
            b = QuickSelect.median(residual);
            for (int i = 0; i < n; ++i) {
                residual[i] = y[i] - b;
            }
        }

        int[][] order = SmileUtils.sort(attributes, x);
        trees = new RegressionTree[T];
        for (int m = 0; m < T; ++m) {
            // Draw a fresh random subsample of N rows for this tree.
            Arrays.fill(samples, 0);
            Math.permutate(perm);
            for (int i = 0; i < N; ++i) {
                samples[perm[i]] = 1;
            }

            if (loss == Loss.Huber) {
                // Rebuilds the clipped pseudo-response from the current residuals.
                output = new HuberNodeOutput(residual, response, 0.9);
            }

            trees[m] = new RegressionTree(attributes, x, response, J, order, samples, output);

            // Move every residual by this tree's shrunken contribution.
            for (int i = 0; i < n; ++i) {
                residual[i] -= shrinkage * trees[m].predict(x[i]);
                if (loss == Loss.LeastAbsoluteDeviation) {
                    response[i] = Math.signum(residual[i]);
                }
            }
        }

        importance = new double[attributes.length];
        for (RegressionTree tree : trees) {
            double[] imp = tree.importance();
            for (int i = 0; i < imp.length; ++i) {
                importance[i] += imp[i];
            }
        }
    }

    /**
     * Returns the variable importance: the sum of the importance of every tree fitted during
     * training. The returned array is the model's internal array.
     *
     * @return the variable importance.
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Returns the fraction of training rows sampled for each tree.
     *
     * @return the sampling rate.
     */
    public double getSamplingRate() {
        return f;
    }

    /**
     * Returns the maximum number of leaf nodes per tree.
     *
     * @return the maximum number of leaf nodes.
     */
    public int getNumLeaves() {
        return J;
    }

    /**
     * Returns the loss function used during training.
     *
     * @return the loss function.
     */
    public Loss getLossFunction() {
        return loss;
    }

    /**
     * Returns the number of trees in the model.
     *
     * @return the number of trees.
     */
    public int size() {
        return trees.length;
    }

    /**
     * Keeps only the first {@code T} trees of the model.
     *
     * @param T the new model size, in [1, size()].
     * @throws IllegalArgumentException if {@code T} is out of range.
     */
    public void trim(int T) {
        if (T > trees.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (T <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + T);
        }
        if (T < trees.length) {
            trees = Arrays.copyOf(trees, T);
            // Keep the tree count in sync with the array so later loops stay in bounds.
            this.T = T;
        }
    }

    /**
     * Predicts the response of an instance.
     *
     * @param x the instance.
     * @return {@code b} plus the shrunken sum of every tree's prediction.
     */
    @Override
    public double predict(double[] x) {
        double y = b;
        for (RegressionTree tree : trees) {
            y += shrinkage * tree.predict(x);
        }
        return y;
    }

    /**
     * Evaluates the RMSE of the model on a test set after each boosting stage.
     *
     * @param x the test instances.
     * @param y the test responses.
     * @return element {@code i} is the RMSE of the model made of the first {@code i + 1} trees.
     */
    public double[] test(double[][] x, double[] y) {
        int ntrees = trees.length;
        double[] rmse = new double[ntrees];
        int n = x.length;
        double[] prediction = new double[n];
        Arrays.fill(prediction, b);
        RMSE measure = new RMSE();
        for (int i = 0; i < ntrees; ++i) {
            for (int j = 0; j < n; ++j) {
                prediction[j] += shrinkage * trees[i].predict(x[j]);
            }
            rmse[i] = measure.measure(y, prediction);
        }
        return rmse;
    }

    /**
     * Evaluates the model on a test set after each boosting stage, using several measures.
     *
     * @param x the test instances.
     * @param y the test responses.
     * @param measures the performance measures.
     * @return element {@code [i][k]} is measure {@code k} for the model of the first
     *     {@code i + 1} trees.
     */
    public double[][] test(double[][] x, double[] y, RegressionMeasure[] measures) {
        int ntrees = trees.length;
        int m = measures.length;
        double[][] results = new double[ntrees][m];
        int n = x.length;
        double[] prediction = new double[n];
        Arrays.fill(prediction, b);
        for (int i = 0; i < ntrees; ++i) {
            for (int j = 0; j < n; ++j) {
                prediction[j] += shrinkage * trees[i].predict(x[j]);
            }
            for (int k = 0; k < m; ++k) {
                results[i][k] = measures[k].measure(y, prediction);
            }
        }
        return results;
    }

    /**
     * Collects the residuals of the rows whose entry in {@code samples} is positive.
     *
     * @param residual the residuals of all training rows.
     * @param samples per-row counts; entries greater than 0 select the row.
     * @return the selected residuals, in row order.
     */
    private static double[] sampledResiduals(double[] residual, int[] samples) {
        int n = 0;
        for (int s : samples) {
            if (s > 0) {
                ++n;
            }
        }
        double[] selected = new double[n];
        int j = 0;
        for (int i = 0; i < samples.length; ++i) {
            if (samples[i] > 0) {
                selected[j++] = residual[i];
            }
        }
        return selected;
    }

    /**
     * Leaf output for Huber loss.
     *
     * <p>On construction it overwrites {@code response} with the clipped residuals
     * {@code sign(r) * min(|r|, delta)}, where {@code delta} is the {@code alpha}-quantile of
     * {@code |residual|}.
     */
    static class HuberNodeOutput
    implements RegressionTree.NodeOutput {
        double[] residual;
        double[] response;
        double alpha;
        double delta;

        /**
         * @param residual the current residuals (read only).
         * @param response the array to fill with the clipped pseudo-response.
         * @param alpha the quantile of |residual| used as the clipping threshold.
         */
        public HuberNodeOutput(double[] residual, double[] response, double alpha) {
            this.residual = residual;
            this.response = response;
            this.alpha = alpha;
            int n = residual.length;
            // Use response as scratch space for the quantile selection.
            for (int i = 0; i < n; ++i) {
                response[i] = Math.abs(residual[i]);
            }
            this.delta = QuickSelect.select(response, (int) ((double) n * alpha));
            for (int i = 0; i < n; ++i) {
                response[i] = Math.abs(residual[i]) <= delta ? residual[i] : delta * Math.signum(residual[i]);
            }
        }

        /**
         * Returns the median of the node's residuals plus the mean of the clipped deviations
         * from that median.
         */
        @Override
        public double calculate(int[] samples) {
            double[] res = sampledResiduals(residual, samples);
            double r = QuickSelect.median(res);
            double output = 0.0;
            for (int i = 0; i < samples.length; ++i) {
                if (samples[i] <= 0) {
                    continue;
                }
                double d = residual[i] - r;
                output += Math.signum(d) * Math.min(delta, Math.abs(d));
            }
            return r + output / (double) res.length;
        }
    }

    /** Leaf output for least absolute deviation loss: the median residual in the node. */
    static class LADNodeOutput
    implements RegressionTree.NodeOutput {
        double[] residual;

        /**
         * @param residual the residual array; read when {@link #calculate} is called.
         */
        public LADNodeOutput(double[] residual) {
            this.residual = residual;
        }

        @Override
        public double calculate(int[] samples) {
            return QuickSelect.median(sampledResiduals(residual, samples));
        }
    }

    /** Trainer for gradient tree boosting regression models. */
    public static class Trainer
    extends RegressionTrainer<double[]> {
        private Loss loss = Loss.LeastAbsoluteDeviation;
        private int T = 500;
        private double shrinkage = 0.005;
        private int J = 6;
        private double f = 0.7;

        /**
         * @param T the number of trees; must be at least 1.
         */
        public Trainer(int T) {
            if (T < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + T);
            }
            this.T = T;
        }

        /**
         * @param attributes the attribute properties.
         * @param T the number of trees; must be at least 1.
         */
        public Trainer(Attribute[] attributes, int T) {
            super(attributes);
            if (T < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + T);
            }
            this.T = T;
        }

        /** Sets the loss function. */
        public Trainer setLoss(Loss loss) {
            this.loss = loss;
            return this;
        }

        /** Sets the number of trees; must be at least 1. */
        public Trainer setNumTrees(int T) {
            if (T < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + T);
            }
            this.T = T;
            return this;
        }

        /** Sets the maximum number of leaf nodes per tree; must be at least 2. */
        public Trainer setMaximumLeafNodes(int J) {
            if (J < 2) {
                throw new IllegalArgumentException("Invalid number of leaf nodes: " + J);
            }
            this.J = J;
            return this;
        }

        /** Sets the shrinkage (learning rate); must be in (0, 1]. */
        public Trainer setShrinkage(double shrinkage) {
            if (shrinkage <= 0.0 || shrinkage > 1.0) {
                throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
            }
            this.shrinkage = shrinkage;
            return this;
        }

        /** Sets the per-tree sampling fraction; must be in (0, 1]. */
        public Trainer setSamplingRates(double f) {
            if (f <= 0.0 || f > 1.0) {
                throw new IllegalArgumentException("Invalid sampling fraction: " + f);
            }
            this.f = f;
            return this;
        }

        @Override
        public GradientTreeBoost train(double[][] x, double[] y) {
            return new GradientTreeBoost(attributes, x, y, loss, T, J, shrinkage, f);
        }
    }

    /** Loss functions supported by gradient tree boosting. */
    public static enum Loss {
        /** Squared error; the initial prediction is the mean of y. */
        LeastSquares,
        /** Absolute error; the initial prediction is the median of y. */
        LeastAbsoluteDeviation,
        /** Huber loss with a clipping threshold at the 0.9 quantile of the absolute residuals. */
        Huber;

    }
}

