/*
 * Decompiled with CFR 0.152.
 */
package hivemall.dataset;

import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Comparator;
import java.util.Random;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

@Description(name="lr_datagen", value="_FUNC_(options string) - Generates a logistic regression dataset", extended="WITH dual AS (SELECT 1) SELECT lr_datagen('-n_examples 1k -n_features 10') FROM dual;")
/**
 * Table-generating function that emits a synthetic logistic-regression dataset.
 * <p>
 * Each call to {@link #process(Object[])} generates {@code n_examples} rows of
 * {@code (label float, features list)}. Rows are buffered in blocks of {@link #N_BUFFERS};
 * a full block is forwarded immediately and any remainder is forwarded in {@link #close()}.
 * Sparse output (the default) emits features as {@code "index:value"} strings; dense output
 * ({@code -dense}) emits a list of {@code Float} values.
 */
public final class LogisticRegressionDataGeneratorUDTF
extends UDTFWithOptions {
    /** Number of generated rows held before they are forwarded as one batch. */
    private static final int N_BUFFERS = 1000;

    /** Orders sparse {@code "index:value"} features by their integer index. */
    private static final Comparator<String> FEATURE_INDEX_ORDER = new Comparator<String>() {
        @Override
        public int compare(String o1, String o2) {
            return Primitives.compare(featureIndex(o1), featureIndex(o2));
        }
    };

    /** Next free slot in the row buffers. */
    private int position;
    private float[] labels;
    /** Sparse feature buffer; allocated only when {@code dense} is false. */
    private String[][] featuresArray;
    /** Dense feature buffer; allocated only when {@code dense} is true. */
    private Float[][] featuresFloatArray;
    private int n_examples;
    private int n_features;
    private int n_dimensions;
    private float eps;
    private float prob_one;
    private long r_seed;
    private boolean dense;
    private boolean sort;
    private boolean classification;
    /** Drives label generation; lazily seeded in {@link #process(Object[])}. */
    private Random rnd1 = null;
    /** Drives feature index/value generation; lazily seeded in {@link #process(Object[])}. */
    private Random rnd2 = null;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("ne", "n_examples", true, "Number of training examples created for each task [DEFAULT: 1000]");
        opts.addOption("nf", "n_features", true, "Number of features contained for each example [DEFAULT: 10]");
        opts.addOption("nd", "n_dims", true, "The size of feature dimensions [DEFAULT: 200]");
        opts.addOption("eps", true, "eps Epsilon factor by which positive examples are scaled [DEFAULT: 3.0]");
        opts.addOption("p1", "prob_one", true, " Probability in [0, 1.0) that a label is 1 [DEFAULT: 0.6]");
        opts.addOption("seed", true, "The seed value for random number generator [DEFAULT: 43L]");
        opts.addOption("dense", false, "Make a dense dataset or not. If not specified, a sparse dataset is generated.\nFor sparse, n_dims should be much larger than n_features. When disabled, n_features must be equals to n_dims ");
        opts.addOption("sort", false, "Sort features if specified (used only for sparse dataset)");
        opts.addOption("cl", "classification", false, "Toggle this option on to generate a classification dataset");
        return opts;
    }

    /**
     * Parses the single constant option string and validates the resulting settings.
     *
     * @param argOIs exactly one argument: a constant option string
     * @return the parsed command line
     * @throws UDFArgumentException on a wrong argument count or inconsistent options
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 1) {
            throw new UDFArgumentException("Expected number of arguments is 1: " + argOIs.length);
        }
        String opts = HiveUtils.getConstString(argOIs[0]);
        CommandLine cl = this.parseOptions(opts);
        this.n_examples = NumberUtils.parseInt(cl.getOptionValue("n_examples"), 1000);
        this.n_features = NumberUtils.parseInt(cl.getOptionValue("n_features"), 10);
        this.n_dimensions = NumberUtils.parseInt(cl.getOptionValue("n_dims"), 200);
        this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 3.0f);
        this.prob_one = Primitives.parseFloat(cl.getOptionValue("prob_one"), 0.6f);
        this.r_seed = Primitives.parseLong(cl.getOptionValue("seed"), 43L);
        this.dense = cl.hasOption("dense");
        this.sort = cl.hasOption("sort");
        this.classification = cl.hasOption("classification");
        // A negative value would otherwise fail later with NegativeArraySizeException in init().
        if (this.n_features < 0) {
            throw new UDFArgumentException("n_features must be non-negative: " + this.n_features);
        }
        // Sparse generation draws n_features distinct indices out of n_dimensions.
        if (this.n_features > this.n_dimensions) {
            throw new UDFArgumentException("n_features '" + this.n_features
                    + "' should be less than or equal to n_dimensions '" + this.n_dimensions + "'");
        }
        if (this.dense && this.n_features != this.n_dimensions) {
            throw new UDFArgumentException("n_features '" + this.n_features + "' must be equals to n_dimensions '" + this.n_dimensions + "' when making a dense dataset");
        }
        return cl;
    }

    /**
     * Parses options, allocates buffers and declares the output schema
     * {@code struct<label:float, features:array<float|string>>}.
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        this.processOptions(argOIs);
        this.init();
        ArrayList<String> fieldNames = new ArrayList<String>(2);
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
        fieldNames.add("label");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaFloatObjectInspector);
        fieldNames.add("features");
        ObjectInspector elementOI = this.dense
                ? PrimitiveObjectInspectorFactory.javaFloatObjectInspector
                : PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(elementOI));
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /** Allocates the label buffer and the feature buffer that matches the output mode. */
    private void init() {
        this.labels = new float[N_BUFFERS];
        if (this.dense) {
            this.featuresFloatArray = new Float[N_BUFFERS][this.n_features];
        } else {
            this.featuresArray = new String[N_BUFFERS][this.n_features];
        }
        this.position = 0;
    }

    /**
     * Generates {@code n_examples} rows for each input row, forwarding every full buffer.
     * The generators are seeded on first use with {@code seed} plus the task id when one is
     * available, so each task produces a different but reproducible stream.
     */
    @Override
    public void process(Object[] argOIs) throws HiveException {
        if (this.rnd1 == null) {
            assert (this.rnd2 == null);
            int taskid = HadoopUtils.getTaskId(-1);
            long seed = taskid == -1 ? this.r_seed : this.r_seed + (long) taskid;
            this.rnd1 = new Random(seed);
            this.rnd2 = new Random(seed + 1L);
        }
        for (int i = 0; i < this.n_examples; ++i) {
            if (this.dense) {
                this.generateDenseData();
            } else {
                this.generateSparseData();
            }
            ++this.position;
            if (this.position == N_BUFFERS) {
                this.flushBuffered(this.position);
                this.position = 0;
            }
        }
    }

    /**
     * Fills the sparse buffer slot at {@code position} with {@code n_features} distinct
     * {@code "index:value"} features. Positive examples have {@code eps} added to each
     * Gaussian value.
     */
    private void generateSparseData() {
        float label = this.rnd1.nextFloat();
        // P(label <= prob_one) == prob_one for a uniform draw in [0, 1).
        float sign = label <= this.prob_one ? 1.0f : 0.0f;
        this.labels[this.position] = this.classification ? sign : label;
        String[] features = this.featuresArray[this.position];
        assert (features != null);
        BitSet used = new BitSet(this.n_dimensions);
        int searchClearBitsFrom = 0;
        int retry = 0;
        for (int i = 0; i < this.n_features; ++i) {
            int f = this.rnd2.nextInt(this.n_dimensions);
            if (used.get(f)) {
                // Redraw a few times; then deterministically take the next unused index.
                if (retry < 3) {
                    --i;
                    ++retry;
                    continue;
                }
                f = searchClearBitsFrom = used.nextClearBit(searchClearBitsFrom);
            }
            used.set(f);
            float w = (float) this.rnd2.nextGaussian() + sign * this.eps;
            features[i] = f + ":" + w;
            retry = 0;
        }
        if (this.sort) {
            Arrays.sort(features, FEATURE_INDEX_ORDER);
        }
    }

    /**
     * Fills the dense buffer slot at {@code position} with {@code n_features} Gaussian
     * values. Positive examples have {@code eps} added to each value.
     */
    private void generateDenseData() {
        float label = this.rnd1.nextFloat();
        // Must match generateSparseData: a label is 1 with probability prob_one.
        float sign = label <= this.prob_one ? 1.0f : 0.0f;
        this.labels[this.position] = this.classification ? sign : label;
        Float[] features = this.featuresFloatArray[this.position];
        assert (features != null);
        for (int i = 0; i < this.n_features; ++i) {
            float w = (float) this.rnd2.nextGaussian() + sign * this.eps;
            features[i] = Float.valueOf(w);
        }
    }

    /** Forwards the first {@code position} buffered rows. */
    private void flushBuffered(int position) throws HiveException {
        Object[] forwardObjs = new Object[2];
        for (int i = 0; i < position; ++i) {
            Object[] row = this.dense ? this.featuresFloatArray[i] : this.featuresArray[i];
            forwardObjs[0] = Float.valueOf(this.labels[i]);
            forwardObjs[1] = Arrays.asList(row);
            this.forward(forwardObjs);
        }
    }

    /** Extracts the integer index from a sparse {@code "index:value"} feature string. */
    private static int featureIndex(String feature) {
        int colon = feature.indexOf(':');
        return Integer.parseInt(colon == -1 ? feature : feature.substring(0, colon));
    }

    /** Forwards any partially filled buffer and releases the buffers. */
    @Override
    public void close() throws HiveException {
        if (this.position > 0) {
            this.flushBuffered(this.position);
        }
        this.labels = null;
        this.featuresArray = null;
        this.featuresFloatArray = null;
    }
}

