/*
 * Decompiled with CFR 0.152.
 */
package hivemall.ensemble.bagging;

import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;

/**
 * Hive aggregate that combines bagged predictions by a majority vote on their sign.
 *
 * <p>Strictly positive and strictly negative inputs are tallied separately; {@code NULL}, zero and
 * NaN inputs cast no vote. The result is the mean of whichever side received strictly more votes.
 * On a tie the negative side wins, and if neither side received a vote the result is {@code 0.0}.
 */
@Description(name="voted_avg", value="_FUNC_(double value) - Returns an averaged value by bagging for classification")
public final class VotedAvgUDAF extends UDAF {

    /**
     * Evaluator driven by Hive's UDAF bridge (iterate / terminatePartial / merge / terminate).
     */
    public static class Evaluator implements UDAFEvaluator {

        /** Running tallies; stays {@code null} until the first non-NULL row or partial arrives. */
        private PartialResult partial;

        /** Discards any previous aggregation state. */
        public void init() {
            partial = null;
        }

        /**
         * Folds one input row into the running tallies.
         *
         * @param o the row value; {@code null} is ignored
         * @return always {@code true}
         */
        public boolean iterate(@Nullable DoubleWritable o) {
            if (o != null) {
                final PartialResult acc = partialOrCreate();
                final double value = o.get();
                if (value > 0.0) {
                    acc.positiveCnt++;
                    acc.positiveSum += value;
                } else if (value < 0.0) {
                    acc.negativeCnt++;
                    acc.negativeSum += value;
                }
                // zero and NaN fall through: the partial exists but no vote is recorded
            }
            return true;
        }

        /**
         * Exposes the running tallies for shipment to a merging stage.
         *
         * @return the current partial, or {@code null} if nothing has been aggregated
         */
        public PartialResult terminatePartial() {
            return partial;
        }

        /**
         * Adds another evaluator's tallies into this one.
         *
         * @param other a partial produced by {@link #terminatePartial()}; {@code null} is ignored
         * @return always {@code true}
         */
        public boolean merge(PartialResult other) {
            if (other != null) {
                final PartialResult acc = partialOrCreate();
                acc.positiveCnt += other.positiveCnt;
                acc.positiveSum += other.positiveSum;
                acc.negativeCnt += other.negativeCnt;
                acc.negativeSum += other.negativeSum;
            }
            return true;
        }

        /**
         * Produces the voted average.
         *
         * @return {@code null} if no non-NULL input or partial was seen; the positive mean if positive
         *         votes strictly outnumber negative ones; otherwise the negative mean, or {@code 0.0}
         *         when there were no negative votes either
         */
        public DoubleWritable terminate() {
            final PartialResult acc = partial;
            if (acc == null) {
                return null;
            }
            final double mean;
            if (acc.positiveCnt > acc.negativeCnt) {
                mean = acc.positiveSum / acc.positiveCnt;
            } else if (acc.negativeCnt != 0) {
                // ties are resolved in favour of the negative side
                mean = acc.negativeSum / acc.negativeCnt;
            } else {
                // no votes on either side (only zero/NaN inputs were seen)
                assert acc.negativeSum == 0.0 : acc.negativeSum;
                mean = 0.0;
            }
            return new DoubleWritable(mean);
        }

        /** Returns the running partial, allocating a zeroed one on first use. */
        private PartialResult partialOrCreate() {
            PartialResult acc = partial;
            if (acc == null) {
                acc = new PartialResult();
                acc.init();
                partial = acc;
            }
            return acc;
        }

        /**
         * Intermediate state passed between map and reduce stages. Hive serializes it
         * reflectively, so the field set is kept as-is.
         */
        public static class PartialResult {
            /** Sum of strictly positive inputs. */
            double positiveSum;
            /** Number of strictly positive inputs. */
            int positiveCnt;
            /** Sum of strictly negative inputs (a non-positive value). */
            double negativeSum;
            /** Number of strictly negative inputs. */
            int negativeCnt;

            /** Resets every tally to zero. */
            void init() {
                positiveCnt = 0;
                positiveSum = 0.0;
                negativeCnt = 0;
                negativeSum = 0.0;
            }
        }
    }
}

