/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl;

import ai.h2o.automl.UserFeedback;
import ai.h2o.automl.UserFeedbackEvent;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import water.DKV;
import water.Iced;
import water.Key;
import water.Keyed;
import water.TAtomic;
import water.api.schemas3.KeyV3;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.Log;
import water.util.TwoDimTable;

/**
 * Ranked collection of model keys for a single project, stored in the DKV under a key
 * derived from the project name (see {@link #idForProject(String)}).
 *
 * <p>Models are ranked by {@code sort_metric}; the metric values are cached in arrays that
 * run parallel to {@code models}. Metrics are computed on {@code leaderboardFrame} when one
 * is supplied, and taken from each model's cross-validation metrics otherwise.
 */
public class Leaderboard
extends Keyed<Leaderboard> {
    /** Project identifier; the constructor derives this object's DKV key from it. */
    private final String project_name;
    /** Model keys; {@code addModels} stores them in the order returned by the metric sort. */
    private Key<Model>[] models = new Key[0];
    /** Metrics objects indexed by their own key; rebuilt from scratch on every {@code addModels}. */
    private IcedHashMap<Key<ModelMetrics>, ModelMetrics> leaderboard_set_metrics = new IcedHashMap();
    /** Values of {@code sort_metric}, one per entry of {@code models}. */
    public double[] sort_metrics = new double[0];
    /** Populated by {@code addModels} only when the sort metric is "mean_residual_deviance". */
    public double[] rmse = new double[0];
    /** Populated by {@code addModels} only when the sort metric is "mean_residual_deviance". */
    public double[] mae = new double[0];
    /** Populated by {@code addModels} only when the sort metric is "mean_residual_deviance". */
    public double[] rmsle = new double[0];
    /** Populated by {@code addModels} only when the sort metric is "auc". */
    public double[] logloss = new double[0];
    /** Name of the metric used to rank models. */
    private String sort_metric;
    /** Names of the additional metric columns shown next to the sort metric. */
    private String[] other_metrics;
    /** Sort direction flag handed to {@code ModelMetrics.sortModelsByMetric}. */
    private boolean sort_decreasing;
    /** Set once a metric has been chosen; while false, {@code addModels} picks defaults from the first model. */
    private boolean have_set_sort_metric = false;
    /** Receives the warnings and "new leader" notices emitted by {@code addModels}. */
    private UserFeedback userFeedback;
    /** Frame used for scoring; when null, cross-validation metrics are used instead. */
    private Frame leaderboardFrame;
    /** Checksum of {@code leaderboardFrame}, or 0 when no frame is set. */
    private long leaderboardFrameChecksum;
    // Column type/format descriptors for the three table layouts built by makeTwoDimTable:
    // model id + sort metric (multinomial), + one extra metric (binomial), + three extra metrics (regression).
    protected static final String[] colTypesMultinomial = new String[]{"string", "double"};
    protected static final String[] colFormatsMultinomial = new String[]{"%s", "%.6f"};
    protected static final String[] colTypesBinomial = new String[]{"string", "double", "double"};
    protected static final String[] colFormatsBinomial = new String[]{"%s", "%.6f", "%.6f"};
    protected static final String[] colTypesRegression = new String[]{"string", "double", "double", "double", "double"};
    protected static final String[] colFormatsRegression = new String[]{"%s", "%.6f", "%.6f", "%.6f", "%.6f"};

    private Leaderboard() {
        throw new UnsupportedOperationException("Do not call the default constructor Leaderboard().");
    }

    /**
     * Creates a leaderboard for the given project. The new object is NOT put into the DKV here.
     *
     * <p>The metric settings start out as "auc" (decreasing) with "logloss" as the extra metric.
     *
     * @param project_name     project identifier, also used to build the DKV key
     * @param userFeedback     sink for warnings and notices
     * @param leaderboardFrame frame to score models on; may be null
     */
    public Leaderboard(String project_name, UserFeedback userFeedback, Frame leaderboardFrame) {
        this._key = Key.make(Leaderboard.idForProject(project_name));
        this.project_name = project_name;
        this.userFeedback = userFeedback;
        this.leaderboardFrame = leaderboardFrame;
        if (this.leaderboardFrame == null) {
            this.leaderboardFrameChecksum = 0L;
        } else {
            this.leaderboardFrameChecksum = leaderboardFrame.checksum();
        }
        // Initial ranking configuration.
        this.sort_decreasing = true;
        this.other_metrics = new String[]{"logloss"};
        this.sort_metric = "auc";
    }

    /**
     * Returns the leaderboard stored in the DKV for {@code project_name}, creating one if absent.
     *
     * <p>An existing leaderboard has its feedback sink, frame and frame checksum replaced by the
     * supplied values. In both cases the resulting object is written to the DKV before returning.
     *
     * @param project_name     project identifier
     * @param userFeedback     sink for warnings and notices
     * @param leaderboardFrame frame to score models on; may be null
     * @return the existing (updated) or newly created leaderboard
     */
    public static Leaderboard getOrMakeLeaderboard(String project_name, UserFeedback userFeedback, Frame leaderboardFrame) {
        Key boardKey = Key.make(Leaderboard.idForProject(project_name));
        Leaderboard board = (Leaderboard)DKV.getGet(boardKey);
        if (board == null) {
            board = new Leaderboard(project_name, userFeedback, leaderboardFrame);
        } else {
            board.userFeedback = userFeedback;
            board.leaderboardFrame = leaderboardFrame;
            board.leaderboardFrameChecksum = leaderboardFrame == null ? 0L : leaderboardFrame.checksum();
        }
        DKV.put((Keyed)board);
        return board;
    }

    /**
     * Builds the DKV key name of the leaderboard for a project.
     *
     * @param project_name project identifier
     * @return {@code "AutoML_Leaderboard_"} followed by the project name
     */
    public static String idForProject(String project_name) {
        StringBuilder id = new StringBuilder("AutoML_Leaderboard_");
        id.append(project_name);
        return id.toString();
    }

    /**
     * @return the project name this leaderboard was created for
     */
    public String getProject() {
        final String name = this.project_name;
        return name;
    }

    /**
     * Sets the ranking metric, the extra metric columns and the sort direction, marks the
     * metric as explicitly chosen, and writes this leaderboard to the DKV.
     *
     * @param metric         name of the metric to rank by
     * @param otherMetrics   names of additional metrics to report
     * @param sortDecreasing sort direction flag
     */
    public void setMetricAndDirection(String metric, String[] otherMetrics, boolean sortDecreasing) {
        this.have_set_sort_metric = true;
        this.sort_decreasing = sortDecreasing;
        this.other_metrics = otherMetrics;
        this.sort_metric = metric;
        // Persist the new configuration.
        DKV.put((Keyed)this);
    }

    /**
     * Sets the ranking metric and sort direction, marks the metric as explicitly chosen, and
     * writes this leaderboard to the DKV. The list of extra metrics is left untouched.
     *
     * @param metric         name of the metric to rank by
     * @param sortDecreasing sort direction flag
     */
    public void setMetricAndDirection(String metric, boolean sortDecreasing) {
        this.have_set_sort_metric = true;
        this.sort_decreasing = sortDecreasing;
        this.sort_metric = metric;
        // Persist the new configuration.
        DKV.put((Keyed)this);
    }

    /**
     * Picks the ranking configuration from the kind of model supplied:
     * binomial classifiers rank by "auc" (decreasing) with "logloss";
     * other classifiers rank by "mean_per_class_error";
     * other supervised models rank by "mean_residual_deviance" with "rmse", "mae", "rmsle".
     * Models matching none of these leave the configuration unchanged.
     *
     * @param m model whose output category decides the defaults
     */
    public void setDefaultMetricAndDirection(Model m) {
        if (m._output.isBinomialClassifier()) {
            this.setMetricAndDirection("auc", new String[]{"logloss"}, true);
            return;
        }
        if (m._output.isClassifier()) {
            this.setMetricAndDirection("mean_per_class_error", false);
            return;
        }
        if (m._output.isSupervised()) {
            String[] regressionExtras = new String[]{"rmse", "mae", "rmsle"};
            this.setMetricAndDirection("mean_residual_deviance", regressionExtras, false);
        }
    }

    /**
     * Adds models to the leaderboard stored in the DKV and re-ranks the whole set.
     *
     * <p>The update runs inside a {@link TAtomic} on this leaderboard's key. It merges the old
     * and new keys (duplicates collapse), collects one {@link ModelMetrics} per model, sorts the
     * keys by the sort metric and refreshes the cached metric arrays. Afterwards the local
     * fields of this instance are refreshed from the stored copy, and a "New leader" notice is
     * sent to the feedback sink if the first-ranked model changed.
     *
     * <p>Metrics come from the leaderboard frame when one is set (scoring the model if no
     * metrics exist for that frame yet), otherwise from the model's cross-validation metrics.
     * Models that have been deleted, or for which no metrics can be obtained, are reported
     * through the feedback sink and left out of the leaderboard, since they cannot be ranked.
     *
     * @param newModels keys of the models to add; null or empty is a no-op
     * @throws H2OIllegalArgumentException if this leaderboard has no key, if the stored
     *         leaderboard is missing, or if sorting by the metric fails
     */
    public final void addModels(final Key<Model>[] newModels) {
        if (null == this._key) {
            throw new H2OIllegalArgumentException("Can't add models to a Leaderboard which isn't in the DKV.");
        }
        if (null == newModels || newModels.length < 1) {
            return;
        }
        if (!this.have_set_sort_metric) {
            // The first model decides the default metric; a deleted model cannot, so keep current settings.
            Model first = null == newModels[0] ? null : (Model)newModels[0].get();
            if (null != first) {
                this.setDefaultMetricAndDirection(first);
            }
        }
        // One-element holders so the anonymous class can report results back.
        final Key[] newLeader = new Key[1];
        final double[] newLeaderSortMetric = new double[1];
        new TAtomic<Leaderboard>(){

            public final Leaderboard atomic(Leaderboard updating) {
                if (updating == null) {
                    Log.err("trying to update null leaderboard!");
                    throw new H2OIllegalArgumentException("Trying to update a null leaderboard.");
                }
                Key[] oldModels = updating.models == null ? new Key[0] : updating.models;
                Key oldLeader = 0 == oldModels.length ? null : oldModels[0];

                // Merge old and new keys, dropping duplicates and null entries.
                HashSet<Key> uniques = new HashSet<Key>(oldModels.length + newModels.length);
                uniques.addAll(Arrays.asList(oldModels));
                uniques.addAll(Arrays.asList(newModels));
                uniques.remove(null);

                // Collect metrics; keep only the keys that can actually be ranked.
                Key[] candidates = uniques.toArray(new Key[0]);
                int kept = 0;
                updating.leaderboard_set_metrics = new IcedHashMap();
                for (int c = 0; c < candidates.length; ++c) {
                    Key aKey = candidates[c];
                    Model aModel = (Model)aKey.get();
                    if (null == aModel) {
                        Leaderboard.this.userFeedback.warn(UserFeedbackEvent.Stage.ModelTraining, "Model in the leaderboard has unexpectedly been deleted from H2O: " + aKey);
                        continue;
                    }
                    ModelMetrics mm;
                    if (Leaderboard.this.leaderboardFrame == null) {
                        mm = aModel._output._cross_validation_metrics;
                    } else {
                        mm = ModelMetrics.getFromDKV(aModel, Leaderboard.this.leaderboardFrame);
                        if (mm == null) {
                            // Scoring is done for its side effect of producing the metrics looked up below;
                            // the predictions themselves are not used, so release them right away.
                            Frame preds = aModel.score(Leaderboard.this.leaderboardFrame);
                            mm = ModelMetrics.getFromDKV(aModel, Leaderboard.this.leaderboardFrame);
                            if (preds != null) {
                                preds.delete();
                            }
                        }
                    }
                    if (mm == null) {
                        Leaderboard.this.userFeedback.warn(UserFeedbackEvent.Stage.ModelTraining, "Metrics could not be found for model in the leaderboard, skipping it: " + aKey);
                        continue;
                    }
                    updating.leaderboard_set_metrics.put(mm._key, mm);
                    // Safe in-place compaction: kept never runs ahead of c.
                    candidates[kept++] = aKey;
                }
                updating.models = Arrays.copyOf(candidates, kept);
                if (kept == 0) {
                    // Nothing rankable is left; store an empty leaderboard.
                    updating.sort_metrics = new double[0];
                    return updating;
                }

                try {
                    List<Key<Model>> modelsSorted = Leaderboard.this.leaderboardFrame == null
                            ? ModelMetrics.sortModelsByMetric(Leaderboard.this.sort_metric, Leaderboard.this.sort_decreasing, Arrays.asList(updating.models))
                            : ModelMetrics.sortModelsByMetric(Leaderboard.this.leaderboardFrame, Leaderboard.this.sort_metric, Leaderboard.this.sort_decreasing, Arrays.asList(updating.models));
                    updating.models = modelsSorted.toArray(new Key[0]);
                }
                catch (H2OIllegalArgumentException e) {
                    Log.warn("ModelMetrics.sortModelsByMetric failed: " + e);
                    throw e;
                }

                // Refresh the cached metric arrays in the new ranking order.
                Model[] updating_models = new Model[updating.models.length];
                Leaderboard.modelsForModelKeys(updating.models, updating_models);
                updating.sort_metrics = Leaderboard.getSortMetrics(updating.sort_metric, updating.leaderboard_set_metrics, Leaderboard.this.leaderboardFrame, updating_models);
                if (Leaderboard.this.sort_metric.equals("auc")) {
                    updating.logloss = Leaderboard.getOtherMetrics("logloss", updating.leaderboard_set_metrics, Leaderboard.this.leaderboardFrame, updating_models);
                } else if (Leaderboard.this.sort_metric.equals("mean_residual_deviance")) {
                    updating.rmse = Leaderboard.getOtherMetrics("rmse", updating.leaderboard_set_metrics, Leaderboard.this.leaderboardFrame, updating_models);
                    updating.mae = Leaderboard.getOtherMetrics("mae", updating.leaderboard_set_metrics, Leaderboard.this.leaderboardFrame, updating_models);
                    updating.rmsle = Leaderboard.getOtherMetrics("rmsle", updating.leaderboard_set_metrics, Leaderboard.this.leaderboardFrame, updating_models);
                }

                // Report a leader change back to the caller.
                if (updating.models.length > 0 && (oldLeader == null || !oldLeader.equals(updating.models[0]))) {
                    newLeader[0] = updating.models[0];
                    newLeaderSortMetric[0] = updating.sort_metrics[0];
                }
                return updating;
            }
        }.invoke(this._key);

        // Bring this instance in line with what is now stored in the DKV.
        Leaderboard updated = (Leaderboard)DKV.getGet((Key)this._key);
        this.models = updated.models;
        this.leaderboard_set_metrics = updated.leaderboard_set_metrics;
        this.sort_metrics = updated.sort_metrics;
        if (this.sort_metric.equals("auc")) {
            this.logloss = updated.logloss;
        } else if (this.sort_metric.equals("mean_residual_deviance")) {
            this.rmse = updated.rmse;
            this.mae = updated.mae;
            this.rmsle = updated.rmsle;
        }
        if (null != newLeader[0]) {
            this.userFeedback.info(UserFeedbackEvent.Stage.ModelTraining, "New leader: " + newLeader[0] + ", " + this.sort_metric + ": " + newLeaderSortMetric[0]);
        }
    }

    /**
     * Adds a single model by key; a null key is ignored.
     *
     * @param key key of the model to add
     */
    public void addModel(Key<Model> key) {
        if (key != null) {
            this.addModels(new Key[]{key});
        }
    }

    /**
     * Adds a single model; a null model is ignored.
     *
     * @param model model whose key is added to the leaderboard
     */
    public void addModel(Model model) {
        if (model != null) {
            this.addModels(new Key[]{model._key});
        }
    }

    /**
     * Fetches the model behind each key from the DKV into the supplied array.
     *
     * @param modelKeys keys to resolve
     * @param models    destination array, at least as long as {@code modelKeys}
     * @return {@code models}, filled position by position with the lookup results
     */
    private static Model[] modelsForModelKeys(Key<Model>[] modelKeys, Model[] models) {
        assert (models.length >= modelKeys.length);
        for (int slot = 0; slot < modelKeys.length; ++slot) {
            models[slot] = (Model)DKV.getGet(modelKeys[slot]);
        }
        return models;
    }

    /**
     * @return the model keys of the leaderboard copy currently stored in the DKV
     */
    public Key<Model>[] getModelKeys() {
        Leaderboard stored = (Leaderboard)DKV.getGet((Key)this._key);
        return stored.models;
    }

    /**
     * Returns the stored model keys re-sorted by an arbitrary metric.
     *
     * @param metric         name of the metric to sort by
     * @param sortDecreasing sort direction flag
     * @return a new array holding the keys in sorted order
     */
    public Key<Model>[] modelKeys(String metric, boolean sortDecreasing) {
        List<Key<Model>> current = Arrays.asList(this.getModelKeys());
        List<Key<Model>> resorted = ModelMetrics.sortModelsByMetric(metric, sortDecreasing, current);
        return resorted.toArray(new Key[0]);
    }

    /**
     * @return the models behind the stored keys, in leaderboard order; an empty array when
     *         there are no keys
     */
    public Model[] getModels() {
        Key<Model>[] keys = this.getModelKeys();
        boolean nothingStored = keys == null || keys.length == 0;
        return nothingStored
                ? new Model[0]
                : Leaderboard.modelsForModelKeys(keys, new Model[keys.length]);
    }

    /**
     * Returns the models ordered by an arbitrary metric.
     *
     * @param metric         name of the metric to sort by
     * @param sortDecreasing sort direction flag
     * @return the models in sorted order; an empty array when there are no keys
     */
    public Model[] getModels(String metric, boolean sortDecreasing) {
        Key<Model>[] keys = this.modelKeys(metric, sortDecreasing);
        boolean nothingStored = keys == null || keys.length == 0;
        return nothingStored
                ? new Model[0]
                : Leaderboard.modelsForModelKeys(keys, new Model[keys.length]);
    }

    /**
     * @return the model behind the first stored key, or null when the leaderboard has no keys
     */
    public Model getLeader() {
        Key<Model>[] keys = this.getModelKeys();
        if (keys != null && keys.length != 0) {
            return (Model)keys[0].get();
        }
        return null;
    }

    /**
     * @return the number of model keys in the stored leaderboard
     */
    public int getModelCount() {
        Key<Model>[] keys = this.getModelKeys();
        return keys.length;
    }

    /**
     * Recomputes the sort-metric values for the currently stored models, using this
     * instance's metric name, metrics map and leaderboard frame.
     *
     * @return one value per model, in leaderboard order
     */
    public double[] getSortMetrics() {
        // Arguments are evaluated left to right, so the DKV read in getModels() happens last.
        return Leaderboard.getSortMetrics(
                this.sort_metric,
                this.leaderboard_set_metrics,
                this.leaderboardFrame,
                this.getModels());
    }

    /**
     * Looks up the value of a named metric for each model.
     *
     * <p>The lookup is the same one performed by the static {@code getSortMetrics}, only with a
     * different metric name, so this method forwards to it.
     *
     * @param other_metric            name of the metric to extract
     * @param leaderboard_set_metrics metrics objects indexed by their key
     * @param leaderboardFrame        frame the metrics were computed on, or null to use the
     *                                frame of each model's cross-validation metrics
     * @param models                  models to look up, in output order
     * @return one metric value per model
     */
    public static double[] getOtherMetrics(String other_metric, IcedHashMap<Key<ModelMetrics>, ModelMetrics> leaderboard_set_metrics, Frame leaderboardFrame, Model[] models) {
        return Leaderboard.getSortMetrics(other_metric, leaderboard_set_metrics, leaderboardFrame, models);
    }

    /**
     * Looks up the value of a named metric for each model.
     *
     * <p>For every model a metrics key is built — from the model and the leaderboard frame when
     * a frame is given, otherwise from the model and the frame of its cross-validation
     * metrics — and the metric is read from the matching entry of the supplied map.
     *
     * @param sort_metric             name of the metric to extract
     * @param leaderboard_set_metrics metrics objects indexed by their key
     * @param leaderboardFrame        frame the metrics were computed on, or null
     * @param models                  models to look up, in output order
     * @return one metric value per model
     */
    public static double[] getSortMetrics(String sort_metric, IcedHashMap<Key<ModelMetrics>, ModelMetrics> leaderboard_set_metrics, Frame leaderboardFrame, Model[] models) {
        final double[] values = new double[models.length];
        for (int idx = 0; idx < models.length; ++idx) {
            final Model model = models[idx];
            final Key metricsKey;
            if (leaderboardFrame == null) {
                // No leaderboard frame: the key is tied to the cross-validation metrics' frame.
                metricsKey = ModelMetrics.buildKey(
                        model._key,
                        model.checksum(),
                        model._output._cross_validation_metrics.frame()._key,
                        model._output._cross_validation_metrics.frame().checksum());
            } else {
                metricsKey = ModelMetrics.buildKey(model, leaderboardFrame);
            }
            ModelMetrics found = (ModelMetrics)leaderboard_set_metrics.get(metricsKey);
            values[idx] = ModelMetrics.getMetricFromModelMetric(found, sort_metric);
        }
        return values;
    }

    /**
     * Removes this leaderboard by calling {@code remove()}; the models it references are kept.
     */
    public void delete() {
        remove();
    }

    /**
     * Deletes every model on the leaderboard and then the leaderboard itself.
     *
     * <p>Entries whose model is no longer in the DKV come back as null from
     * {@link #getModels()} and are skipped, so an already-deleted model does not abort the
     * cleanup of the remaining ones.
     */
    public void deleteWithChildren() {
        for (Model m : this.getModels()) {
            if (m != null) {
                m.delete();
            }
        }
        this.delete();
    }

    /**
     * Returns the default metric values for a model, preferring its cross-validation metrics,
     * then its validation metrics, then its training metrics.
     *
     * @param m model to read metrics from
     * @return the values produced by {@link #defaultMetricForModel(Model, ModelMetrics)}
     */
    public static double[] defaultMetricForModel(Model m) {
        ModelMetrics chosen = m._output._cross_validation_metrics;
        if (chosen == null) {
            chosen = m._output._validation_metrics;
        }
        if (chosen == null) {
            chosen = m._output._training_metrics;
        }
        return Leaderboard.defaultMetricForModel(m, chosen);
    }

    /**
     * Extracts the default metric values for a model from the given metrics object.
     *
     * <ul>
     *   <li>binomial classifier: {auc, logloss}</li>
     *   <li>other classifier: {mean_per_class_error}</li>
     *   <li>other supervised model: {mean_residual_deviance, rmse, mae, rmsle}</li>
     *   <li>anything else: a warning is logged and {NaN} is returned</li>
     * </ul>
     *
     * @param m  model whose output category selects the metrics
     * @param mm metrics object to read from
     * @return the metric values, primary metric first
     */
    public static double[] defaultMetricForModel(Model m, ModelMetrics mm) {
        final double[] values;
        if (m._output.isBinomialClassifier()) {
            ModelMetricsBinomial binomial = (ModelMetricsBinomial)mm;
            values = new double[]{binomial.auc(), binomial.logloss()};
        } else if (m._output.isClassifier()) {
            ModelMetricsMultinomial multinomial = (ModelMetricsMultinomial)mm;
            values = new double[]{multinomial.mean_per_class_error()};
        } else if (m._output.isSupervised()) {
            ModelMetricsRegression regression = (ModelMetricsRegression)mm;
            values = new double[]{regression.mean_residual_deviance(), mm.rmse(), regression.mae(), regression.rmsle()};
        } else {
            Log.warn("Failed to find metric for model: " + m);
            values = new double[]{Double.NaN};
        }
        return values;
    }

    /**
     * Returns the names of the default metrics for a model, matching the order of the values
     * from {@link #defaultMetricForModel(Model)}.
     *
     * @param m model whose output category selects the names
     * @return the metric names, or {"unknown"} for an unsupported model kind
     */
    public static String[] defaultMetricNameForModel(Model m) {
        final String[] names;
        if (m._output.isBinomialClassifier()) {
            names = new String[]{"auc", "logloss"};
        } else if (m._output.isClassifier()) {
            names = new String[]{"mean per-class error"};
        } else if (m._output.isSupervised()) {
            names = new String[]{"mean_residual_deviance", "rmse", "mae", "rmsle"};
        } else {
            names = new String[]{"unknown"};
        }
        return names;
    }

    /**
     * Renders the primary default metric of every model as a one-column listing with an
     * "Error" header, walking the leaderboard from the last entry to the first.
     *
     * <p>Lines are terminated by the two-character sequence backslash + 'n' (not a real
     * newline), exactly as before. Only the first value returned by
     * {@link #defaultMetricForModel(Model)} is written: appending the whole {@code double[]}
     * would print the array's identity string rather than a number.
     *
     * @return the listing as a single string
     */
    public String rankTsv() {
        String lineSeparator = "\\n";
        StringBuffer sb = new StringBuffer();
        sb.append("Error").append(lineSeparator);
        Model[] models = this.getModels();
        for (int i = models.length - 1; i >= 0; --i) {
            Model m = models[i];
            // Index 0 is the primary metric; every branch of defaultMetricForModel returns at least one value.
            sb.append(Leaderboard.defaultMetricForModel(m)[0]);
            sb.append(lineSeparator);
        }
        return sb.toString();
    }

    /**
     * Renders the primary default metric of every model as a one-column listing with an
     * "Error" header, walking the leaderboard from the last entry to the first.
     *
     * <p>Lines are terminated by the two-character sequence backslash + 'n' (not a real
     * newline), exactly as before. Only the first value returned by
     * {@link #defaultMetricForModel(Model)} is written: appending the whole {@code double[]}
     * would print the array's identity string rather than a number.
     *
     * @return the listing as a single string
     */
    public String timeTsv() {
        String lineSeparator = "\\n";
        StringBuffer sb = new StringBuffer();
        sb.append("Error").append(lineSeparator);
        Model[] models = this.getModels();
        for (int i = models.length - 1; i >= 0; --i) {
            Model m = models[i];
            // Index 0 is the primary metric; every branch of defaultMetricForModel returns at least one value.
            sb.append(Leaderboard.defaultMetricForModel(m)[0]);
            sb.append(lineSeparator);
        }
        return sb.toString();
    }

    /**
     * Builds the table column headers: "model_id", the sort metric, then the extra metrics.
     *
     * @param metric       name of the sort metric
     * @param other_metric names of the extra metric columns
     * @return the combined header array
     */
    protected static final String[] colHeaders(String metric, String[] other_metric) {
        String[] leading = new String[]{"model_id", metric.toString()};
        return ArrayUtils.append(leading, other_metric);
    }

    /**
     * Builds the two-column header used for the multinomial layout.
     *
     * @param metric name of the sort metric
     * @return {"model_id", metric}
     */
    protected static final String[] colHeadersMult(String metric) {
        final String metricColumn = metric.toString();
        String[] headers = new String[2];
        headers[0] = "model_id";
        headers[1] = metricColumn;
        return headers;
    }

    /**
     * Creates an empty table with the layout that matches the sort metric.
     *
     * <ul>
     *   <li>null metric with zero rows: placeholder table using the binomial layout</li>
     *   <li>"mean_per_class_error": model id + metric</li>
     *   <li>"auc": model id + metric + extra metrics, binomial column types</li>
     *   <li>anything else: model id + metric + extra metrics, regression column types</li>
     * </ul>
     *
     * <p>The placeholder branch uses fixed headers. Building them from the null metric would
     * throw a NullPointerException in {@code colHeaders}, so that branch could never succeed.
     *
     * @param tableHeader  title of the table
     * @param sort_metric  name of the sort metric; may be null only when {@code length} is 0
     * @param other_metric names of the extra metric columns
     * @param length       number of rows
     * @return the new table, with row headers "0" .. "length-1"
     */
    public static final TwoDimTable makeTwoDimTable(String tableHeader, String sort_metric, String[] other_metric, int length) {
        assert (sort_metric != null || length == 0) : "sort_metrics needs to be always not-null for non-empty array!";
        String[] rowHeaders = new String[length];
        for (int i = 0; i < length; ++i) {
            rowHeaders[i] = "" + i;
        }
        if (sort_metric == null && length == 0) {
            // Three headers, matching the three binomial column types/formats used here.
            String[] placeholderHeaders = new String[]{"model_id", "auc", "logloss"};
            return new TwoDimTable(tableHeader, "no models in this leaderboard", rowHeaders, placeholderHeaders, colTypesBinomial, colFormatsBinomial, "-");
        }
        String description = "models sorted in order of " + sort_metric + ", best first";
        if ("mean_per_class_error".equals(sort_metric)) {
            return new TwoDimTable(tableHeader, description, rowHeaders, Leaderboard.colHeadersMult(sort_metric), colTypesMultinomial, colFormatsMultinomial, "#");
        }
        if ("auc".equals(sort_metric)) {
            return new TwoDimTable(tableHeader, description, rowHeaders, Leaderboard.colHeaders(sort_metric, other_metric), colTypesBinomial, colFormatsBinomial, "#");
        }
        return new TwoDimTable(tableHeader, description, rowHeaders, Leaderboard.colHeaders(sort_metric, other_metric), colTypesRegression, colFormatsRegression, "#");
    }

    /**
     * Fills one row of a multinomial-layout table: model id, then the sort metric.
     *
     * @param table    table to write into
     * @param row      row index, also used to index the arrays
     * @param modelIDs model ids, one per row
     * @param errors   sort-metric values, one per row
     */
    public void addTwoDimTableRowMultinomial(TwoDimTable table, int row, String[] modelIDs, double[] errors) {
        table.set(row, 0, (Object)modelIDs[row]);
        table.set(row, 1, (Object)errors[row]);
    }

    /**
     * Fills one row of a binomial-layout table: model id, sort metric, extra metric.
     *
     * @param table       table to write into
     * @param row         row index, also used to index the arrays
     * @param modelIDs    model ids, one per row
     * @param errors      sort-metric values, one per row
     * @param otherErrors extra-metric values, one per row
     */
    public void addTwoDimTableRowBinomial(TwoDimTable table, int row, String[] modelIDs, double[] errors, double[] otherErrors) {
        table.set(row, 0, (Object)modelIDs[row]);
        table.set(row, 1, (Object)errors[row]);
        table.set(row, 2, (Object)otherErrors[row]);
    }

    /**
     * Fills one row of a regression-layout table: model id, sort metric, rmse, mae, rmsle.
     *
     * @param table    table to write into
     * @param row      row index, also used to index the arrays
     * @param modelIDs model ids, one per row
     * @param errors   sort-metric values, one per row
     * @param rmse     rmse values, one per row
     * @param mae      mae values, one per row
     * @param rmsle    rmsle values, one per row
     */
    public void addTwoDimTableRowRegression(TwoDimTable table, int row, String[] modelIDs, double[] errors, double[] rmse, double[] mae, double[] rmsle) {
        table.set(row, 0, (Object)modelIDs[row]);
        table.set(row, 1, (Object)errors[row]);
        table.set(row, 2, (Object)rmse[row]);
        table.set(row, 3, (Object)mae[row]);
        table.set(row, 4, (Object)rmsle[row]);
    }

    /**
     * @return the leaderboard as a table titled with the project name, model ids not padded
     */
    public TwoDimTable toTwoDimTable() {
        final String title = "Leaderboard for project_name: " + this.project_name;
        return this.toTwoDimTable(title, false);
    }

    /**
     * Renders the leaderboard as a table with one row per model.
     *
     * <p>The table layout and the row filler are chosen from the sort metric:
     * "mean_per_class_error" uses the multinomial row, "auc" the binomial row (with logloss),
     * anything else the regression row (with rmse, mae, rmsle). Metric values are taken from
     * the arrays cached on this instance.
     *
     * <p>With {@code leftJustifyModelIds}, every id is right-padded with spaces to the length
     * of the longest id. Padding is built to the exact width needed, so ids of very different
     * lengths no longer risk a {@code StringIndexOutOfBoundsException}.
     *
     * @param tableHeader         title of the table
     * @param leftJustifyModelIds whether to pad model ids to a common width
     * @return the populated table
     */
    public TwoDimTable toTwoDimTable(String tableHeader, boolean leftJustifyModelIds) {
        Model[] models = this.getModels();
        String[] modelIDsFormatted = new String[models.length];
        TwoDimTable table = Leaderboard.makeTwoDimTable(tableHeader, this.sort_metric, this.other_metrics, models.length);
        int maxModelIdLen = -1;
        for (Model m : models) {
            maxModelIdLen = Math.max(maxModelIdLen, m._key.toString().length());
        }
        for (int i = 0; i < models.length; ++i) {
            String id = models[i]._key.toString();
            if (leftJustifyModelIds) {
                StringBuilder padded = new StringBuilder(id);
                while (padded.length() < maxModelIdLen) {
                    padded.append(' ');
                }
                id = padded.toString();
            }
            modelIDsFormatted[i] = id;
        }
        for (int i = 0; i < models.length; ++i) {
            if (this.sort_metric.equals("mean_per_class_error")) {
                this.addTwoDimTableRowMultinomial(table, i, modelIDsFormatted, this.sort_metrics);
            } else if (this.sort_metric.equals("auc")) {
                this.addTwoDimTableRowBinomial(table, i, modelIDsFormatted, this.sort_metrics, this.logloss);
            } else {
                this.addTwoDimTableRowRegression(table, i, modelIDsFormatted, this.sort_metrics, this.rmse, this.mae, this.rmsle);
            }
        }
        return table;
    }

    /**
     * Formats a list of models as delimited text.
     *
     * <p>Optionally starts with a title (followed by "&lt;empty&gt;" and nothing else when there
     * are no models) and a header line of "model_id" plus the default metric names of the first
     * model. Each model then contributes one line: its key followed by its default metric
     * values. Names and values are written one by one, separated by {@code fieldSeparator};
     * appending the arrays themselves would print their identity strings instead.
     *
     * @param project_name   project name shown in the title
     * @param models         models to list, in output order
     * @param fieldSeparator text placed between columns
     * @param lineSeparator  text placed after every line
     * @param includeTitle   whether to emit the title
     * @param includeHeader  whether to emit the column header line
     * @return the formatted text
     */
    public static String toString(String project_name, Model[] models, String fieldSeparator, String lineSeparator, boolean includeTitle, boolean includeHeader) {
        StringBuilder sb = new StringBuilder();
        if (includeTitle) {
            sb.append("Leaderboard for project_name \"").append(project_name).append("\": ");
            if (models.length == 0) {
                sb.append("<empty>");
                return sb.toString();
            }
            sb.append(lineSeparator);
        }
        boolean printedHeader = false;
        for (Model m : models) {
            if (includeHeader && !printedHeader) {
                sb.append("model_id");
                for (String metricName : Leaderboard.defaultMetricNameForModel(m)) {
                    sb.append(fieldSeparator).append(metricName);
                }
                sb.append(lineSeparator);
                printedHeader = true;
            }
            sb.append(m._key.toString());
            for (double metricValue : Leaderboard.defaultMetricForModel(m)) {
                sb.append(fieldSeparator).append(metricValue);
            }
            sb.append(lineSeparator);
        }
        return sb.toString();
    }

    /**
     * Formats this leaderboard's stored models as delimited text, with title and header.
     *
     * @param fieldSeparator text placed between columns
     * @param lineSeparator  text placed after every line
     * @return the formatted text
     */
    public String toString(String fieldSeparator, String lineSeparator) {
        final Model[] current = this.getModels();
        return Leaderboard.toString(this.project_name, current, fieldSeparator, lineSeparator, true, true);
    }

    /**
     * @return this leaderboard as text, using " ; " between columns and " | " after each line
     */
    public String toString() {
        final String columnGap = " ; ";
        final String rowGap = " | ";
        return this.toString(columnGap, rowGap);
    }

    /**
     * Setter for the private {@code models} field, marked synthetic in this decompiled source.
     *
     * @param x0 leaderboard to modify
     * @param x1 new model keys
     * @return {@code x1}
     */
    static /* synthetic */ Key[] access$002(Leaderboard x0, Key[] x1) {
        final Key[] assigned = x1;
        x0.models = assigned;
        return assigned;
    }

    /**
     * {@link KeyV3} schema specialisation for keys that point at a {@link Leaderboard}.
     */
    public static class LeaderboardKeyV3
    extends KeyV3<Iced, LeaderboardKeyV3, Leaderboard> {
        /** Creates an empty schema instance. */
        public LeaderboardKeyV3() {
            super();
        }

        /**
         * Creates a schema instance for the given key.
         *
         * @param key leaderboard key handed to the superclass constructor
         */
        public LeaderboardKeyV3(Key<Leaderboard> key) {
            super(key);
        }
    }
}

