/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.tools;

import hivemall.UDFWithOptions;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.classification.PredictionHandler;
import hivemall.smile.regression.RegressionTree;
import hivemall.utils.codec.Base91;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.StringUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import matrix4j.vector.DenseVector;
import matrix4j.vector.SparseVector;
import matrix4j.vector.Vector;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.Text;

@Description(name="decision_path", value="_FUNC_(string modelId, string model, array<double|string> features [, const string options] [, optional array<string> featureNames=null, optional array<string> classNames=null]) - Returns a decision path for each prediction in array<string>", extended="SELECT\n  t.passengerid,\n  decision_path(m.model_id, m.model, t.features, '-classification')\nFROM\n  model_rf m\n  LEFT OUTER JOIN\n  test_rf t;\n | 892 | [\"2 [0.0] = 0.0\",\"0 [3.0] = 3.0\",\"1 [696.0] != 107.0\",\"7 [7.8292] <= 7.9104\",\"1 [696.0] != 828.0\",\"1 [696.0] != 391.0\",\"0 [0.961038961038961, 0.03896103896103896]\"] |\n\n-- Show 100 frequent branches\nWITH tmp as (\n  SELECT\n    decision_path(m.model_id, m.model, t.features, '-classification -no_verbose -no_leaf', array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) as path\n  FROM\n    model_rf m\n    LEFT OUTER JOIN -- CROSS JOIN\n    test_rf t\n)\nselect\n  r.branch,\n  count(1) as cnt\nfrom\n  tmp l\n  LATERAL VIEW explode(l.path) r as branch\ngroup by\n  r.branch\norder by\n  cnt desc\nlimit 100;")
@UDFType(deterministic=true, stateful=false)
public final class DecisionPathUDF
extends UDFWithOptions {
    private StringObjectInspector modelOI;
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    private boolean denseInput;
    private boolean classification = false;
    private boolean summarize = true;
    private boolean verbose = true;
    private boolean noLeaf = false;
    @Nullable
    private String[] featureNames;
    @Nullable
    private String[] classNames;
    @Nullable
    private transient Vector featuresProbe;
    @Nullable
    private transient Evaluator evaluator;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("c", "classification", false, "Predict as classification [default: not enabled]");
        opts.addOption("no_sumarize", "disable_summarization", false, "Do not summarize decision paths");
        opts.addOption("no_verbose", "disable_verbose_output", false, "Disable verbose output [default: verbose]");
        opts.addOption("no_leaf", "disable_leaf_output", false, "Show leaf value [default: not enabled]");
        return opts;
    }

    @Override
    protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
        CommandLine cl = this.parseOptions(optionValue);
        this.classification = cl.hasOption("classification");
        this.summarize = !cl.hasOption("no_sumarize");
        this.verbose = !cl.hasOption("disable_verbose_output");
        this.noLeaf = cl.hasOption("disable_leaf_output");
        return cl;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        ListObjectInspector listOI;
        if (argOIs.length < 3 || argOIs.length > 6) {
            this.showHelp("tree_predict takes 3 ~ 6 arguments");
        }
        this.modelOI = HiveUtils.asStringOI(argOIs[1]);
        this.featureListOI = listOI = HiveUtils.asListOI(argOIs[2]);
        ObjectInspector elemOI = listOI.getListElementObjectInspector();
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseInput = true;
        } else {
            if (!HiveUtils.isStringOI(elemOI)) throw new UDFArgumentException("tree_predict takes array<double> or array<string> for the 3rd argument: " + listOI.getTypeName());
            this.featureElemOI = HiveUtils.asStringOI(elemOI);
            this.denseInput = false;
        }
        if (argOIs.length < 4) return ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector argOI3 = argOIs[3];
        if (HiveUtils.isConstString(argOI3)) {
            String opts = HiveUtils.getConstString(argOI3);
            this.processOptions(opts);
            if (argOIs.length < 5) return ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.javaStringObjectInspector);
            ObjectInspector argOI4 = argOIs[4];
            if (!HiveUtils.isConstStringListOI(argOI4)) throw new UDFArgumentException("decision_path expects 'const array<string> featureNames' for the 5th argument: " + argOI4.getTypeName());
            this.featureNames = HiveUtils.getConstStringArray(argOI4);
            if (argOIs.length < 6) return ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.javaStringObjectInspector);
            ObjectInspector argOI5 = argOIs[5];
            if (!HiveUtils.isConstStringListOI(argOI5)) throw new UDFArgumentException("decision_path expects 'const array<string> classNames' for the 6th argument: " + argOI5.getTypeName());
            if (!this.classification) {
                throw new UDFArgumentException("classNames should not be provided for regression");
            }
            this.classNames = HiveUtils.getConstStringArray(argOI5);
            return ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        } else {
            if (!HiveUtils.isConstStringListOI(argOI3)) throw new UDFArgumentException("decision_path expects 'const array<string> options' or 'const array<string> featureNames' for the 4th argument: " + argOI3.getTypeName());
            this.featureNames = HiveUtils.getConstStringArray(argOI3);
            if (argOIs.length < 5) return ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.javaStringObjectInspector);
            ObjectInspector argOI4 = argOIs[4];
            if (!HiveUtils.isConstStringListOI(argOI4)) throw new UDFArgumentException("decision_path expects 'const array<string> classNames' for the 5th argument: " + argOI4.getTypeName());
            if (!this.classification) {
                throw new UDFArgumentException("classNames should not be provided for regression");
            }
            this.classNames = HiveUtils.getConstStringArray(argOI4);
        }
        return ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    }

    public List<String> evaluate(@Nonnull GenericUDF.DeferredObject[] arguments) throws HiveException {
        Object arg0 = arguments[0].get();
        if (arg0 == null) {
            throw new HiveException("modelId should not be null");
        }
        String modelId = arg0.toString();
        Object arg1 = arguments[1].get();
        if (arg1 == null) {
            return null;
        }
        Text model = this.modelOI.getPrimitiveWritableObject(arg1);
        Object arg2 = arguments[2].get();
        if (arg2 == null) {
            throw new HiveException("features was null");
        }
        this.featuresProbe = this.parseFeatures(arg2, this.featuresProbe);
        if (this.evaluator == null) {
            this.evaluator = this.classification ? new ClassificationEvaluator(this) : new RegressionEvaluator(this);
        }
        return this.evaluator.evaluate(modelId, model, this.featuresProbe);
    }

    @Nonnull
    private Vector parseFeatures(@Nonnull Object argObj, @Nullable Vector probe) throws UDFArgumentException {
        if (this.denseInput) {
            int length = this.featureListOI.getListLength(argObj);
            if (probe == null) {
                probe = new DenseVector(length);
            } else if (length != probe.size()) {
                probe = new DenseVector(length);
            }
            for (int i = 0; i < length; ++i) {
                Object o = this.featureListOI.getListElement(argObj, i);
                if (o == null) {
                    probe.set(i, 0.0);
                    continue;
                }
                double v = PrimitiveObjectInspectorUtils.getDouble((Object)o, (PrimitiveObjectInspector)this.featureElemOI);
                probe.set(i, v);
            }
        } else {
            if (probe == null) {
                probe = new SparseVector();
            } else {
                probe.clear();
            }
            int length = this.featureListOI.getListLength(argObj);
            for (int i = 0; i < length; ++i) {
                double value;
                String feature;
                Object o = this.featureListOI.getListElement(argObj, i);
                if (o == null) continue;
                String col = o.toString();
                int pos = col.indexOf(58);
                if (pos == 0) {
                    throw new UDFArgumentException("Invalid feature value representation: " + col);
                }
                if (pos > 0) {
                    feature = col.substring(0, pos);
                    String s2 = col.substring(pos + 1);
                    value = Double.parseDouble(s2);
                } else {
                    feature = col;
                    value = 1.0;
                }
                if (feature.indexOf(58) != -1) {
                    throw new UDFArgumentException("Invalid feature format `<index>:<value>`: " + col);
                }
                int colIndex = Integer.parseInt(feature);
                if (colIndex < 0) {
                    throw new UDFArgumentException("Col index MUST be greater than or equals to 0: " + colIndex);
                }
                probe.set(colIndex, value);
            }
        }
        return probe;
    }

    public void close() throws IOException {
        this.modelOI = null;
        this.featureElemOI = null;
        this.featureListOI = null;
        this.featureNames = null;
        this.classNames = null;
        this.featuresProbe = null;
        this.evaluator = null;
    }

    public String getDisplayString(String[] children) {
        return "decision_path(" + StringUtils.join(children, ',') + ")";
    }

    static final class RegressionEvaluator
    implements Evaluator {
        @Nullable
        private final String[] featureNames;
        @Nonnull
        private final List<String> result;
        @Nonnull
        private final PredictionHandler handler;
        @Nullable
        private String prevModelId = null;
        private RegressionTree.Node rNode = null;

        RegressionEvaluator(final @Nonnull DecisionPathUDF udf) {
            this.featureNames = udf.featureNames;
            final StringBuilder buf = new StringBuilder();
            final ArrayList<String> result = new ArrayList<String>();
            this.result = result;
            if (udf.summarize) {
                final LinkedHashMap map = new LinkedHashMap();
                this.handler = new PredictionHandler(){

                    @Override
                    public void init() {
                        map.clear();
                        result.clear();
                    }

                    @Override
                    public void visitBranch(PredictionHandler.Operator op, int splitFeatureIndex, double splitFeature, double splitValue) {
                        buf.append(RegressionEvaluator.this.resolveFeatureName(splitFeatureIndex));
                        if (udf.verbose) {
                            buf.append(" [" + splitFeature + "] ");
                        } else {
                            buf.append(' ');
                        }
                        buf.append((Object)op);
                        if (op == PredictionHandler.Operator.EQ || op == PredictionHandler.Operator.NE) {
                            buf.append(' ');
                            buf.append(splitValue);
                        }
                        String key = buf.toString();
                        map.put(key, splitValue);
                        StringUtils.clear(buf);
                    }

                    @Override
                    public void visitLeaf(double output) {
                        for (Map.Entry e : map.entrySet()) {
                            String key = (String)e.getKey();
                            if (key.indexOf(60) == -1 && key.indexOf(62) == -1) {
                                result.add(key);
                                continue;
                            }
                            double value = (Double)e.getValue();
                            result.add(key + ' ' + value);
                        }
                        if (udf.noLeaf) {
                            return;
                        }
                        result.add(Double.toString(output));
                    }

                    public ArrayList<String> getResult() {
                        return result;
                    }
                };
            } else {
                this.handler = new PredictionHandler(){

                    @Override
                    public void init() {
                        result.clear();
                    }

                    @Override
                    public void visitBranch(PredictionHandler.Operator op, int splitFeatureIndex, double splitFeature, double splitValue) {
                        buf.append(RegressionEvaluator.this.resolveFeatureName(splitFeatureIndex));
                        if (udf.verbose) {
                            buf.append(" [" + splitFeature + "] ");
                        }
                        buf.append((Object)op);
                        buf.append(' ');
                        buf.append(splitValue);
                        result.add(buf.toString());
                        StringUtils.clear(buf);
                    }

                    @Override
                    public void visitLeaf(double output) {
                        if (udf.noLeaf) {
                            return;
                        }
                        result.add(Double.toString(output));
                    }

                    public ArrayList<String> getResult() {
                        return result;
                    }
                };
            }
        }

        @Nonnull
        private String resolveFeatureName(int splitFeatureIndex) {
            if (this.featureNames == null) {
                return Integer.toString(splitFeatureIndex);
            }
            return this.featureNames[splitFeatureIndex];
        }

        @Override
        @Nonnull
        public List<String> evaluate(@Nonnull String modelId, @Nonnull Text script, @Nonnull Vector features) throws HiveException {
            if (!modelId.equals(this.prevModelId)) {
                this.prevModelId = modelId;
                int length = script.getLength();
                byte[] b = script.getBytes();
                b = Base91.decode(b, 0, length);
                this.rNode = RegressionTree.deserialize(b, b.length, true);
            }
            Preconditions.checkNotNull(this.rNode);
            this.handler.init();
            this.rNode.predict(features, this.handler);
            return (List)this.handler.getResult();
        }
    }

    static final class ClassificationEvaluator
    implements Evaluator {
        @Nullable
        private final String[] featureNames;
        @Nullable
        private final String[] classNames;
        @Nonnull
        private final List<String> result;
        @Nonnull
        private final PredictionHandler handler;
        @Nullable
        private String prevModelId = null;
        private DecisionTree.Node cNode = null;

        ClassificationEvaluator(final @Nonnull DecisionPathUDF udf) {
            this.featureNames = udf.featureNames;
            this.classNames = udf.classNames;
            final StringBuilder buf = new StringBuilder();
            final ArrayList<String> result = new ArrayList<String>();
            this.result = result;
            if (udf.summarize) {
                final LinkedHashMap map = new LinkedHashMap();
                this.handler = new PredictionHandler(){

                    @Override
                    public void init() {
                        map.clear();
                        result.clear();
                    }

                    @Override
                    public void visitBranch(PredictionHandler.Operator op, int splitFeatureIndex, double splitFeature, double splitValue) {
                        buf.append(ClassificationEvaluator.this.resolveFeatureName(splitFeatureIndex));
                        if (udf.verbose) {
                            buf.append(" [" + splitFeature + "] ");
                        } else {
                            buf.append(' ');
                        }
                        buf.append((Object)op);
                        if (op == PredictionHandler.Operator.EQ || op == PredictionHandler.Operator.NE) {
                            buf.append(' ');
                            buf.append(splitValue);
                        }
                        String key = buf.toString();
                        map.put(key, splitValue);
                        StringUtils.clear(buf);
                    }

                    @Override
                    public void visitLeaf(int output, double[] posteriori) {
                        for (Map.Entry e : map.entrySet()) {
                            String key = (String)e.getKey();
                            if (key.indexOf(60) == -1 && key.indexOf(62) == -1) {
                                result.add(key);
                                continue;
                            }
                            double value = (Double)e.getValue();
                            result.add(key + ' ' + value);
                        }
                        if (udf.noLeaf) {
                            return;
                        }
                        if (udf.verbose) {
                            buf.append(ClassificationEvaluator.this.resolveClassName(output));
                            buf.append(' ');
                            buf.append(Arrays.toString(posteriori));
                            result.add(buf.toString());
                            StringUtils.clear(buf);
                        } else {
                            result.add(ClassificationEvaluator.this.resolveClassName(output));
                        }
                    }

                    public ArrayList<String> getResult() {
                        return result;
                    }
                };
            } else {
                this.handler = new PredictionHandler(){

                    @Override
                    public void init() {
                        result.clear();
                    }

                    @Override
                    public void visitBranch(PredictionHandler.Operator op, int splitFeatureIndex, double splitFeature, double splitValue) {
                        buf.append(ClassificationEvaluator.this.resolveFeatureName(splitFeatureIndex));
                        if (udf.verbose) {
                            buf.append(" [" + splitFeature + "] ");
                        } else {
                            buf.append(' ');
                        }
                        buf.append((Object)op);
                        buf.append(' ');
                        buf.append(splitValue);
                        result.add(buf.toString());
                        StringUtils.clear(buf);
                    }

                    @Override
                    public void visitLeaf(int output, double[] posteriori) {
                        if (udf.noLeaf) {
                            return;
                        }
                        if (udf.verbose) {
                            buf.append(ClassificationEvaluator.this.resolveClassName(output));
                            buf.append(' ');
                            buf.append(Arrays.toString(posteriori));
                            result.add(buf.toString());
                            StringUtils.clear(buf);
                        } else {
                            result.add(ClassificationEvaluator.this.resolveClassName(output));
                        }
                    }

                    public ArrayList<String> getResult() {
                        return result;
                    }
                };
            }
        }

        @Nonnull
        private String resolveFeatureName(int splitFeatureIndex) {
            if (this.featureNames == null) {
                return Integer.toString(splitFeatureIndex);
            }
            return this.featureNames[splitFeatureIndex];
        }

        @Nonnull
        private String resolveClassName(int classLabel) {
            if (this.classNames == null) {
                return Integer.toString(classLabel);
            }
            return this.classNames[classLabel];
        }

        @Override
        @Nonnull
        public List<String> evaluate(@Nonnull String modelId, @Nonnull Text script, @Nonnull Vector features) throws HiveException {
            if (!modelId.equals(this.prevModelId)) {
                this.prevModelId = modelId;
                int length = script.getLength();
                byte[] b = script.getBytes();
                b = Base91.decode(b, 0, length);
                this.cNode = DecisionTree.deserialize(b, b.length, true);
            }
            Preconditions.checkNotNull(this.cNode);
            this.handler.init();
            this.cNode.predict(features, this.handler);
            return (List)this.handler.getResult();
        }
    }

    static interface Evaluator {
        @Nonnull
        public List<String> evaluate(@Nonnull String var1, @Nonnull Text var2, @Nonnull Vector var3) throws HiveException;
    }
}

