/*
 * Decompiled with CFR 0.152.
 */
package hex.api;

import hex.Model;
import hex.ModelParametersBuilderFactory;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.ModelFactory;
import hex.schemas.GridSearchSchema;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import water.Iced;
import water.Job;
import water.Key;
import water.api.API;
import water.api.Handler;
import water.api.JobV3;
import water.api.ModelParametersSchema;
import water.api.Schema;
import water.api.SchemaMetadata;
import water.exceptions.H2OIllegalArgumentException;
import water.util.IcedHashMap;
import water.util.PojoUtils;

public abstract class GridSearchHandler<G extends Grid<MP>, S extends GridSearchSchema<G, S, MP, P>, MP extends Model.Parameters, P extends ModelParametersSchema>
extends Handler {
    public S do_train(int version, S gridSearchSchema) {
        Object parametersSchema = ((GridSearchSchema)((Object)gridSearchSchema)).parameters;
        IcedHashMap<String, Object[]> hyperParams = ((GridSearchSchema)((Object)gridSearchSchema)).hyper_parameters;
        this.validateHyperParams(parametersSchema, (Map<String, Object[]>)hyperParams);
        Key destKey = ((GridSearchSchema)((Object)gridSearchSchema)).grid_id != null ? ((GridSearchSchema)((Object)gridSearchSchema)).grid_id.key() : null;
        Model.Parameters params = (Model.Parameters)parametersSchema.createAndFillImpl();
        ModelFactory<MP> modelFactory = this.getModelFactory();
        GridSearch gsJob = GridSearch.startGridSearch((Key)destKey, (Model.Parameters)params, hyperParams, modelFactory, new DefaultModelParametersBuilderFactory());
        ((GridSearchSchema)((Object)gridSearchSchema)).hyper_parameters = null;
        ((GridSearchSchema)((Object)gridSearchSchema)).total_models = gsJob.getModelCount();
        ((GridSearchSchema)((Object)gridSearchSchema)).job = (JobV3)Schema.schema((int)version, Job.class).fillFromImpl((Iced)gsJob);
        return gridSearchSchema;
    }

    protected abstract ModelFactory<MP> getModelFactory();

    protected void validateHyperParams(P params, Map<String, Object[]> hyperParams) {
        List fsMeta = SchemaMetadata.getFieldMetadata(params);
        for (Map.Entry<String, Object[]> hparam : hyperParams.entrySet()) {
            SchemaMetadata.FieldMetadata fieldMetadata = null;
            for (SchemaMetadata.FieldMetadata fm : fsMeta) {
                if (!fm.name.equals(hparam.getKey())) continue;
                fieldMetadata = fm;
                break;
            }
            if (fieldMetadata == null) {
                throw new H2OIllegalArgumentException(hparam.getKey(), "grid", (Object)"Unknown hyper parameter for grid search!");
            }
            if (fieldMetadata.is_gridable) continue;
            throw new H2OIllegalArgumentException(hparam.getKey(), "grid", (Object)("Illegal hyper parameter for grid search! The parameter '" + fieldMetadata.name + " is not gridable!"));
        }
    }

    public static class ModelParametersFromSchemaBuilder<MP extends Model.Parameters, PS extends ModelParametersSchema>
    implements ModelParametersBuilderFactory.ModelParametersBuilder<MP> {
        private final MP params;
        private final PS paramsSchema;
        private final ArrayList<String> fields;

        public ModelParametersFromSchemaBuilder(MP initialParams) {
            this.params = initialParams;
            this.paramsSchema = (ModelParametersSchema)Schema.schema((int)Schema.getHighestSupportedVersion(), this.params.getClass());
            this.fields = new ArrayList(7);
        }

        public ModelParametersFromSchemaBuilder<MP, PS> set(String name, Object value) {
            try {
                Field f = this.paramsSchema.getClass().getField(name);
                API api = (API)f.getAnnotations()[0];
                Schema.setField(this.paramsSchema, (Field)f, (String)name, (String)value.toString(), (boolean)api.required(), this.paramsSchema.getClass());
                this.fields.add(name);
            }
            catch (NoSuchFieldException e) {
                throw new IllegalArgumentException("Cannot find field '" + name + "'", e);
            }
            catch (IllegalAccessException e) {
                throw new IllegalArgumentException("Cannot set field '" + name + "'", e);
            }
            catch (RuntimeException e) {
                throw new IllegalArgumentException("Cannot set field '" + name + "'", e);
            }
            return this;
        }

        public MP build() {
            PojoUtils.copyProperties(this.params, this.paramsSchema, (PojoUtils.FieldNaming)PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES, null, (String[])this.fields.toArray(new String[this.fields.size()]));
            if (((Model.Parameters)this.params)._valid == null && ((ModelParametersSchema)this.paramsSchema).validation_frame != null) {
                ((Model.Parameters)this.params)._valid = Key.make((String)((ModelParametersSchema)this.paramsSchema).validation_frame.name);
            }
            if (((Model.Parameters)this.params)._train == null && ((ModelParametersSchema)this.paramsSchema).training_frame != null) {
                ((Model.Parameters)this.params)._train = Key.make((String)((ModelParametersSchema)this.paramsSchema).training_frame.name);
            }
            return this.params;
        }
    }

    static class DefaultModelParametersBuilderFactory<MP extends Model.Parameters, PS extends ModelParametersSchema>
    implements ModelParametersBuilderFactory<MP> {
        DefaultModelParametersBuilderFactory() {
        }

        public ModelParametersBuilderFactory.ModelParametersBuilder<MP> get(MP initialParams) {
            return new ModelParametersFromSchemaBuilder(initialParams);
        }

        public PojoUtils.FieldNaming getFieldNamingStrategy() {
            return PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES;
        }
    }
}

