/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.ConfusionMatrix;
import hex.CustomMetric;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import hex.genmodel.GenModel;
import java.util.Arrays;
import water.Key;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.MathUtils;
import water.util.TwoDimTable;

/**
 * Model metrics for models with an ordinal (ordered categorical) response.
 * <p>
 * On top of the supervised metrics handled by {@link ModelMetricsSupervised} (row count, MSE,
 * sigma, optional custom metric), this class stores:
 * <ul>
 *   <li>an actual-vs-predicted confusion matrix;</li>
 *   <li>cumulative top-K hit ratios;</li>
 *   <li>the log loss;</li>
 *   <li>the mean per-class error derived from the confusion matrix.</li>
 * </ul>
 */
public class ModelMetricsOrdinal
extends ModelMetricsSupervised {
    /**
     * Cumulative hit ratios: element {@code k} is the weighted fraction of scored rows whose actual
     * class was among the {@code k + 1} best predicted classes.
     */
    public final float[] _hit_ratios;
    /** Actual-vs-predicted confusion matrix; may be null. */
    public final ConfusionMatrix _cm;
    /** Weighted mean log loss (NaN when the builder saw zero total weight). */
    public final double _logloss;
    /** Mean per-class error taken from {@link #_cm}; NaN when the matrix is null or too large. */
    public final double _mean_per_class_error;

    /**
     * Creates ordinal model metrics.
     *
     * @param model        the model the metrics belong to; null for metrics on user-given predictions
     * @param frame        the frame the metrics were computed on
     * @param nobs         number of scored rows
     * @param mse          weighted mean squared error
     * @param domain       class labels of the response
     * @param sigma        weighted standard deviation of the response
     * @param cm           confusion matrix, possibly null
     * @param hr           cumulative hit ratios
     * @param logloss      weighted mean log loss
     * @param customMetric optional user-defined metric
     */
    public ModelMetricsOrdinal(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma, ConfusionMatrix cm, float[] hr, double logloss, CustomMetric customMetric) {
        super(model, frame, nobs, mse, domain, sigma, customMetric);
        this._cm = cm;
        this._hit_ratios = hr;
        this._logloss = logloss;
        this._mean_per_class_error = cm == null || cm.tooLarge() ? Double.NaN : cm.mean_per_class_error();
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(" logloss: " + (float)this._logloss + "\n");
        sb.append(" mean_per_class_error: " + (float)this._mean_per_class_error + "\n");
        sb.append(" hit ratios: " + Arrays.toString(this._hit_ratios) + "\n");
        if (this.cm() != null) {
            // Only small matrices are rendered; larger ones would swamp the output.
            if (this.cm().nclasses() <= 20) {
                sb.append(" CM: " + this.cm().toASCII());
            } else {
                sb.append(" CM: too large to print.\n");
            }
        }
        return sb.toString();
    }

    /** @return the weighted mean log loss */
    public double logloss() {
        return this._logloss;
    }

    /** @return the mean per-class error, or NaN when no usable confusion matrix exists */
    public double mean_per_class_error() {
        return this._mean_per_class_error;
    }

    @Override
    public ConfusionMatrix cm() {
        return this._cm;
    }

    @Override
    public float[] hr() {
        return this._hit_ratios;
    }

    /**
     * Looks up the stored metrics for a model/frame pair and checks that they are ordinal metrics.
     *
     * @throws H2OIllegalArgumentException if the stored metrics are missing or of another kind
     */
    public static ModelMetricsOrdinal getFromDKV(Model model, Frame frame) {
        ModelMetrics mm = ModelMetrics.getFromDKV(model, frame);
        if (!(mm instanceof ModelMetricsOrdinal)) {
            // Guard against a null lookup result: the instanceof check lets null through to here.
            String found = mm == null ? "null" : String.valueOf(mm.getClass());
            throw new H2OIllegalArgumentException("Expected to find an Ordinal ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(), "Expected to find a ModelMetricsOrdinal for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + found);
        }
        return (ModelMetricsOrdinal)mm;
    }

    /** Same as {@link #updateHits(double, int, double[], double[], double[])} with no prior class distribution. */
    public static void updateHits(double w2, int iact, double[] ds, double[] hits) {
        ModelMetricsOrdinal.updateHits(w2, iact, ds, hits, null);
    }

    /**
     * Adds the weight of one row to the hit bucket matching the rank of its actual class.
     * <p>
     * {@code hits[0]} collects rows whose top prediction ({@code ds[0]}) is correct. {@code hits[k]}
     * collects rows whose actual class is the {@code k}-th next best prediction. Those next-best
     * predictions are found by repeatedly zeroing the current best probability in a copy of
     * {@code ds}.
     *
     * @param w2   row weight
     * @param iact index of the actual class
     * @param ds   {@code ds[0]} is the predicted class index; {@code ds[1 + c]} is the probability of class {@code c}
     * @param hits per-rank weighted hit accumulators (length K); updated in place
     * @param priorClassDistribution passed through to {@link GenModel#getPrediction}; may be null
     */
    public static void updateHits(double w2, int iact, double[] ds, double[] hits, double[] priorClassDistribution) {
        if (iact == ds[0]) {
            // Weighted like the other buckets: makeModelMetrics divides every bucket by the weighted row count.
            hits[0] += w2;
            return;
        }
        double before = ArrayUtils.sum(hits);
        double[] dsCopy = Arrays.copyOf(ds, ds.length); // the caller's ds must not be modified
        dsCopy[1 + (int)ds[0]] = 0.0;                    // drop the top prediction, then search the runners-up
        for (int k = 1; k < hits.length; ++k) {
            // NOTE(review): the 0.5 threshold is presumably ignored for multi-class predictions -- confirm in GenModel.
            int predLabel = GenModel.getPrediction(dsCopy, priorClassDistribution, ds, 0.5);
            dsCopy[1 + predLabel] = 0.0;
            if (predLabel == iact) {
                hits[k] += w2;
                break;
            }
        }
        // When K equals the number of classes the row must land in some bucket; if none matched, count it at the worst rank.
        if (hits.length == ds.length - 1 && ArrayUtils.sum(hits) == before) {
            hits[hits.length - 1] += w2;
        }
    }

    /**
     * Builds a one-column table of hit ratios, with one row per K value (1-based).
     *
     * @param hits cumulative hit ratios
     * @return the rendered table
     */
    public static TwoDimTable getHitRatioTable(float[] hits) {
        String tableHeader = "Top-" + hits.length + " Hit Ratios";
        String[] rowHeaders = new String[hits.length];
        for (int k = 0; k < hits.length; ++k) {
            rowHeaders[k] = Integer.toString(k + 1);
        }
        String[] colHeaders = new String[]{"Hit Ratio"};
        String[] colTypes = new String[]{"float"};
        String[] colFormats = new String[]{"%f"};
        String colHeaderForRowHeaders = "K";
        TwoDimTable table = new TwoDimTable(tableHeader, null, rowHeaders, colHeaders, colTypes, colFormats, colHeaderForRowHeaders);
        for (int k = 0; k < hits.length; ++k) {
            table.set(k, 0, Float.valueOf(hits[k]));
        }
        return table;
    }

    /**
     * Computes ordinal metrics from user-given per-class probabilities. The column names of
     * {@code perClassProbs} are used as the class domain.
     *
     * @throws IllegalArgumentException if the column names share no value with the label domain
     */
    public static ModelMetricsOrdinal make(Frame perClassProbs, Vec actualLabels) {
        String[] names = perClassProbs.names();
        String[] label = actualLabels.domain();
        String[] union = ArrayUtils.union(names, label, true);
        // A union as long as both inputs combined means the two sets are disjoint.
        if (union.length == names.length + label.length) {
            throw new IllegalArgumentException("Column names of per-class-probabilities and categorical domain of actual labels have no common values!");
        }
        return ModelMetricsOrdinal.make(perClassProbs, actualLabels, perClassProbs.names());
    }

    /**
     * Computes ordinal metrics from user-given per-class probabilities and actual labels.
     *
     * @param perClassProbs one numeric column in [0, 1] per class
     * @param actualLabels  actual labels; converted to categorical and adapted to {@code domain}
     * @param domain        class labels, one per column of {@code perClassProbs}
     * @throws IllegalArgumentException on missing inputs, mismatched lengths, non-numeric or out-of-range
     *         probabilities, or a domain size different from the number of probability columns
     */
    public static ModelMetricsOrdinal make(Frame perClassProbs, Vec actualLabels, String[] domain) {
        // Validate before dereferencing anything: previously a null actualLabels caused an NPE here.
        if (actualLabels == null || perClassProbs == null) {
            throw new IllegalArgumentException("Missing actualLabels or predictedProbs for multinomial metrics!");
        }
        Scope.enter();
        try {
            Vec labels = actualLabels.toCategoricalVec();
            if (labels == null) {
                throw new IllegalArgumentException("Missing actualLabels or predictedProbs for multinomial metrics!");
            }
            if (labels.length() != perClassProbs.numRows()) {
                throw new IllegalArgumentException("Both arguments must have the same length for multinomial metrics (" + labels.length() + "!=" + perClassProbs.numRows() + ")!");
            }
            for (Vec p : perClassProbs.vecs()) {
                if (!p.isNumeric()) {
                    throw new IllegalArgumentException("Predicted probabilities must be numeric per-class probabilities for multinomial metrics.");
                }
                if (p.min() < 0.0 || p.max() > 1.0) {
                    throw new IllegalArgumentException("Predicted probabilities must be between 0 and 1 for multinomial metrics.");
                }
            }
            int nclasses = perClassProbs.numCols();
            if (domain.length != nclasses) {
                throw new IllegalArgumentException("Given domain has " + domain.length + " classes, but predictions have " + nclasses + " columns (per-class probabilities) for multinomial metrics.");
            }
            labels = labels.adaptTo(domain);
            Frame predsLabel = new Frame(perClassProbs);
            predsLabel.add("labels", labels);
            MetricBuilderOrdinal mb = ((OrdinalMetrics)new OrdinalMetrics(labels.domain()).doAll(predsLabel))._mb;
            labels.remove();
            ModelMetricsOrdinal mm = (ModelMetricsOrdinal)mb.makeModelMetrics(null, predsLabel, null, null);
            mm._description = "Computed on user-given predictions and labels.";
            return mm;
        } finally {
            // Always balance Scope.enter(), including when validation throws.
            Scope.exit();
        }
    }

    /**
     * Incrementally accumulates ordinal metrics row by row; partial builders are merged via {@link #reduce}.
     */
    public static class MetricBuilderOrdinal<T extends MetricBuilderOrdinal<T>>
    extends ModelMetricsSupervised.MetricBuilderSupervised<T> {
        /** Weighted actual-vs-predicted counts; null when the domain exceeds ConfusionMatrix.maxClasses(). */
        double[][] _cm;
        /** Weighted hit counts per rank (length {@link #_K}). */
        double[] _hits;
        /** Number of hit-ratio ranks tracked: min(10, number of classes). */
        int _K;
        /** Sum of weighted per-row log loss. */
        double _logloss;
        /** Prior class distribution used for hit ratios when no model is supplied to perRow. */
        public transient double[] _priorDistribution;

        public MetricBuilderOrdinal(int nclasses, String[] domain) {
            super(nclasses, domain);
            this._cm = domain.length > ConfusionMatrix.maxClasses() ? (double[][])null : new double[domain.length][domain.length];
            this._K = Math.min(10, this._nclasses);
            this._hits = new double[this._K];
        }

        @Override
        public double[] perRow(double[] ds, float[] yact, Model m4) {
            return this.perRow(ds, yact, 1.0, 0.0, m4);
        }

        /**
         * Accumulates one scored row.
         * <p>
         * The row is skipped (and {@code ds} returned unchanged) in any of these cases:
         * <ul>
         *   <li>the confusion matrix is disabled;</li>
         *   <li>the actual label is NaN;</li>
         *   <li>any prediction entry is NaN;</li>
         *   <li>the weight is 0 or NaN.</li>
         * </ul>
         *
         * @param ds   {@code ds[0]} is the predicted class index; {@code ds[1 + c]} is the probability of class {@code c}
         * @param yact {@code yact[0]} is the actual class index
         * @param w2   row weight
         * @param o2   offset (unused)
         * @param m4   model supplying the prior class distribution, or null to use {@link #_priorDistribution}
         * @return {@code ds}, unchanged
         */
        @Override
        public double[] perRow(double[] ds, float[] yact, double w2, double o2, Model m4) {
            if (this._cm == null) {
                return ds;
            }
            if (Float.isNaN(yact[0])) {
                return ds;
            }
            if (ArrayUtils.hasNaNs(ds)) {
                return ds;
            }
            if (w2 == 0.0 || Double.isNaN(w2)) {
                return ds;
            }
            int iact = (int)yact[0];
            ++this._count;
            this._wcount += w2;
            this._wY += w2 * (double)iact;
            this._wYY += w2 * (double)iact * (double)iact;
            // Error = distance from predicting the actual class with probability 1.
            double err = iact + 1 < ds.length ? 1.0 - ds[iact + 1] : 1.0;
            this._sumsqe += w2 * err * err;
            assert (!Double.isNaN(this._sumsqe));
            // Weighted, consistent with the MSE, log loss and hit-ratio accumulators.
            this._cm[iact][(int)ds[0]] += w2;
            if (this._K > 0 && iact < ds.length - 1) {
                ModelMetricsOrdinal.updateHits(w2, iact, ds, this._hits, m4 != null ? ((Model.Output)m4._output)._priorClassDist : this._priorDistribution);
            }
            this._logloss += w2 * MathUtils.logloss(err);
            return ds;
        }

        /** Merges another partial builder into this one (no-op when the confusion matrix is disabled). */
        @Override
        public void reduce(T mb) {
            if (this._cm == null) {
                return;
            }
            super.reduce(mb);
            MetricBuilderOrdinal<?> other = mb;
            assert (other._K == this._K);
            ArrayUtils.add(this._cm, other._cm);
            this._hits = ArrayUtils.add(this._hits, other._hits);
            this._logloss += other._logloss;
        }

        /**
         * Turns the accumulated sums into a {@link ModelMetricsOrdinal}.
         * <p>
         * MSE, log loss and hit ratios stay NaN/zero when the total weight is 0. Hit ratios are made
         * cumulative, so element {@code k} covers ranks {@code 0..k}. When {@code m4} is non-null, the
         * metrics are also registered on it.
         */
        @Override
        public ModelMetrics makeModelMetrics(Model m4, Frame f2, Frame adaptedFrame, Frame preds) {
            double mse = Double.NaN;
            double logloss = Double.NaN;
            float[] hr = new float[this._K];
            ConfusionMatrix cm = new ConfusionMatrix(this._cm, this._domain);
            double sigma = this.weightedSigma();
            if (this._wcount > 0.0) {
                if (this._hits != null) {
                    for (int i = 0; i < hr.length; ++i) {
                        hr[i] = (float)(this._hits[i] / this._wcount);
                    }
                    // Prefix-sum so hr[k] is the fraction of rows hit within the top k+1 predictions.
                    for (int i = 1; i < hr.length; ++i) {
                        hr[i] += hr[i - 1];
                    }
                }
                mse = this._sumsqe / this._wcount;
                logloss = this._logloss / this._wcount;
            }
            ModelMetricsOrdinal mm = new ModelMetricsOrdinal(m4, f2, this._count, mse, this._domain, sigma, cm, hr, logloss, this._customMetric);
            if (m4 != null) {
                m4.addModelMetrics(mm);
            }
            return mm;
        }
    }

    /**
     * Scores user-given predictions. The input frame holds one probability column per class, followed
     * by the label column as its last column.
     */
    private static class OrdinalMetrics
    extends MRTask<OrdinalMetrics> {
        String[] domain;
        private MetricBuilderOrdinal _mb;

        public OrdinalMetrics(String[] domain) {
            this.domain = domain;
        }

        @Override
        public void map(Chunk[] chks) {
            this._mb = new MetricBuilderOrdinal(this.domain.length, this.domain);
            Chunk actuals = chks[chks.length - 1];
            double[] ds = new double[chks.length];
            for (int row = 0; row < chks[0]._len; ++row) {
                for (int c = 1; c < chks.length; ++c) {
                    ds[c] = chks[c - 1].atd(row);
                }
                // NOTE(review): 0.5 threshold presumably ignored for multi-class input -- confirm in GenModel.
                ds[0] = GenModel.getPrediction(ds, null, ds, 0.5);
                // Map missing labels to NaN so perRow's isNaN guard skips them instead of reading
                // an integer value for an NA cell.
                float actual = actuals.isNA(row) ? Float.NaN : (float)actuals.at8(row);
                this._mb.perRow(ds, new float[]{actual}, null);
            }
        }

        @Override
        public void reduce(OrdinalMetrics mrt) {
            this._mb.reduce(mrt._mb);
        }
    }
}

