/*
 * Decompiled with CFR 0.152.
 */
package hivemall.ftvec.text;

import hivemall.UDFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.StringUtils;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

/**
 * Hive UDF that computes an Okapi BM25 (optionally BM25+) score for one term/document pair.
 *
 * <p>The value returned is
 * {@code idf * (tf * (k1 + 1) / (tf + k1 * (1 - b + b * docLength / avgDocLength)) + delta)}
 * with
 * {@code idf = max(min_idf, log10(1 + (numDocs - numDocsWithTerm + 0.5) / (numDocsWithTerm + 0.5)))}.
 *
 * <p>The hyperparameters {@code k1}, {@code b}, {@code delta} and {@code min_idf} can be
 * overridden through the optional sixth constant string argument; otherwise the field
 * defaults below are used.
 *
 * <p>The {@link DoubleWritable} returned by {@link #evaluate} is a single instance that is
 * overwritten on every call.
 */
@Description(name="bm25", value="_FUNC_(double termFrequency, int docLength, double avgDocLength, int numDocs, int numDocsWithTerm [, const string options]) - Return an Okapi BM25 score in double. Refer http://hivemall.incubator.apache.org/userguide/ft_engineering/bm25.html for usage")
@UDFType(deterministic=true, stateful=false)
public final class OkapiBM25UDF
extends UDFWithOptions {
    // Term-frequency saturation hyperparameter; must be finite and >= 0.
    private double k1 = 1.2;
    // Document-length normalization hyperparameter; must be within [0, 1].
    private double b = 0.75;
    // Additive BM25+ term; 0.0 leaves the plain BM25 formula.
    private double delta = 0.0;
    // Lower bound applied to the IDF factor.
    private double minIDF = 1.0E-8;
    // Object inspectors for the five required arguments, resolved in initialize().
    private PrimitiveObjectInspector frequencyOI;
    private PrimitiveObjectInspector docLengthOI;
    private PrimitiveObjectInspector averageDocLengthOI;
    private PrimitiveObjectInspector numDocsOI;
    private PrimitiveObjectInspector numDocsWithTermOI;
    // Reused output holder; overwritten by every evaluate() call.
    @Nonnull
    private final DoubleWritable result = new DoubleWritable();

    /**
     * Declares the command-line style options accepted in the optional sixth argument.
     *
     * @return the option definitions for {@code k1}, {@code b}, {@code delta} and {@code min_idf}
     */
    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("k1", true, "Hyperparameter with type double, usually in range 1.2 and 2.0 [default: 1.2]");
        opts.addOption("b", true, "Hyperparameter with type double in range 0.0 and 1.0 [default: 0.75]");
        opts.addOption("d", "delta", true, "Hyperparameter delta of BM25+ [default: 0.0]");
        // The help text previously repeated the delta description by mistake.
        opts.addOption("min_idf", "epsilon", true, "Hyperparameter epsilon, the lower bound applied to the IDF term [default: 1e-8]");
        return opts;
    }

    /**
     * Parses the option string and stores the validated hyperparameters on this instance.
     * Options that are absent keep their current (default) value.
     *
     * @param opts the raw option string
     * @return the parsed command line
     * @throws UDFArgumentException if any hyperparameter is outside its permitted range
     */
    @Override
    protected CommandLine processOptions(@Nonnull String opts) throws UDFArgumentException {
        CommandLine cl = this.parseOptions(opts);
        this.k1 = Primitives.parseDouble(cl.getOptionValue("k1"), this.k1);
        if (!Primitives.isFinite(this.k1) || this.k1 < 0.0) {
            throw new UDFArgumentException("k1 must be a non-negative finite value: " + this.k1);
        }
        this.b = Primitives.parseDouble(cl.getOptionValue("b"), this.b);
        if (Double.isNaN(this.b) || this.b < 0.0 || this.b > 1.0) {
            throw new UDFArgumentException("b hyperparameter must be in the range [0.0, 1.0]: " + this.b);
        }
        this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), this.delta);
        if (!Primitives.isFinite(this.delta)) {
            throw new UDFArgumentException("Delta must be a finite value: " + this.delta);
        }
        this.minIDF = Primitives.parseDouble(cl.getOptionValue("min_idf"), this.minIDF);
        // A bare "< 0.0" test lets NaN through, and Math.max(NaN, idf) would then turn every
        // score into NaN; require a finite value just like k1.
        if (!Primitives.isFinite(this.minIDF) || this.minIDF < 0.0) {
            throw new UDFArgumentException("min_idf must be a non-negative finite value: " + this.minIDF);
        }
        return cl;
    }

    /**
     * Validates the argument count, applies the optional option string and resolves the
     * object inspectors for the five required arguments.
     *
     * @param argOIs inspectors for (termFrequency, docLength, avgDocLength, numDocs,
     *               numDocsWithTerm [, options])
     * @return the writable double object inspector describing the result type
     * @throws UDFArgumentException if the arguments or options are invalid
     */
    public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException {
        int numArgOIs = argOIs.length;
        // NOTE(review): showHelp is inherited and not visible here; it is expected to throw,
        // otherwise the indexing below would fail for short argument lists.
        if (numArgOIs < 5) {
            this.showHelp("#arguments must be greater than or equal to 5: " + numArgOIs);
        } else if (numArgOIs > 6) {
            // Extra arguments used to be silently ignored; reject them explicitly.
            this.showHelp("#arguments must be less than or equal to 6: " + numArgOIs);
        } else if (numArgOIs == 6) {
            String opts = HiveUtils.getConstString(argOIs[5]);
            this.processOptions(opts);
        }
        this.frequencyOI = HiveUtils.asDoubleCompatibleOI(argOIs[0]);
        this.docLengthOI = HiveUtils.asIntegerOI(argOIs[1]);
        this.averageDocLengthOI = HiveUtils.asDoubleCompatibleOI(argOIs[2]);
        this.numDocsOI = HiveUtils.asIntegerOI(argOIs[3]);
        this.numDocsWithTermOI = HiveUtils.asIntegerOI(argOIs[4]);
        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
    }

    /**
     * Computes the BM25 score for one row.
     *
     * @param arguments deferred values for the five required arguments, in declaration order
     * @return the shared {@link DoubleWritable} holding the score
     * @throws HiveException if a required argument is null; range violations are reported
     *                       through the inherited {@code assumeFalse} helper
     */
    public DoubleWritable evaluate(@Nonnull GenericUDF.DeferredObject[] arguments) throws HiveException {
        Object arg0 = arguments[0].get();
        Object arg1 = arguments[1].get();
        Object arg2 = arguments[2].get();
        Object arg3 = arguments[3].get();
        Object arg4 = arguments[4].get();
        if (arg0 == null || arg1 == null || arg2 == null || arg3 == null || arg4 == null) {
            throw new UDFArgumentException("Required arguments cannot be null");
        }
        double frequency = PrimitiveObjectInspectorUtils.getDouble(arg0, this.frequencyOI);
        int docLength = PrimitiveObjectInspectorUtils.getInt(arg1, this.docLengthOI);
        double averageDocLength = PrimitiveObjectInspectorUtils.getDouble(arg2, this.averageDocLengthOI);
        int numDocs = PrimitiveObjectInspectorUtils.getInt(arg3, this.numDocsOI);
        int numDocsWithTerm = PrimitiveObjectInspectorUtils.getInt(arg4, this.numDocsWithTermOI);
        // Zero is accepted for the frequency, so the message says "non-negative".
        OkapiBM25UDF.assumeFalse(frequency < 0.0, "#frequency must be non-negative");
        OkapiBM25UDF.assumeFalse(docLength < 1, "#docLength must be greater than or equal to 1");
        OkapiBM25UDF.assumeFalse(averageDocLength <= 0.0, "#averageDocLength must be positive");
        OkapiBM25UDF.assumeFalse(numDocs < 1, "#numDocs must be greater than or equal to 1");
        OkapiBM25UDF.assumeFalse(numDocsWithTerm < 1, "#numDocsWithTerm must be greater than or equal to 1");
        double v = this.bm25(frequency, docLength, averageDocLength, numDocs, numDocsWithTerm);
        this.result.set(v);
        return this.result;
    }

    /**
     * Evaluates the BM25/BM25+ formula using the configured hyperparameters.
     *
     * @param tf               term frequency within the document
     * @param docLength        length of the document
     * @param averageDocLength average document length
     * @param numDocs          number of documents
     * @param numDocsWithTerm  number of documents containing the term
     * @return {@code idf * (tfComponent + delta)}
     */
    private double bm25(double tf, int docLength, double averageDocLength, int numDocs, int numDocsWithTerm) {
        double numerator = tf * (this.k1 + 1.0);
        double denominator = tf + this.k1 * (1.0 - this.b + this.b * (double)docLength / averageDocLength);
        // With validated inputs the denominator is zero only when tf is zero too (e.g. tf == 0
        // and k1 == 0), which would yield 0/0 = NaN; an absent term contributes nothing.
        double tfComponent = denominator == 0.0 ? 0.0 : numerator / denominator;
        double idf = Math.max(this.minIDF, OkapiBM25UDF.idf(numDocs, numDocsWithTerm));
        return idf * (tfComponent + this.delta);
    }

    /**
     * Computes {@code log10(1 + (numDocs - numDocsWithTerm + 0.5) / (numDocsWithTerm + 0.5))}.
     *
     * @param numDocs         number of documents
     * @param numDocsWithTerm number of documents containing the term
     * @return the unclamped inverse document frequency
     */
    private static double idf(int numDocs, int numDocsWithTerm) {
        return Math.log10(1.0 + ((double)(numDocs - numDocsWithTerm) + 0.5) / ((double)numDocsWithTerm + 0.5));
    }

    /**
     * Renders the call as {@code bm25(child1,child2,...)}.
     *
     * @param children string forms of the argument expressions
     * @return the display string
     */
    public String getDisplayString(String[] children) {
        return "bm25(" + StringUtils.join(children, ',') + ")";
    }
}

