/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.regression;

import hivemall.UDTFWithOptions;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.SmileTaskExecutor;
import hivemall.utils.codec.Base91;
import hivemall.utils.collections.lists.DoubleArrayList;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.SerdeUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.RandomUtils;
import hivemall.utils.random.PRNG;
import hivemall.utils.random.RandomNumberGeneratorFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import matrix4j.matrix.Matrix;
import matrix4j.matrix.builders.CSRMatrixBuilder;
import matrix4j.matrix.builders.MatrixBuilder;
import matrix4j.matrix.builders.RowMajorDenseMatrixBuilder;
import matrix4j.vector.Vector;
import matrix4j.vector.VectorProcedure;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.MapredContextAccessor;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;
import org.roaringbitmap.RoaringBitmap;
import smile.math.Math;

@Description(name="train_randomforest_regressor", value="_FUNC_(array<double|string> features, double target [, string options]) - Returns a relation consists of <int model_id, int model_type, string model, array<double> var_importance, double oob_errors, int oob_tests>")
public final class RandomForestRegressionUDTF
extends UDTFWithOptions {
    // Class-wide commons-logging logger.
    private static final Log logger = LogFactory.getLog(RandomForestRegressionUDTF.class);
    // Inspector for the features array (argument 0); assigned in initialize().
    private ListObjectInspector featureListOI;
    // Inspector for a single feature element: double-compatible for dense input, string otherwise.
    private PrimitiveObjectInspector featureElemOI;
    // Inspector for the target value (argument 1); assigned in initialize().
    private PrimitiveObjectInspector targetOI;
    // true when the feature elements are numeric; false when they are strings.
    private boolean denseInput;
    // Accumulates one row per process() call; RowMajorDenseMatrixBuilder when denseInput, CSRMatrixBuilder otherwise.
    private MatrixBuilder matrixBuilder;
    // Target values collected in process(), one per accumulated row.
    private DoubleArrayList targets;
    // The fields below are populated by processOptions() from the option string (or its defaults).
    private int _numTrees;
    private float _numVars;
    private int _maxDepth;
    private int _maxLeafNodes;
    private int _minSamplesSplit;
    private int _minSamplesLeaf;
    private long _seed;
    // Nominal-attribute bitmap in the serialized form produced by SerdeUtils.serializeRoaring.
    private byte[] _nominalAttrs;
    // Reporter and counters are looked up in close(); all are null when no reporter is available.
    @Nullable
    private transient Reporter _progressReporter;
    @Nullable
    private transient Counters.Counter _treeBuildTaskCounter;
    @Nullable
    private transient Counters.Counter _treeConstructionTimeCounter;
    @Nullable
    private transient Counters.Counter _treeSerializationTimeCounter;

    /**
     * Declares the options that may be passed in the optional third (string) argument.
     *
     * @return the option definitions understood by this UDTF
     */
    @Override
    protected Options getOptions() {
        final Options options = new Options();

        // Forest size and per-split feature sampling.
        options.addOption("trees", "num_trees", true, "The number of trees for each task [default: 50]");
        options.addOption("vars", "num_variables", true, "The number of random selected features [default: ceil(sqrt(x[0].length))]. int(num_variables * x[0].length) is considered if num_variable is (0.0,1.0]");

        // Tree growth limits.
        options.addOption("depth", "max_depth", true, "The maximum number of the tree depth [default: Integer.MAX_VALUE]");
        options.addOption("leafs", "max_leaf_nodes", true, "The maximum number of leaf nodes [default: Integer.MAX_VALUE]");
        options.addOption("min_samples_split", true, "A node that has greater than or equals to `min_split` examples will split [default: 5]");
        options.addOption("split", "min_split", true, "A node that has greater than or equals to `min_split` examples will split [default: 5]");
        options.addOption("min_samples_leaf", true, "The minimum number of samples in a leaf node [default: 1]");

        // Randomness.
        options.addOption("seed", true, "seed value in long [default: -1 (random)]");

        // Attribute typing.
        options.addOption("attrs", "attribute_types", true, "Comma separated attribute types (Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
        options.addOption("nominal_attr_indicies", "categorical_attr_indicies", true, "Comma seperated indicies of categorical attributes, e.g., [3,5,6]");

        return options;
    }

    /**
     * Parses the optional option string (argument 2) and stores the resulting
     * hyper-parameters in the instance fields. When fewer than three arguments
     * are supplied the defaults below are stored instead.
     *
     * @param argOIs the argument inspectors passed to {@code initialize}
     * @return the parsed command line, or {@code null} when no option string was given
     * @throws UDFArgumentException declared by the overridden signature
     * @throws IllegalArgumentException if {@code num_trees} is less than 1
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        // Defaults applied when an option is absent.
        int numTrees = 50;
        float numVariables = -1.0f;
        int depthLimit = Integer.MAX_VALUE;
        int leafLimit = Integer.MAX_VALUE;
        int splitThreshold = 5;
        int leafThreshold = 1;
        long rngSeed = -1L;
        RoaringBitmap nominalAttrs = new RoaringBitmap();
        CommandLine parsed = null;

        if (argOIs.length >= 3) {
            parsed = this.parseOptions(HiveUtils.getConstString(argOIs, 2));

            numTrees = Primitives.parseInt(parsed.getOptionValue("num_trees"), numTrees);
            if (numTrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + numTrees);
            }
            numVariables = Primitives.parseFloat(parsed.getOptionValue("num_variables"), numVariables);
            depthLimit = Primitives.parseInt(parsed.getOptionValue("max_depth"), depthLimit);
            leafLimit = Primitives.parseInt(parsed.getOptionValue("max_leaf_nodes"), leafLimit);

            // `min_samples_split` takes precedence over `min_split` when both are present.
            final String explicitSplit = parsed.getOptionValue("min_samples_split");
            if (explicitSplit == null) {
                splitThreshold = Primitives.parseInt(parsed.getOptionValue("min_split"), splitThreshold);
            } else {
                splitThreshold = Integer.parseInt(explicitSplit);
            }

            leafThreshold = Primitives.parseInt(parsed.getOptionValue("min_samples_leaf"), leafThreshold);
            rngSeed = Primitives.parseLong(parsed.getOptionValue("seed"), rngSeed);

            // Explicit nominal indices win over the attribute-type list.
            final String nominalIndices = parsed.getOptionValue("nominal_attr_indicies");
            if (nominalIndices == null) {
                nominalAttrs = SmileExtUtils.resolveAttributes(parsed.getOptionValue("attribute_types"));
            } else {
                nominalAttrs = SmileExtUtils.parseNominalAttributeIndicies(nominalIndices);
            }
        }

        this._numTrees = numTrees;
        this._numVars = numVariables;
        this._maxDepth = depthLimit;
        this._maxLeafNodes = leafLimit;
        this._minSamplesSplit = splitThreshold;
        this._minSamplesLeaf = leafThreshold;
        this._seed = rngSeed;
        this._nominalAttrs = SerdeUtils.serializeRoaring(nominalAttrs);
        return parsed;
    }

    /**
     * Validates the argument inspectors, selects dense or sparse feature handling,
     * parses the options and declares the output row schema.
     *
     * <p>The emitted struct has six columns: {@code model_id} (string),
     * {@code model_err} (double), {@code model} (string), {@code var_importance}
     * (list of double for numeric features, map of int to double for string
     * features), {@code oob_errors} (double) and {@code oob_tests} (int).
     *
     * @param argOIs inspectors for {@code features}, {@code target} and the optional options string
     * @return the struct inspector describing forwarded rows
     * @throws UDFArgumentException if the argument count is not 2 or 3, or the
     *         feature element type is neither numeric nor string
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            throw new UDFArgumentException(this.getClass().getSimpleName() + " takes 2 or 3 arguments: array<double|string> features, double target [, const string options]: " + argOIs.length);
        }

        ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
        ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureListOI = listOI;
        if (HiveUtils.isNumberOI(elemOI)) {
            // Numeric elements: positional values collected into a dense matrix.
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseInput = true;
            this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192);
        } else if (HiveUtils.isStringOI(elemOI)) {
            // String elements: handed to the CSR builder as-is.
            this.featureElemOI = HiveUtils.asStringOI(elemOI);
            this.denseInput = false;
            this.matrixBuilder = new CSRMatrixBuilder(8192);
        } else {
            throw new UDFArgumentException("_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName());
        }
        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);

        this.processOptions(argOIs);
        this.targets = new DoubleArrayList(1024);

        ArrayList<String> fieldNames = new ArrayList<String>(6);
        // Must be a list of ObjectInspector: getStandardStructObjectInspector does not accept List<Object>.
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);

        fieldNames.add("model_id");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        fieldNames.add("model_err");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        fieldNames.add("model");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        fieldNames.add("var_importance");
        if (this.denseInput) {
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        } else {
            fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector, PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        }
        fieldNames.add("oob_errors");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        fieldNames.add("oob_tests");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Buffers one training example: appends the feature row to the matrix
     * builder and the target value to the target list.
     *
     * @param args {@code args[0]} is the features array, {@code args[1]} the target
     * @throws HiveException if the features array is {@code null}
     */
    public void process(Object[] args) throws HiveException {
        final Object features = args[0];
        if (features == null) {
            throw new HiveException("array<double> features was null");
        }
        this.parseFeatures(features, this.matrixBuilder);

        final double label = PrimitiveObjectInspectorUtils.getDouble(args[1], this.targetOI);
        this.targets.add(label);
    }

    /**
     * Appends one row to {@code builder} from a Hive list object. Null elements
     * are skipped. Numeric elements are written at their list position; string
     * elements are passed to the builder unchanged.
     *
     * @param argObj the Hive list object holding the features
     * @param builder the matrix builder receiving the row
     */
    private void parseFeatures(@Nonnull Object argObj, @Nonnull MatrixBuilder builder) {
        final int numElems = this.featureListOI.getListLength(argObj);
        for (int col = 0; col < numElems; col++) {
            final Object elem = this.featureListOI.getListElement(argObj, col);
            if (elem == null) {
                continue;
            }
            if (this.denseInput) {
                final double value = PrimitiveObjectInspectorUtils.getDouble(elem, this.featureElemOI);
                builder.nextColumn(col, value);
            } else {
                builder.nextColumn(elem.toString());
            }
        }
        builder.nextRow();
    }

    /**
     * Resolves the progress reporter and its counters, trains the forest on the
     * buffered rows (if any were received) and then drops references to the
     * inspectors and the serialized attribute bitmap.
     *
     * @throws HiveException if training fails
     */
    public void close() throws HiveException {
        this._progressReporter = this.getReporter();
        if (this._progressReporter == null) {
            // No reporter available: counters stay unset.
            this._treeBuildTaskCounter = null;
            this._treeConstructionTimeCounter = null;
            this._treeSerializationTimeCounter = null;
        } else {
            this._treeBuildTaskCounter = this._progressReporter.getCounter("hivemall.smile.RandomForestRegression$Counter", "Number of finished tree construction tasks");
            this._treeConstructionTimeCounter = this._progressReporter.getCounter("hivemall.smile.RandomForestRegression$Counter", "Elapsed time in seconds for tree construction");
            this._treeSerializationTimeCounter = this._progressReporter.getCounter("hivemall.smile.RandomForestRegression$Counter", "Elapsed time in seconds for tree serialization");
        }
        RandomForestRegressionUDTF.reportProgress(this._progressReporter);

        final boolean hasExamples = !this.targets.isEmpty();
        if (hasExamples) {
            // Release each buffer as soon as its contents have been materialized.
            Matrix features = this.matrixBuilder.buildMatrix();
            this.matrixBuilder = null;
            double[] labels = this.targets.toArray();
            this.targets = null;
            this.train(features, labels);
        }

        this.featureListOI = null;
        this.featureElemOI = null;
        this.targetOI = null;
        this._nominalAttrs = null;
    }

    /**
     * Rejects hyper-parameters that cannot produce a tree.
     *
     * @throws HiveException if {@code minSamplesSplit} or {@code maxDepth} is not positive
     */
    private void checkOptions() throws HiveException {
        final int minSplit = this._minSamplesSplit;
        if (minSplit < 1) {
            throw new HiveException("Invalid minSamplesSplit: " + minSplit);
        }

        final int depth = this._maxDepth;
        if (depth <= 0) {
            throw new HiveException("Invalid maxDepth: " + depth);
        }
    }

    /**
     * Trains {@code _numTrees} regression trees on the given data by submitting
     * one {@link TrainingTask} per tree to a {@link SmileTaskExecutor}. Each task
     * forwards its own result row.
     *
     * @param x the feature matrix, one row per example
     * @param y the target values, one per row of {@code x}
     * @throws HiveException if the row count and target count differ, an option
     *         is invalid, or the executor reports a failure
     */
    private void train(@Nonnull Matrix x, @Nonnull double[] y) throws HiveException {
        final int numExamples = x.numRows();
        if (numExamples != y.length) {
            throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d", numExamples, y.length));
        }
        this.checkOptions();

        // The matrix returned by shuffle() is the one used for training.
        x = SmileExtUtils.shuffle(x, y, this._seed);
        final int numInputVars = SmileExtUtils.computeNumInputVars(this._numVars, x);
        if (logger.isInfoEnabled()) {
            // Report minSamplesLeaf rather than repeating minSamplesSplit under a second label.
            logger.info("numTrees: " + this._numTrees + ", numVars: " + numInputVars + ", minSamplesSplit: " + this._minSamplesSplit + ", maxDepth: " + this._maxDepth + ", maxLeafs: " + this._maxLeafNodes + ", minSamplesLeaf: " + this._minSamplesLeaf + ", seed: " + this._seed);
        }

        final RoaringBitmap nominalAttrs = SerdeUtils.deserializeRoaring(this._nominalAttrs);
        this._nominalAttrs = null;

        // Accumulators shared by all tasks: summed out-of-bag predictions and the
        // number of trees for which each row was out-of-bag.
        final double[] prediction = new double[numExamples];
        final int[] oob = new int[numExamples];
        final AtomicInteger remainingTasks = new AtomicInteger(this._numTrees);

        final ArrayList<TrainingTask> tasks = new ArrayList<TrainingTask>(this._numTrees);
        for (int i = 0; i < this._numTrees; ++i) {
            // A fixed seed is offset per tree; -1 is passed through unchanged.
            long s = this._seed == -1L ? -1L : this._seed + (long)i;
            tasks.add(new TrainingTask(this, i, nominalAttrs, x, y, numInputVars, prediction, oob, s, remainingTasks));
        }

        MapredContext mapredContext = MapredContextAccessor.get();
        SmileTaskExecutor executor = new SmileTaskExecutor(mapredContext);
        try {
            executor.run(tasks);
        }
        catch (Exception ex) {
            throw new HiveException(ex);
        }
        finally {
            executor.shutdown();
        }
    }

    /**
     * Emits one output row for a finished tree. Synchronized on this UDTF, so
     * rows from concurrently running tasks are forwarded one at a time.
     *
     * <p>The out-of-bag columns are non-zero only on the row emitted with
     * {@code lastTask == true}: {@code oob_errors} is the sum (not the mean) of
     * squared differences between the averaged out-of-bag prediction and the
     * target, and {@code oob_tests} is the number of rows that were out-of-bag
     * for at least one tree.
     *
     * @param taskId 1-based number of the tree, used only for logging
     * @param model the serialized tree
     * @param importance the tree's variable importance
     * @param error the tree's own error as computed by the task
     * @param y the target values
     * @param prediction per-row sum of out-of-bag predictions across trees
     * @param oob per-row count of trees for which the row was out-of-bag
     * @param lastTask whether this is the final tree of the forest
     * @throws HiveException if forwarding the row fails
     */
    synchronized void forward(int taskId, @Nonnull Text model, @Nonnull Vector importance, @Nonnegative double error, double[] y, double[] prediction, int[] oob, boolean lastTask) throws HiveException {
        double oobErrors = 0.0;
        int oobTests = 0;
        if (lastTask) {
            for (int i = 0; i < y.length; ++i) {
                if (oob[i] <= 0) continue;
                ++oobTests;
                // Average the accumulated predictions over the trees that left row i out.
                double pred = prediction[i] / (double)oob[i];
                oobErrors += Math.sqr(pred - y[i]);
            }
        }

        String modelId = RandomUtils.getUUID();
        Object[] forwardObjs = new Object[6];
        forwardObjs[0] = new Text(modelId);
        forwardObjs[1] = new DoubleWritable(error);
        forwardObjs[2] = model;
        if (this.denseInput) {
            forwardObjs[3] = WritableUtils.toWritableList(importance.toArray());
        } else {
            // Typed map matching the map<int,double> column declared in initialize().
            final HashMap<IntWritable, DoubleWritable> map = new HashMap<IntWritable, DoubleWritable>(importance.size());
            importance.each(new VectorProcedure(){

                @Override
                public void apply(int i, double value) {
                    map.put(new IntWritable(i), new DoubleWritable(value));
                }
            });
            forwardObjs[3] = map;
        }
        forwardObjs[4] = new DoubleWritable(oobErrors);
        forwardObjs[5] = new IntWritable(oobTests);
        this.forward(forwardObjs);

        RandomForestRegressionUDTF.reportProgress(this._progressReporter);
        RandomForestRegressionUDTF.incrCounter(this._treeBuildTaskCounter, 1L);
        logger.info("Forwarded " + taskId + "-th RegressionTree out of " + this._numTrees);
    }

    /**
     * Builds a single regression tree on a bootstrap sample and forwards the
     * result through the owning UDTF.
     */
    private static final class TrainingTask implements Callable<Integer> {
        private final RandomForestRegressionUDTF _udtf;
        private final int _taskId;
        private final long _seed;
        private final RoaringBitmap _nominalAttrs;
        private final Matrix _x;
        private final double[] _y;
        private final int _numVars;
        // Shared across tasks; written only while holding the UDTF's monitor.
        private final double[] _prediction;
        private final int[] _oob;
        private final AtomicInteger _remainingTasks;

        TrainingTask(RandomForestRegressionUDTF udtf, int taskId, RoaringBitmap nominalAttrs, Matrix x, double[] y, int numVars, double[] prediction, int[] oob, long seed, AtomicInteger remainingTasks) {
            this._udtf = udtf;
            this._taskId = taskId;
            this._seed = seed;
            this._nominalAttrs = nominalAttrs;
            this._x = x;
            this._y = y;
            this._numVars = numVars;
            this._prediction = prediction;
            this._oob = oob;
            this._remainingTasks = remainingTasks;
        }

        /**
         * Draws a bootstrap sample with replacement, fits the tree, evaluates it
         * on the rows that were never drawn, and forwards the serialized model.
         *
         * @return the number of tasks still outstanding after this one
         * @throws HiveException if serialization or forwarding fails
         */
        @Override
        public Integer call() throws HiveException {
            final long baseSeed;
            if (this._seed == -1L) {
                baseSeed = SmileExtUtils.generateSeed();
            } else {
                baseSeed = RandomNumberGeneratorFactory.createPRNG(this._seed).nextLong();
            }
            final PRNG samplingRng = RandomNumberGeneratorFactory.createPRNG(baseSeed);
            final PRNG treeRng = RandomNumberGeneratorFactory.createPRNG(samplingRng.nextLong());

            // samples[i] counts how many times row i was drawn.
            final int numRows = this._x.numRows();
            final int[] samples = new int[numRows];
            for (int draw = 0; draw < numRows; draw++) {
                samples[samplingRng.nextInt(numRows)]++;
            }

            final StopWatch stopwatch = new StopWatch();
            RegressionTree tree = new RegressionTree(this._nominalAttrs, this._x, this._y, this._numVars, this._udtf._maxDepth, this._udtf._maxLeafNodes, this._udtf._minSamplesSplit, this._udtf._minSamplesLeaf, samples, treeRng);
            RandomForestRegressionUDTF.incrCounter(this._udtf._treeConstructionTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS));

            // Evaluate on rows that were never drawn for this tree.
            int numOob = 0;
            double absErrorSum = 0.0;
            final Vector probe = this._x.rowVector();
            for (int row = 0; row < samples.length; row++) {
                if (samples[row] == 0) {
                    numOob++;
                    this._x.getRow(row, probe);
                    final double predicted = tree.predict(probe);
                    synchronized (this._udtf) {
                        this._prediction[row] += predicted;
                        this._oob[row]++;
                    }
                    absErrorSum += java.lang.Math.abs(predicted - this._y[row]);
                }
            }
            final double error = (numOob != 0) ? absErrorSum / (double) numOob : absErrorSum;

            stopwatch.reset().start();
            final Text model = TrainingTask.getModel(tree);
            final Vector importance = tree.importance();
            tree = null;

            final int remain = this._remainingTasks.decrementAndGet();
            this._udtf.forward(this._taskId + 1, model, importance, error, this._y, this._prediction, this._oob, remain == 0);
            RandomForestRegressionUDTF.incrCounter(this._udtf._treeSerializationTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS));
            return remain;
        }

        /**
         * Serializes the tree and Base91-encodes the bytes into a {@link Text}.
         *
         * @param tree the tree to serialize
         * @return the encoded model
         * @throws HiveException if serialization fails
         */
        @Nonnull
        private static Text getModel(@Nonnull RegressionTree tree) throws HiveException {
            final byte[] serialized = tree.serialize(true);
            final byte[] encoded = Base91.encode(serialized);
            return new Text(encoded);
        }
    }
}

