/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.classification;

import hivemall.UDTFWithOptions;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.codec.Base91;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.SerdeUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.math.MathUtils;
import hivemall.utils.random.PRNG;
import hivemall.utils.random.RandomNumberGeneratorFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import matrix4j.matrix.Matrix;
import matrix4j.matrix.builders.CSRMatrixBuilder;
import matrix4j.matrix.builders.MatrixBuilder;
import matrix4j.matrix.builders.RowMajorDenseMatrixBuilder;
import matrix4j.vector.AbstractVector;
import matrix4j.vector.DenseVector;
import matrix4j.vector.SparseVector;
import matrix4j.vector.Vector;
import matrix4j.vector.VectorProcedure;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;
import org.roaringbitmap.RoaringBitmap;
import smile.math.Math;

@Description(name="train_gradient_tree_boosting_classifier", value="_FUNC_(array<double|string> features, int label [, string options]) - Returns a relation consists of <int iteration, int model_type, array<string> pred_models, double intercept, double shrinkage, array<double> var_importance, float oob_error_rate>")
public final class GradientTreeBoostingClassifierUDTF
extends UDTFWithOptions {
    // Class-wide logger.
    private static final Log logger = LogFactory.getLog(GradientTreeBoostingClassifierUDTF.class);
    // Inspector for the first argument (the feature array); assigned in initialize(), cleared in close().
    private ListObjectInspector featureListOI;
    // Inspector for a single feature element: double-compatible for numeric input, string otherwise.
    private PrimitiveObjectInspector featureElemOI;
    // Int-compatible inspector for the second argument (the label).
    private PrimitiveObjectInspector labelOI;
    // True when the feature array holds numbers; false when it holds strings.
    private boolean denseInput;
    // Accumulates one row per process() call; released in close() once the matrix is built.
    private MatrixBuilder matrixBuilder;
    // Accumulates one label per process() call; released in close().
    private IntArrayList labels;
    // Hyperparameters below are assigned by processOptions().
    private int _numTrees;
    private double _eta;
    // The initializer is overwritten by processOptions(), which always assigns this field.
    private double _subsample = 0.7;
    private float _numVars;
    private int _maxDepth;
    private int _maxLeafNodes;
    private int _minSamplesSplit;
    private int _minSamplesLeaf;
    // -1 makes the training methods obtain a seed from SmileExtUtils.generateSeed().
    private long _seed;
    // Serialized RoaringBitmap of nominal attribute indices; set to null once training deserializes it.
    private byte[] _nominalAttrs;
    // Obtained from getReporter() in close(); may be null.
    @Nullable
    private transient Reporter _progressReporter;
    // Counter looked up from the reporter in close(); null when there is no reporter.
    @Nullable
    private transient Counters.Counter _iterationCounter;

    /**
     * Declares the command-line style options accepted in the optional third argument.
     *
     * @return the option definitions for this function
     */
    @Override
    protected Options getOptions() {
        return new Options()
            .addOption("trees", "num_trees", true, "The number of trees for each task [default: 500]")
            .addOption("eta", "learning_rate", true, "The learning rate (0, 1]  of procedure [default: 0.05]")
            .addOption("subsample", "sampling_frac", true, "The fraction of samples to be used for fitting the individual base learners [default: 0.7]")
            .addOption("vars", "num_variables", true, "The number of random selected features [default: ceil(sqrt(x[0].length))]. int(num_variables * x[0].length) is considered if num_variable is (0,1]")
            .addOption("depth", "max_depth", true, "The maximum number of the tree depth [default: 8]")
            .addOption("leafs", "max_leaf_nodes", true, "The maximum number of leaf nodes [default: Integer.MAX_VALUE]")
            .addOption("splits", "min_split", true, "A node that has greater than or equals to `min_split` examples will split [default: 5]")
            .addOption("min_samples_leaf", true, "The minimum number of samples in a leaf node [default: 1]")
            .addOption("seed", true, "seed value in long [default: -1 (random)]")
            .addOption("attrs", "attribute_types", true, "Comma separated attribute types (Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])")
            .addOption("nominal_attr_indicies", "categorical_attr_indicies", true, "Comma seperated indicies of categorical attributes, e.g., [3,5,6]");
    }

    /**
     * Parses the optional option string (third argument) and stores the resulting
     * hyperparameters in the instance fields. Defaults are used when no option string is given.
     *
     * @param argOIs the argument inspectors passed to {@code initialize}
     * @return the parsed command line, or {@code null} when fewer than three arguments are given
     * @throws UDFArgumentException declared by the overridden method
     * @throws IllegalArgumentException if the requested number of trees is less than 1
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        // Start from the defaults; they are only overridden when an option string is present.
        int numTrees = 500;
        double learningRate = 0.05;
        double samplingFraction = 0.7;
        float numVariables = -1.0f;
        int depthLimit = 8;
        int leafLimit = Integer.MAX_VALUE;
        int splitThreshold = 5;
        int leafThreshold = 1;
        long seedValue = -1L;
        RoaringBitmap nominalAttrs = new RoaringBitmap();

        CommandLine parsed = null;
        if (argOIs.length >= 3) {
            parsed = this.parseOptions(HiveUtils.getConstString(argOIs[2]));

            numTrees = Primitives.parseInt(parsed.getOptionValue("num_trees"), numTrees);
            if (numTrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + numTrees);
            }
            learningRate = Primitives.parseDouble(parsed.getOptionValue("learning_rate"), learningRate);
            samplingFraction = Primitives.parseDouble(parsed.getOptionValue("subsample"), samplingFraction);
            numVariables = Primitives.parseFloat(parsed.getOptionValue("num_variables"), numVariables);
            depthLimit = Primitives.parseInt(parsed.getOptionValue("max_depth"), depthLimit);
            leafLimit = Primitives.parseInt(parsed.getOptionValue("max_leaf_nodes"), leafLimit);

            // "min_samples_split" takes precedence over "min_split" when it yields a value.
            String explicitMinSplit = parsed.getOptionValue("min_samples_split");
            if (explicitMinSplit != null) {
                splitThreshold = Integer.parseInt(explicitMinSplit);
            } else {
                splitThreshold = Primitives.parseInt(parsed.getOptionValue("min_split"), splitThreshold);
            }
            leafThreshold = Primitives.parseInt(parsed.getOptionValue("min_samples_leaf"), leafThreshold);
            seedValue = Primitives.parseLong(parsed.getOptionValue("seed"), seedValue);

            // Explicit indices win over the per-attribute type list.
            String nominalIndices = parsed.getOptionValue("nominal_attr_indicies");
            if (nominalIndices == null) {
                nominalAttrs = SmileExtUtils.resolveAttributes(parsed.getOptionValue("attribute_types"));
            } else {
                nominalAttrs = SmileExtUtils.parseNominalAttributeIndicies(nominalIndices);
            }
        }

        // Commit to the fields only after every value parsed successfully.
        this._numTrees = numTrees;
        this._eta = learningRate;
        this._subsample = samplingFraction;
        this._numVars = numVariables;
        this._maxDepth = depthLimit;
        this._maxLeafNodes = leafLimit;
        this._minSamplesSplit = splitThreshold;
        this._minSamplesLeaf = leafThreshold;
        this._seed = seedValue;
        this._nominalAttrs = SerdeUtils.serializeRoaring(nominalAttrs);
        return parsed;
    }

    /**
     * Validates the argument inspectors, prepares the row/label accumulators and declares the
     * output schema.
     *
     * <p>Numeric feature arrays are collected with a dense row-major builder; string feature
     * arrays are collected with a CSR builder. The emitted struct has six fields: iteration,
     * pred_models, intercept, shrinkage, var_importance (a list for dense input, an int-to-double
     * map otherwise) and oob_error_rate.</p>
     *
     * @param argOIs inspectors for: features, label and optionally a constant option string
     * @return the struct inspector describing each forwarded row
     * @throws UDFArgumentException if the argument count or the feature element type is unsupported
     */
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            throw new UDFArgumentException(getClass().getSimpleName() + " takes 2 or 3 arguments: array<double|string> features, int label [, const string options]: " + argOIs.length);
        }
        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
        ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureListOI = listOI;
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseInput = true;
            this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192);
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.featureElemOI = HiveUtils.asStringOI(elemOI);
            this.denseInput = false;
            this.matrixBuilder = new CSRMatrixBuilder(8192);
        } else {
            throw new UDFArgumentException("_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName());
        }
        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs, 1);
        this.processOptions(argOIs);
        this.labels = new IntArrayList(1024);

        // Typed as ObjectInspector so the list matches the struct-inspector factory signature.
        ArrayList<String> fieldNames = new ArrayList<String>(6);
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
        fieldNames.add("iteration");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("pred_models");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector));
        fieldNames.add("intercept");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        fieldNames.add("shrinkage");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        fieldNames.add("var_importance");
        if (this.denseInput) {
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        } else {
            fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector, PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        }
        fieldNames.add("oob_error_rate");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Buffers one training example: appends the feature row to the matrix builder and the
     * label to the label list. No training happens here.
     *
     * @param args {@code args[0]} is the feature array, {@code args[1]} the label
     * @throws HiveException if the feature array is {@code null}
     */
    public void process(Object[] args) throws HiveException {
        final Object features = args[0];
        if (features == null) {
            throw new HiveException("array<double> features was null");
        }
        this.parseFeatures(features, this.matrixBuilder);
        this.labels.add(PrimitiveObjectInspectorUtils.getInt(args[1], this.labelOI));
    }

    /**
     * Appends one feature row to {@code builder}.
     *
     * <p>For dense input each non-null element is written at its array position as a double;
     * otherwise each non-null element is passed to the builder as a feature string. Null
     * elements are skipped in both modes.</p>
     *
     * @param argObj the raw feature array object
     * @param builder the matrix builder receiving the row
     */
    private void parseFeatures(@Nonnull Object argObj, @Nonnull MatrixBuilder builder) {
        final int numElements = this.featureListOI.getListLength(argObj);
        for (int col = 0; col < numElements; col++) {
            final Object element = this.featureListOI.getListElement(argObj, col);
            if (element != null) {
                if (this.denseInput) {
                    builder.nextColumn(col, PrimitiveObjectInspectorUtils.getDouble(element, this.featureElemOI));
                } else {
                    builder.nextColumn(element.toString());
                }
            }
        }
        builder.nextRow();
    }

    /**
     * Runs training over everything buffered by {@code process} and then drops the
     * references to the input inspectors. Nothing is trained when no label was collected.
     *
     * @throws HiveException propagated from training
     */
    public void close() throws HiveException {
        final Reporter reporter = this.getReporter();
        this._progressReporter = reporter;
        if (reporter == null) {
            this._iterationCounter = null;
        } else {
            this._iterationCounter = reporter.getCounter("hivemall.smile.GradientTreeBoostingClassifier$Counter", "iteration");
        }
        GradientTreeBoostingClassifierUDTF.reportProgress(reporter);

        if (!this.labels.isEmpty()) {
            // Release each buffer as soon as its contents have been materialized.
            final Matrix features = this.matrixBuilder.buildMatrix();
            this.matrixBuilder = null;
            final int[] classLabels = this.labels.toArray();
            this.labels = null;
            this.train(features, classLabels);
        }

        this.labelOI = null;
        this.featureElemOI = null;
        this.featureListOI = null;
    }

    /**
     * Validates the hyperparameters before training starts.
     *
     * @throws HiveException if the shrinkage or sampling fraction is outside (0, 1], if
     *         minSamplesSplit is not positive, or if maxDepth is less than 1
     */
    private void checkOptions() throws HiveException {
        final double shrinkage = this._eta;
        if (shrinkage <= 0.0 || shrinkage > 1.0) {
            throw new HiveException("Invalid shrinkage: " + shrinkage);
        }

        final double fraction = this._subsample;
        if (fraction <= 0.0 || fraction > 1.0) {
            throw new HiveException("Invalid sampling fraction: " + fraction);
        }

        final int minSplit = this._minSamplesSplit;
        if (minSplit <= 0) {
            throw new HiveException("Invalid minSamplesSplit: " + minSplit);
        }

        final int depth = this._maxDepth;
        if (depth < 1) {
            throw new HiveException("Invalid maxDepth: " + depth);
        }
    }

    /**
     * Entry point for training: validates inputs, shuffles the data, and dispatches to the
     * two-class or multi-class trainer depending on the largest label value.
     *
     * <p>The number of classes is taken as {@code max(y) + 1}. In the two-class case labels
     * equal to 1 become +1 and every other label becomes -1 before training.</p>
     *
     * @param x feature matrix, one row per instance
     * @param y class labels, one per row of {@code x}
     * @throws HiveException if the row and label counts differ, an option is invalid, or
     *         fewer than two classes are implied by the labels
     */
    private void train(@Nonnull Matrix x, @Nonnull int[] y) throws HiveException {
        final int rowCount = x.numRows();
        if (rowCount != y.length) {
            throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d", rowCount, y.length));
        }
        this.checkOptions();

        final Matrix shuffled = SmileExtUtils.shuffle(x, y, this._seed);
        final int numClasses = Math.max(y) + 1;
        if (numClasses < 2) {
            throw new UDFArgumentException("Only one class or negative class labels.");
        }
        if (numClasses != 2) {
            this.traink(shuffled, y, numClasses);
            return;
        }

        // Binary case: recode the labels as +1 / -1.
        final int[] signedLabels = new int[rowCount];
        for (int row = 0; row < rowCount; row++) {
            if (y[row] == 1) {
                signedLabels[row] = 1;
            } else {
                signedLabels[row] = -1;
            }
        }
        this.train2(shuffled, signedLabels);
    }

    /**
     * Trains a two-class gradient tree boosting model and forwards one output row per
     * boosting iteration.
     *
     * <p>Each iteration marks the first {@code numSamples} entries of a freshly shuffled index
     * array as in-sample, fits a {@link RegressionTree} to the responses
     * {@code 2y / (1 + exp(2yh))}, adds {@code eta * tree.predict(row)} to the running score
     * {@code h} of every row, and computes the error rate over rows that were not in-sample.</p>
     *
     * @param x training features, one row per instance
     * @param y class labels encoded as {@code +1} / {@code -1}
     * @throws HiveException propagated from tree construction or from forwarding a row
     */
    private void train2(@Nonnull Matrix x, @Nonnull int[] y) throws HiveException {
        final int numVars = SmileExtUtils.computeNumInputVars(this._numVars, x);
        if (logger.isInfoEnabled()) {
            logger.info("k: 2, numTrees: " + this._numTrees + ", shrinkage: " + this._eta + ", subsample: " + this._subsample + ", numVars: " + numVars + ", maxDepth: " + this._maxDepth + ", minSamplesSplit: " + this._minSamplesSplit + ", maxLeafs: " + this._maxLeafNodes + ", seed: " + this._seed);
        }
        final int numInstances = x.numRows();
        final int numSamples = (int) java.lang.Math.round((double) numInstances * this._subsample);

        // h holds the current additive score of each row; it starts at the intercept.
        final double[] h = new double[numInstances];
        final double[] response = new double[numInstances];
        final double mu = Math.mean(y);
        final double intercept = 0.5 * java.lang.Math.log((1.0 + mu) / (1.0 - mu));
        Arrays.fill(h, intercept);

        // The node output reads `response` by reference, so it sees each iteration's update.
        final L2NodeOutput output = new L2NodeOutput(response);
        final int[] samples = new int[numInstances];
        final int[] perm = MathUtils.permutation(numInstances);
        final long s = this._seed == -1L ? SmileExtUtils.generateSeed() : RandomNumberGeneratorFactory.createPRNG(this._seed).nextLong();
        final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
        final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
        final RoaringBitmap nominalAttrs = SerdeUtils.deserializeRoaring(this._nominalAttrs);
        this._nominalAttrs = null;
        final Vector xProbe = x.rowVector();

        for (int m = 0; m < this._numTrees; ++m) {
            GradientTreeBoostingClassifierUDTF.reportProgress(this._progressReporter);

            // Draw this iteration's subsample; samples[i] == 0 marks row i as out-of-bag.
            Arrays.fill(samples, 0);
            SmileExtUtils.shuffle(perm, rnd1);
            for (int i = 0; i < numSamples; ++i) {
                samples[perm[i]]++;
            }

            for (int i = 0; i < numInstances; ++i) {
                response[i] = 2.0 * (double) y[i] / (1.0 + java.lang.Math.exp(2.0 * (double) y[i] * h[i]));
            }

            final RegressionTree tree = new RegressionTree(nominalAttrs, x, response, numVars, this._maxDepth, this._maxLeafNodes, this._minSamplesSplit, this._minSamplesLeaf, samples, output, rnd2);
            for (int i = 0; i < numInstances; ++i) {
                x.getRow(i, xProbe);
                h[i] += this._eta * tree.predict(xProbe);
            }

            int oobTests = 0;
            int oobErrors = 0;
            for (int i = 0; i < samples.length; ++i) {
                if (samples[i] != 0) {
                    continue;
                }
                ++oobTests;
                // Labels are encoded as +1 / -1, so the prediction must use the same
                // encoding; comparing a 0 prediction against -1 would always count as an error.
                final int pred = h[i] > 0.0 ? 1 : -1;
                if (pred != y[i]) {
                    ++oobErrors;
                }
            }
            float oobErrorRate = 0.0f;
            if (oobTests > 0) {
                oobErrorRate = (float) oobErrors / (float) oobTests;
            }
            this.forward(m + 1, intercept, this._eta, oobErrorRate, x.numColumns(), tree);
        }
    }

    /**
     * Trains a {@code k}-class gradient tree boosting model and forwards one output row
     * (holding {@code k} trees) per boosting iteration.
     *
     * <p>Each iteration converts the current scores {@code h} to class probabilities with a
     * max-shifted softmax, then for every class fits a {@link RegressionTree} to the residual
     * {@code 1{y == j} - p_j} on a freshly drawn subsample and adds
     * {@code eta * tree.predict(row)} to that class's score.</p>
     *
     * <p>The out-of-bag error is measured on rows absent from the subsample drawn for the last
     * class of the iteration (the subsample is redrawn per class), using the per-row argmax
     * over the updated scores as the prediction.</p>
     *
     * @param x training features, one row per instance
     * @param y class labels; a row contributes a positive target to class {@code j} when {@code y[i] == j}
     * @param k number of classes
     * @throws HiveException propagated from tree construction or from forwarding a row
     */
    private void traink(Matrix x, int[] y, int k) throws HiveException {
        final int numVars = SmileExtUtils.computeNumInputVars(this._numVars, x);
        if (logger.isInfoEnabled()) {
            logger.info("k: " + k + ", numTrees: " + this._numTrees + ", shrinkage: " + this._eta + ", subsample: " + this._subsample + ", numVars: " + numVars + ", minSamplesSplit: " + this._minSamplesSplit + ", maxDepth: " + this._maxDepth + ", maxLeafs: " + this._maxLeafNodes + ", seed: " + this._seed);
        }
        final int numInstances = x.numRows();
        final int numSamples = (int) java.lang.Math.round((double) numInstances * this._subsample);
        final double[][] h = new double[k][numInstances];
        final double[][] p = new double[k][numInstances];
        final double[][] response = new double[k][numInstances];
        final LKNodeOutput[] output = new LKNodeOutput[k];
        for (int j = 0; j < k; ++j) {
            output[j] = new LKNodeOutput(response[j], k);
        }
        final int[] samples = new int[numInstances];
        final int[] perm = MathUtils.permutation(numInstances);
        final long s = this._seed == -1L ? SmileExtUtils.generateSeed() : RandomNumberGeneratorFactory.createPRNG(this._seed).nextLong();
        final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s);
        final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());
        final RoaringBitmap nominalAttrs = SerdeUtils.deserializeRoaring(this._nominalAttrs);
        this._nominalAttrs = null;
        final Vector xProbe = x.rowVector();

        for (int m = 0; m < this._numTrees; ++m) {
            // Softmax over classes for every row; subtracting the row maximum before
            // exponentiating keeps exp() from overflowing.
            for (int i = 0; i < numInstances; ++i) {
                double max = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < k; ++j) {
                    final double h_ji = h[j][i];
                    if (max < h_ji) {
                        max = h_ji;
                    }
                }
                double Z = 0.0;
                for (int j = 0; j < k; ++j) {
                    final double p_ji = java.lang.Math.exp(h[j][i] - max);
                    p[j][i] = p_ji;
                    Z += p_ji;
                }
                for (int j = 0; j < k; ++j) {
                    p[j][i] /= Z;
                }
            }

            final RegressionTree[] trees = new RegressionTree[k];
            for (int j = 0; j < k; ++j) {
                GradientTreeBoostingClassifierUDTF.reportProgress(this._progressReporter);
                final double[] response_j = response[j];
                final double[] p_j = p[j];
                final double[] h_j = h[j];
                for (int i = 0; i < numInstances; ++i) {
                    response_j[i] = (y[i] == j ? 1.0 : 0.0) - p_j[i];
                }

                // Redraw the subsample for this class; samples[i] == 0 marks row i as left out.
                Arrays.fill(samples, 0);
                SmileExtUtils.shuffle(perm, rnd1);
                for (int i = 0; i < numSamples; ++i) {
                    samples[perm[i]]++;
                }

                final RegressionTree tree = new RegressionTree(nominalAttrs, x, response_j, numVars, this._maxDepth, this._maxLeafNodes, this._minSamplesSplit, this._minSamplesLeaf, samples, output[j], rnd2);
                trees[j] = tree;

                // Additive update h += eta * f(x). The previous score must be counted once only.
                for (int i = 0; i < numInstances; ++i) {
                    x.getRow(i, xProbe);
                    h_j[i] += this._eta * tree.predict(xProbe);
                }
            }

            int oobTests = 0;
            int oobErrors = 0;
            for (int i = 0; i < samples.length; ++i) {
                if (samples[i] != 0) {
                    continue;
                }
                ++oobTests;
                // The predicted class is the per-row argmax over all class scores.
                if (argmaxClass(h, i) != y[i]) {
                    ++oobErrors;
                }
            }
            float oobErrorRate = 0.0f;
            if (oobTests > 0) {
                oobErrorRate = (float) oobErrors / (float) oobTests;
            }
            this.forward(m + 1, 0.0, this._eta, oobErrorRate, x.numColumns(), trees);
        }
    }

    /**
     * Returns the class index whose score is largest for row {@code i}; the lowest index wins ties.
     *
     * @param scores per-class score arrays, indexed as {@code scores[class][row]}
     * @param i the row index
     * @return the index of the highest-scoring class for the row
     */
    private static int argmaxClass(@Nonnull double[][] scores, int i) {
        int best = 0;
        double bestScore = scores[0][i];
        for (int j = 1; j < scores.length; ++j) {
            final double score = scores[j][i];
            if (score > bestScore) {
                bestScore = score;
                best = j;
            }
        }
        return best;
    }

    /**
     * Emits the output row for one boosting iteration and updates progress reporting.
     *
     * <p>The variable importance is the element-wise sum of the importances of all trees of
     * the iteration. It is emitted as a writable list for dense input and as an
     * index-to-value map otherwise.</p>
     *
     * @param m 1-based iteration number
     * @param intercept the model intercept
     * @param shrinkage the learning rate
     * @param oobErrorRate the out-of-bag error rate measured for this iteration
     * @param numColumns number of feature columns, used to size the dense importance vector
     * @param trees the tree(s) fitted in this iteration
     * @throws HiveException propagated from model serialization or from forwarding
     */
    private void forward(int m, double intercept, double shrinkage, float oobErrorRate, int numColumns, RegressionTree ... trees) throws HiveException {
        final Text[] serializedTrees = GradientTreeBoostingClassifierUDTF.getModel(trees);

        // Accumulate importance across all trees of this iteration.
        final AbstractVector totalImportance;
        if (this.denseInput) {
            totalImportance = new DenseVector(numColumns);
        } else {
            totalImportance = new SparseVector();
        }
        for (int t = 0; t < trees.length; t++) {
            final Vector treeImportance = trees[t].importance();
            final int dims = treeImportance.size();
            for (int col = 0; col < dims; col++) {
                totalImportance.incr(col, treeImportance.get(col));
            }
        }

        final Object importanceField;
        if (this.denseInput) {
            importanceField = WritableUtils.toWritableList(totalImportance.toArray());
        } else {
            final HashMap<IntWritable, DoubleWritable> byIndex = new HashMap<IntWritable, DoubleWritable>(totalImportance.size());
            totalImportance.each(new VectorProcedure() {

                @Override
                public void apply(int index, double weight) {
                    byIndex.put(new IntWritable(index), new DoubleWritable(weight));
                }
            });
            importanceField = byIndex;
        }

        final Object[] row = new Object[] {
            new IntWritable(m),
            serializedTrees,
            new DoubleWritable(intercept),
            new DoubleWritable(shrinkage),
            importanceField,
            new FloatWritable(oobErrorRate)
        };
        this.forward(row);

        GradientTreeBoostingClassifierUDTF.reportProgress(this._progressReporter);
        GradientTreeBoostingClassifierUDTF.incrCounter(this._iterationCounter, 1L);
        logger.info("Forwarded the output of " + m + "-th Boosting iteration out of " + this._numTrees);
    }

    /**
     * Serializes each tree and Base91-encodes the bytes into a {@link Text}.
     *
     * @param trees the trees to export
     * @return one encoded model per tree, in the same order
     * @throws HiveException propagated from tree serialization
     */
    @Nonnull
    private static Text[] getModel(@Nonnull RegressionTree[] trees) throws HiveException {
        final Text[] encoded = new Text[trees.length];
        int slot = 0;
        for (RegressionTree tree : trees) {
            final byte[] raw = tree.serialize(true);
            encoded[slot++] = new Text(Base91.encode(raw));
        }
        return encoded;
    }

    /**
     * Leaf-value calculator used by the multi-class trainer. It keeps a reference to the
     * response array (not a copy), so it reads whatever values the array currently holds.
     */
    private static final class LKNodeOutput
    implements RegressionTree.NodeOutput {
        /** Response values, indexed by instance. */
        final double[] y;
        /** Number of classes, stored as a double for the arithmetic in {@link #calculate}. */
        final double k;

        public LKNodeOutput(double[] response, int k) {
            this.k = k;
            this.y = response;
        }

        /**
         * Computes the node output over the instances whose sample count is positive.
         *
         * @param samples per-instance sample counts; entries {@code <= 0} are ignored
         * @return {@code (k-1)/k * sum(y) / sum(|y|(1-|y|))}, or the plain mean of {@code y}
         *         when the denominator is below {@code 1e-10}
         */
        @Override
        public double calculate(int[] samples) {
            int count = 0;
            double numerator = 0.0;
            double denominator = 0.0;
            final int len = samples.length;
            int idx = 0;
            while (idx < len) {
                if (samples[idx] > 0) {
                    final double r = this.y[idx];
                    final double magnitude = java.lang.Math.abs(r);
                    count++;
                    numerator += r;
                    denominator += magnitude * (1.0 - magnitude);
                }
                idx++;
            }
            // Fall back to the mean when the denominator is too small to divide by.
            if (denominator < 1.0E-10) {
                return numerator / (double) count;
            }
            return (this.k - 1.0) / this.k * (numerator / denominator);
        }
    }

    /**
     * Leaf-value calculator used by the two-class trainer. It keeps a reference to the
     * response array (not a copy), so it reads whatever values the array currently holds.
     */
    private static final class L2NodeOutput
    implements RegressionTree.NodeOutput {
        /** Response values, indexed by instance. */
        final double[] y;

        public L2NodeOutput(double[] y) {
            this.y = y;
        }

        /**
         * Computes {@code sum(y) / sum(|y|(2-|y|))} over the instances whose sample count is
         * positive. There is no guard on the denominator, so a zero denominator yields NaN
         * or an infinite value.
         *
         * @param samples per-instance sample counts; entries {@code <= 0} are ignored
         * @return the node output value
         */
        @Override
        public double calculate(int[] samples) {
            double numerator = 0.0;
            double denominator = 0.0;
            final int len = samples.length;
            int idx = 0;
            while (idx < len) {
                if (samples[idx] > 0) {
                    final double r = this.y[idx];
                    final double magnitude = java.lang.Math.abs(r);
                    numerator += r;
                    denominator += magnitude * (2.0 - magnitude);
                }
                idx++;
            }
            return numerator / denominator;
        }
    }
}

