/*
 * Decompiled with CFR 0.152.
 */
package hivemall.evaluation;

import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;

/**
 * Hive UDAF that computes R squared (the coefficient of determination) of
 * predicted values against actual values:
 * {@code 1 - residualSumOfSquares / totalSumOfSquares}.
 *
 * <p>The aggregation state is constant-size: the total sum of squares is kept
 * as a running statistic instead of buffering every actual value.
 */
@Description(name="r2", value="_FUNC_(double predicted, double actual) - Return R Squared (coefficient of determination)")
public final class R2UDAF
extends UDAF {

    /**
     * Mergeable, fixed-size aggregation state for R squared.
     *
     * <p>The spread of the actual values is tracked with Welford's online
     * update (see {@link #iterate(double, double)}) and combined across
     * partial states with the pairwise formula of Chan et al. (see
     * {@link #merge(PartialResult)}), so no per-row data is retained.
     *
     * <p>NOTE(review): this object is what {@code terminatePartial()} hands
     * back to Hive; presumably its fields are (de)serialized reflectively, so
     * they are kept as plain primitives - confirm against the Hive UDAF bridge.
     */
    public static class PartialResult {
        /** Sum of {@code (actual - predicted)^2} over all rows folded in so far. */
        double residual_sum_of_squares = 0.0;
        /** Running mean of the actual values folded in so far. */
        double mean_actuals = 0.0;
        /** Sum of squared deviations of the actual values from their mean. */
        double total_sum_of_squares = 0.0;
        /** Number of rows folded in so far. */
        long count = 0L;

        PartialResult() {
        }

        /**
         * Folds one (predicted, actual) pair into this state.
         *
         * @param predicted the model's prediction for the row
         * @param actual the observed value for the row
         */
        void iterate(double predicted, double actual) {
            double residual = actual - predicted;
            this.residual_sum_of_squares += residual * residual;

            // Welford's update: the deviation is taken once against the old
            // mean and once against the new mean, which keeps the running sum
            // of squares accurate without a second pass over the data.
            ++this.count;
            double delta = actual - this.mean_actuals;
            this.mean_actuals += delta / (double)this.count;
            this.total_sum_of_squares += delta * (actual - this.mean_actuals);
        }

        /**
         * Folds another partial state into this one.
         *
         * @param other the state to absorb; must not be {@code null}
         */
        void merge(PartialResult other) {
            if (other.count == 0L) {
                // Nothing to absorb.
                return;
            }
            if (this.count == 0L) {
                // Adopt the other state verbatim; the pairwise formula below
                // would give the same answer but with needless rounding.
                this.residual_sum_of_squares = other.residual_sum_of_squares;
                this.mean_actuals = other.mean_actuals;
                this.total_sum_of_squares = other.total_sum_of_squares;
                this.count = other.count;
                return;
            }

            double left = (double)this.count;
            double right = (double)other.count;
            double merged = left + right;
            double delta = other.mean_actuals - this.mean_actuals;

            // Pairwise combination (Chan et al.): the two within-group sums
            // of squares plus a correction for the gap between group means.
            this.total_sum_of_squares += other.total_sum_of_squares + delta * delta * (left * right / merged);
            this.mean_actuals += delta * (right / merged);
            this.residual_sum_of_squares += other.residual_sum_of_squares;
            this.count += other.count;
        }

        /**
         * Computes R squared from the accumulated state.
         *
         * @return {@code 1 - residual_sum_of_squares / total_sum_of_squares},
         *         or {@code 1.0} when the total sum of squares is exactly zero
         *         (no rows, a single row, or all actual values identical)
         */
        double getR2() {
            if (this.total_sum_of_squares == 0.0) {
                // The ratio is undefined when the actual values have no
                // variance; 1.0 is the value this function has always
                // returned in that case.
                return 1.0;
            }
            return 1.0 - this.residual_sum_of_squares / this.total_sum_of_squares;
        }
    }

    /**
     * Evaluator exposing the aggregation lifecycle methods
     * ({@code init}, {@code iterate}, {@code terminatePartial}, {@code merge},
     * {@code terminate}) on top of {@link PartialResult}.
     */
    public static class Evaluator
    implements UDAFEvaluator {
        /** Current aggregation state; {@code null} until a row or partial is seen. */
        private PartialResult partial;

        /** Resets the evaluator so that it holds no aggregation state. */
        public void init() {
            this.partial = null;
        }

        /**
         * Folds one input row into the aggregation. Rows where either argument
         * is {@code null} are skipped.
         *
         * @param predicted predicted value, may be {@code null}
         * @param actual actual value, may be {@code null}
         * @return always {@code true}
         */
        public boolean iterate(DoubleWritable predicted, DoubleWritable actual) throws HiveException {
            if (predicted == null || actual == null) {
                return true;
            }
            if (this.partial == null) {
                this.partial = new PartialResult();
            }
            this.partial.iterate(predicted.get(), actual.get());
            return true;
        }

        /**
         * Returns the current partial aggregation state.
         *
         * @return the state, or {@code null} if nothing has been aggregated
         */
        public PartialResult terminatePartial() {
            return this.partial;
        }

        /**
         * Folds a partial aggregation state into this evaluator.
         *
         * @param other partial state to absorb; {@code null} is ignored
         * @return always {@code true}
         */
        public boolean merge(PartialResult other) throws HiveException {
            if (other == null) {
                return true;
            }
            if (this.partial == null) {
                this.partial = new PartialResult();
            }
            this.partial.merge(other);
            return true;
        }

        /**
         * Produces the final R squared value.
         *
         * @return R squared of the aggregated rows, or {@code 0.0} if no
         *         state was ever created
         */
        public double terminate() {
            if (this.partial == null) {
                return 0.0;
            }
            return this.partial.getR2();
        }
    }
}

