/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.AUC2;
import hex.ConfusionMatrix;
import hex.CustomMetric;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import hex.MultinomialAUC;
import hex.MultinomialAucType;
import hex.genmodel.GenModel;
import java.util.Arrays;
import water.Key;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;
import water.util.TwoDimTable;

/**
 * Model metrics for multinomial (multi-class) classification.
 *
 * <p>In addition to what {@link ModelMetricsSupervised} tracks, this holds the log loss, the mean
 * per-class error, the cumulative top-K hit ratios, the confusion matrix and the multinomial
 * AUC / AUCPR summary.
 */
public class ModelMetricsMultinomial
extends ModelMetricsSupervised {
    /** Domains up to this size get their AUC / confusion-matrix tables printed by {@link #toString()}. */
    private static final int MAX_PRINTABLE_CLASSES = 20;
    /** Largest domain for which one-vs-one / one-vs-rest AUC builders are allocated. */
    private static final int MAX_AUC_DOMAIN_SIZE = 50;
    /** Upper bound on K for the top-K hit ratios. */
    private static final int MAX_HIT_RATIO_K = 10;
    /** Size argument passed to every {@link AUC2.AUCBuilder} (presumably the number of threshold bins). */
    private static final int AUC_BUILDER_BINS = 400;

    /**
     * Cumulative hit ratios: element {@code k} is the weighted fraction of rows whose actual class
     * was ranked within the top {@code k + 1} predictions.
     */
    public final float[] _hit_ratios;
    /** Confusion matrix (actual vs. predicted). */
    public final ConfusionMatrix _cm;
    /** Weighted mean log loss. */
    public final double _logloss;
    /** Mean per-class error from {@link #_cm}, or NaN when the matrix is null or too large. */
    public double _mean_per_class_error;
    /** Multinomial AUC / AUCPR summary; may be null. */
    public MultinomialAUC _auc;

    public ModelMetricsMultinomial(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma, ConfusionMatrix cm, float[] hr, double logloss, MultinomialAUC auc, CustomMetric customMetric) {
        super(model, frame, nobs, mse, domain, sigma, customMetric);
        this._cm = cm;
        this._hit_ratios = hr;
        this._logloss = logloss;
        this._mean_per_class_error = cm == null || cm.tooLarge() ? Double.NaN : cm.mean_per_class_error();
        this._auc = auc;
    }

    /**
     * Renders the scalar metrics plus the AUC / AUCPR tables and the confusion matrix when they
     * exist and the domain is small enough to print.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(" logloss: " + (float) this._logloss + "\n");
        sb.append(" mean_per_class_error: " + (float) this._mean_per_class_error + "\n");
        sb.append(" hit ratios: " + Arrays.toString(this._hit_ratios) + "\n");
        sb.append(" AUC: " + this.auc() + "\n");
        sb.append(" pr_auc: " + this.pr_auc() + "\n");
        // _auc may be null (auc()/pr_auc() handle that case), so guard before reading its tables.
        if (this._auc == null || this._auc.getAucTable() == null) {
            sb.append(" AUC table: is not computed because it is disabled (model parameter 'auc_type' is set to AUTO or NONE) or due to domain size (maximum is 50 domains).\n");
            sb.append(" pr_auc table: is not computed because it is disabled (model parameter 'auc_type' is set to AUTO or NONE) or due to domain size (maximum is 50 domains).\n");
        } else if (this._domain.length <= MAX_PRINTABLE_CLASSES) {
            sb.append(" AUC table: " + this._auc.getAucTable() + "\n");
            sb.append(" pr_auc table: " + this._auc.getAucPrTable() + "\n");
        } else {
            sb.append(" AUC table: too large to print.\n");
            sb.append(" pr_auc table: too large to print.\n");
        }
        ConfusionMatrix matrix = this.cm();
        if (matrix != null) {
            if (matrix.nclasses() <= MAX_PRINTABLE_CLASSES) {
                sb.append(" CM: " + matrix.toASCII());
            } else {
                sb.append(" CM: too large to print.\n");
            }
        }
        return sb.toString();
    }

    /** @return the weighted mean log loss */
    public double logloss() {
        return this._logloss;
    }

    /** @return the mean per-class error, or NaN if no usable confusion matrix was available */
    public double mean_per_class_error() {
        return this._mean_per_class_error;
    }

    @Override
    public ConfusionMatrix cm() {
        return this._cm;
    }

    @Override
    public float[] hr() {
        return this._hit_ratios;
    }

    /** @return the multinomial AUC, or NaN when no AUC summary is attached */
    public double auc() {
        return this._auc != null ? this._auc.auc() : Double.NaN;
    }

    /** @return the multinomial area under the PR curve, or NaN when no AUC summary is attached */
    public double pr_auc() {
        return this._auc != null ? this._auc.pr_auc() : Double.NaN;
    }

    /** Alias of {@link #pr_auc()}. */
    public double aucpr() {
        return this.pr_auc();
    }

    /**
     * Looks up the stored metrics for {@code model} on {@code frame}.
     *
     * @throws H2OIllegalArgumentException if nothing is stored, or the stored metrics are not multinomial
     */
    public static ModelMetricsMultinomial getFromDKV(Model model, Frame frame) {
        ModelMetrics mm = ModelMetrics.getFromDKV(model, frame);
        if (!(mm instanceof ModelMetricsMultinomial)) {
            // instanceof also fails for null; avoid mm.getClass() turning that into an NPE.
            Object found = mm == null ? null : mm.getClass();
            throw new H2OIllegalArgumentException("Expected to find a Multinomial ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(), "Expected to find a ModelMetricsMultinomial for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + found);
        }
        return (ModelMetricsMultinomial) mm;
    }

    /** Same as {@link #updateHits(double, int, double[], double[], double[])} without a prior class distribution. */
    public static void updateHits(double w, int iact, double[] ds, double[] hits) {
        ModelMetricsMultinomial.updateHits(w, iact, ds, hits, null);
    }

    /**
     * Adds weight {@code w} to the hit bucket matching the rank of the actual class among the predictions.
     *
     * <p>{@code ds[0]} holds the predicted label and {@code ds[1 + c]} the score of class {@code c}.
     * Bucket 0 is a top-1 hit. Further ranks are found by repeatedly zeroing the best remaining
     * score, using {@link GenModel#getPrediction}'s tie-breaking. Buckets are per-rank, not cumulative.
     *
     * @param w weight of the row
     * @param iact index of the actual class
     * @param ds prediction row; left unmodified
     * @param hits per-rank weighted hit counts, updated in place
     * @param priorClassDistribution forwarded to {@link GenModel#getPrediction}; may be null
     */
    public static void updateHits(double w, int iact, double[] ds, double[] hits, double[] priorClassDistribution) {
        if (iact == ds[0]) {
            // Weighted like the other buckets: makeModelMetrics divides every bucket by the weighted row count.
            hits[0] += w;
            return;
        }
        double[] remaining = Arrays.copyOf(ds, ds.length);
        remaining[1 + (int) ds[0]] = 0.0;  // the top-1 label already missed
        boolean found = false;
        for (int k = 1; k < hits.length; ++k) {
            int predLabel = GenModel.getPrediction(remaining, priorClassDistribution, ds, 0.5);
            remaining[1 + predLabel] = 0.0;  // next iteration finds the next-best label
            if (predLabel == iact) {
                hits[k] += w;
                found = true;
                break;
            }
        }
        // When K covers every class, the row must land in some bucket; otherwise charge it to the last one.
        if (!found && hits.length == ds.length - 1) {
            hits[hits.length - 1] += w;
        }
    }

    /**
     * Builds a one-column table of hit ratios.
     *
     * @param hits hit ratios; row {@code k} is labelled {@code k + 1}
     * @return the table titled "Top-K Hit Ratios"
     */
    public static TwoDimTable getHitRatioTable(float[] hits) {
        String tableHeader = "Top-" + hits.length + " Hit Ratios";
        String[] rowHeaders = new String[hits.length];
        for (int k = 0; k < hits.length; ++k) {
            rowHeaders[k] = Integer.toString(k + 1);
        }
        TwoDimTable table = new TwoDimTable(tableHeader, null, rowHeaders, new String[]{"Hit Ratio"}, new String[]{"float"}, new String[]{"%f"}, "K");
        for (int k = 0; k < hits.length; ++k) {
            table.set(k, 0, Float.valueOf(hits[k]));
        }
        return table;
    }

    /**
     * Computes metrics using the probability column names as the class domain.
     *
     * @throws IllegalArgumentException if the column names share no value with the labels' domain
     */
    public static ModelMetricsMultinomial make(Frame perClassProbs, Vec actualLabels, MultinomialAucType aucType) {
        String[] names = perClassProbs.names();
        String[] labelDomain = actualLabels.domain();
        String[] union = ArrayUtils.union(names, labelDomain, true);
        if (union.length == names.length + labelDomain.length) {
            throw new IllegalArgumentException("Column names of per-class-probabilities and categorical domain of actual labels have no common values!");
        }
        return ModelMetricsMultinomial.make(perClassProbs, actualLabels, perClassProbs.names(), aucType);
    }

    /** Computes unweighted metrics for the given domain. */
    public static ModelMetricsMultinomial make(Frame perClassProbs, Vec actualLabels, String[] domain, MultinomialAucType aucType) {
        return ModelMetricsMultinomial.make(perClassProbs, actualLabels, null, domain, aucType);
    }

    /**
     * Computes multinomial metrics from user-given per-class probabilities and actual labels.
     *
     * @param perClassProbs one numeric column per class, values in [0, 1]
     * @param actualLabels actual labels, converted to categorical and adapted to {@code domain}
     * @param weights optional row weights; may be null
     * @param domain class names, one per column of {@code perClassProbs}
     * @param aucType AUC flavour; AUTO and NONE disable the AUC computation
     * @throws IllegalArgumentException on missing inputs, length/column mismatches or invalid probabilities
     */
    public static ModelMetricsMultinomial make(Frame perClassProbs, Vec actualLabels, Vec weights, String[] domain, MultinomialAucType aucType) {
        Scope.enter();
        try {
            Vec labels = actualLabels == null ? null : actualLabels.toCategoricalVec();
            if (labels == null || perClassProbs == null) {
                throw new IllegalArgumentException("Missing actualLabels or predictedProbs for multinomial metrics!");
            }
            if (labels.length() != perClassProbs.numRows()) {
                throw new IllegalArgumentException("Both arguments must have the same length for multinomial metrics (" + labels.length() + "!=" + perClassProbs.numRows() + ")!");
            }
            for (Vec p : perClassProbs.vecs()) {
                if (!p.isNumeric()) {
                    throw new IllegalArgumentException("Predicted probabilities must be numeric per-class probabilities for multinomial metrics.");
                }
                if (p.min() < 0.0 || p.max() > 1.0) {
                    throw new IllegalArgumentException("Predicted probabilities must be between 0 and 1 for multinomial metrics.");
                }
            }
            if (aucType.equals(MultinomialAucType.AUTO) || aucType.equals(MultinomialAucType.NONE)) {
                Log.info("Multinomial AUC and AUCPR will not be calculated in metric summary. The model parameter auc_type is set to \"NONE\" or \"AUTO\" or the maximum size of domain (50) was reached.");
            }
            int nclasses = perClassProbs.numCols();
            if (domain.length != nclasses) {
                throw new IllegalArgumentException("Given domain has " + domain.length + " classes, but predictions have " + nclasses + " columns (per-class probabilities) for multinomial metrics.");
            }
            labels = labels.adaptTo(domain);
            // Column layout expected by MultinomialMetrics.map: probabilities, labels, optional weights.
            Frame fr = new Frame(perClassProbs);
            fr.add("labels", labels);
            if (weights != null) {
                fr.add("weights", weights);
            }
            MultinomialMetrics task = (MultinomialMetrics) new MultinomialMetrics(labels.domain(), aucType).doAll(fr);
            MetricBuilderMultinomial mb = task._mb;
            labels.remove();
            ModelMetricsMultinomial mm = (ModelMetricsMultinomial) mb.makeModelMetrics(null, fr, null, null);
            mm._description = "Computed on user-given predictions and labels.";
            return mm;
        } finally {
            // Balance Scope.enter() on every path, including the validation failures above.
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Accumulates multinomial metrics row by row and merges partial results in {@link #reduce}.
     *
     * @param <T> concrete builder type accepted by {@link #reduce}
     */
    public static class MetricBuilderMultinomial<T extends MetricBuilderMultinomial<T>>
    extends ModelMetricsSupervised.MetricBuilderSupervised<T> {
        /** Unweighted actual-by-predicted counts; null when the domain exceeds ConfusionMatrix.maxClasses(). */
        double[][] _cm;
        /** Weighted per-rank hit counts (not cumulative), length {@code _K}. */
        double[] _hits;
        /** Number of hit-ratio buckets. */
        int _K;
        /** Weighted sum of per-row log loss. */
        double _logloss;
        /** Whether the one-vs-one / one-vs-rest AUC builders are populated. */
        boolean _calculateAuc;
        /** {@code _ovoAucs[a][b]} for a != b; diagonal entries stay null. */
        AUC2.AUCBuilder[][] _ovoAucs;
        /** One-vs-rest builder per class. */
        AUC2.AUCBuilder[] _ovrAucs;
        MultinomialAucType _aucType;
        /** Prior class distribution used for hit ranking when perRow is called without a model. */
        public transient double[] _priorDistribution;

        public MetricBuilderMultinomial() {
        }

        public MetricBuilderMultinomial(int nclasses, String[] domain, MultinomialAucType aucType) {
            super(nclasses, domain);
            int domainLength = domain.length;
            this._cm = domainLength > ConfusionMatrix.maxClasses() ? null : new double[domainLength][domainLength];
            this._K = Math.min(MAX_HIT_RATIO_K, this._nclasses);
            this._hits = new double[this._K];
            this._aucType = aucType;
            this._calculateAuc = !aucType.equals(MultinomialAucType.NONE) && !aucType.equals(MultinomialAucType.AUTO) && domainLength <= MAX_AUC_DOMAIN_SIZE;
            if (this._calculateAuc) {
                this._ovoAucs = new AUC2.AUCBuilder[domainLength][domainLength];
                this._ovrAucs = new AUC2.AUCBuilder[domainLength];
                for (int i = 0; i < domainLength; ++i) {
                    this._ovrAucs[i] = new AUC2.AUCBuilder(AUC_BUILDER_BINS);
                    for (int j = 0; j < domainLength; ++j) {
                        if (i != j) {
                            this._ovoAucs[i][j] = new AUC2.AUCBuilder(AUC_BUILDER_BINS);
                        }
                    }
                }
            }
        }

        /** Unit-weight, zero-offset variant of {@link #perRow(double[], float[], double, double, Model)}. */
        @Override
        public double[] perRow(double[] ds, float[] yact, Model m) {
            return this.perRow(ds, yact, 1.0, 0.0, m);
        }

        /**
         * Accumulates one row. The row is skipped when there is no confusion matrix, the actual
         * label is NaN, any prediction is NaN, or the weight is 0 or NaN.
         *
         * @param ds prediction row: {@code ds[0]} is the predicted label, {@code ds[1 + c]} the score of class c
         * @param yact {@code yact[0]} is the actual class index
         * @return {@code ds}, unchanged
         */
        @Override
        public double[] perRow(double[] ds, float[] yact, double w, double o, Model m) {
            if (this._cm == null || Float.isNaN(yact[0]) || ArrayUtils.hasNaNs(ds) || w == 0.0 || Double.isNaN(w)) {
                return ds;
            }
            int iact = (int) yact[0];
            ++this._count;
            this._wcount += w;
            this._wY += w * iact;
            this._wYY += w * iact * iact;
            // An actual class without a score column counts as having probability 0.
            double err = iact + 1 < ds.length ? 1.0 - ds[iact + 1] : 1.0;
            this._sumsqe += w * err * err;
            assert !Double.isNaN(this._sumsqe);
            assert iact < this._cm.length : "iact = " + iact + "; _cm.length = " + this._cm.length;
            assert (int) ds[0] < this._cm.length : "ds[0] = " + ds[0] + "; _cm.length = " + this._cm.length;
            // Confusion-matrix cells count rows; they are not weighted.
            this._cm[iact][(int) ds[0]] += 1.0;
            if (this._K > 0 && iact < ds.length - 1) {
                double[] prior = m != null ? ((Model.Output) m._output)._priorClassDist : this._priorDistribution;
                ModelMetricsMultinomial.updateHits(w, iact, ds, this._hits, prior);
            }
            this._logloss += w * MathUtils.logloss(err);
            if (this._calculateAuc) {
                this.calculateAucsPerRow(ds, iact, w);
            }
            return ds;
        }

        /**
         * Feeds one row into the AUC builders. For the actual class a and every other class i:
         * ovo[a][i] gets score(i) with label 0, ovo[i][a] gets score(a) with label 1, and ovr[i]
         * gets score(i) with label 0. Finally, ovr[a] gets score(a) with label 1. Classes without
         * a score column use score 0. An actual index past the domain is clamped to the last class.
         */
        private void calculateAucsPerRow(double[] ds, int iact, double w) {
            if (iact >= this._domain.length) {
                iact = this._domain.length - 1;
            }
            double actualScore = iact < ds.length - 1 ? ds[iact + 1] : 0.0;
            for (int i = 0; i < this._domain.length; ++i) {
                if (i == iact) {
                    this._ovrAucs[iact].perRow(actualScore, 1, w);
                    continue;
                }
                double score = i < ds.length - 1 ? ds[i + 1] : 0.0;
                this._ovoAucs[iact][i].perRow(score, 0, w);
                this._ovoAucs[i][iact].perRow(actualScore, 1, w);
                this._ovrAucs[i].perRow(score, 0, w);
            }
        }

        /** Merges another builder's partial sums into this one; a no-op when this builder has no confusion matrix. */
        @Override
        public void reduce(T mb) {
            if (this._cm == null) {
                return;
            }
            super.reduce(mb);
            MetricBuilderMultinomial<?> other = mb;
            assert other._K == this._K;
            ArrayUtils.add(this._cm, other._cm);
            this._hits = ArrayUtils.add(this._hits, other._hits);
            this._logloss += other._logloss;
            if (this._calculateAuc) {
                for (int i = 0; i < this._ovoAucs.length; ++i) {
                    this._ovrAucs[i].reduce(other._ovrAucs[i]);
                    for (int j = 0; j < this._ovoAucs[i].length; ++j) {
                        if (i != j) {
                            this._ovoAucs[i][j].reduce(other._ovoAucs[i][j]);
                        }
                    }
                }
            }
        }

        /**
         * Turns the accumulated sums into a {@link ModelMetricsMultinomial}.
         * MSE, log loss and hit ratios are normalized by the weighted row count, and hit ratios
         * are made cumulative. These values stay NaN / zero when no weight was seen.
         * The result is also registered on {@code m} when it is non-null.
         */
        @Override
        public ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
            double mse = Double.NaN;
            double logloss = Double.NaN;
            float[] hr = new float[this._K];
            ConfusionMatrix cm = new ConfusionMatrix(this._cm, this._domain);
            double sigma = this.weightedSigma();
            if (this._wcount > 0.0) {
                if (this._hits != null) {
                    for (int i = 0; i < hr.length; ++i) {
                        hr[i] = (float) (this._hits[i] / this._wcount);
                    }
                    for (int i = 1; i < hr.length; ++i) {
                        hr[i] += hr[i - 1];
                    }
                }
                mse = this._sumsqe / this._wcount;
                logloss = this._logloss / this._wcount;
            }
            MultinomialAUC auc = new MultinomialAUC(this._ovrAucs, this._ovoAucs, this._domain, this._wcount == 0.0, this._aucType);
            ModelMetricsMultinomial mm = new ModelMetricsMultinomial(m, f, this._count, mse, this._domain, sigma, cm, hr, logloss, auc, this._customMetric);
            if (m != null) {
                m.addModelMetrics(mm);
            }
            return mm;
        }
    }

    /**
     * Distributed pass over a frame laid out as [one probability column per class, labels,
     * optional weights]. Each map builds its own {@link MetricBuilderMultinomial}, and reduce merges them.
     */
    private static class MultinomialMetrics
    extends MRTask<MultinomialMetrics> {
        private final String[] _domain;
        private final MultinomialAucType _aucType;
        private MetricBuilderMultinomial _mb;

        MultinomialMetrics(String[] domain, MultinomialAucType aucType) {
            this._domain = domain;
            this._aucType = aucType;
        }

        @Override
        public void map(Chunk[] chks) {
            int nclasses = this._domain.length;
            this._mb = new MetricBuilderMultinomial(nclasses, this._domain, this._aucType);
            Chunk actuals = chks[nclasses];
            Chunk weights = chks.length == nclasses + 2 ? chks[nclasses + 1] : null;
            double[] ds = new double[nclasses + 1];
            float[] acts = new float[1];
            for (int i = 0; i < chks[0]._len; ++i) {
                for (int c = 0; c < nclasses; ++c) {
                    ds[c + 1] = chks[c].atd(i);
                }
                ds[0] = GenModel.getPrediction(ds, null, ds, 0.5);
                // at8 returns a long and cannot encode a missing label; use NaN so perRow skips the row.
                acts[0] = actuals.isNA(i) ? Float.NaN : actuals.at8(i);
                double w = weights != null ? weights.atd(i) : 1.0;
                this._mb.perRow(ds, acts, w, 0.0, null);
            }
        }

        @SuppressWarnings("unchecked")
        @Override
        public void reduce(MultinomialMetrics mrt) {
            this._mb.reduce(mrt._mb);
        }
    }
}

