/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.ConfusionMatrix;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import hex.genmodel.GenModel;
import java.util.Arrays;
import water.MRTask;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/**
 * Model metrics for multinomial (multi-class) classifiers.
 *
 * <p>On top of the statistics inherited from {@link ModelMetricsSupervised}, this holds a
 * confusion matrix, cumulative top-K hit ratios, the (weighted) logloss and the mean per-class
 * error derived from the confusion matrix.
 */
public class ModelMetricsMultinomial
extends ModelMetricsSupervised {
    /**
     * Cumulative hit ratios: entry {@code k} is the weighted fraction of scored rows whose actual
     * class was among the top {@code k + 1} predicted classes.
     */
    public final float[] _hit_ratios;
    /** Confusion matrix (actual vs. predicted); may be {@code null}. */
    public final ConfusionMatrix _cm;
    /** Weighted mean logloss. */
    public final double _logloss;
    /** Mean per-class error taken from {@link #_cm}, or NaN when there is no confusion matrix. */
    public final double _mean_per_class_error;

    /**
     * Creates multinomial metrics.
     *
     * @param model   model the metrics belong to (may be {@code null} for user-given predictions)
     * @param frame   frame the metrics were computed on
     * @param nobs    number of scored observations
     * @param mse     mean squared error
     * @param domain  class labels
     * @param sigma   weighted standard deviation of the response
     * @param cm      confusion matrix, or {@code null}
     * @param hr      cumulative hit ratios
     * @param logloss weighted mean logloss
     */
    public ModelMetricsMultinomial(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma, ConfusionMatrix cm, float[] hr, double logloss) {
        super(model, frame, nobs, mse, domain, sigma);
        this._cm = cm;
        this._hit_ratios = hr;
        this._logloss = logloss;
        this._mean_per_class_error = cm == null ? Double.NaN : cm.mean_per_class_error();
    }

    /** Appends logloss, mean per-class error, hit ratios and (for at most 20 classes) the confusion matrix. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(" logloss: " + (float)this._logloss + "\n");
        sb.append(" mean_per_class_error: " + (float)this._mean_per_class_error + "\n");
        sb.append(" hit ratios: " + Arrays.toString(this._hit_ratios) + "\n");
        ConfusionMatrix matrix = this.cm();
        if (matrix != null) {
            if (matrix.nclasses() <= 20) {
                sb.append(" CM: " + matrix.toASCII());
            } else {
                sb.append(" CM: too large to print.\n");
            }
        }
        return sb.toString();
    }

    /** @return the weighted mean logloss */
    public double logloss() {
        return this._logloss;
    }

    /** @return the mean per-class error (NaN when no confusion matrix is available) */
    public double mean_per_class_error() {
        return this._mean_per_class_error;
    }

    /** @return the confusion matrix, possibly {@code null} */
    @Override
    public ConfusionMatrix cm() {
        return this._cm;
    }

    /** @return the cumulative hit ratios */
    @Override
    public float[] hr() {
        return this._hit_ratios;
    }

    /**
     * Looks up previously computed metrics for a model/frame pair and checks they are multinomial.
     *
     * @throws H2OIllegalArgumentException if the stored metrics are missing or not multinomial
     */
    public static ModelMetricsMultinomial getFromDKV(Model model, Frame frame) {
        ModelMetrics mm = ModelMetrics.getFromDKV(model, frame);
        if (!(mm instanceof ModelMetricsMultinomial)) {
            // mm may be null here; avoid an NPE while building the diagnostic message.
            Object found = mm == null ? null : mm.getClass();
            throw new H2OIllegalArgumentException("Expected to find a Multinomial ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(), "Expected to find a ModelMetricsMultinomial for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + found);
        }
        return (ModelMetricsMultinomial)mm;
    }

    /** Same as {@link #updateHits(double, int, double[], double[], double[])} without a prior class distribution. */
    public static void updateHits(double w, int iact, double[] ds, double[] hits) {
        ModelMetricsMultinomial.updateHits(w, iact, ds, hits, null);
    }

    /**
     * Adds the row weight {@code w} to the hit bucket of the rank at which the actual class appears
     * among the predicted classes.
     *
     * @param w    row weight
     * @param iact index of the actual class
     * @param ds   prediction row: {@code ds[0]} is the predicted class index and
     *             {@code ds[1 + c]} the score of class {@code c}
     * @param hits per-rank (non-cumulative) weighted hit counts, updated in place
     * @param priorClassDistribution prior class distribution passed to
     *             {@code GenModel.getPrediction}; may be {@code null}
     */
    public static void updateHits(double w, int iact, double[] ds, double[] hits, double[] priorClassDistribution) {
        if (iact == ds[0]) {
            // Weighted like every other rank: makeModelMetrics divides all buckets by the weighted count.
            hits[0] += w;
            return;
        }
        double before = ArrayUtils.sum(hits);
        // Successively blank out the best remaining class to find the next-ranked prediction.
        double[] remaining = Arrays.copyOf(ds, ds.length);
        remaining[1 + (int)ds[0]] = 0.0;
        for (int k = 1; k < hits.length; ++k) {
            int predicted = GenModel.getPrediction(remaining, priorClassDistribution, ds, 0.5);
            remaining[1 + predicted] = 0.0;
            if (predicted == iact) {
                hits[k] += w;
                break;
            }
        }
        // When every class is ranked but no bucket received the row, credit the last rank.
        if (hits.length == ds.length - 1 && ArrayUtils.sum(hits) == before) {
            hits[hits.length - 1] += w;
        }
    }

    /**
     * Builds a one-column table of the given hit ratios, one row per K (1-based).
     *
     * @param hits cumulative hit ratios
     * @return a table titled "Top-N Hit Ratios"
     */
    public static TwoDimTable getHitRatioTable(float[] hits) {
        String tableHeader = "Top-" + hits.length + " Hit Ratios";
        String[] rowHeaders = new String[hits.length];
        for (int k = 0; k < hits.length; ++k) {
            rowHeaders[k] = Integer.toString(k + 1);
        }
        String[] colHeaders = new String[]{"Hit Ratio"};
        String[] colTypes = new String[]{"float"};
        String[] colFormats = new String[]{"%f"};
        String colHeaderForRowHeaders = "K";
        TwoDimTable table = new TwoDimTable(tableHeader, null, rowHeaders, colHeaders, colTypes, colFormats, colHeaderForRowHeaders);
        for (int k = 0; k < hits.length; ++k) {
            table.set(k, 0, Float.valueOf(hits[k]));
        }
        return table;
    }

    /**
     * Computes multinomial metrics from user-given per-class probabilities, using the probability
     * column names as the class domain.
     *
     * @throws IllegalArgumentException if an argument is missing or the column names share no value
     *         with the categorical domain of the labels
     */
    public static ModelMetricsMultinomial make(Frame perClassProbs, Vec actualLabels) {
        if (perClassProbs == null || actualLabels == null) {
            throw new IllegalArgumentException("Missing actualLabels or predictedProbs for multinomial metrics!");
        }
        String[] names = perClassProbs.names();
        String[] label = actualLabels.domain();
        String[] union = ArrayUtils.union(names, label, true);
        if (union.length == names.length + label.length) {
            throw new IllegalArgumentException("Column names of per-class-probabilities and categorical domain of actual labels have no common values!");
        }
        return ModelMetricsMultinomial.make(perClassProbs, actualLabels, perClassProbs.names());
    }

    /**
     * Computes multinomial metrics from user-given per-class probabilities and actual labels.
     *
     * <p>The labels are converted to a categorical Vec and adapted to {@code domain}. The temporary
     * Vecs are removed once the metrics are built; the caller's {@code actualLabels} is never removed.
     *
     * @param perClassProbs one numeric column per class with values in [0, 1]
     * @param actualLabels  actual labels, one per row of {@code perClassProbs}
     * @param domain        class labels, one per column of {@code perClassProbs}
     * @throws IllegalArgumentException on missing arguments or inconsistent shapes/values
     */
    public static ModelMetricsMultinomial make(Frame perClassProbs, Vec actualLabels, String[] domain) {
        // Validate everything before allocating temporary Vecs so nothing leaks on bad input.
        if (actualLabels == null || perClassProbs == null) {
            throw new IllegalArgumentException("Missing actualLabels or predictedProbs for multinomial metrics!");
        }
        if (domain == null) {
            throw new IllegalArgumentException("Missing domain for multinomial metrics!");
        }
        if (actualLabels.length() != perClassProbs.numRows()) {
            throw new IllegalArgumentException("Both arguments must have the same length for multinomial metrics (" + actualLabels.length() + "!=" + perClassProbs.numRows() + ")!");
        }
        for (Vec p : perClassProbs.vecs()) {
            if (!p.isNumeric()) {
                throw new IllegalArgumentException("Predicted probabilities must be numeric per-class probabilities for multinomial metrics.");
            }
            if (p.min() < 0.0 || p.max() > 1.0) {
                throw new IllegalArgumentException("Predicted probabilities must be between 0 and 1 for multinomial metrics.");
            }
        }
        int nclasses = perClassProbs.numCols();
        if (domain.length != nclasses) {
            throw new IllegalArgumentException("Given domain has " + domain.length + " classes, but predictions have " + nclasses + " columns (per-class probabilities) for multinomial metrics.");
        }
        Vec catLabels = actualLabels.toCategoricalVec();
        if (catLabels == null) {
            throw new IllegalArgumentException("Missing actualLabels or predictedProbs for multinomial metrics!");
        }
        Vec adapted = null;
        try {
            adapted = catLabels.adaptTo(domain);
            Frame predsLabel = new Frame(perClassProbs);
            predsLabel.add("labels", adapted);
            MetricBuilderMultinomial mb = ((MultinomialMetrics)new MultinomialMetrics(adapted.domain()).doAll(predsLabel))._mb;
            // Build the metrics while the labels Vec inside predsLabel is still alive.
            ModelMetricsMultinomial mm = (ModelMetricsMultinomial)mb.makeModelMetrics(null, predsLabel, null, null);
            mm._description = "Computed on user-given predictions and labels.";
            return mm;
        } finally {
            // The adapted Vec may wrap catLabels, so remove it first; never remove the caller's Vec.
            if (adapted != null && adapted != actualLabels) {
                adapted.remove();
            }
            if (catLabels != actualLabels && catLabels != adapted) {
                catLabels.remove();
            }
        }
    }

    /**
     * Accumulates multinomial statistics row by row: confusion matrix counts, per-rank weighted hit
     * counts for the top {@code min(10, nclasses)} ranks, weighted logloss and the squared-error sums
     * used by {@link ModelMetricsSupervised}.
     */
    public static class MetricBuilderMultinomial<T extends MetricBuilderMultinomial<T>>
    extends ModelMetricsSupervised.MetricBuilderSupervised<T> {
        /** Confusion matrix counts indexed {@code [actual][predicted]}. */
        double[][] _cm;
        /** Per-rank (non-cumulative) weighted hit counts. */
        double[] _hits;
        /** Number of hit-ratio ranks tracked. */
        int _K;
        /** Accumulated weighted negative log-likelihood. */
        double _logloss;
        /** Prior class distribution used for hit ranking when no model is available. */
        public transient double[] _priorDistribution;

        public MetricBuilderMultinomial(int nclasses, String[] domain) {
            super(nclasses, domain);
            this._cm = new double[domain.length][domain.length];
            this._K = Math.min(10, this._nclasses);
            this._hits = new double[this._K];
        }

        /** Scores one row with unit weight and zero offset. */
        @Override
        public double[] perRow(double[] ds, float[] yact, Model m) {
            return this.perRow(ds, yact, 1.0, 0.0, m);
        }

        /**
         * Accumulates one prediction row.
         *
         * <p>The row is ignored when the actual label is NaN, any prediction value is NaN, or the
         * weight is 0 or NaN.
         *
         * @param ds   prediction row: {@code ds[0]} predicted class index, {@code ds[1 + c]}
         *             probability of class {@code c}
         * @param yact {@code yact[0]} is the actual class index
         * @param w    row weight
         * @param o    offset (unused here)
         * @param m    model supplying the prior class distribution, or {@code null}
         * @return {@code ds}, unchanged
         */
        @Override
        public double[] perRow(double[] ds, float[] yact, double w, double o, Model m) {
            if (Float.isNaN(yact[0])) {
                return ds;
            }
            if (ArrayUtils.hasNaNs(ds)) {
                return ds;
            }
            if (w == 0.0 || Double.isNaN(w)) {
                return ds;
            }
            int iact = (int)yact[0];
            ++this._count;
            this._wcount += w;
            this._wY += w * (double)iact;
            this._wYY += w * (double)iact * (double)iact;
            // Error is the distance from predicting the actual class with probability 1.
            double err = iact + 1 < ds.length ? 1.0 - ds[iact + 1] : 1.0;
            this._sumsqe += w * err * err;
            assert (!Double.isNaN(this._sumsqe));
            // NOTE(review): confusion-matrix counts are unweighted, unlike the other statistics.
            this._cm[iact][(int)ds[0]] += 1.0;
            if (this._K > 0 && iact < ds.length - 1) {
                double[] prior = m != null ? ((Model.Output)m._output)._priorClassDist : this._priorDistribution;
                ModelMetricsMultinomial.updateHits(w, iact, ds, this._hits, prior);
            }
            // Clamp the probability so log(0) cannot produce an infinite loss.
            double eps = 1.0E-15;
            this._logloss -= w * Math.log(Math.max(eps, 1.0 - err));
            return ds;
        }

        /** Merges another builder's accumulated statistics into this one. */
        @Override
        public void reduce(T mb) {
            super.reduce(mb);
            assert (mb._K == this._K);
            ArrayUtils.add(this._cm, mb._cm);
            this._hits = ArrayUtils.add(this._hits, mb._hits);
            this._logloss += mb._logloss;
        }

        /**
         * Turns the accumulated statistics into {@link ModelMetricsMultinomial}. MSE, logloss and hit
         * ratios are normalized by the weighted row count; hit ratios are made cumulative. If
         * {@code m} is non-null the metrics are also registered on its output.
         */
        @Override
        public ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
            double mse = Double.NaN;
            double logloss = Double.NaN;
            float[] hr = new float[this._K];
            ConfusionMatrix cm = new ConfusionMatrix(this._cm, this._domain);
            double sigma = this.weightedSigma();
            if (this._wcount > 0.0) {
                if (this._hits != null) {
                    for (int i = 0; i < hr.length; ++i) {
                        hr[i] = (float)(this._hits[i] / this._wcount);
                    }
                    // Running sum: hr[k] becomes the fraction of hits within the top k+1.
                    for (int i = 1; i < hr.length; ++i) {
                        hr[i] += hr[i - 1];
                    }
                }
                mse = this._sumsqe / this._wcount;
                logloss = this._logloss / this._wcount;
            }
            ModelMetricsMultinomial mm = new ModelMetricsMultinomial(m, f, this._count, mse, this._domain, sigma, cm, hr, logloss);
            if (m != null) {
                ((Model.Output)m._output).addModelMetrics(mm);
            }
            return mm;
        }
    }

    /**
     * Distributed task that scores a frame of per-class probabilities whose last column holds the
     * categorical actual labels.
     */
    private static class MultinomialMetrics
    extends MRTask<MultinomialMetrics> {
        public ModelMetricsMultinomial _mm;
        String[] domain;
        private MetricBuilderMultinomial _mb;

        public MultinomialMetrics(String[] domain) {
            this.domain = domain;
        }

        @Override
        public void map(Chunk[] chks) {
            this._mb = new MetricBuilderMultinomial(this.domain.length, this.domain);
            Chunk actuals = chks[chks.length - 1];
            double[] ds = new double[chks.length];
            for (int i = 0; i < chks[0]._len; ++i) {
                // ds[1..] are the user-given per-class probabilities; ds[0] the predicted class.
                for (int c = 1; c < chks.length; ++c) {
                    ds[c] = chks[c - 1].atd(i);
                }
                ds[0] = GenModel.getPrediction(ds, null, ds, 0.5);
                // atd yields NaN for a missing label (which perRow skips); at8 cannot represent it.
                this._mb.perRow(ds, new float[]{(float)actuals.atd(i)}, null);
            }
        }

        @Override
        public void reduce(MultinomialMetrics mrt) {
            // A task result for which map() never ran carries no builder.
            if (mrt._mb == null) {
                return;
            }
            if (this._mb == null) {
                this._mb = mrt._mb;
                return;
            }
            this._mb.reduce(mrt._mb);
        }

        @Override
        protected void postGlobal() {
            if (this._mb == null) {
                this._mb = new MetricBuilderMultinomial(this.domain.length, this.domain);
            }
            this._mm = (ModelMetricsMultinomial)this._mb.makeModelMetrics(null, this._fr, null, null);
        }
    }
}

