/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.tools;

import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Counter;
import hivemall.utils.lang.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.IntWritable;
import smile.math.Math;

@Description(name="rf_ensemble", value="_FUNC_(int yhat [, array<double> proba [, double model_weight=1.0]]) - Returns ensembled prediction results in <int label, double probability, array<double> probabilities>")
public final class RandomForestEnsembleUDAF
extends AbstractGenericUDAFResolver {
    /**
     * Picks the evaluator implementation from the number and types of the UDAF arguments.
     *
     * <ul>
     * <li>1 argument ({@code yhat}): {@link RfEvaluatorV1}, which counts votes per label.</li>
     * <li>2 or 3 arguments ({@code yhat, proba [, model_weight]}): {@link RfEvaluatorV2}, which
     * accumulates {@code proba[yhat] * model_weight} per label.</li>
     * </ul>
     *
     * @param typeInfo type information of the UDAF arguments
     * @return the evaluator matching the argument list
     * @throws UDFArgumentTypeException if an argument has an unexpected type
     * @throws UDFArgumentLengthException if fewer than 1 or more than 3 arguments are given
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
        switch (typeInfo.length) {
            case 1: {
                if (!HiveUtils.isIntegerTypeInfo(typeInfo[0])) {
                    throw new UDFArgumentTypeException(0, "Expected INT for yhat: " + typeInfo[0]);
                }
                return new RfEvaluatorV1();
            }
            case 3: {
                if (!HiveUtils.isFloatingPointTypeInfo(typeInfo[2])) {
                    throw new UDFArgumentTypeException(2, "Expected DOUBLE or FLOAT for model_weight: " + typeInfo[2]);
                }
                // Intentional fall-through: the first two arguments of the 3-arg form are
                // validated exactly like the 2-arg form, and both use RfEvaluatorV2.
            }
            case 2: {
                if (!HiveUtils.isIntegerTypeInfo(typeInfo[0])) {
                    throw new UDFArgumentTypeException(0, "Expected INT for yhat: " + typeInfo[0]);
                }
                if (!HiveUtils.isFloatingPointListTypeInfo(typeInfo[1])) {
                    throw new UDFArgumentTypeException(1, "ARRAY<double> is expected for a posteriori: " + typeInfo[1]);
                }
                return new RfEvaluatorV2();
            }
            default:
                break;
        }
        throw new UDFArgumentLengthException("Expected 1~3 arguments but got " + typeInfo.length);
    }

    public static final class RfAggregationBufferV2
    extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        @Nullable
        private double[] _posteriori;
        private int _k;

        public RfAggregationBufferV2() {
            this.reset();
        }

        void reset() {
            this._posteriori = null;
            this._k = -1;
        }

        void iterate(int yhat, double weight, @Nonnull double[] posteriori) throws HiveException {
            if (this._posteriori == null) {
                this._k = posteriori.length;
                this._posteriori = new double[this._k];
            }
            if (yhat >= this._k) {
                throw new HiveException("Predicted class " + yhat + " is out of bounds: " + this._k);
            }
            if (posteriori.length != this._k) {
                throw new HiveException("Given |a posteriori| " + posteriori.length + " is differs from expected one: " + this._k);
            }
            int n = yhat;
            this._posteriori[n] = this._posteriori[n] + posteriori[yhat] * weight;
        }

        void merge(int size, @Nonnull Object posterioriObj, @Nonnull StandardListObjectInspector posterioriOI) throws HiveException {
            if (size != this._k) {
                if (this._k == -1) {
                    this._k = size;
                    this._posteriori = new double[size];
                } else {
                    throw new HiveException("Mismatch in the number of elements: _k=" + this._k + ", size=" + size);
                }
            }
            double[] posteriori = this._posteriori;
            WritableDoubleObjectInspector doubleOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            int i = 0;
            int len = this._k;
            while (i < len) {
                Object o2 = posterioriOI.getListElement(posterioriObj, i);
                int n = i++;
                posteriori[n] = posteriori[n] + doubleOI.get(o2);
            }
        }

        public int estimate() {
            if (this._k == -1) {
                return 0;
            }
            return 4 + this._k * 8;
        }
    }

    public static final class RfEvaluatorV2
    extends GenericUDAFEvaluator {
        private PrimitiveObjectInspector yhatOI;
        private ListObjectInspector posterioriOI;
        private PrimitiveObjectInspector posterioriElemOI;
        @Nullable
        private PrimitiveObjectInspector weightOI;
        private StructObjectInspector internalMergeOI;
        private StructField sizeField;
        private StructField posterioriField;
        private IntObjectInspector sizeFieldOI;
        private StandardListObjectInspector posterioriFieldOI;

        public ObjectInspector init(@Nonnull GenericUDAFEvaluator.Mode mode, @Nonnull ObjectInspector[] parameters) throws HiveException {
            StandardStructObjectInspector outputOI;
            super.init(mode, parameters);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.yhatOI = HiveUtils.asIntegerOI(parameters[0]);
                this.posterioriOI = HiveUtils.asListOI(parameters[1]);
                this.posterioriElemOI = HiveUtils.asDoubleCompatibleOI(this.posterioriOI.getListElementObjectInspector());
                if (parameters.length == 3) {
                    this.weightOI = HiveUtils.asDoubleCompatibleOI(parameters[2]);
                }
            } else {
                StructObjectInspector soi;
                this.internalMergeOI = soi = (StructObjectInspector)parameters[0];
                this.sizeField = soi.getStructFieldRef("size");
                this.posterioriField = soi.getStructFieldRef("posteriori");
                this.sizeFieldOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
                this.posterioriFieldOI = ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            }
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                ArrayList<String> fieldNames = new ArrayList<String>(3);
                ArrayList<Object> fieldOIs = new ArrayList<Object>(3);
                fieldNames.add("size");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
                fieldNames.add("posteriori");
                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
                outputOI = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
            } else {
                ArrayList<String> fieldNames = new ArrayList<String>(3);
                ArrayList<Object> fieldOIs = new ArrayList<Object>(3);
                fieldNames.add("label");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
                fieldNames.add("probability");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                fieldNames.add("probabilities");
                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
                outputOI = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
            }
            return outputOI;
        }

        public RfAggregationBufferV2 getNewAggregationBuffer() throws HiveException {
            RfAggregationBufferV2 buf = new RfAggregationBufferV2();
            this.reset((GenericUDAFEvaluator.AggregationBuffer)buf);
            return buf;
        }

        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            RfAggregationBufferV2 buf = (RfAggregationBufferV2)agg;
            buf.reset();
        }

        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            RfAggregationBufferV2 buf = (RfAggregationBufferV2)agg;
            Preconditions.checkNotNull(parameters[0]);
            int yhat = PrimitiveObjectInspectorUtils.getInt((Object)parameters[0], (PrimitiveObjectInspector)this.yhatOI);
            Preconditions.checkNotNull(parameters[1]);
            double[] posteriori = HiveUtils.asDoubleArray(parameters[1], this.posterioriOI, this.posterioriElemOI);
            double weight = 1.0;
            if (parameters.length == 3) {
                Preconditions.checkNotNull(parameters[2]);
                weight = PrimitiveObjectInspectorUtils.getDouble((Object)parameters[2], (PrimitiveObjectInspector)this.weightOI);
            }
            buf.iterate(yhat, weight, posteriori);
        }

        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            RfAggregationBufferV2 buf = (RfAggregationBufferV2)agg;
            if (buf._k == -1) {
                return null;
            }
            Object[] partial = new Object[]{new IntWritable(buf._k), WritableUtils.toWritableList(buf._posteriori)};
            return partial;
        }

        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            RfAggregationBufferV2 buf = (RfAggregationBufferV2)agg;
            Object o1 = this.internalMergeOI.getStructFieldData(partial, this.sizeField);
            int size = this.sizeFieldOI.get(o1);
            Object posteriori = this.internalMergeOI.getStructFieldData(partial, this.posterioriField);
            if (posteriori instanceof LazyBinaryArray) {
                posteriori = ((LazyBinaryArray)posteriori).getList();
            }
            buf.merge(size, posteriori, this.posterioriFieldOI);
        }

        public Object terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            RfAggregationBufferV2 buf = (RfAggregationBufferV2)agg;
            if (buf._k == -1) {
                return null;
            }
            double[] posteriori = buf._posteriori;
            int label = Math.whichMax(posteriori);
            Math.unitize1(posteriori);
            double proba = posteriori[label];
            Object[] result = new Object[]{new IntWritable(label), new DoubleWritable(proba), WritableUtils.toWritableList(posteriori)};
            return result;
        }
    }

    /**
     * Aggregation buffer for {@link RfEvaluatorV1}: a vote count per predicted label.
     */
    public static final class RfAggregationBufferV1
    extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        /** Number of votes per label. */
        @Nonnull
        private Counter<Integer> partial;

        public RfAggregationBufferV1() {
            this.reset();
        }

        /** Discards all votes. */
        void reset() {
            this.partial = new Counter<Integer>();
        }

        /** Records one vote for label {@code k}. */
        void iterate(int k) {
            this.partial.increment(k);
        }

        /** Returns the label-to-count map backing this buffer (not a copy). */
        @Nonnull
        Map<Integer, Integer> terminatePartial() {
            return this.partial.getMap();
        }

        /** Adds {@code v} votes for label {@code k}. */
        void merge(int k, int v) {
            this.partial.increment(k, v);
        }

        /**
         * Computes the final result {@code [label, probability, probabilities]}.
         *
         * <p>{@code label} is the label with the highest count. On ties, the entry visited last in
         * the map's iteration order wins. {@code probability} is its share of all votes.
         * {@code probabilities[i]} is the vote share of label {@code i}, for
         * {@code i in [0, max(2, largestLabel + 1))}.
         *
         * @return the result array, or {@code null} when no votes were recorded
         */
        @Nullable
        Object[] terminate() {
            final Map<Integer, Integer> counts = this.partial.getMap();
            if (counts.isEmpty()) {
                return null;
            }
            long totalCnt = 0L;
            int maxKey = 0; // always overwritten: every count satisfies cnt >= Integer.MIN_VALUE
            int maxCnt = Integer.MIN_VALUE;
            int largestLabel = Integer.MIN_VALUE;
            for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
                final int key = e.getKey().intValue();
                final int cnt = e.getValue().intValue();
                totalCnt += cnt;
                if (cnt >= maxCnt) {
                    maxCnt = cnt;
                    maxKey = key;
                }
                // Only the largest label is needed to size the output; no need to collect and sort.
                if (key > largestLabel) {
                    largestLabel = key;
                }
            }
            final double totalCnt_d = totalCnt;
            // NOTE(review): votes for negative labels count toward totalCnt but get no slot here.
            final double[] probabilities = new double[java.lang.Math.max(2, largestLabel + 1)];
            for (int i = 0; i < probabilities.length; i++) {
                Integer cnt = counts.get(i);
                probabilities[i] = (cnt == null) ? 0.0 : cnt.intValue() / totalCnt_d;
            }
            final double proba = maxCnt / totalCnt_d;
            return new Object[] {new IntWritable(maxKey), new DoubleWritable(proba), WritableUtils.toWritableList(probabilities)};
        }
    }

    @Deprecated
    public static final class RfEvaluatorV1
    extends GenericUDAFEvaluator {
        private PrimitiveObjectInspector yhatOI;
        private StandardMapObjectInspector internalMergeOI;
        private IntObjectInspector keyOI;
        private IntObjectInspector valueOI;

        public ObjectInspector init(@Nonnull GenericUDAFEvaluator.Mode mode, @Nonnull ObjectInspector[] argOIs) throws HiveException {
            StandardMapObjectInspector outputOI;
            super.init(mode, argOIs);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.yhatOI = HiveUtils.asIntegerOI(argOIs, 0);
            } else {
                this.internalMergeOI = (StandardMapObjectInspector)argOIs[0];
                this.keyOI = HiveUtils.asIntOI(this.internalMergeOI.getMapKeyObjectInspector());
                this.valueOI = HiveUtils.asIntOI(this.internalMergeOI.getMapValueObjectInspector());
            }
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                outputOI = ObjectInspectorFactory.getStandardMapObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.javaIntObjectInspector, (ObjectInspector)PrimitiveObjectInspectorFactory.javaIntObjectInspector);
            } else {
                ArrayList<String> fieldNames = new ArrayList<String>(3);
                ArrayList<Object> fieldOIs = new ArrayList<Object>(3);
                fieldNames.add("label");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
                fieldNames.add("probability");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                fieldNames.add("probabilities");
                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector((ObjectInspector)PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
                outputOI = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
            }
            return outputOI;
        }

        public RfAggregationBufferV1 getNewAggregationBuffer() throws HiveException {
            RfAggregationBufferV1 buf = new RfAggregationBufferV1();
            buf.reset();
            return buf;
        }

        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            RfAggregationBufferV1 buf = (RfAggregationBufferV1)agg;
            buf.reset();
        }

        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            RfAggregationBufferV1 buf = (RfAggregationBufferV1)agg;
            Preconditions.checkNotNull(parameters[0]);
            int yhat = PrimitiveObjectInspectorUtils.getInt((Object)parameters[0], (PrimitiveObjectInspector)this.yhatOI);
            buf.iterate(yhat);
        }

        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            RfAggregationBufferV1 buf = (RfAggregationBufferV1)agg;
            return buf.terminatePartial();
        }

        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial) throws HiveException {
            RfAggregationBufferV1 buf = (RfAggregationBufferV1)agg;
            Map partialResult = this.internalMergeOI.getMap(partial);
            for (Map.Entry entry : partialResult.entrySet()) {
                this.putIntoMap(entry.getKey(), entry.getValue(), buf);
            }
        }

        private void putIntoMap(@CheckForNull Object key, @CheckForNull Object value, @Nonnull RfAggregationBufferV1 dst) {
            Preconditions.checkNotNull(key);
            Preconditions.checkNotNull(value);
            int k = this.keyOI.get(key);
            int v = this.valueOI.get(value);
            dst.merge(k, v);
        }

        public Object terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            RfAggregationBufferV1 buf = (RfAggregationBufferV1)agg;
            return buf.terminate();
        }
    }
}

