/*
 * Decompiled with CFR 0.152.
 */
package hivemall.classifier;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

@Description(name="kpa_predict", value="_FUNC_(@Nonnull double xh, @Nonnull double xk, @Nullable float w0, @Nonnull float w1, @Nonnull float w2, @Nullable float w3) - Returns a prediction value in Double")
public final class KPAPredictUDAF
extends AbstractGenericUDAFResolver {
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 6) {
            throw new UDFArgumentException("_FUNC_(double xh, double xk, float w0, float w1, float w2, float w3) takes exactly 6 arguments but got: " + parameters.length);
        }
        if (!HiveUtils.isNumberTypeInfo(parameters[0])) {
            throw new UDFArgumentTypeException(0, "Number type is expected for xh (1st argument): " + parameters[0].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(parameters[1])) {
            throw new UDFArgumentTypeException(1, "Number type is expected for xk (2nd argument): " + parameters[1].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(parameters[2])) {
            throw new UDFArgumentTypeException(2, "Number type is expected for w0 (3rd argument): " + parameters[2].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(parameters[3])) {
            throw new UDFArgumentTypeException(3, "Number type is expected for w1 (4th argument): " + parameters[3].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(parameters[4])) {
            throw new UDFArgumentTypeException(4, "Number type is expected for w2 (5th argument): " + parameters[4].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(parameters[5])) {
            throw new UDFArgumentTypeException(5, "Number type is expected for w3 (6th argument): " + parameters[5].getTypeName());
        }
        return new Evaluator();
    }

    @GenericUDAFEvaluator.AggregationType(estimable=true)
    static class AggrBuffer
    extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        double score;

        AggrBuffer() {
            this.reset();
        }

        public int estimate() {
            return 8;
        }

        void reset() {
            this.score = 0.0;
        }

        double get() {
            return this.score;
        }

        void addW0(@Nonnull double w0) {
            this.score += w0;
        }

        void addW1W2(double xh, double w1h, double w2h) {
            this.score += w1h * xh + w2h * xh * xh;
        }

        void addW3(double xh, double xk, double w3hk) {
            this.score += w3hk * xh * xk;
        }

        void merge(double other) {
            this.score += other;
        }
    }

    public static class Evaluator
    extends GenericUDAFEvaluator {
        @Nullable
        private transient PrimitiveObjectInspector xhOI;
        @Nullable
        private transient PrimitiveObjectInspector xkOI;
        @Nullable
        private transient PrimitiveObjectInspector w0OI;
        @Nullable
        private transient PrimitiveObjectInspector w1OI;
        @Nullable
        private transient PrimitiveObjectInspector w2OI;
        @Nullable
        private transient PrimitiveObjectInspector w3OI;

        public ObjectInspector init(GenericUDAFEvaluator.Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            if (m == GenericUDAFEvaluator.Mode.PARTIAL1 || m == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.xhOI = HiveUtils.asNumberOI(parameters[0]);
                this.xkOI = HiveUtils.asNumberOI(parameters[1]);
                this.w0OI = HiveUtils.asNumberOI(parameters[2]);
                this.w1OI = HiveUtils.asNumberOI(parameters[3]);
                this.w2OI = HiveUtils.asNumberOI(parameters[4]);
                this.w3OI = HiveUtils.asNumberOI(parameters[5]);
            }
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        public AggrBuffer getNewAggregationBuffer() throws HiveException {
            return new AggrBuffer();
        }

        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            AggrBuffer aggr = (AggrBuffer)agg;
            aggr.reset();
        }

        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            Preconditions.checkArgument(parameters.length == 6, HiveException.class);
            AggrBuffer aggr = (AggrBuffer)agg;
            if (parameters[0] != null) {
                double xh = HiveUtils.getDouble(parameters[0], this.xhOI);
                if (parameters[1] != null) {
                    if (parameters[5] == null) {
                        return;
                    }
                    double xk = HiveUtils.getDouble(parameters[1], this.xkOI);
                    double w3hk = HiveUtils.getDouble(parameters[5], this.w3OI);
                    aggr.addW3(xh, xk, w3hk);
                } else {
                    if (parameters[3] == null) {
                        return;
                    }
                    Preconditions.checkNotNull(parameters[4], HiveException.class);
                    double w1h = HiveUtils.getDouble(parameters[3], this.w1OI);
                    double w2h = HiveUtils.getDouble(parameters[4], this.w2OI);
                    aggr.addW1W2(xh, w1h, w2h);
                }
            } else if (parameters[2] != null) {
                double w0 = HiveUtils.getDouble(parameters[2], this.w0OI);
                aggr.addW0(w0);
            }
        }

        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            AggrBuffer aggr = (AggrBuffer)agg;
            double v = aggr.get();
            return new DoubleWritable(v);
        }

        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            AggrBuffer aggr = (AggrBuffer)agg;
            DoubleWritable other = (DoubleWritable)partial;
            double v = other.get();
            aggr.merge(v);
        }

        public DoubleWritable terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            AggrBuffer aggr = (AggrBuffer)agg;
            double v = aggr.get();
            return new DoubleWritable(v);
        }
    }
}

