/*
 * Decompiled with CFR 0.152.
 */
package water.api;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelParametersBuilderFactory;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.schemas.GridSearchSchema;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import water.H2O;
import water.Job;
import water.Key;
import water.TypeMap;
import water.api.API;
import water.api.Handler;
import water.api.JobV3;
import water.api.ModelParametersSchema;
import water.api.Route;
import water.api.Schema;
import water.api.SchemaMetadata;
import water.exceptions.H2OIllegalArgumentException;
import water.util.PojoUtils;

public class GridSearchHandler<G extends Grid<MP>, S extends GridSearchSchema<G, S, MP, P>, MP extends Model.Parameters, P extends ModelParametersSchema>
extends Handler {
    S handle(int version, Route route, Properties parms) throws Exception {
        if (!route._handler_method.getName().equals("train")) {
            throw H2O.unimpl();
        }
        String[] ss = route._url_pattern_raw.split("/");
        String algoURLName = ss[3];
        String algoName = ModelBuilder.algoName(algoURLName);
        String schemaDir = ModelBuilder.schemaDirectory(algoURLName);
        String algoSchemaName = Schema.schemaClass(version, algoName).getSimpleName();
        int algoVersion = Integer.valueOf(algoSchemaName.substring(algoSchemaName.lastIndexOf("V") + 1));
        String paramSchemaName = schemaDir + algoName + "V" + algoVersion + "$" + ModelBuilder.paramName(algoURLName) + "V" + algoVersion;
        GridSearchSchema gss = new GridSearchSchema();
        gss.init_meta();
        gss.parameters = (ModelParametersSchema)TypeMap.newFreezable(paramSchemaName);
        ((Schema)gss.parameters).init_meta();
        Object builder = ModelBuilder.make(algoURLName, null, null);
        ((ModelParametersSchema)gss.parameters).fillFromImpl(((ModelBuilder)builder)._parms);
        gss.fillFromParms(parms);
        this.validateHyperParams(gss.parameters, gss.hyper_parameters);
        Model.Parameters params = (Model.Parameters)((Schema)gss.parameters).createAndFillImpl();
        Key<Grid> destKey = gss.grid_id != null ? gss.grid_id.key() : null;
        Job<Grid> gsJob = GridSearch.startGridSearch(destKey, params, gss.hyper_parameters, new DefaultModelParametersBuilderFactory(), (HyperSpaceSearchCriteria)gss.search_criteria.createAndFillImpl());
        gss.hyper_parameters = null;
        gss.total_models = ((Grid)gsJob._result.get()).getModelCount();
        gss.job = (JobV3)Schema.schema(version, Job.class).fillFromImpl(gsJob);
        return (S)gss;
    }

    public S train(int version, S gridSearchSchema) {
        throw H2O.fail();
    }

    protected void validateHyperParams(P params, Map<String, Object[]> hyperParams) {
        List<SchemaMetadata.FieldMetadata> fsMeta = SchemaMetadata.getFieldMetadata(params);
        for (Map.Entry<String, Object[]> hparam : hyperParams.entrySet()) {
            SchemaMetadata.FieldMetadata fieldMetadata = null;
            for (SchemaMetadata.FieldMetadata fm : fsMeta) {
                if (!fm.name.equals(hparam.getKey())) continue;
                fieldMetadata = fm;
                break;
            }
            if (fieldMetadata == null) {
                throw new H2OIllegalArgumentException(hparam.getKey(), "grid", "Unknown hyper parameter for grid search!");
            }
            if (fieldMetadata.is_gridable) continue;
            throw new H2OIllegalArgumentException(hparam.getKey(), "grid", "Illegal hyper parameter for grid search! The parameter '" + fieldMetadata.name + " is not gridable!");
        }
    }

    public static class ModelParametersFromSchemaBuilder<MP extends Model.Parameters, PS extends ModelParametersSchema>
    implements ModelParametersBuilderFactory.ModelParametersBuilder<MP> {
        private final MP params;
        private final PS paramsSchema;
        private final ArrayList<String> fields;

        public ModelParametersFromSchemaBuilder(MP initialParams) {
            this.params = initialParams;
            this.paramsSchema = (ModelParametersSchema)Schema.schema(Schema.getHighestSupportedVersion(), this.params.getClass());
            this.fields = new ArrayList(7);
        }

        public ModelParametersFromSchemaBuilder<MP, PS> set(String name, Object value) {
            try {
                Field f = this.paramsSchema.getClass().getField(name);
                API api = (API)f.getAnnotations()[0];
                Schema.setField(this.paramsSchema, f, name, value.toString(), api.required(), this.paramsSchema.getClass());
                this.fields.add(name);
            }
            catch (NoSuchFieldException e) {
                throw new IllegalArgumentException("Cannot find field '" + name + "'", e);
            }
            catch (IllegalAccessException e) {
                throw new IllegalArgumentException("Cannot set field '" + name + "'", e);
            }
            catch (RuntimeException e) {
                throw new IllegalArgumentException("Cannot set field '" + name + "'", e);
            }
            return this;
        }

        @Override
        public MP build() {
            PojoUtils.copyProperties(this.params, this.paramsSchema, PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES, null, this.fields.toArray(new String[this.fields.size()]));
            if (((Model.Parameters)this.params)._valid == null && ((ModelParametersSchema)this.paramsSchema).validation_frame != null) {
                ((Model.Parameters)this.params)._valid = Key.make(((ModelParametersSchema)this.paramsSchema).validation_frame.name);
            }
            if (((Model.Parameters)this.params)._train == null && ((ModelParametersSchema)this.paramsSchema).training_frame != null) {
                ((Model.Parameters)this.params)._train = Key.make(((ModelParametersSchema)this.paramsSchema).training_frame.name);
            }
            return this.params;
        }
    }

    static class DefaultModelParametersBuilderFactory<MP extends Model.Parameters, PS extends ModelParametersSchema>
    implements ModelParametersBuilderFactory<MP> {
        DefaultModelParametersBuilderFactory() {
        }

        @Override
        public ModelParametersBuilderFactory.ModelParametersBuilder<MP> get(MP initialParams) {
            return new ModelParametersFromSchemaBuilder(initialParams);
        }

        @Override
        public PojoUtils.FieldNaming getFieldNamingStrategy() {
            return PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES;
        }
    }
}

