/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.tools;

import hivemall.UDFWithOptions;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.classification.PredictionHandler;
import hivemall.smile.regression.RegressionTree;
import hivemall.utils.codec.Base91;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Preconditions;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import matrix4j.vector.DenseVector;
import matrix4j.vector.SparseVector;
import matrix4j.vector.Vector;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

/**
 * Hive UDF {@code tree_predict}: scores one feature vector against one serialized tree model.
 * <p>
 * The model argument is a Base91-encoded serialized tree. The decoded tree is cached per
 * {@code modelId}, so consecutive rows that carry the same model id decode it only once.
 * <p>
 * Output:
 * <ul>
 * <li>regression (default): {@code double}</li>
 * <li>classification (fourth argument {@code true}, or option {@code -c/-classification}):
 * {@code struct<value:int, posteriori:array<double>>}</li>
 * </ul>
 * Features are either dense ({@code array<number>}, null elements read as 0.0) or sparse
 * ({@code array<string>} of {@code "index"} or {@code "index:value"}).
 */
@Description(name="tree_predict", value="_FUNC_(string modelId, string model, array<double|string> features [, const string options | const boolean classification=false]) - Returns a prediction result of a random forest in <int value, array<double> a posteriori> for classification and <double> for regression")
@UDFType(deterministic=true, stateful=false)
public final class TreePredictUDF extends UDFWithOptions {

    /** {@code true} to emit the classification struct, {@code false} to emit a double. */
    private boolean classification;
    private StringObjectInspector modelOI;
    private ListObjectInspector featureListOI;
    /** Double-compatible inspector for dense input, string inspector for sparse input. */
    private PrimitiveObjectInspector featureElemOI;
    /** {@code true} when features are {@code array<number>}, {@code false} for {@code array<string>}. */
    private boolean denseInput;
    /** Feature vector recycled across rows to avoid a per-row allocation. */
    @Nullable
    private transient Vector featuresProbe;
    /** Lazily created evaluator holding the cached decoded model. */
    @Nullable
    private transient Evaluator evaluator;

    /** Declares the {@code -c/-classification} switch accepted in the options string. */
    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("c", "classification", false, "Predict as classification [default: not enabled]");
        return opts;
    }

    /** Parses the constant options string and sets {@link #classification} from it. */
    @Override
    protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
        CommandLine cl = this.parseOptions(optionValue);
        this.classification = cl.hasOption("classification");
        return cl;
    }

    /**
     * Validates the argument inspectors and returns the output inspector.
     *
     * @param argOIs (modelId, model string, feature array [, const boolean | const string options])
     * @return writable double inspector for regression, or a struct inspector
     *         {@code <value:int, posteriori:array<double>>} for classification
     * @throws UDFArgumentException on a wrong argument count or unsupported argument types
     */
    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 3 && argOIs.length != 4) {
            this.showHelp("tree_predict takes 3 or 4 arguments");
        }
        // argOIs[0] (modelId) is deliberately not type-checked: evaluate() only calls toString() on it.
        this.modelOI = HiveUtils.asStringOI(argOIs, 1);
        final ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 2);
        this.featureListOI = listOI;
        final ObjectInspector elemOI = listOI.getListElementObjectInspector();
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseInput = true;
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.featureElemOI = HiveUtils.asStringOI(elemOI);
            this.denseInput = false;
        } else {
            throw new UDFArgumentException("tree_predict takes array<double> or array<string> for the third argument: " + listOI.getTypeName());
        }

        if (argOIs.length == 4) {
            final ObjectInspector argOI3 = argOIs[3];
            if (HiveUtils.isConstBoolean(argOI3)) {
                this.classification = HiveUtils.getConstBoolean(argOI3);
            } else if (HiveUtils.isConstString(argOI3)) {
                this.processOptions(HiveUtils.getConstString(argOI3));
            } else {
                throw new UDFArgumentException("tree_predict expects <const boolean> or <const string> for the fourth argument: " + argOI3.getTypeName());
            }
        } else {
            this.classification = false;
        }

        // Drop state built under a previous initialization: the evaluator's output type and the
        // probe's vector kind (dense/sparse) depend on the settings chosen above.
        this.evaluator = null;
        this.featuresProbe = null;

        if (!classification) {
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }
        final ArrayList<String> fieldNames = new ArrayList<String>(2);
        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
        fieldNames.add("value");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("posteriori");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Scores one row.
     *
     * @return a reused DoubleWritable (regression) or a reused Object[2] struct (classification);
     *         {@code null} when the model argument is null
     * @throws HiveException if modelId or features is null, or the features/model are malformed
     */
    @Override
    public Object evaluate(@Nonnull GenericUDF.DeferredObject[] arguments) throws HiveException {
        final Object arg0 = arguments[0].get();
        if (arg0 == null) {
            throw new HiveException("modelId should not be null");
        }
        final String modelId = arg0.toString();
        final Object arg1 = arguments[1].get();
        if (arg1 == null) {
            return null;
        }
        final Text model = modelOI.getPrimitiveWritableObject(arg1);
        final Object arg2 = arguments[2].get();
        if (arg2 == null) {
            throw new HiveException("features was null");
        }
        this.featuresProbe = parseFeatures(arg2, featuresProbe);
        if (evaluator == null) {
            this.evaluator = classification ? new ClassificationEvaluator() : new RegressionEvaluator();
        }
        return evaluator.evaluate(modelId, model, featuresProbe);
    }

    /**
     * Converts the Hive feature array into a {@link Vector}, reusing {@code probe} when possible.
     * <p>
     * Dense input: element i becomes position i, and null elements become 0.0. A new DenseVector
     * is allocated only when the length changes.
     * <p>
     * Sparse input: each non-null element is {@code "index"} (value 1.0) or {@code "index:value"},
     * with a non-negative integer index. Null elements are skipped.
     *
     * @throws UDFArgumentException on a malformed sparse element
     */
    @Nonnull
    private Vector parseFeatures(@Nonnull Object argObj, @Nullable Vector probe) throws UDFArgumentException {
        final int length = featureListOI.getListLength(argObj);
        if (denseInput) {
            if (probe == null || length != probe.size()) {
                probe = new DenseVector(length);
            }
            for (int i = 0; i < length; i++) {
                final Object o = featureListOI.getListElement(argObj, i);
                final double v = (o == null) ? 0.d : PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
                probe.set(i, v);
            }
            return probe;
        }

        if (probe == null) {
            probe = new SparseVector();
        } else {
            probe.clear();
        }
        for (int i = 0; i < length; i++) {
            final Object o = featureListOI.getListElement(argObj, i);
            if (o == null) {
                continue;
            }
            final String col = o.toString();
            final int pos = col.indexOf(':');
            if (pos == 0) {
                throw new UDFArgumentException("Invalid feature value representation: " + col);
            }
            final String feature;
            final double value;
            if (pos > 0) {
                feature = col.substring(0, pos);
                final String valueStr = col.substring(pos + 1);
                // Reject "i:v:w" explicitly rather than surfacing an opaque number-format error.
                if (valueStr.indexOf(':') != -1) {
                    throw new UDFArgumentException("Invalid feature format `<index>:<value>`: " + col);
                }
                value = parseFeatureValue(valueStr, col);
            } else {
                feature = col;
                value = 1.d;
            }
            final int colIndex = parseFeatureIndex(feature, col);
            if (colIndex < 0) {
                throw new UDFArgumentException("Col index MUST be greater than or equals to 0: " + colIndex);
            }
            probe.set(colIndex, value);
        }
        return probe;
    }

    /** Parses a sparse feature index, reporting malformed input as a UDF argument error. */
    private static int parseFeatureIndex(@Nonnull String feature, @Nonnull String col) throws UDFArgumentException {
        try {
            return Integer.parseInt(feature);
        } catch (NumberFormatException e) {
            throw invalidFeature("Invalid feature index: " + col, e);
        }
    }

    /** Parses a sparse feature value, reporting malformed input as a UDF argument error. */
    private static double parseFeatureValue(@Nonnull String valueStr, @Nonnull String col) throws UDFArgumentException {
        try {
            return Double.parseDouble(valueStr);
        } catch (NumberFormatException e) {
            throw invalidFeature("Invalid feature value: " + col, e);
        }
    }

    /** Builds a UDFArgumentException that keeps the underlying parse failure as its cause. */
    @Nonnull
    private static UDFArgumentException invalidFeature(@Nonnull String msg, @Nonnull Throwable cause) {
        final UDFArgumentException ex = new UDFArgumentException(msg);
        ex.initCause(cause);
        return ex;
    }

    /** Releases inspectors and cached per-row state. */
    @Override
    public void close() throws IOException {
        this.modelOI = null;
        this.featureElemOI = null;
        this.featureListOI = null;
        this.featuresProbe = null;
        this.evaluator = null;
    }

    @Override
    public String getDisplayString(String[] children) {
        return "tree_predict(" + Arrays.toString(children) + ")";
    }

    /** Regression scorer; returns a single DoubleWritable instance reused across rows. */
    static final class RegressionEvaluator implements Evaluator {
        @Nonnull
        private final DoubleWritable result = new DoubleWritable();
        /** Id of the model currently held in {@link #rNode}; set only after a successful decode. */
        @Nullable
        private String prevModelId = null;
        @Nullable
        private RegressionTree.Node rNode = null;

        RegressionEvaluator() {}

        @Override
        @Nonnull
        public DoubleWritable evaluate(@Nonnull String modelId, @Nonnull Text script, @Nonnull Vector features) throws HiveException {
            if (!modelId.equals(prevModelId)) {
                final byte[] b = Base91.decode(script.getBytes(), 0, script.getLength());
                // Record the id only after the new tree is in place. Otherwise a failed decode would
                // leave the cache claiming modelId is loaded while it still holds the previous tree.
                this.rNode = RegressionTree.deserialize(b, b.length, true);
                this.prevModelId = modelId;
            }
            Preconditions.checkNotNull(rNode);
            result.set(rNode.predict(features));
            return result;
        }
    }

    /** Classification scorer; returns a reused Object[]{IntWritable value, posteriori list}. */
    static final class ClassificationEvaluator implements Evaluator {
        @Nonnull
        private final Object[] result = new Object[2];
        /** Id of the model currently held in {@link #cNode}; set only after a successful decode. */
        @Nullable
        private String prevModelId = null;
        @Nullable
        private DecisionTree.Node cNode = null;
        /** Leaf visitor that writes the prediction into {@link #result}; shared across rows. */
        @Nonnull
        private final PredictionHandler handler = new PredictionHandler() {
            @Override
            public void visitLeaf(int output, double[] posteriori) {
                ClassificationEvaluator.this.result[0] = new IntWritable(output);
                ClassificationEvaluator.this.result[1] = WritableUtils.toWritableList(posteriori);
            }
        };

        ClassificationEvaluator() {}

        @Override
        @Nonnull
        public Object[] evaluate(@Nonnull String modelId, @Nonnull Text script, @Nonnull Vector features) throws HiveException {
            if (!modelId.equals(prevModelId)) {
                final byte[] b = Base91.decode(script.getBytes(), 0, script.getLength());
                // Same ordering rule as RegressionEvaluator: commit the tree first, then its id.
                this.cNode = DecisionTree.deserialize(b, b.length, true);
                this.prevModelId = modelId;
            }
            // Clear the previous row's output; it stays null if no leaf is visited.
            Arrays.fill(result, null);
            Preconditions.checkNotNull(cNode);
            cNode.predict(features, handler);
            return result;
        }
    }

    /** Scores a feature vector with a serialized tree, caching the decoded tree by model id. */
    interface Evaluator {
        @Nonnull
        Object evaluate(@Nonnull String modelId, @Nonnull Text script, @Nonnull Vector features) throws HiveException;
    }
}

