/*
 * Decompiled with CFR 0.152.
 */
package hivemall.ftvec.selection;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.math.StatsUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

@Description(name="chi2", value="_FUNC_(array<array<number>> observed, array<array<number>> expected) - Returns chi2_val and p_val of each columns as <array<double>, array<double>>")
@UDFType(deterministic=true, stateful=false)
public final class ChiSquareUDF
extends GenericUDF {
    private ListObjectInspector observedOI;
    private ListObjectInspector observedRowOI;
    private PrimitiveObjectInspector observedElOI;
    private ListObjectInspector expectedOI;
    private ListObjectInspector expectedRowOI;
    private PrimitiveObjectInspector expectedElOI;
    private int nFeatures = -1;
    private double[] observedRow = null;
    private double[] expectedRow = null;
    private double[][] observed = null;
    private double[][] expected = null;
    private List<DoubleWritable>[] result;

    public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
        if (OIs.length != 2) {
            throw new UDFArgumentLengthException("Specify two arguments: " + OIs.length);
        }
        if (!HiveUtils.isNumberListListOI(OIs[0])) {
            throw new UDFArgumentTypeException(0, "Only array<array<number>> type argument is acceptable but " + OIs[0].getTypeName() + " was passed as `observed`");
        }
        if (!HiveUtils.isNumberListListOI(OIs[1])) {
            throw new UDFArgumentTypeException(1, "Only array<array<number>> type argument is acceptable but " + OIs[1].getTypeName() + " was passed as `expected`");
        }
        this.observedOI = HiveUtils.asListOI(OIs[1]);
        this.observedRowOI = HiveUtils.asListOI(this.observedOI.getListElementObjectInspector());
        this.observedElOI = HiveUtils.asDoubleCompatibleOI(this.observedRowOI.getListElementObjectInspector());
        this.expectedOI = HiveUtils.asListOI(OIs[0]);
        this.expectedRowOI = HiveUtils.asListOI(this.expectedOI.getListElementObjectInspector());
        this.expectedElOI = HiveUtils.asDoubleCompatibleOI(this.expectedRowOI.getListElementObjectInspector());
        this.result = new List[2];
        ArrayList<StandardListObjectInspector> fieldOIs = new ArrayList<StandardListObjectInspector>();
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        return ObjectInspectorFactory.getStandardStructObjectInspector(Arrays.asList("chi2", "pvalue"), fieldOIs);
    }

    public List<DoubleWritable>[] evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
        List observedObj = this.observedOI.getList(dObj[0].get());
        List expectedObj = this.expectedOI.getList(dObj[1].get());
        if (observedObj == null || expectedObj == null) {
            return null;
        }
        int nClasses = observedObj.size();
        Preconditions.checkArgument(nClasses == expectedObj.size(), UDFArgumentException.class);
        for (int i = 0; i < nClasses; ++i) {
            Object observedObjRow = observedObj.get(i);
            Object expectedObjRow = expectedObj.get(i);
            Preconditions.checkNotNull(observedObjRow, UDFArgumentException.class);
            Preconditions.checkNotNull(expectedObjRow, UDFArgumentException.class);
            if (this.observedRow == null) {
                this.observedRow = HiveUtils.asDoubleArray(observedObjRow, this.observedRowOI, this.observedElOI, false);
                this.expectedRow = HiveUtils.asDoubleArray(expectedObjRow, this.expectedRowOI, this.expectedElOI, false);
                this.nFeatures = this.observedRow.length;
                this.observed = new double[this.nFeatures][nClasses];
                this.expected = new double[this.nFeatures][nClasses];
            } else {
                HiveUtils.toDoubleArray(observedObjRow, this.observedRowOI, this.observedElOI, this.observedRow, false);
                HiveUtils.toDoubleArray(expectedObjRow, this.expectedRowOI, this.expectedElOI, this.expectedRow, false);
            }
            for (int j = 0; j < this.nFeatures; ++j) {
                this.observed[j][i] = this.observedRow[j];
                this.expected[j][i] = this.expectedRow[j];
            }
        }
        Map.Entry<double[], double[]> chi2 = StatsUtils.chiSquare(this.observed, this.expected);
        this.result[0] = WritableUtils.toWritableList(chi2.getKey(), this.result[0]);
        this.result[1] = WritableUtils.toWritableList(chi2.getValue(), this.result[1]);
        return this.result;
    }

    public void close() throws IOException {
        this.observedRow = null;
        this.expectedRow = null;
        this.observed = null;
        this.expected = null;
        this.result = null;
    }

    public String getDisplayString(String[] children) {
        StringBuilder sb = new StringBuilder();
        sb.append("chi2");
        sb.append("(");
        if (children.length > 0) {
            sb.append(children[0]);
            for (int i = 1; i < children.length; ++i) {
                sb.append(", ");
                sb.append(children[i]);
            }
        }
        sb.append(")");
        return sb.toString();
    }
}

