/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.CustomMetric;
import hex.Distribution;
import hex.DistributionFactory;
import hex.MeanResidualDeviance;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import hex.genmodel.utils.DistributionFamily;
import water.IcedUtils;
import water.MRTask;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.MathUtils;

/**
 * Model metrics for regression problems.
 *
 * <p>Extends the supervised metrics with three regression-specific values: mean residual
 * deviance, mean absolute error (MAE) and root mean squared log error (RMSLE).
 */
public class ModelMetricsRegression
extends ModelMetricsSupervised {
    /** Mean residual deviance; {@link #toString()} renders it as "N/A" when it is NaN. */
    public final double _mean_residual_deviance;
    /** Mean absolute error. */
    public final double _mean_absolute_error;
    /** Root mean squared log error. */
    public final double _root_mean_squared_log_error;

    /**
     * Returns the same value as {@link #mean_residual_deviance()}, i.e. the <em>mean</em>
     * residual deviance (not a sum).
     *
     * @return the mean residual deviance
     */
    public double residual_deviance() {
        return this._mean_residual_deviance;
    }

    /**
     * @return the mean residual deviance
     */
    public double mean_residual_deviance() {
        return this._mean_residual_deviance;
    }

    /**
     * @return the mean absolute error
     */
    public double mae() {
        return this._mean_absolute_error;
    }

    /**
     * @return the root mean squared log error
     */
    public double rmsle() {
        return this._root_mean_squared_log_error;
    }

    /**
     * Creates regression metrics from already-computed values.
     *
     * @param model the model the metrics belong to (callers in this file may pass {@code null})
     * @param frame the frame the metrics were computed on
     * @param nobs number of observations
     * @param mse mean squared error
     * @param sigma passed through to the superclass together with {@code mse}
     * @param mae mean absolute error
     * @param rmsle root mean squared log error
     * @param meanResidualDeviance mean residual deviance (may be NaN)
     * @param customMetric custom metric passed through to the superclass
     */
    public ModelMetricsRegression(Model model, Frame frame, long nobs, double mse, double sigma, double mae, double rmsle, double meanResidualDeviance, CustomMetric customMetric) {
        // The fifth superclass argument is always null here (presumably the response domain,
        // which a regression does not have -- confirm against ModelMetricsSupervised).
        super(model, frame, nobs, mse, null, sigma, customMetric);
        this._mean_residual_deviance = meanResidualDeviance;
        this._mean_absolute_error = mae;
        this._root_mean_squared_log_error = rmsle;
    }

    /**
     * Looks up the metrics stored for the given model/frame pair and checks that they are
     * regression metrics.
     *
     * @param model the model whose metrics are requested
     * @param frame the frame the metrics were computed on
     * @return the stored regression metrics
     * @throws H2OIllegalArgumentException if nothing is stored for the pair or the stored
     *         metrics are not a {@code ModelMetricsRegression}
     */
    public static ModelMetricsRegression getFromDKV(Model model, Frame frame) {
        ModelMetrics mm4 = ModelMetrics.getFromDKV(model, frame);
        if (!(mm4 instanceof ModelMetricsRegression)) {
            String modelKey = model._key.toString();
            String frameKey = frame._key.toString();
            // instanceof is false for null as well, so a failed lookup ends up here; guard the
            // getClass() call so the caller gets the intended exception instead of an NPE.
            String found = mm4 == null ? "null" : String.valueOf(mm4.getClass());
            throw new H2OIllegalArgumentException("Expected to find a Regression ModelMetrics for model: " + modelKey + " and frame: " + frameKey, "Expected to find a ModelMetricsRegression for model: " + modelKey + " and frame: " + frameKey + " but found a: " + found);
        }
        return (ModelMetricsRegression)mm4;
    }

    /**
     * Appends the regression-specific metrics to the superclass description. The mean residual
     * deviance is printed as "N/A" when it is NaN; all values are printed as floats.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(" mean residual deviance: ");
        if (Double.isNaN(this._mean_residual_deviance)) {
            sb.append("N/A");
        } else {
            sb.append((float)this._mean_residual_deviance);
        }
        sb.append("\n");
        sb.append(" mean absolute error: ").append((float)this._mean_absolute_error).append("\n");
        sb.append(" root mean squared log error: ").append((float)this._root_mean_squared_log_error).append("\n");
        return sb.toString();
    }

    /**
     * Computes regression metrics from predictions and actuals without observation weights.
     *
     * @param predicted numeric predictions
     * @param actual numeric targets
     * @param family distribution family; {@code null} selects gaussian
     * @return the computed metrics
     * @see #make(Vec, Vec, Vec, DistributionFamily)
     */
    public static ModelMetricsRegression make(Vec predicted, Vec actual, DistributionFamily family) {
        return ModelMetricsRegression.make(predicted, actual, null, family);
    }

    /**
     * Computes regression metrics from user-supplied predictions and actuals.
     *
     * @param predicted numeric predictions
     * @param actual numeric targets
     * @param weights optional per-row weights, or {@code null} to weight every row with 1.0
     * @param family distribution family used for the deviance; {@code null} selects gaussian
     * @return the computed metrics, with a description naming the distribution used
     * @throws IllegalArgumentException if {@code predicted} or {@code actual} is {@code null} or
     *         not numeric, or if {@code family} is quantile, tweedie or huber
     */
    public static ModelMetricsRegression make(Vec predicted, Vec actual, Vec weights, DistributionFamily family) {
        if (predicted == null || actual == null) {
            throw new IllegalArgumentException("Missing actual or predicted targets for regression metrics!");
        }
        if (!predicted.isNumeric()) {
            throw new IllegalArgumentException("Predicted values must be numeric for regression metrics.");
        }
        if (!actual.isNumeric()) {
            throw new IllegalArgumentException("Actual values must be numeric for regression metrics.");
        }
        if (family == DistributionFamily.quantile || family == DistributionFamily.tweedie || family == DistributionFamily.huber) {
            throw new IllegalArgumentException("Unsupported distribution family, requires additional parameters which cannot be specified right now.");
        }
        // Column order matters: RegressionMetrics.map reads predictions, actuals and the
        // optional weights by position (0, 1, 2).
        Frame fr = new Frame(predicted);
        fr.add("actual", actual);
        if (weights != null) {
            fr.add("weights", weights);
        }
        family = family == null ? DistributionFamily.gaussian : family;
        MetricBuilderRegression mb = ((RegressionMetrics)new RegressionMetrics((DistributionFamily)family).doAll((Frame)fr))._mb;
        ModelMetricsRegression mm4 = mb.makeModelMetrics(null, fr, null, null);
        mm4._description = "Computed on user-given predictions and targets, distribution: " + family.toString() + ".";
        return mm4;
    }

    /**
     * Computes the Huber delta as the weighted {@code huberAlpha}-quantile of the absolute
     * differences between predictions and actuals.
     *
     * @param actual actual target values
     * @param preds predicted values
     * @param weight per-row weights passed to the quantile computation (callers in this file may
     *        pass the result of a column lookup; whether {@code null} is accepted depends on
     *        {@code MathUtils.computeWeightedQuantile} -- confirm)
     * @param huberAlpha quantile to take from the absolute differences
     * @return the weighted quantile of {@code |preds - actual|}
     */
    public static double computeHuberDelta(Vec actual, Vec preds, Vec weight, double huberAlpha) {
        // One output column of type code 3 (presumably Vec.T_NUM -- confirm) holding |preds - actual|.
        Vec absdiff = ((MRTask)new MRTask(){

            @Override
            public void map(Chunk[] cs, NewChunk[] nc) {
                int len = cs[0].len();
                for (int i2 = 0; i2 < len; ++i2) {
                    nc[0].addNum(Math.abs(cs[0].atd(i2) - cs[1].atd(i2)));
                }
            }
        }.doAll(1, (byte)3, new Frame(new String[]{"preds", "actual"}, new Vec[]{preds, actual}))).outputFrame().anyVec();
        try {
            return MathUtils.computeWeightedQuantile(weight, absdiff, huberAlpha);
        } finally {
            // Always release the temporary vector, even when the quantile computation fails.
            absdiff.remove();
        }
    }

    /**
     * Accumulates the per-row sums needed for regression metrics and turns them into a
     * {@link ModelMetricsRegression}.
     */
    public static class MetricBuilderRegression<T extends MetricBuilderRegression<T>>
    extends ModelMetricsSupervised.MetricBuilderSupervised<T> {
        /** Running sum of per-row deviances. */
        double _sumdeviance;
        /** Distribution used for the deviance when no model is available; may be {@code null}. */
        Distribution _dist;
        /** Running weighted sum of absolute errors. */
        double _abserror;
        /** Running weighted sum of squared differences of log1p(prediction) and log1p(actual). */
        double _rmslerror;

        /** Creates a builder without a distribution. */
        public MetricBuilderRegression() {
            super(1, null);
        }

        /**
         * Creates a builder that uses the given distribution for deviance computation.
         *
         * @param dist the distribution; consulted by {@code perRow} when no model supplies the deviance
         */
        public MetricBuilderRegression(Distribution dist) {
            super(1, null);
            this._dist = dist;
        }

        /**
         * Tells whether a deviance can be accumulated: either the model or this builder's own
         * distribution must be present and use a non-custom distribution family.
         */
        private boolean usesNonCustomDistribution(Model m4) {
            return (m4 != null && ((Model.Parameters)m4._parms)._distribution != DistributionFamily.custom)
                || (this._dist != null && this._dist._family != DistributionFamily.custom);
        }

        /**
         * Accumulates one row with weight 1.0 and a zero fourth argument.
         *
         * @see #perRow(double[], float[], double, double, Model)
         */
        @Override
        public double[] perRow(double[] ds, float[] yact, Model m4) {
            return this.perRow(ds, yact, 1.0, 0.0, m4);
        }

        /**
         * Accumulates the error statistics of a single row.
         *
         * <p>The row is ignored when the actual value is NaN, when {@code ds} contains a NaN, or
         * when the weight is zero or NaN.
         *
         * @param ds predictions; only {@code ds[0]} is used
         * @param yact actual values; only {@code yact[0]} is used
         * @param w2 row weight
         * @param o2 not used by this implementation (presumably the offset -- confirm)
         * @param m4 model used for the deviance, or {@code null} to use this builder's distribution
         * @return {@code ds}, unchanged
         */
        @Override
        public double[] perRow(double[] ds, float[] yact, double w2, double o2, Model m4) {
            if (Float.isNaN(yact[0])) {
                return ds;
            }
            if (ArrayUtils.hasNaNs(ds)) {
                return ds;
            }
            if (w2 == 0.0 || Double.isNaN(w2)) {
                return ds;
            }
            double err = (double)yact[0] - ds[0];
            double err_msle = Math.pow(Math.log1p(ds[0]) - Math.log1p(yact[0]), 2.0);
            this._sumsqe += w2 * err * err;
            this._abserror += w2 * Math.abs(err);
            this._rmslerror += w2 * err_msle;
            assert (!Double.isNaN(this._sumsqe));
            if (this.usesNonCustomDistribution(m4)) {
                // A non-Huber model supplies the deviance itself; otherwise fall back to the
                // builder's own distribution when there is one.
                if (m4 != null && !m4.isDistributionHuber()) {
                    this._sumdeviance += m4.deviance(w2, yact[0], ds[0]);
                } else if (this._dist != null) {
                    this._sumdeviance += this._dist.deviance(w2, yact[0], ds[0]);
                }
            }
            ++this._count;
            this._wcount += w2;
            this._wY += w2 * (double)yact[0];
            this._wYY += w2 * (double)yact[0] * (double)yact[0];
            return ds;
        }

        /**
         * Merges the sums accumulated by another builder into this one.
         *
         * @param mb the builder to merge in
         */
        @Override
        public void reduce(T mb) {
            super.reduce(mb);
            this._sumdeviance += ((MetricBuilderRegression)mb)._sumdeviance;
            this._abserror += ((MetricBuilderRegression)mb)._abserror;
            this._rmslerror += ((MetricBuilderRegression)mb)._rmslerror;
        }

        /**
         * Computes the metrics and, when a model is given, registers them with that model.
         *
         * @param m4 the model, or {@code null}
         * @param f2 the frame the metrics refer to
         * @param adaptedFrame the adapted frame, or {@code null} to use {@code f2}
         * @param preds the predictions frame, or {@code null}
         * @return the computed metrics
         */
        @Override
        public ModelMetricsRegression makeModelMetrics(Model m4, Frame f2, Frame adaptedFrame, Frame preds) {
            ModelMetricsRegression mm4 = this.computeModelMetrics(m4, f2, adaptedFrame, preds);
            if (m4 != null) {
                m4.addModelMetrics(mm4);
            }
            return mm4;
        }

        /**
         * Turns the accumulated sums into a {@link ModelMetricsRegression}.
         *
         * <p>For a Huber model the mean residual deviance is computed in a separate pass over
         * {@code preds} (and stays 0.0 when {@code preds} is {@code null}). Otherwise it is the
         * accumulated deviance divided by the total weight, or NaN when neither the model nor this
         * builder has a non-custom distribution.
         *
         * @param m4 the model, or {@code null}
         * @param f2 the frame the metrics refer to
         * @param adaptedFrame the adapted frame, or {@code null} to use {@code f2}
         * @param preds the predictions frame, or {@code null}
         * @return the computed metrics
         */
        ModelMetricsRegression computeModelMetrics(Model m4, Frame f2, Frame adaptedFrame, Frame preds) {
            double mse = this._sumsqe / this._wcount;
            double mae = this._abserror / this._wcount;
            double rmsle = Math.sqrt(this._rmslerror / this._wcount);
            if (adaptedFrame == null) {
                adaptedFrame = f2;
            }
            double meanResDeviance = 0.0;
            if (m4 != null && m4.isDistributionHuber()) {
                assert (this._sumdeviance == 0.0);
                if (preds != null) {
                    Vec actual = adaptedFrame.vec(((Model.Parameters)m4._parms)._response_column);
                    Vec weight = adaptedFrame.vec(((Model.Parameters)m4._parms)._weights_column);
                    double huberDelta = ModelMetricsRegression.computeHuberDelta(actual, preds.anyVec(), weight, ((Model.Parameters)m4._parms)._huber_alpha);
                    // Work on a copy so the model's own distribution is not modified by setHuberDelta.
                    this._dist = IcedUtils.deepCopy(m4._dist);
                    this._dist.setHuberDelta(huberDelta);
                    meanResDeviance = new MeanResidualDeviance((Distribution)this._dist, (Vec)preds.anyVec(), (Vec)actual, (Vec)weight).exec().meanResidualDeviance;
                }
            } else if (this.usesNonCustomDistribution(m4)) {
                meanResDeviance = this._sumdeviance / this._wcount;
            } else {
                meanResDeviance = Double.NaN;
            }
            ModelMetricsRegression mm4 = new ModelMetricsRegression(m4, f2, this._count, mse, this.weightedSigma(), mae, rmsle, meanResDeviance, this._customMetric);
            return mm4;
        }
    }

    /**
     * Task that feeds a frame laid out as (predictions, actuals[, weights]) row by row into a
     * {@link MetricBuilderRegression}.
     */
    private static class RegressionMetrics
    extends MRTask<RegressionMetrics> {
        /** Builder holding the sums for the chunks processed so far. */
        public MetricBuilderRegression _mb;
        /** Distribution handed to every builder created in {@link #map(Chunk[])}. */
        final Distribution _distribution;

        RegressionMetrics(DistributionFamily family) {
            this._distribution = DistributionFactory.getDistribution(family);
        }

        /**
         * Accumulates all rows of one chunk group into a fresh builder.
         *
         * @param chks chunk 0 holds predictions, chunk 1 actuals and, when exactly three chunks
         *        are present, chunk 2 the weights
         */
        @Override
        public void map(Chunk[] chks) {
            this._mb = new MetricBuilderRegression(this._distribution);
            Chunk preds = chks[0];
            Chunk actuals = chks[1];
            Chunk weights = chks.length == 3 ? chks[2] : null;
            // Single-element buffers reused for every row.
            double[] ds = new double[1];
            float[] acts = new float[1];
            for (int i2 = 0; i2 < chks[0]._len; ++i2) {
                ds[0] = preds.atd(i2);
                acts[0] = (float)actuals.atd(i2);
                double w2 = weights != null ? weights.atd(i2) : 1.0;
                this._mb.perRow(ds, acts, w2, 0.0, null);
            }
        }

        /** Merges the builder of another task instance into this one. */
        @Override
        public void reduce(RegressionMetrics mrt) {
            this._mb.reduce(mrt._mb);
        }
    }
}

