/*
 * Decompiled with CFR 0.152.
 */
package hex.schemas;

import hex.Model;
import hex.ModelMetrics;
import hex.grid.Grid;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import water.DKV;
import water.Key;
import water.api.API;
import water.api.KeyV3;
import water.api.ModelMetricsBase;
import water.api.ModelParametersSchema;
import water.api.Schema;
import water.api.TwoDimTableBase;
import water.exceptions.H2OIllegalArgumentException;
import water.util.TwoDimTable;

/**
 * REST schema (version 99) for a {@link Grid}.
 *
 * <p>It exposes the grid id, the keys of the models that still exist in the DKV, the
 * hyper-parameter names, failed-build information, optional per-model metrics, and the
 * summary / scoring-history tables.
 *
 * <p>If {@code sort_by} is non-empty after defaulting, the model ids and the per-model metric
 * arrays are returned ordered by that metric.
 */
public class GridSchemaV99
extends Schema<Grid, GridSchemaV99> {
    @API(help="Grid id")
    public KeyV3.GridKeyV3 grid_id;
    @API(help="Model performance metric to sort by. Examples: logloss, residual_deviance, mse, auc, r2, f1, recall, precision, accuracy, mcc, err, err_count, lift_top_group, max_per_class_error", required=false, direction=API.Direction.INOUT)
    public String sort_by;
    @API(help="Specify whether sort order should be decreasing.", required=false, direction=API.Direction.INOUT)
    public boolean decreasing;
    @API(help="Model IDs built by a grid search")
    public KeyV3.ModelKeyV3[] model_ids;
    @API(help="Used hyper parameters.", direction=API.Direction.OUTPUT)
    public String[] hyper_names;
    @API(help="List of failed parameters", direction=API.Direction.OUTPUT)
    public ModelParametersSchema[] failed_params;
    @API(help="List of detailed failure messages", direction=API.Direction.OUTPUT)
    public String[] failure_details;
    @API(help="List of detailed failure stack traces", direction=API.Direction.OUTPUT)
    public String[] failure_stack_traces;
    @API(help="List of raw parameters causing model building failure", direction=API.Direction.OUTPUT)
    public String[][] failed_raw_params;
    @API(help="Training model metrics for the returned models; only returned if sort_by is set", direction=API.Direction.OUTPUT)
    public ModelMetricsBase[] training_metrics;
    @API(help="Validation model metrics for the returned models; only returned if sort_by is set", direction=API.Direction.OUTPUT)
    public ModelMetricsBase[] validation_metrics;
    @API(help="Cross validation model metrics for the returned models; only returned if sort_by is set", direction=API.Direction.OUTPUT)
    public ModelMetricsBase[] cross_validation_metrics;
    @API(help="Cross validation model metrics summary for the returned models; only returned if sort_by is set", direction=API.Direction.OUTPUT)
    public TwoDimTableBase[] cross_validation_metrics_summary;
    @API(help="Summary", direction=API.Direction.OUTPUT)
    TwoDimTableBase summary_table;
    @API(help="Scoring history", direction=API.Direction.OUTPUT, level=API.Level.secondary)
    TwoDimTableBase scoring_history;

    @Override
    public Grid createImpl() {
        return Grid.GRID_PROTO;
    }

    /**
     * Populates this schema from {@code grid}.
     *
     * <p>Models whose keys are null or no longer present in the DKV are skipped. If
     * {@code sort_by} is unset and the first live model is supervised, a default ascending sort
     * metric is chosen and written back into {@code sort_by} / {@code decreasing}.
     *
     * @param grid the grid to describe
     * @return this schema
     * @throws H2OIllegalArgumentException if {@code sort_by} is not an allowed metric for the
     *     grid's models
     */
    @Override
    public GridSchemaV99 fillFromImpl(Grid grid) {
        List<Key<Model>> modelKeys = liveModelKeys(grid.getModelKeys());
        applyDefaultSortOrder(modelKeys);
        validateSortBy(modelKeys);
        if (this.sort_by != null && !this.sort_by.isEmpty()) {
            modelKeys = ModelMetrics.sortModelsByMetric(this.sort_by, this.decreasing, modelKeys);
            fillMetrics(modelKeys);
        }

        KeyV3.ModelKeyV3[] modelIds = new KeyV3.ModelKeyV3[modelKeys.size()];
        for (int i = 0; i < modelIds.length; ++i) {
            modelIds[i] = new KeyV3.ModelKeyV3(modelKeys.get(i));
        }
        // Same keys, same order as model_ids; handed to the summary-table builder.
        Key[] keys = modelKeys.toArray(new Key[0]);

        this.grid_id = new KeyV3.GridKeyV3((Key<Grid>)grid._key);
        this.model_ids = modelIds;
        this.hyper_names = grid.getHyperNames();
        this.failed_params = this.toModelParametersSchema(grid.getFailedParameters());
        this.failure_details = grid.getFailureDetails();
        this.failure_stack_traces = grid.getFailureStackTraces();
        this.failed_raw_params = grid.getFailedRawParameters();

        TwoDimTable summary = grid.createSummaryTable(keys, this.sort_by, this.decreasing);
        if (summary != null) {
            this.summary_table = new TwoDimTableBase().fillFromImpl(summary);
        }
        TwoDimTable history = grid.createScoringHistoryTable();
        if (history != null) {
            this.scoring_history = new TwoDimTableBase().fillFromImpl(history);
        }
        return this;
    }

    /** Returns the non-null keys from {@code gridModelKeys} that still resolve to a DKV value, in order. */
    private static List<Key<Model>> liveModelKeys(Key<Model>[] gridModelKeys) {
        List<Key<Model>> live = new ArrayList<Key<Model>>(gridModelKeys.length);
        for (Key<Model> k : gridModelKeys) {
            if (k != null && DKV.get(k) != null) {
                live.add(k);
            }
        }
        return live;
    }

    /**
     * Chooses a default {@code sort_by} when none was given and the first model is supervised.
     *
     * <p>{@code "logloss"} is used when the model reports more than one class, and
     * {@code "residual_deviance"} otherwise. Both defaults sort ascending.
     */
    private void applyDefaultSortOrder(List<Key<Model>> modelKeys) {
        if (this.sort_by != null || modelKeys.isEmpty()) {
            return;
        }
        Model first = (Model)DKV.getGet(modelKeys.get(0));
        if (first == null || !first.isSupervised()) {
            return;
        }
        this.sort_by = ((Model.Output)first._output).nclasses() > 1 ? "logloss" : "residual_deviance";
        this.decreasing = false;
    }

    /**
     * Rejects a {@code sort_by} that is not among the metrics allowed for the first model.
     *
     * <p>The name is lower-cased with {@link Locale#ROOT}. The default locale could change
     * ASCII letters: in Turkish, {@code "I"} lower-cases to a dotless {@code "ı"}.
     */
    private void validateSortBy(List<Key<Model>> modelKeys) {
        if (modelKeys.isEmpty() || this.sort_by == null) {
            return;
        }
        Set<String> possibleMetrics = ModelMetrics.getAllowedMetrics(modelKeys.get(0));
        if (!possibleMetrics.contains(this.sort_by.toLowerCase(Locale.ROOT))) {
            throw new H2OIllegalArgumentException("Invalid argument for sort_by specified. Must be one of: " + Arrays.toString(possibleMetrics.toArray(new String[0])));
        }
    }

    /**
     * Allocates the per-model metric arrays and fills entry {@code i} from the model at
     * {@code sortedKeys.get(i)}.
     *
     * <p>Entries stay null when the model has disappeared from the DKV since filtering, or when
     * it lacks that metric.
     */
    private void fillMetrics(List<Key<Model>> sortedKeys) {
        int n = sortedKeys.size();
        this.training_metrics = new ModelMetricsBase[n];
        this.validation_metrics = new ModelMetricsBase[n];
        this.cross_validation_metrics = new ModelMetricsBase[n];
        this.cross_validation_metrics_summary = new TwoDimTableBase[n];
        for (int i = 0; i < n; ++i) {
            Model model = (Model)DKV.getGet(sortedKeys.get(i));
            if (model == null) {
                // The model may have been removed after liveModelKeys() ran.
                continue;
            }
            Model.Output output = (Model.Output)model._output;
            this.training_metrics[i] = toMetricsSchema(output._training_metrics);
            this.validation_metrics[i] = toMetricsSchema(output._validation_metrics);
            this.cross_validation_metrics[i] = toMetricsSchema(output._cross_validation_metrics);
            if (output._cross_validation_metrics_summary != null) {
                this.cross_validation_metrics_summary[i] = (TwoDimTableBase)Schema.schema(3, output._cross_validation_metrics_summary).fillFromImpl(output._cross_validation_metrics_summary);
            }
        }
    }

    /** Converts {@code metrics} to its version-3 schema, or returns null when {@code metrics} is null. */
    private static ModelMetricsBase toMetricsSchema(ModelMetrics metrics) {
        return metrics == null ? null : (ModelMetricsBase)Schema.schema(3, metrics).fillFromImpl(metrics);
    }

    /**
     * Converts each element of {@code modelParameters} to its latest-version parameters schema.
     *
     * <p>Null elements map to null entries. A null input array yields null.
     */
    private ModelParametersSchema[] toModelParametersSchema(Model.Parameters[] modelParameters) {
        if (modelParameters == null) {
            return null;
        }
        ModelParametersSchema[] result = new ModelParametersSchema[modelParameters.length];
        for (int i = 0; i < modelParameters.length; ++i) {
            result[i] = modelParameters[i] != null ? (ModelParametersSchema)Schema.schema(Schema.getLatestVersion(), modelParameters[i]).fillFromImpl(modelParameters[i]) : null;
        }
        return result;
    }
}

