/*
 * Decompiled with CFR 0.152.
 */
package hex.grid;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelParametersBuilderFactory;
import hex.grid.Grid;
import hex.grid.HyperSpaceWalker;
import hex.grid.ModelFactory;
import java.util.Map;
import java.util.Objects;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.util.Log;
import water.util.PojoUtils;

/**
 * Job that walks a hyper-parameter space and builds one model per visited parameter
 * combination, collecting the results in a {@link Grid}.
 *
 * <p>Models are built one after another: each build blocks on the model builder's result
 * before the walker is asked for the next combination. A failure to construct parameters or
 * to build a single model is recorded on the grid and does not abort the search.
 *
 * @param <MP> the model-parameters type handled by this search
 */
public final class GridSearch<MP extends Model.Parameters>
extends Job<Grid> {
    /** Creates a model builder for each visited parameter combination. */
    private final transient ModelFactory<MP> _modelFactory;
    /** Enumerates the hyper-parameter combinations to visit. */
    private final transient HyperSpaceWalker<MP> _hyperSpaceWalker;

    /**
     * @param gkey destination key of the grid produced by this job
     * @param modelFactory factory used to create model builders; must not be null
     * @param hyperSpaceWalker walker over the hyper space; must not be null
     * @throws NullPointerException if {@code modelFactory} or {@code hyperSpaceWalker} is null
     */
    private GridSearch(Key gkey, ModelFactory<MP> modelFactory, HyperSpaceWalker<MP> hyperSpaceWalker) {
        // modelFactory is dereferenced while building the job description, so the null check
        // has to happen inside the super(...) arguments; an assert placed after super(...)
        // could never be reached for a null factory.
        super(gkey, Objects.requireNonNull(modelFactory, "Grid search needs to know how to build a new model!").getModelName() + " Grid Search");
        this._modelFactory = modelFactory;
        this._hyperSpaceWalker = Objects.requireNonNull(hyperSpaceWalker, "Grid search needs to know how to walk around hyper space!");
    }

    /**
     * Prepares the destination grid and launches the search.
     *
     * <p>If a grid already exists under {@link #dest()}, new models are appended to it, which is
     * only allowed when its training frame has the same key and checksum as the one in the
     * walker's parameters. Otherwise a fresh grid is created (any previous value under the key
     * is deleted).
     *
     * @return this job
     * @throws H2OIllegalArgumentException if an existing grid was trained on different (or
     *         missing) input
     */
    GridSearch<MP> start() {
        final int gridSize = this._hyperSpaceWalker.getHyperSpaceSize();
        Log.info("Starting gridsearch: estimated size of search space = " + gridSize);
        Grid<MP> grid = (Grid<MP>)DKV.getGet(this.dest());
        if (grid != null) {
            Frame specTrainFrame = ((Model.Parameters)this._hyperSpaceWalker.getParams()).train();
            Frame oldTrainFrame = grid.getTrainingFrame();
            // A missing frame on either side cannot be verified, so it is treated as a mismatch
            // instead of failing with a NullPointerException.
            if (!GridSearch.isSameFrame(specTrainFrame, oldTrainFrame)) {
                throw new H2OIllegalArgumentException("training_frame", "grid", "Cannot append new models to a grid with different training input");
            }
            grid.write_lock(this.jobKey());
        } else {
            grid = new Grid<MP>(this.dest(), this._hyperSpaceWalker.getParams(), this._hyperSpaceWalker.getHyperParamNames(), this._modelFactory.getModelName(), this._hyperSpaceWalker.getParametersBuilderFactory().getFieldNamingStrategy());
            grid.delete_and_lock(this.jobKey());
        }
        final Grid<MP> gridToExpose = grid;
        this.start(new H2O.H2OCountedCompleter(){

            @Override
            public void compute2() {
                GridSearch.this.gridSearch(gridToExpose);
                this.tryComplete();
            }
        }, gridSize, true);
        return this;
    }

    /**
     * Returns true when both frames are present and have equal keys and equal checksums.
     * The checksum is only computed when the keys already match.
     */
    private static boolean isSameFrame(Frame a, Frame b) {
        if (a == null || b == null || a._key == null) {
            return false;
        }
        return a._key.equals(b._key) && a.checksum() == b.checksum();
    }

    /** @return the hyper-space size reported by the walker (estimated number of models) */
    public int getModelCount() {
        return this._hyperSpaceWalker.getHyperSpaceSize();
    }

    /**
     * Walks the hyper space and builds a model for each combination.
     *
     * <p>Per-combination failures are logged and recorded on the grid. Progress and the grid in
     * DKV are updated after every attempt. The grid lock is always released on exit.
     *
     * @param grid locked grid that receives built models and failed parameters
     */
    private void gridSearch(Grid<MP> grid) {
        Model model = null;
        try {
            HyperSpaceWalker.HyperSpaceIterator<MP> it = this._hyperSpaceWalker.iterator();
            while (it.hasNext(model)) {
                // Honour an end-user cancel request between model builds.
                if (!this.isRunning()) {
                    this.cancel();
                    return;
                }
                try {
                    MP params = it.nextModelParameters(model);
                    // A failing build must not stop the search; just record the combination.
                    try {
                        model = this.buildModel(params, grid);
                    } catch (RuntimeException e) {
                        Log.warn("Grid search: model builder for parameters " + params + " failed! Exception: ", e);
                        grid.appendFailedModelParameters(params, e);
                    }
                } catch (IllegalArgumentException e) {
                    // The parameters themselves could not be constructed; record the raw values.
                    Log.warn("Grid search: construction of model parameters failed! Exception: ", e);
                    Object[] rawParams = it.getCurrentRawParameters();
                    grid.appendFailedModelParameters(rawParams, e);
                } finally {
                    this.update(1L);
                    grid.update(this.jobKey());
                }
            }
            this.done();
        } catch (Throwable e) {
            Job thisJob = (Job)DKV.getGet(this.jobKey());
            // Guard against a missing job entry so the lookup cannot mask the original error.
            if (thisJob != null && thisJob._state == Job.JobState.CANCELLED) {
                Log.info("Job " + this.jobKey() + " cancelled by user.");
            } else {
                this.failed(e);
                throw e;
            }
        } finally {
            grid.unlock(this.jobKey());
        }
    }

    /**
     * Returns the model for {@code params}, building it if the grid does not already hold one
     * for the same parameter checksum. This call blocks until the build finishes.
     */
    private Model buildModel(MP params, Grid<MP> grid) {
        long checksum = ((Model.Parameters)params).checksum();
        Key<Model> key = grid.getModelKey(checksum);
        if (key != null) {
            return key.get();
        }
        ModelBuilder builder = this.startBuildModel(params, grid);
        if (builder == null) {
            // startBuildModel found an existing model for these parameters; reuse it instead of
            // dereferencing a null builder.
            return (Model)grid.getModel(params);
        }
        Model m = (Model)builder.get();
        grid.putModel(checksum, m._key);
        return m;
    }

    /**
     * Creates a builder for {@code params} and starts training.
     *
     * @return the started builder, or {@code null} if the grid already holds a model for these
     *         parameters
     */
    private ModelBuilder startBuildModel(MP params, Grid<MP> grid) {
        if (grid.getModel(params) != null) {
            return null;
        }
        ModelBuilder mb = this._modelFactory.buildModel(params);
        mb.trainModel();
        return mb;
    }

    /**
     * Builds a unique default grid key of the form {@code Grid_<model>_<frameKey><uniqueId>}.
     *
     * @throws IllegalArgumentException if {@code fr} is null or has no key
     */
    protected static Key<Grid> gridKeyName(String modelName, Frame fr) {
        if (fr == null) {
            throw new IllegalArgumentException("Grid search requires a training frame");
        }
        if (fr._key == null) {
            throw new IllegalArgumentException("The frame being grid-searched over must have a Key");
        }
        return Key.make("Grid_" + modelName + "_" + fr._key.toString() + H2O.calcNextUniqueModelId(""));
    }

    /**
     * Starts a Cartesian grid search over {@code hyperParams}.
     *
     * @param destKey grid key, or {@code null} to derive one from the model name and training frame
     * @param params base model parameters that hyper-parameter values are applied to
     * @param hyperParams hyper-parameter name to candidate values
     * @param modelFactory factory creating model builders
     * @param paramsBuilderFactory strategy for applying hyper-parameter values to parameters
     * @return the started job
     */
    public static <MP extends Model.Parameters> GridSearch<MP> startGridSearch(Key<Grid> destKey, MP params, Map<String, Object[]> hyperParams, ModelFactory<MP> modelFactory, ModelParametersBuilderFactory<MP> paramsBuilderFactory) {
        HyperSpaceWalker.CartesianWalker<MP> hyperSpaceWalker = new HyperSpaceWalker.CartesianWalker<MP>(params, hyperParams, paramsBuilderFactory);
        return GridSearch.startGridSearch(destKey, modelFactory, hyperSpaceWalker);
    }

    /** Same as the five-argument overload, using {@link SimpleParametersBuilderFactory}. */
    public static <MP extends Model.Parameters> GridSearch<MP> startGridSearch(Key<Grid> destKey, MP params, Map<String, Object[]> hyperParams, ModelFactory<MP> modelFactory) {
        return GridSearch.startGridSearch(destKey, params, hyperParams, modelFactory, new SimpleParametersBuilderFactory<MP>());
    }

    /** Same as the four-argument overload with a generated grid key. */
    public static <MP extends Model.Parameters> GridSearch<MP> startGridSearch(MP params, Map<String, Object[]> hyperParams, ModelFactory<MP> modelFactory) {
        return GridSearch.startGridSearch(null, params, hyperParams, modelFactory);
    }

    /**
     * Starts a grid search driven by an arbitrary hyper-space walker.
     *
     * @param destKey grid key, or {@code null} to derive one from the walker's training frame
     */
    public static <MP extends Model.Parameters> GridSearch<MP> startGridSearch(Key<Grid> destKey, ModelFactory<MP> modelFactory, HyperSpaceWalker<MP> hyperSpaceWalker) {
        Key<Grid> gridKey = destKey != null ? destKey : GridSearch.gridKeyName(modelFactory.getModelName(), ((Model.Parameters)hyperSpaceWalker.getParams()).train());
        return new GridSearch<MP>(gridKey, modelFactory, hyperSpaceWalker).start();
    }

    /**
     * Parameters-builder factory that writes hyper-parameter values directly into the given
     * parameters object by field name, using {@link PojoUtils.FieldNaming#CONSISTENT}.
     */
    static class SimpleParametersBuilderFactory<MP extends Model.Parameters>
    implements ModelParametersBuilderFactory<MP> {
        SimpleParametersBuilderFactory() {
        }

        @Override
        public ModelParametersBuilderFactory.ModelParametersBuilder<MP> get(MP initialParams) {
            return new SimpleParamsBuilder<MP>(initialParams);
        }

        @Override
        public PojoUtils.FieldNaming getFieldNamingStrategy() {
            return PojoUtils.FieldNaming.CONSISTENT;
        }

        /**
         * Builder that mutates the supplied parameters object in place (no copy is made), so
         * {@link #build()} returns that same instance.
         * NOTE(review): callers presumably pass a fresh copy per combination — confirm in the walker.
         */
        public static class SimpleParamsBuilder<MP extends Model.Parameters>
        implements ModelParametersBuilderFactory.ModelParametersBuilder<MP> {
            private final MP params;

            public SimpleParamsBuilder(MP initialParams) {
                this.params = initialParams;
            }

            @Override
            public ModelParametersBuilderFactory.ModelParametersBuilder<MP> set(String name, Object value) {
                PojoUtils.setField(this.params, name, value, PojoUtils.FieldNaming.CONSISTENT);
                return this;
            }

            @Override
            public MP build() {
                return this.params;
            }
        }
    }
}

