/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.Distribution;
import hex.MeanResidualDeviance;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import water.MRTask;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.MathUtils;

/**
 * Model metrics for regression: the supervised metrics (MSE, sigma, observation count)
 * plus the mean residual deviance and the mean absolute error.
 */
public class ModelMetricsRegression extends ModelMetricsSupervised {
    /** Mean residual deviance, as supplied to the constructor. */
    public final double _mean_residual_deviance;
    /** Mean absolute error, as supplied to the constructor. */
    public final double _mean_absolute_error;

    /**
     * Returns the mean residual deviance.
     *
     * @return the value of {@link #_mean_residual_deviance}
     */
    public double residual_deviance() {
        return this._mean_residual_deviance;
    }

    /**
     * Creates a regression metrics object.
     *
     * @param model                the model these metrics describe; {@code null} when computed from
     *                             user-supplied predictions (see {@link #make})
     * @param frame                the frame the metrics were computed on
     * @param nobs                 number of rows that contributed to the metrics
     * @param mse                  mean squared error
     * @param sigma                passed through to the superclass
     * @param mae                  mean absolute error
     * @param meanResidualDeviance mean residual deviance
     */
    public ModelMetricsRegression(Model model, Frame frame, long nobs, double mse, double sigma, double mae, double meanResidualDeviance) {
        // NOTE(review): the null argument is presumably the response domain, which regression lacks.
        super(model, frame, nobs, mse, null, sigma);
        this._mean_residual_deviance = meanResidualDeviance;
        this._mean_absolute_error = mae;
    }

    /**
     * Looks up the metrics stored for the given model/frame pair and checks they are regression metrics.
     *
     * @param model the model whose metrics are requested
     * @param frame the frame the metrics were computed on
     * @return the stored regression metrics
     * @throws H2OIllegalArgumentException if nothing is found or the stored metrics are not regression metrics
     */
    public static ModelMetricsRegression getFromDKV(Model model, Frame frame) {
        ModelMetrics mm = ModelMetrics.getFromDKV(model, frame);
        if (!(mm instanceof ModelMetricsRegression)) {
            // mm can be null here (instanceof is false for null); describing it must not dereference it,
            // otherwise an NPE would replace the intended exception.
            String found = mm == null ? "null" : mm.getClass().toString();
            throw new H2OIllegalArgumentException(
                    "Expected to find a Regression ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(),
                    "Expected to find a ModelMetricsRegression for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + found);
        }
        return (ModelMetricsRegression) mm;
    }

    /** Appends the mean residual deviance and mean absolute error to the superclass description. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(" mean residual deviance: " + (float) this._mean_residual_deviance + "\n");
        sb.append(" mean absolute error: " + (float) this._mean_absolute_error + "\n");
        return sb.toString();
    }

    /**
     * Computes regression metrics from user-supplied prediction and target vectors.
     * Every row gets weight 1 and no model is involved.
     *
     * @param predicted    numeric vector of predictions
     * @param actual       numeric vector of true targets
     * @param distribution distribution family used for the deviance; {@code null} means gaussian
     * @return the computed metrics
     * @throws IllegalArgumentException on missing or non-numeric inputs, or on families
     *                                  (quantile, tweedie, huber) that need extra parameters
     */
    public static ModelMetricsRegression make(Vec predicted, Vec actual, Distribution.Family distribution) {
        if (predicted == null || actual == null) {
            throw new IllegalArgumentException("Missing actual or predicted targets for regression metrics!");
        }
        if (!predicted.isNumeric()) {
            throw new IllegalArgumentException("Predicted values must be numeric for regression metrics.");
        }
        if (!actual.isNumeric()) {
            throw new IllegalArgumentException("Actual values must be numeric for regression metrics.");
        }
        if (distribution == Distribution.Family.quantile || distribution == Distribution.Family.tweedie || distribution == Distribution.Family.huber) {
            throw new IllegalArgumentException("Unsupported distribution family, requires additional parameters which cannot be specified right now.");
        }
        // Column 0 = predictions, column 1 = actuals; RegressionMetrics.map relies on this order.
        Frame predsActual = new Frame(predicted);
        predsActual.add("actual", actual);
        MetricBuilderRegression mb = new RegressionMetrics(distribution).doAll(predsActual)._mb;
        ModelMetricsRegression mm = (ModelMetricsRegression) mb.makeModelMetrics(null, predsActual, null, null);
        mm._description = "Computed on user-given predictions and targets, distribution: " + (distribution == null ? Distribution.Family.gaussian.toString() : distribution.toString()) + ".";
        return mm;
    }

    /**
     * Accumulates per-row regression statistics (weighted squared error, weighted absolute
     * error, deviance) and turns them into a {@link ModelMetricsRegression}.
     */
    public static class MetricBuilderRegression<T extends MetricBuilderRegression<T>>
            extends ModelMetricsSupervised.MetricBuilderSupervised<T> {
        /** Sum of per-row deviances (left at 0 for huber models with a model present). */
        double _sumdeviance;
        /** Distribution used for the deviance when no (non-huber) model is given; may be null. */
        Distribution _dist;
        /** Weighted sum of absolute errors. */
        double _abserror;

        /** Creates a builder without a fallback distribution. */
        public MetricBuilderRegression() {
            super(1, null);
        }

        /**
         * Creates a builder with a fallback distribution for the deviance.
         *
         * @param dist distribution used when no non-huber model is passed to {@code perRow}
         */
        public MetricBuilderRegression(Distribution dist) {
            super(1, null);
            this._dist = dist;
        }

        /** Accumulates one row with weight 1 and offset 0. */
        @Override
        public double[] perRow(double[] ds, float[] yact, Model m) {
            return this.perRow(ds, yact, 1.0, 0.0, m);
        }

        /**
         * Accumulates one weighted row. Rows with a NaN target, any NaN prediction,
         * or a zero/NaN weight are skipped.
         *
         * @param ds   prediction array; {@code ds[0]} is the predicted value
         * @param yact actual values; {@code yact[0]} is the target
         * @param w    row weight
         * @param o    row offset (not used here)
         * @param m    model supplying the deviance, or {@code null}
         * @return {@code ds}, unchanged
         */
        @Override
        public double[] perRow(double[] ds, float[] yact, double w, double o, Model m) {
            if (Float.isNaN(yact[0])) {
                return ds;
            }
            if (ArrayUtils.hasNaNs(ds)) {
                return ds;
            }
            if (w == 0.0 || Double.isNaN(w)) {
                return ds;
            }
            double err = (double) yact[0] - ds[0];
            this._sumsqe += w * err * err;
            // Weighted like _sumsqe: makeModelMetrics divides by the weighted count _wcount.
            this._abserror += w * Math.abs(err);
            assert (!Double.isNaN(this._sumsqe));
            if (m != null && ((Model.Parameters) m._parms)._distribution != Distribution.Family.huber) {
                this._sumdeviance += m.deviance(w, yact[0], ds[0]);
            } else if (this._dist != null) {
                this._sumdeviance += this._dist.deviance(w, yact[0], ds[0]);
            }
            ++this._count;
            this._wcount += w;
            this._wY += w * (double) yact[0];
            this._wYY += w * (double) yact[0] * (double) yact[0];
            return ds;
        }

        /** Merges another builder's accumulators into this one. */
        @Override
        public void reduce(T mb) {
            super.reduce(mb);
            this._sumdeviance += ((MetricBuilderRegression) mb)._sumdeviance;
            this._abserror += ((MetricBuilderRegression) mb)._abserror;
        }

        /**
         * Builds the final metrics. If a model is given, they are registered on its output.
         *
         * For huber models, the deviance is not accumulated per row. Instead, the huber delta is
         * computed here as a weighted quantile (at {@code _huber_alpha}) of |prediction - actual| over
         * {@code preds}, and the mean residual deviance is then computed with that delta. If
         * {@code preds} is null in that case, the mean residual deviance is reported as 0.
         *
         * @param m            the model, or {@code null}
         * @param f            the frame the metrics are attributed to
         * @param adaptedFrame frame holding response/weight columns; defaults to {@code f} when null
         * @param preds        prediction frame (used only for huber)
         * @return the regression metrics
         */
        @Override
        public ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
            double mse = this._sumsqe / this._wcount;
            double mae = this._abserror / this._wcount;
            if (adaptedFrame == null) {
                adaptedFrame = f;
            }
            double meanResDeviance = 0.0;
            if (m != null && ((Model.Parameters) m._parms)._distribution == Distribution.Family.huber) {
                assert (this._sumdeviance == 0.0);
                if (preds != null) {
                    Model.Parameters parms = (Model.Parameters) m._parms;
                    Vec actual = adaptedFrame.vec(parms._response_column);
                    // Temporary vector of absolute residuals, removed below once the quantile is known.
                    Vec absdiff = ((MRTask) new MRTask() {
                        @Override
                        public void map(Chunk[] cs, NewChunk[] nc) {
                            for (int i = 0; i < cs[0].len(); ++i) {
                                nc[0].addNum(Math.abs(cs[0].atd(i) - cs[1].atd(i)));
                            }
                        }
                    }.doAll(1, (byte) 3, new Frame(new String[]{"preds", "actual"}, new Vec[]{preds.anyVec(), actual}))).outputFrame().anyVec();
                    Distribution dist = new Distribution(parms);
                    Vec weight = adaptedFrame.vec(parms._weights_column);
                    double huberDelta;
                    try {
                        huberDelta = MathUtils.computeWeightedQuantile(weight, absdiff, parms._huber_alpha);
                    } finally {
                        // Always release the temporary vector, even if the quantile computation fails.
                        absdiff.remove();
                    }
                    dist.setHuberDelta(huberDelta);
                    meanResDeviance = new MeanResidualDeviance(dist, preds.anyVec(), actual, weight).exec().meanResidualDeviance;
                }
            } else {
                meanResDeviance = this._sumdeviance / this._wcount;
            }
            ModelMetricsRegression mm = new ModelMetricsRegression(m, f, this._count, mse, this.weightedSigma(), mae, meanResDeviance);
            if (m != null) {
                ((Model.Output) m._output).addModelMetrics(mm);
            }
            return mm;
        }
    }

    /**
     * Distributed task over a two-column frame (predictions, actuals). It feeds every row into a
     * {@link MetricBuilderRegression} using the given distribution (gaussian when null).
     */
    private static class RegressionMetrics extends MRTask<RegressionMetrics> {
        public MetricBuilderRegression _mb;
        final Distribution _distribution;

        RegressionMetrics(Distribution.Family distribution) {
            this._distribution = distribution == null ? new Distribution(Distribution.Family.gaussian) : new Distribution(distribution);
        }

        @Override
        public void map(Chunk[] chks) {
            this._mb = new MetricBuilderRegression(this._distribution);
            Chunk preds = chks[0];
            Chunk actuals = chks[1];
            double[] ds = new double[1];
            // perRow only reads yact[0], so one buffer can be reused for every row.
            float[] yact = new float[1];
            for (int i = 0; i < preds._len; ++i) {
                ds[0] = preds.atd(i);
                yact[0] = (float) actuals.atd(i);
                this._mb.perRow(ds, yact, null);
            }
        }

        @Override
        public void reduce(RegressionMetrics mrt) {
            // _mb is only assigned in map(); defensively tolerate an instance that never mapped a chunk.
            if (mrt._mb == null) {
                return;
            }
            if (this._mb == null) {
                this._mb = mrt._mb;
                return;
            }
            this._mb.reduce(mrt._mb);
        }
    }
}

