/*
 * Decompiled with CFR 0.152.
 */
package water.api;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelParametersBuilderFactory;
import hex.faulttolerance.Recovery;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.schemas.GridSearchSchema;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.TreeMap;
import water.H2O;
import water.Job;
import water.Key;
import water.TypeMap;
import water.api.API;
import water.api.Handler;
import water.api.Route;
import water.api.Schema;
import water.api.SchemaMetadata;
import water.api.SchemaServer;
import water.api.schemas3.JobV3;
import water.api.schemas3.ModelParametersSchemaV3;
import water.exceptions.H2OIllegalArgumentException;
import water.util.IcedHashMap;
import water.util.PojoUtils;

/**
 * REST handler for grid-search routes.
 *
 * <p>{@link #handle} dispatches on the name of the route's handler method: {@code "train"} starts a
 * new grid search and {@code "resume"} resumes an existing one. Any other method name fails via
 * {@code H2O.unimpl()}. The typed {@link #train} and {@link #resume} methods are not expected to be
 * invoked; they fail via {@code H2O.fail()}.
 *
 * @param <G>  grid type
 * @param <S>  grid-search schema type returned to the client
 * @param <MP> model-parameters implementation type
 * @param <P>  model-parameters schema type
 */
public class GridSearchHandler<G extends Grid<MP>, S extends GridSearchSchema<G, S, MP, P>, MP extends Model.Parameters, P extends ModelParametersSchemaV3>
extends Handler {
    /** Hyper-parameter key whose value is an array of nested hyper-parameter maps (subspaces). */
    private static final String SUBSPACES = "subspaces";

    /**
     * Dispatches a grid-search request based on the route's handler method name.
     *
     * @param version  API version (not used here)
     * @param route    matched route; {@code _handler_method}'s name selects the action, and the fourth
     *                 {@code "/"}-separated segment of {@code _url} is taken as the algorithm URL name
     *                 (NOTE(review): presumably URLs look like {@code /99/Grid/<algo>} — confirm)
     * @param parms    raw request parameters
     * @param postBody request body (not used here)
     * @return the filled grid-search schema, including the started job
     * @throws Exception propagated from the train/resume paths
     */
    public S handle(int version, Route route, Properties parms, String postBody) throws Exception {
        String methodName = route._handler_method.getName();
        String[] ss = route._url.split("/");
        String algoURLName = ss[3];
        if ("train".equals(methodName)) {
            return this.trainGrid(algoURLName, parms);
        }
        if ("resume".equals(methodName)) {
            return this.resumeGrid(algoURLName, parms);
        }
        throw H2O.unimpl();
    }

    /**
     * Resumes the grid identified by the required {@code grid_id} request parameter.
     *
     * @throws IllegalArgumentException if {@code grid_id} is absent from {@code parms}, or if no grid
     *                                  is stored under that key
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private S resumeGrid(String algoURLName, Properties parms) {
        if (!parms.containsKey("grid_id")) {
            throw new IllegalArgumentException("grid_id is missing");
        }
        S gss = this.buildGridSearchSchema(algoURLName, parms);
        GridSearchSchema schema = gss;
        Key<Grid> gridKey = schema.grid_id.key();
        Grid grid = (Grid) gridKey.get();
        // Fail fast with a clear message instead of handing a null grid to the resume logic.
        if (grid == null) {
            throw new IllegalArgumentException("Grid '" + gridKey + "' not found");
        }
        Key<Job> jobKey = schema.job_id != null ? schema.job_id.key() : null;
        Recovery<Grid> recovery = this.getRecovery(schema);
        Job<Grid> gsJob = GridSearch.resumeGridSearch(jobKey, grid, new DefaultModelParametersBuilderFactory(), recovery);
        // Hyper-parameters are not echoed back to the client.
        schema.hyper_parameters = null;
        schema.job = new JobV3(gsJob);
        return gss;
    }

    /**
     * Builds a grid-search schema for the given algorithm.
     *
     * <p>The model parameters are first filled with the defaults of a freshly made
     * {@link ModelBuilder}, then overlaid with the request parameters. Hyper-parameters start as an
     * empty map before {@code fillFromParms} runs.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private S buildGridSearchSchema(String algoURLName, Properties parms) {
        String algoName = ModelBuilder.algoName(algoURLName);
        String schemaDir = ModelBuilder.schemaDirectory(algoURLName);
        // These algorithms expose their parameter schemas only as V99; all others use V3.
        int algoVersion = 3;
        if (algoName.equals("SVD") || algoName.equals("Aggregator") || algoName.equals("StackedEnsemble")) {
            algoVersion = 99;
        }
        // Fully-qualified name of the nested parameter-schema class, e.g. <dir>GBMV3$GBMParametersV3.
        String paramSchemaName = schemaDir + algoName + "V" + algoVersion + "$" + ModelBuilder.paramName(algoURLName) + "V" + algoVersion;
        GridSearchSchema gss = new GridSearchSchema();
        gss.init_meta();
        gss.parameters = (ModelParametersSchemaV3) TypeMap.newFreezable(paramSchemaName);
        ((Schema) gss.parameters).init_meta();
        gss.hyper_parameters = new IcedHashMap();
        // Defaults come from a builder created without a job or result key.
        ModelBuilder builder = (ModelBuilder) ModelBuilder.make(algoURLName, null, null);
        ((ModelParametersSchemaV3) gss.parameters).fillFromImpl(builder._parms);
        gss.fillFromParms(parms);
        return (S) gss;
    }

    /**
     * Validates the requested hyper-parameters and starts a new grid search.
     *
     * @throws H2OIllegalArgumentException if a hyper-parameter is unknown or not gridable
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private S trainGrid(String algoURLName, Properties parms) {
        S gss = this.buildGridSearchSchema(algoURLName, parms);
        GridSearchSchema schema = gss;
        this.validateHyperParams((P) schema.parameters, schema.hyper_parameters);
        Model.Parameters params = (Model.Parameters) ((Schema) schema.parameters).createAndFillImpl();
        TreeMap<String, Object[]> sortedMap = new TreeMap<String, Object[]>(schema.hyper_parameters);
        // The user-facing name "validation_frame" is passed on under the internal name "valid".
        if (sortedMap.containsKey("validation_frame")) {
            sortedMap.put("valid", sortedMap.remove("validation_frame"));
        }
        Key<Grid> destKey = schema.grid_id != null ? schema.grid_id.key() : null;
        Key<Job> jobKey = schema.job_id != null ? schema.job_id.key() : null;
        Recovery<Grid> recovery = this.getRecovery(schema);
        HyperSpaceSearchCriteria criteria = (HyperSpaceSearchCriteria) schema.search_criteria.createAndFillImpl();
        Job<Grid> gsJob = GridSearch.startGridSearch(jobKey, destKey, params, sortedMap, new DefaultModelParametersBuilderFactory(), criteria, recovery, GridSearch.getParallelismLevel(schema.parallelism));
        // Hyper-parameters are not echoed back to the client.
        schema.hyper_parameters = null;
        schema.total_models = ((Grid) gsJob._result.get()).getModelCount();
        schema.job = new JobV3(gsJob);
        return gss;
    }

    /** Not expected to be invoked: {@link #handle} dispatches "train" requests itself. */
    public S train(int version, S gridSearchSchema) {
        throw H2O.fail();
    }

    /** Not expected to be invoked: {@link #handle} dispatches "resume" requests itself. */
    public S resume(int version, S gridSearchSchema) {
        throw H2O.fail();
    }

    /**
     * Checks that every hyper-parameter name is a known, gridable field of {@code params}.
     * Nested maps under the {@code "subspaces"} key are validated recursively.
     *
     * @param params      parameter schema whose field metadata defines the legal names
     * @param hyperParams hyper-parameter name to candidate values
     * @throws H2OIllegalArgumentException on an unknown or non-gridable hyper-parameter
     */
    @SuppressWarnings("unchecked")
    protected void validateHyperParams(P params, Map<String, Object[]> hyperParams) {
        List<SchemaMetadata.FieldMetadata> fsMeta = SchemaMetadata.getFieldMetadata(params);
        HashSet<String> names = new HashSet<String>(hyperParams.keySet());
        names.remove(SUBSPACES);
        for (String hparam : names) {
            SchemaMetadata.FieldMetadata fieldMetadata = null;
            for (SchemaMetadata.FieldMetadata fm : fsMeta) {
                if (fm.name.equals(hparam)) {
                    fieldMetadata = fm;
                    break;
                }
            }
            if (fieldMetadata == null) {
                throw new H2OIllegalArgumentException(hparam, "grid", "Unknown hyper parameter for grid search!");
            }
            if (!fieldMetadata.is_gridable) {
                throw new H2OIllegalArgumentException(hparam, "grid", "Illegal hyper parameter for grid search! The parameter '" + fieldMetadata.name + "' is not gridable!");
            }
        }
        Object[] subspaces = hyperParams.get(SUBSPACES);
        if (subspaces != null) {
            for (Object subspace : subspaces) {
                this.validateHyperParams(params, (Map<String, Object[]>) subspace);
            }
        }
    }

    /**
     * Returns a recovery handle for the grid search.
     *
     * <p>The schema's explicit {@code recovery_dir} wins. Otherwise {@code H2O.ARGS.auto_recovery_dir}
     * is used, and {@code null} is returned when neither is set.
     */
    @SuppressWarnings("rawtypes")
    private Recovery<Grid> getRecovery(GridSearchSchema gss) {
        if (gss.recovery_dir != null) {
            return new Recovery<Grid>(gss.recovery_dir);
        }
        if (H2O.ARGS.auto_recovery_dir != null) {
            return new Recovery<Grid>(H2O.ARGS.auto_recovery_dir);
        }
        return null;
    }

    /**
     * Builds model parameters by setting values on a parameter schema and copying only the touched
     * fields onto an initial parameters object.
     */
    public static class ModelParametersFromSchemaBuilder<MP extends Model.Parameters, PS extends ModelParametersSchemaV3>
    implements ModelParametersBuilderFactory.ModelParametersBuilder<MP> {
        private final MP params;
        private final PS paramsSchema;
        // Names of schema fields set via set(); only these are copied in build().
        private final ArrayList<String> fields;

        /**
         * @param initialParams parameters object that {@link #build()} mutates and returns
         */
        @SuppressWarnings("unchecked")
        public ModelParametersFromSchemaBuilder(MP initialParams) {
            this.params = initialParams;
            // The schema class is resolved at runtime from the parameters class, hence the unchecked cast.
            this.paramsSchema = (PS) SchemaServer.schema(-1, this.params.getClass());
            this.fields = new ArrayList<String>(7);
        }

        /**
         * Sets a public schema field from the string form of {@code value}.
         *
         * @throws IllegalArgumentException if the field does not exist, lacks an {@code @API}
         *                                  annotation, or cannot be set to {@code value}
         */
        public ModelParametersFromSchemaBuilder<MP, PS> set(String name, Object value) {
            try {
                Field f = this.paramsSchema.getClass().getField(name);
                // Look the annotation up by type; it need not be the field's first annotation.
                API api = f.getAnnotation(API.class);
                if (api == null) {
                    throw new IllegalStateException("Field '" + name + "' has no @API annotation");
                }
                Schema.setField(this.paramsSchema, f, name, value.toString(), api.required(), this.paramsSchema.getClass());
                this.fields.add(name);
            }
            catch (NoSuchFieldException e) {
                throw new IllegalArgumentException("Cannot find field '" + name + "' to value " + value, e);
            }
            catch (IllegalAccessException | RuntimeException e) {
                throw new IllegalArgumentException("Cannot set field '" + name + "' to value " + value, e);
            }
            return this;
        }

        /**
         * Copies the fields set so far onto the initial parameters and returns them.
         * Training/validation frame keys are filled from the schema only if still unset.
         */
        @Override
        public MP build() {
            PojoUtils.copyProperties(this.params, this.paramsSchema, PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES, null, this.fields.toArray(new String[0]));
            if (this.params._valid == null && this.paramsSchema.validation_frame != null) {
                this.params._valid = Key.make(this.paramsSchema.validation_frame.name);
            }
            if (this.params._train == null && this.paramsSchema.training_frame != null) {
                this.params._train = Key.make(this.paramsSchema.training_frame.name);
            }
            return this.params;
        }
    }

    /** Factory producing {@link ModelParametersFromSchemaBuilder} instances. */
    public static class DefaultModelParametersBuilderFactory<MP extends Model.Parameters, PS extends ModelParametersSchemaV3>
    implements ModelParametersBuilderFactory<MP> {
        @Override
        public ModelParametersBuilderFactory.ModelParametersBuilder<MP> get(MP initialParams) {
            return new ModelParametersFromSchemaBuilder<MP, PS>(initialParams);
        }

        @Override
        public PojoUtils.FieldNaming getFieldNamingStrategy() {
            return PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES;
        }
    }
}

