/*
 * Decompiled with CFR 0.152.
 */
package hivemall.knn.similarity;

import hivemall.UDTFWithOptions;
import hivemall.factorization.fm.Feature;
import hivemall.factorization.fm.IntFeature;
import hivemall.factorization.fm.StringFeature;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.random.PRNG;
import hivemall.utils.random.RandomNumberGeneratorFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

@Description(name="dimsum_mapper", value="_FUNC_(array<string> row, map<int col_id, double norm> colNorms [, const string options]) - Returns column-wise partial similarities")
/**
 * Mapper step of a DIMSUM-style sampled all-pairs column similarity computation.
 *
 * <p>For each input row (a list of features whose index/name is a column id), pairs of non-zero
 * entries {@code (j, k)} with k appearing later in the row than j are emitted as
 * {@code (j, k, b_jk)} where
 * {@code b_jk = (a_j / min(sqrt(gamma), norm_j)) * (a_k / min(sqrt(gamma), norm_k))}.
 * Column j is kept with probability {@code p_j = min(1, sqrt(gamma) / norm_j)}, and each later
 * column k is then kept with probability {@code p_k} using a fresh random draw per pair.
 *
 * <p>Column norms come from the second argument. They are read only once, from the first row
 * whose features parse to a non-null array, and are reused for every later row.
 */
public final class DIMSUMMapperUDTF
extends UDTFWithOptions {
    /** Inspector for argument 0: the row as a list of features. */
    private ListObjectInspector rowOI;
    /** Inspector for argument 1: map from column id to column norm. */
    private MapObjectInspector colNormsOI;
    /** Previously parsed row, handed back to {@code Feature.parseFeatures} as its probe argument. */
    @Nullable
    private Feature[] probes;
    /** Random source for sampling; created with a fixed seed in {@link #initialize}. */
    @Nonnull
    private PRNG rnd;
    /** Similarity threshold used to derive gamma when gamma is not given explicitly. */
    private double threshold;
    /** sqrt(gamma); stays POSITIVE_INFINITY until derived from the threshold in {@link #process}. */
    private double sqrtGamma;
    /** When true, both (j, k) and (k, j) are emitted for each sampled pair. */
    private boolean symmetricOutput;
    /** When true, column ids are handled as ints; otherwise as strings. */
    private boolean parseFeatureAsInt;
    /** Column norms keyed by column id (Integer or String); zero norms are stored as 1.0. */
    private Map<Object, Double> colNorms;
    /** Per-column sampling probabilities {@code min(1, sqrtGamma / norm)}. */
    private Map<Object, Double> colProbs;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("th", "threshold", true, "Theoretically, similarities above this threshold are estimated [default: 0.5]");
        opts.addOption("g", "gamma", true, "Oversampling parameter; if `gamma` is given, `threshold` will be ignored [default: 10 * log(numCols) / threshold]");
        opts.addOption("disable_symmetric", "disable_symmetric_output", false, "Output only contains (col j, col k) pair; symmetric (col k, col j) pair is omitted");
        opts.addOption("int_feature", "feature_as_integer", false, "Parse a feature (i.e. column ID) as integer");
        return opts;
    }

    /**
     * Parses the optional constant options string (argument 2) and stores the resulting settings.
     *
     * @param argOIs argument inspectors; options are read only when there are at least 3
     * @return the parsed command line, or {@code null} when no options argument was given
     * @throws UDFArgumentException if threshold is outside [0, 1) or gamma is not greater than 1
     */
    @Override
    protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException {
        double threshold = 0.5;
        double gamma = Double.POSITIVE_INFINITY;
        boolean symmetricOutput = true;
        boolean parseFeatureAsInt = false;
        CommandLine cl = null;
        if (argOIs.length >= 3) {
            String rawArgs = HiveUtils.getConstString(argOIs[2]);
            cl = this.parseOptions(rawArgs);
            threshold = Primitives.parseDouble(cl.getOptionValue("threshold"), threshold);
            if (threshold < 0.0 || threshold >= 1.0) {
                throw new UDFArgumentException("`threshold` MUST be in range [0,1): " + threshold);
            }
            gamma = Primitives.parseDouble(cl.getOptionValue("gamma"), gamma);
            if (gamma <= 1.0) {
                throw new UDFArgumentException("`gamma` MUST be greater than 1: " + gamma);
            }
            symmetricOutput = !cl.hasOption("disable_symmetric_output");
            parseFeatureAsInt = cl.hasOption("feature_as_integer");
        }
        this.threshold = threshold;
        // sqrt(+inf) == +inf, which marks "gamma not given" for process().
        this.sqrtGamma = Math.sqrt(gamma);
        this.symmetricOutput = symmetricOutput;
        this.parseFeatureAsInt = parseFeatureAsInt;
        return cl;
    }

    /**
     * Validates arguments and parses options. Returns the output struct (j, k, b_jk), where j and
     * k are ints or strings depending on {@code -feature_as_integer}, and b_jk is a double.
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            throw new UDFArgumentException(getClass().getSimpleName() + " takes 2 or 3 arguments: array<string> x, map<long, double> colNorms [, CONSTANT STRING options]: " + Arrays.toString(argOIs));
        }
        this.rowOI = HiveUtils.asListOI(argOIs[0]);
        HiveUtils.validateFeatureOI(this.rowOI.getListElementObjectInspector());
        this.colNormsOI = HiveUtils.asMapOI(argOIs[1]);
        this.processOptions(argOIs);
        this.rnd = RandomNumberGeneratorFactory.createPRNG(1001L);
        this.colNorms = null;
        this.colProbs = null;

        List<String> fieldNames = new ArrayList<String>(3);
        fieldNames.add("j");
        fieldNames.add("k");
        fieldNames.add("b_jk");
        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(3);
        if (this.parseFeatureAsInt) {
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        } else {
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        }
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Processes one row. Rows whose features parse to null are skipped. On the first
     * non-skipped row, the column norm map (argument 1) is loaded and cached.
     */
    @Override
    public void process(Object[] args) throws HiveException {
        Feature[] row = this.parseFeatures(args[0]);
        if (row == null) {
            return;
        }
        this.probes = row;
        if (this.colNorms == null || this.colProbs == null) {
            this.initColumnStats(args[1]);
        }
        if (this.parseFeatureAsInt) {
            this.forwardAsIntFeature(row);
        } else {
            this.forwardAsStringFeature(row);
        }
    }

    /**
     * Builds {@link #colNorms} and {@link #colProbs} from the column-norm map argument. When gamma
     * was not given and threshold is positive, it first derives
     * {@code sqrtGamma = sqrt(10 * log(numCols) / threshold)}.
     *
     * @param colNormsArg raw value of argument 1
     * @throws HiveException if the column-norm map is null
     */
    private void initColumnStats(@Nullable Object colNormsArg) throws HiveException {
        if (colNormsArg == null) {
            throw new HiveException("colNorms MUST NOT be null");
        }
        int numCols = this.colNormsOI.getMapSize(colNormsArg);
        if (this.sqrtGamma == Double.POSITIVE_INFINITY && this.threshold > 0.0) {
            this.sqrtGamma = Math.sqrt(10.0 * Math.log(numCols) / this.threshold);
        }
        this.colNorms = new HashMap<Object, Double>(numCols);
        this.colProbs = new HashMap<Object, Double>(numCols);
        Map<?, ?> m = this.colNormsOI.getMap(colNormsArg);
        for (Map.Entry<?, ?> e : m.entrySet()) {
            Object rawKey = e.getKey();
            // Normalize keys so lookups with feature indices/names hit the same entries.
            Object j = this.parseFeatureAsInt ? Integer.valueOf(HiveUtils.asJavaInt(rawKey)) : rawKey.toString();
            double norm = HiveUtils.asJavaDouble(e.getValue());
            if (norm == 0.0) { // avoid zero-division
                norm = 1.0;
            }
            this.colNorms.put(j, norm);
            this.colProbs.put(j, Math.min(1.0, this.sqrtGamma / norm));
        }
    }

    /** Returns the cached norm of {@code col}, or 1.0 when it is absent or zero (avoids zero-division). */
    private double columnNorm(@Nonnull Object col) {
        double norm = Primitives.doubleValue(this.colNorms.get(col), 0.0);
        return norm == 0.0 ? 1.0 : norm;
    }

    /** Returns the cached sampling probability of {@code col}, or 0.0 when it is absent. */
    private double columnProb(@Nonnull Object col) {
        return Primitives.doubleValue(this.colProbs.get(col), 0.0);
    }

    /** Scales the row by {@code min(sqrtGamma, norm)} per column and emits sampled int-keyed pairs. */
    private void forwardAsIntFeature(@Nonnull Feature[] row) throws HiveException {
        int length = row.length;
        Feature[] rowScaled = new Feature[length];
        for (int i = 0; i < length; ++i) {
            int j = row[i].getFeatureIndex();
            double scaled = row[i].getValue() / Math.min(this.sqrtGamma, this.columnNorm(j));
            rowScaled[i] = new IntFeature(j, scaled);
        }
        IntWritable jWritable = new IntWritable();
        IntWritable kWritable = new IntWritable();
        DoubleWritable bWritable = new DoubleWritable();
        Object[] forwardObjs = new Object[]{jWritable, kWritable, bWritable};
        for (int ij = 0; ij < length; ++ij) {
            int j = rowScaled[ij].getFeatureIndex();
            double jVal = rowScaled[ij].getValue();
            double jProb = this.columnProb(j);
            // Zero entries never consume a random draw; !(x < p) also rejects a NaN probability.
            if (jVal == 0.0 || !(this.rnd.nextDouble() < jProb)) continue;
            for (int ik = ij + 1; ik < length; ++ik) {
                int k = rowScaled[ik].getFeatureIndex();
                double kVal = rowScaled[ik].getValue();
                double kProb = this.columnProb(k);
                if (kVal == 0.0 || !(this.rnd.nextDouble() < kProb)) continue;
                bWritable.set(jVal * kVal);
                if (this.symmetricOutput) {
                    jWritable.set(j);
                    kWritable.set(k);
                    this.forward(forwardObjs);
                    jWritable.set(k);
                    kWritable.set(j);
                    this.forward(forwardObjs);
                    continue;
                }
                // Non-symmetric output: smaller column id first.
                if (j < k) {
                    jWritable.set(j);
                    kWritable.set(k);
                } else {
                    jWritable.set(k);
                    kWritable.set(j);
                }
                this.forward(forwardObjs);
            }
        }
    }

    /** Scales the row by {@code min(sqrtGamma, norm)} per column and emits sampled string-keyed pairs. */
    private void forwardAsStringFeature(@Nonnull Feature[] row) throws HiveException {
        int length = row.length;
        Feature[] rowScaled = new Feature[length];
        for (int i = 0; i < length; ++i) {
            String j = row[i].getFeature();
            double scaled = row[i].getValue() / Math.min(this.sqrtGamma, this.columnNorm(j));
            rowScaled[i] = new StringFeature(j, scaled);
        }
        Text jWritable = new Text();
        Text kWritable = new Text();
        DoubleWritable bWritable = new DoubleWritable();
        Object[] forwardObjs = new Object[]{jWritable, kWritable, bWritable};
        for (int ij = 0; ij < length; ++ij) {
            String j = rowScaled[ij].getFeature();
            double jVal = rowScaled[ij].getValue();
            double jProb = this.columnProb(j);
            // Zero entries never consume a random draw; !(x < p) also rejects a NaN probability.
            if (jVal == 0.0 || !(this.rnd.nextDouble() < jProb)) continue;
            for (int ik = ij + 1; ik < length; ++ik) {
                String k = rowScaled[ik].getFeature();
                double kVal = rowScaled[ik].getValue();
                // Sample column k with its own probability (previously looked up with j by mistake).
                double kProb = this.columnProb(k);
                if (kVal == 0.0 || !(this.rnd.nextDouble() < kProb)) continue;
                bWritable.set(jVal * kVal);
                if (this.symmetricOutput) {
                    jWritable.set(j);
                    kWritable.set(k);
                    this.forward(forwardObjs);
                    jWritable.set(k);
                    kWritable.set(j);
                    this.forward(forwardObjs);
                    continue;
                }
                // Non-symmetric output: lexicographically smaller column id first.
                if (j.compareTo(k) < 0) {
                    jWritable.set(j);
                    kWritable.set(k);
                } else {
                    jWritable.set(k);
                    kWritable.set(j);
                }
                this.forward(forwardObjs);
            }
        }
    }

    /** Parses argument 0 into features, passing the previous row as the probe argument. */
    @Nullable
    protected Feature[] parseFeatures(@Nonnull Object arg) throws HiveException {
        return Feature.parseFeatures(arg, this.rowOI, this.probes, this.parseFeatureAsInt);
    }

    /** Releases the cached row and column statistics. */
    @Override
    public void close() throws HiveException {
        this.probes = null;
        this.colNorms = null;
        this.colProbs = null;
    }
}

