/*
 * Decompiled with CFR 0.152.
 */
package water.api;

import hex.Model;
import hex.ModelMetrics;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.KeySnapshot;
import water.Value;
import water.api.API;
import water.api.Handler;
import water.api.Schema;
import water.api.SchemaServer;
import water.api.schemas3.FrameV3;
import water.api.schemas3.JobV3;
import water.api.schemas3.KeyV3;
import water.api.schemas3.ModelMetricsBaseV3;
import water.api.schemas3.SchemaV3;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OKeyNotFoundArgumentException;
import water.fvec.Frame;
import water.util.Log;

class ModelMetricsHandler
extends Handler {
    /**
     * Creates the handler. This class declares no instance fields, so nothing needs
     * initializing beyond the {@link Handler} superclass.
     */
    ModelMetricsHandler() {
        super();
    }

    /**
     * Looks up the {@link ModelMetrics} object stored in the DKV under {@code key}.
     *
     * @param key DKV key of the metrics object; must not be {@code null}
     * @return the {@link ModelMetrics} stored under {@code key}
     * @throws IllegalArgumentException if {@code key} is {@code null}, nothing is stored under
     *         it, or the stored object is not a {@link ModelMetrics}
     */
    public static ModelMetrics getFromDKV(Key key) {
        if (null == key) {
            throw new IllegalArgumentException("Got null key.");
        }
        Value v = DKV.get(key);
        if (null == v) {
            throw new IllegalArgumentException("Did not find key: " + key.toString());
        }
        Object ice = v.get();
        if (!(ice instanceof ModelMetrics)) {
            // Describe what was actually found. Guard against a null payload so callers get the
            // documented IllegalArgumentException rather than a NullPointerException.
            String found = null == ice ? "null" : ice.getClass().toString();
            throw new IllegalArgumentException("Expected a ModelMetrics for key: " + key.toString() + "; got a: " + found);
        }
        return (ModelMetrics)ice;
    }

    /**
     * Handles a model-metrics listing request. The schema is converted to a
     * {@link ModelMetricsList}, the metrics matching its optional model/frame filters are
     * collected, and the result is written back into the same schema object.
     *
     * @param version REST API version (not used by this method)
     * @param s request schema carrying the optional model/frame filters
     * @return {@code s}, filled with the matching metrics
     */
    public ModelMetricsListSchemaV3 fetch(int version, ModelMetricsListSchemaV3 s) {
        ModelMetricsList request = (ModelMetricsList)s.createAndFillImpl();
        ModelMetricsList matches = request.fetch();
        return s.fillFromImpl(matches);
    }

    /**
     * Handles a model-metrics delete request. Every metrics object matching the schema's
     * optional model/frame filters is removed from the DKV.
     *
     * @param version REST API version (not used by this method)
     * @param s request schema carrying the optional model/frame filters
     * @return {@code s}, filled with the metrics that were removed
     */
    public ModelMetricsListSchemaV3 delete(int version, ModelMetricsListSchemaV3 s) {
        ModelMetricsList removed = ((ModelMetricsList)s.createAndFillImpl()).delete();
        return s.fillFromImpl(removed);
    }

    /**
     * Scores the requested model on the requested frame, then returns the stored metrics
     * matching that model/frame pair.
     *
     * <p>The frame returned by {@code Model.score} is removed right away. The call is made for
     * its side effect, presumably the creation of the model metrics that are fetched
     * afterwards. A warning is logged if none are found.
     *
     * @param version REST API version, forwarded to {@link #fetch}
     * @param s request schema; {@code model} and {@code frame} are required
     * @return the matching metrics (possibly empty)
     * @throws H2OIllegalArgumentException if the model or frame key is missing from the request
     * @throws H2OKeyNotFoundArgumentException if the model or frame is not present in the DKV
     */
    public ModelMetricsListSchemaV3 score(int version, ModelMetricsListSchemaV3 s) {
        if (s.model == null) {
            throw new H2OIllegalArgumentException("model", "predict", s.model);
        }
        if (DKV.get(s.model.name) == null) {
            throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name);
        }
        if (s.frame == null) {
            throw new H2OIllegalArgumentException("frame", "predict", s.frame);
        }
        if (DKV.get(s.frame.name) == null) {
            throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name);
        }
        ModelMetricsList request = (ModelMetricsList)s.createAndFillImpl();
        Frame scored = request._model.score(request._frame, request._predictions_name);
        scored.remove();
        ModelMetricsListSchemaV3 result = this.fetch(version, s);
        if (result == null) {
            result = new ModelMetricsListSchemaV3();
        }
        boolean noMetrics = result.model_metrics == null || result.model_metrics.length == 0;
        if (noMetrics) {
            Log.warn("Score() did not return a ModelMetrics for model: " + s.model + " on frame: " + s.frame);
        }
        return result;
    }

    /**
     * Starts an asynchronous prediction job and returns a handle to it without waiting.
     *
     * <p>If {@code deep_features_hidden_layer} is non-negative, the job extracts deep features
     * for that hidden layer; the model is cast to {@link Model.DeepFeatures}. Otherwise it runs
     * ordinary {@code Model.score}. The output frame is stored under {@code predictions_frame}
     * if given, or under a generated name.
     *
     * @param version REST API version (not used by this method)
     * @param s request schema; {@code model} and {@code frame} are required
     * @return a {@link JobV3} describing the started job
     * @throws H2OIllegalArgumentException if the model or frame key is missing from the request
     * @throws H2OKeyNotFoundArgumentException if the model or frame is not present in the DKV
     */
    public JobV3 predictAsync(int version, final ModelMetricsListSchemaV3 s) {
        if (null == s.model) {
            throw new H2OIllegalArgumentException("model", "predict", s.model);
        }
        if (null == DKV.get(s.model.name)) {
            throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name);
        }
        if (null == s.frame) {
            throw new H2OIllegalArgumentException("frame", "predict", s.frame);
        }
        if (null == DKV.get(s.frame.name)) {
            throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name);
        }
        final ModelMetricsList parms = (ModelMetricsList)s.createAndFillImpl();
        // Decide the mode once, so the generated name and the work done in compute2() agree.
        // Any layer index >= 0 (including layer 0) yields deep features, matching predict().
        final boolean deepFeatures = s.deep_features_hidden_layer >= 0;
        if (null == parms._predictions_name) {
            if (deepFeatures) {
                parms._predictions_name = "deep_features" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString();
            } else if (parms._exemplar_index >= 0) {
                parms._predictions_name = "members_" + parms._model._key.toString() + "_for_exemplar_" + parms._exemplar_index;
            } else {
                parms._predictions_name = "predictions" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString();
            }
        }
        final Job j = new Job(Key.make(parms._predictions_name), Frame.class.getName(), "prediction");
        H2O.H2OCountedCompleter work = new H2O.H2OCountedCompleter(){

            @Override
            public void compute2() {
                if (deepFeatures) {
                    Frame predictions = ((Model.DeepFeatures)parms._model).scoreDeepFeatures(parms._frame, s.deep_features_hidden_layer, j);
                    // Re-key the result under the requested/generated name and publish it.
                    predictions = new Frame(Key.make(parms._predictions_name), predictions.names(), predictions.vecs());
                    DKV.put(predictions._key, predictions);
                } else {
                    parms._model.score(parms._frame, parms._predictions_name, j);
                }
                this.tryComplete();
            }
        };
        j.start(work, parms._frame.anyVec().nChunks());
        return new JobV3().fillFromImpl(j);
    }

    /**
     * Synchronously computes a predictions-style frame for the requested model. It returns the
     * frame's key together with any stored model metrics matching the request's model/frame.
     *
     * <p>The kind of output depends on the request flags and the model's type:
     * <ul>
     *   <li>no special flag set: ordinary {@code Model.score} predictions;</li>
     *   <li>{@link Model.DeepFeatures} models: reconstruction error
     *       ({@code reconstruction_error} / {@code reconstruction_error_per_feature}) or the
     *       deep features of hidden layer {@code deep_features_hidden_layer}. These are
     *       mutually exclusive;</li>
     *   <li>{@link Model.GLRMArchetypes} models: archetypes projected back into the original
     *       feature space ({@code project_archetypes}) or the reconstructed training frame
     *       ({@code reconstruct_train});</li>
     *   <li>{@code leaf_node_assignment}: requires a {@link Model.LeafNodeAssignment} model;</li>
     *   <li>{@code exemplar_index >= 0}: requires a {@link Model.ExemplarMembers} model. No
     *       frame is needed in this mode.</li>
     * </ul>
     * If {@code predictions_frame} is not supplied, the output frame gets a generated name.
     *
     * @param version REST API version, forwarded to {@link #fetch}
     * @param s request schema
     * @return the metrics schema with {@code predictions_frame} set to the output frame's key
     * @throws H2OIllegalArgumentException if a required model/frame key is missing, or the
     *         requested flags are not supported by the model
     * @throws H2OKeyNotFoundArgumentException if the model or frame is not present in the DKV
     */
    public ModelMetricsListSchemaV3 predict(int version, ModelMetricsListSchemaV3 s) {
        if (null == s.model) {
            throw new H2OIllegalArgumentException("model", "predict", s.model);
        }
        if (null == DKV.get(s.model.name)) {
            throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name);
        }
        // Exemplar-member retrieval does not read a frame (see the exemplar branch below), so
        // the frame is only mandatory for the other modes.
        if (s.exemplar_index < 0) {
            if (null == s.frame) {
                throw new H2OIllegalArgumentException("frame", "predict", s.frame);
            }
            if (null == DKV.get(s.frame.name)) {
                throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name);
            }
        }
        ModelMetricsList parms = (ModelMetricsList)s.createAndFillImpl();
        Model model = parms._model;
        Frame predictions;
        boolean specialMode = s.reconstruction_error || s.reconstruction_error_per_feature || s.deep_features_hidden_layer >= 0 || s.project_archetypes || s.reconstruct_train || s.leaf_node_assignment || s.exemplar_index >= 0;
        if (!specialMode) {
            if (null == parms._predictions_name) {
                parms._predictions_name = generatedPredictionsName("predictions", parms, "_on_");
            }
            predictions = model.score(parms._frame, parms._predictions_name);
        } else if (model instanceof Model.DeepFeatures) {
            Model.DeepFeatures dlModel = (Model.DeepFeatures)model;
            if (s.reconstruction_error || s.reconstruction_error_per_feature) {
                if (s.deep_features_hidden_layer >= 0) {
                    throw new H2OIllegalArgumentException("Can only compute either reconstruction error OR deep features.", "");
                }
                if (null == parms._predictions_name) {
                    parms._predictions_name = generatedPredictionsName("reconstruction_error", parms, "_on_");
                }
                predictions = dlModel.scoreAutoEncoder(parms._frame, Key.make(parms._predictions_name), parms._reconstruction_error_per_feature);
            } else {
                if (s.deep_features_hidden_layer < 0) {
                    throw new H2OIllegalArgumentException("Deep features hidden layer index must be >= 0.", "");
                }
                if (null == parms._predictions_name) {
                    parms._predictions_name = generatedPredictionsName("deep_features", parms, "_on_");
                }
                predictions = dlModel.scoreDeepFeatures(parms._frame, s.deep_features_hidden_layer);
            }
            // Re-key the output under the requested/generated name and publish it in the DKV.
            predictions = new Frame(Key.make(parms._predictions_name), predictions.names(), predictions.vecs());
            DKV.put(predictions._key, predictions);
        } else if (model instanceof Model.GLRMArchetypes) {
            Model.GLRMArchetypes glrmModel = (Model.GLRMArchetypes)model;
            if (s.project_archetypes) {
                if (null == parms._predictions_name) {
                    parms._predictions_name = generatedPredictionsName("reconstructed_archetypes_", parms, "_of_");
                }
                predictions = glrmModel.scoreArchetypes(parms._frame, Key.make(parms._predictions_name), s.reverse_transform);
            } else {
                // NOTE(review): only enforced with -ea; with assertions disabled any other flag
                // that reaches this point also produces a reconstruction.
                assert (s.reconstruct_train);
                if (null == parms._predictions_name) {
                    parms._predictions_name = generatedPredictionsName("reconstruction_", parms, "_of_");
                }
                predictions = glrmModel.scoreReconstruction(parms._frame, Key.make(parms._predictions_name), s.reverse_transform);
            }
        } else if (s.leaf_node_assignment) {
            // Explicit check instead of an assert. Assertions are off by default, and then the
            // cast below would fail with an opaque ClassCastException.
            if (!(model instanceof Model.LeafNodeAssignment)) {
                throw new H2OIllegalArgumentException("Leaf node assignment requires a model implementing Model.LeafNodeAssignment (e.g. DRF or GBM).", "Model " + s.model.name + " does not implement Model.LeafNodeAssignment.");
            }
            if (null == parms._predictions_name) {
                parms._predictions_name = generatedPredictionsName("leaf_node_assignment", parms, "_on_");
            }
            predictions = ((Model.LeafNodeAssignment)model).scoreLeafNodeAssignment(parms._frame, Key.make(parms._predictions_name));
        } else if (s.exemplar_index >= 0) {
            // Same reasoning as above: fail with a clear message rather than a ClassCastException.
            if (!(model instanceof Model.ExemplarMembers)) {
                throw new H2OIllegalArgumentException("Exemplar members require a model implementing Model.ExemplarMembers (e.g. Aggregator).", "Model " + s.model.name + " does not implement Model.ExemplarMembers.");
            }
            if (null == parms._predictions_name) {
                parms._predictions_name = "members_" + parms._model._key.toString() + "_for_exemplar_" + parms._exemplar_index;
            }
            predictions = ((Model.ExemplarMembers)model).scoreExemplarMembers(Key.make(parms._predictions_name), parms._exemplar_index);
        } else {
            throw new H2OIllegalArgumentException("Requires a Deep Learning, GLRM, DRF or GBM model.", "Model must implement specific methods.");
        }
        ModelMetricsListSchemaV3 mm = this.fetch(version, s);
        if (null == mm) {
            mm = new ModelMetricsListSchemaV3();
        }
        mm.predictions_frame = new KeyV3.FrameKeyV3((Key<Frame>)predictions._key);
        // Leaf-node-assignment responses carry no model metrics.
        if (parms._leaf_node_assignment) {
            mm.model_metrics = null;
        }
        if (null != mm.model_metrics && 0 != mm.model_metrics.length) {
            // Attach a view of the output frame to the first metrics entry
            // (arguments 0L, 100 are presumably row offset and row count).
            mm.model_metrics[0].predictions = new FrameV3(predictions, 0L, 100);
        }
        return mm;
    }

    /**
     * Builds a default output-frame name:
     * {@code prefix + <first 5 chars of a freshly made Key> + "_" + modelKey + joiner + frameKey}.
     */
    private static String generatedPredictionsName(String prefix, ModelMetricsList parms, String joiner) {
        return prefix + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + joiner + parms._frame._key.toString();
    }

    /**
     * REST (v3) schema for {@link ModelMetricsList}. It is the request/response shape shared
     * by this handler's fetch, delete, score, predict and predictAsync endpoints.
     */
    public static final class ModelMetricsListSchemaV3
    extends SchemaV3<ModelMetricsList, ModelMetricsListSchemaV3> {
        @API(help="Key of Model of interest (optional)", json=true)
        public KeyV3.ModelKeyV3 model;
        @API(help="Key of Frame of interest (optional)", json=true)
        public KeyV3.FrameKeyV3 frame;
        @API(help="Key of predictions frame, if predictions are requested (optional)", json=true, required=false, direction=API.Direction.INOUT)
        public KeyV3.FrameKeyV3 predictions_frame;
        @API(help="Compute reconstruction error (optional, only for Deep Learning AutoEncoder models)", json=false, required=false)
        public boolean reconstruction_error;
        @API(help="Compute reconstruction error per feature (optional, only for Deep Learning AutoEncoder models)", json=false, required=false)
        public boolean reconstruction_error_per_feature;
        @API(help="Extract Deep Features for given hidden layer (optional, only for Deep Learning models)", json=false, required=false)
        public int deep_features_hidden_layer;
        @API(help="Reconstruct original training frame (optional, only for GLRM models)", json=false, required=false)
        public boolean reconstruct_train;
        @API(help="Project GLRM archetypes back into original feature space (optional, only for GLRM models)", json=false, required=false)
        public boolean project_archetypes;
        @API(help="Reverse transformation applied during training to model output (optional, only for GLRM models)", json=false, required=false)
        public boolean reverse_transform;
        @API(help="Return the leaf node assignment (optional, only for DRF/GBM models)", json=false, required=false)
        public boolean leaf_node_assignment;
        @API(help="Retrieve all members for a given exemplar (optional, only for Aggregator models)", json=false, required=false)
        public int exemplar_index;
        @API(help="ModelMetrics", direction=API.Direction.OUTPUT)
        public ModelMetricsBaseV3[] model_metrics;

        /**
         * Copies this schema into {@code mml}. The model and frame keys are resolved to the
         * objects stored under them ({@code null} when no key is given). Each metrics schema
         * entry is converted to its implementation object.
         *
         * @param mml implementation object to fill
         * @return {@code mml}
         */
        @Override
        public ModelMetricsList fillImpl(ModelMetricsList mml) {
            mml._model = null == this.model || null == this.model.key() ? null : (Model)this.model.key().get();
            mml._frame = null == this.frame || null == this.frame.key() ? null : (Frame)this.frame.key().get();
            mml._predictions_name = null == this.predictions_frame || null == this.predictions_frame.key() ? null : this.predictions_frame.key().toString();
            mml._reconstruction_error = this.reconstruction_error;
            mml._reconstruction_error_per_feature = this.reconstruction_error_per_feature;
            mml._deep_features_hidden_layer = this.deep_features_hidden_layer;
            mml._reconstruct_train = this.reconstruct_train;
            mml._project_archetypes = this.project_archetypes;
            mml._reverse_transform = this.reverse_transform;
            mml._leaf_node_assignment = this.leaf_node_assignment;
            mml._exemplar_index = this.exemplar_index;
            if (null != this.model_metrics) {
                mml._model_metrics = new ModelMetrics[this.model_metrics.length];
                // Index-for-index conversion. The index must be advanced only by the loop
                // header; an extra increment would skip entries and overrun the array.
                for (int i = 0; i < this.model_metrics.length; ++i) {
                    mml._model_metrics[i] = (ModelMetrics)this.model_metrics[i].createImpl();
                }
            }
            return mml;
        }

        /**
         * Fills this schema from {@code mml}. The model, frame and predictions name are exposed
         * as keys. Each metrics object is rendered with its v3 schema, and a {@code null}
         * metrics array is reported as an empty array.
         *
         * @param mml implementation object to read
         * @return this schema
         */
        @Override
        public ModelMetricsListSchemaV3 fillFromImpl(ModelMetricsList mml) {
            this.model = mml._model == null ? null : new KeyV3.ModelKeyV3((Key<? extends Model>)mml._model._key);
            this.frame = mml._frame == null ? null : new KeyV3.FrameKeyV3((Key<Frame>)mml._frame._key);
            this.predictions_frame = mml._predictions_name == null ? null : new KeyV3.FrameKeyV3(Key.make(mml._predictions_name));
            this.reconstruction_error = mml._reconstruction_error;
            this.reconstruction_error_per_feature = mml._reconstruction_error_per_feature;
            this.deep_features_hidden_layer = mml._deep_features_hidden_layer;
            this.reconstruct_train = mml._reconstruct_train;
            this.project_archetypes = mml._project_archetypes;
            this.reverse_transform = mml._reverse_transform;
            this.leaf_node_assignment = mml._leaf_node_assignment;
            this.exemplar_index = mml._exemplar_index;
            if (null != mml._model_metrics) {
                this.model_metrics = new ModelMetricsBaseV3[mml._model_metrics.length];
                for (int i = 0; i < this.model_metrics.length; ++i) {
                    ModelMetrics mm = mml._model_metrics[i];
                    this.model_metrics[i] = (ModelMetricsBaseV3)SchemaServer.schema(3, mm.getClass()).fillFromImpl(mm);
                }
            } else {
                this.model_metrics = new ModelMetricsBaseV3[0];
            }
            return this;
        }
    }

    /**
     * Implementation-side request/result object for the model-metrics endpoints. It holds the
     * optional model/frame filters, the prediction options, and the matched {@link ModelMetrics}.
     * Field declaration order is kept as-is.
     */
    public static final class ModelMetricsList
    extends Iced {
        public Model _model;
        public Frame _frame;
        public ModelMetrics[] _model_metrics;
        public String _predictions_name;
        public boolean _reconstruction_error;
        public boolean _reconstruction_error_per_feature;
        public int _deep_features_hidden_layer = -1;
        public boolean _reconstruct_train;
        public boolean _project_archetypes;
        public boolean _reverse_transform;
        public boolean _leaf_node_assignment;
        public int _exemplar_index = -1;

        /**
         * Collects every {@link ModelMetrics} in the global key snapshot that passes this
         * object's filters. When {@code _model} is set, only metrics for that model are kept;
         * when {@code _frame} is set, only metrics for that frame. Keys whose lookup or check
         * throws a ClassCastException or NullPointerException are treated as non-matching.
         * The matches are stored in {@code _model_metrics}.
         *
         * @return this object
         */
        ModelMetricsList fetch() {
            KeySnapshot.KVFilter matchesFilters = new KeySnapshot.KVFilter(){

                @Override
                public boolean filter(KeySnapshot.KeyInfo info) {
                    try {
                        if (!Value.isSubclassOf(info._type, ModelMetrics.class)) {
                            return false;
                        }
                        ModelMetrics candidate = (ModelMetrics)DKV.getGet(info._key);
                        return (ModelMetricsList.this._model == null || candidate.isForModel((Model)DKV.getGet(ModelMetricsList.this._model._key)))
                            && (ModelMetricsList.this._frame == null || candidate.isForFrame((Frame)DKV.getGet(ModelMetricsList.this._frame._key)));
                    }
                    catch (ClassCastException | NullPointerException ignored) {
                        // Best effort: anything that cannot be inspected simply does not match.
                        return false;
                    }
                }
            };
            Key[] matching = KeySnapshot.globalSnapshot().filter(matchesFilters).keys();
            this._model_metrics = new ModelMetrics[matching.length];
            int slot = 0;
            for (Key metricsKey : matching) {
                this._model_metrics[slot++] = (ModelMetrics)DKV.getGet(metricsKey);
            }
            return this;
        }

        /**
         * Removes every metrics object matching this object's filters from the DKV.
         *
         * @return this object, with {@code _model_metrics} holding the removed metrics
         */
        ModelMetricsList delete() {
            ModelMetricsList matches = this.fetch();
            ModelMetrics[] doomed = matches._model_metrics;
            for (int i = 0; i < doomed.length; ++i) {
                DKV.remove(doomed[i]._key);
            }
            return matches;
        }

        /**
         * Fetches the metrics matching {@code m}'s filters and renders them with the schema for
         * {@code version}.
         */
        public Schema list(int version, ModelMetricsList m) {
            ModelMetricsListSchemaV3 target = this.schema(version);
            return target.fillFromImpl(m.fetch());
        }

        /**
         * Returns a fresh schema instance for the given REST API version; only version 3 is
         * supported.
         */
        protected ModelMetricsListSchemaV3 schema(int version) {
            if (version == 3) {
                return new ModelMetricsListSchemaV3();
            }
            throw H2O.fail("Bad version for ModelMetrics schema: " + version);
        }
    }
}

