/*
 * Decompiled with CFR 0.152.
 */
package hex.grid;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelParametersBuilderFactory;
import hex.ScoringInfo;
import hex.grid.Grid;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.HyperSpaceWalker;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Map;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.KeySnapshot;
import water.Keyed;
import water.Value;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.util.Log;
import water.util.PojoUtils;

public final class GridSearch<MP extends Model.Parameters>
extends Keyed<GridSearch> {
    public final Key<Grid> _result;
    public final Job<Grid> _job;
    private final transient HyperSpaceWalker<MP, ?> _hyperSpaceWalker;

    private GridSearch(Key<Grid> gkey, HyperSpaceWalker<MP, ?> hyperSpaceWalker) {
        assert (hyperSpaceWalker != null) : "Grid search needs to know how to walk around hyper space!";
        this._hyperSpaceWalker = hyperSpaceWalker;
        this._result = gkey;
        String algoName = ((Model.Parameters)hyperSpaceWalker.getParams()).algoName();
        this._job = new Job<Grid>(gkey, Grid.class.getName(), algoName + " Grid Search");
    }

    Job<Grid> start() {
        Grid<MP> grid;
        long gridSize = this._hyperSpaceWalker.getMaxHyperSpaceSize();
        Log.info("Starting gridsearch: estimated size of search space = " + gridSize);
        Keyed keyed = (Keyed)DKV.getGet(this._result);
        if (keyed != null) {
            if (!(keyed instanceof Grid)) {
                throw new H2OIllegalArgumentException("Name conflict: tried to create a Grid using the ID of a non-Grid object that's already in H2O: " + this._job._result + "; it is a: " + keyed.getClass());
            }
            grid = (Grid<MP>)keyed;
            Frame specTrainFrame = ((Model.Parameters)this._hyperSpaceWalker.getParams()).train();
            Frame oldTrainFrame = grid.getTrainingFrame();
            if (oldTrainFrame != null && !specTrainFrame._key.equals(oldTrainFrame._key) || oldTrainFrame != null && specTrainFrame.checksum() != oldTrainFrame.checksum()) {
                throw new H2OIllegalArgumentException("training_frame", "grid", "Cannot append new models to a grid with different training input");
            }
            grid.write_lock(this._job);
        } else {
            grid = new Grid<MP>(this._result, this._hyperSpaceWalker.getParams(), this._hyperSpaceWalker.getHyperParamNames(), this._hyperSpaceWalker.getParametersBuilderFactory().getFieldNamingStrategy());
            grid.delete_and_lock(this._job);
        }
        Model model = null;
        HyperSpaceWalker.HyperSpaceIterator<MP> it = this._hyperSpaceWalker.iterator();
        long gridWork = 0L;
        if (gridSize > 0L) {
            int count = 0;
            while (it.hasNext(model) && it.max_models() > 0 && count++ < it.max_models()) {
                try {
                    MP parms = it.nextModelParameters(model);
                    gridWork += (long)(((Model.Parameters)parms)._nfolds > 0 ? ((Model.Parameters)parms)._nfolds + 1 : 1) * ((Model.Parameters)parms).progressUnits();
                }
                catch (Throwable ex) {}
            }
        } else {
            gridWork = Long.MAX_VALUE;
        }
        it.reset();
        return this._job.start(new H2O.H2OCountedCompleter(){

            @Override
            public void compute2() {
                GridSearch.this.gridSearch(grid);
                this.tryComplete();
            }
        }, gridWork, it.max_runtime_secs());
    }

    public long getModelCount() {
        return this._hyperSpaceWalker.getMaxHyperSpaceSize();
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void gridSearch(Grid<MP> grid) {
        Model model = null;
        String protoModelKey = grid._key + "_model_";
        try {
            HyperSpaceWalker.HyperSpaceIterator<MP> it = this._hyperSpaceWalker.iterator();
            int counter = grid.getModelCount();
            while (it.hasNext(model)) {
                if (this._job.stop_requested()) {
                    return;
                }
                double max_runtime_secs = it.max_runtime_secs();
                double time_remaining_secs = Double.MAX_VALUE;
                if (max_runtime_secs > 0.0 && (time_remaining_secs = it.time_remaining_secs()) < 0.0) {
                    Log.info("Grid max_runtime_secs of " + max_runtime_secs + " secs has expired; stopping early.");
                    return;
                }
                try {
                    MP params = it.nextModelParameters(model);
                    if (max_runtime_secs > 0.0) {
                        double scale;
                        Log.info("Grid time is limited to: " + max_runtime_secs + " for grid: " + grid._key + ". Remaining time is: " + time_remaining_secs);
                        double d = scale = ((Model.Parameters)params)._nfolds > 0 ? (double)(((Model.Parameters)params)._nfolds + 1) : 1.0;
                        if (((Model.Parameters)params)._max_runtime_secs == 0.0) {
                            ((Model.Parameters)params)._max_runtime_secs = time_remaining_secs / scale;
                            Log.info("Due to the grid time limit, changing model max runtime to: " + ((Model.Parameters)params)._max_runtime_secs + " secs.");
                        } else {
                            double was = ((Model.Parameters)params)._max_runtime_secs;
                            ((Model.Parameters)params)._max_runtime_secs = Math.min(((Model.Parameters)params)._max_runtime_secs, time_remaining_secs / scale);
                            Log.info("Due to the grid time limit, changing model max runtime from: " + was + " secs to: " + ((Model.Parameters)params)._max_runtime_secs + " secs.");
                        }
                    }
                    try {
                        ScoringInfo scoringInfo = new ScoringInfo();
                        scoringInfo.time_stamp_ms = System.currentTimeMillis();
                        model = this.buildModel(params, grid, counter++, protoModelKey);
                        if (model != null) {
                            model.fillScoringInfo(scoringInfo);
                            grid.setScoringInfos(ScoringInfo.prependScoringInfo(scoringInfo, grid.getScoringInfos()));
                            ScoringInfo.sort(grid.getScoringInfos(), ((HyperSpaceSearchCriteria)this._hyperSpaceWalker.search_criteria()).stopping_metric());
                        }
                    }
                    catch (RuntimeException e) {
                        if (!Job.isCancelledException(e)) {
                            StringWriter sw = new StringWriter();
                            PrintWriter pw = new PrintWriter(sw);
                            e.printStackTrace(pw);
                            Log.warn("Grid search: model builder for parameters " + params + " failed! Exception: ", e, sw.toString());
                        }
                        grid.appendFailedModelParameters(params, (Exception)e);
                    }
                }
                catch (IllegalArgumentException e) {
                    Log.warn("Grid search: construction of model parameters failed! Exception: ", e);
                    it.modelFailed(model);
                    Object[] rawParams = it.getCurrentRawParameters();
                    grid.appendFailedModelParameters(rawParams, (Exception)e);
                }
                finally {
                    this._job.update(1L);
                    grid.update(this._job);
                }
                if (model == null || grid.getScoringInfos() == null || !this._hyperSpaceWalker.stopEarly(model, grid.getScoringInfos())) continue;
                Log.info("Convergence detected based on simple moving average of the loss function. Grid building completed.");
                break;
            }
            Log.info("For grid: " + grid._key + " built: " + grid.getModelCount() + " models.");
        }
        finally {
            grid.unlock(this._job);
        }
    }

    private Model buildModel(MP params, Grid<MP> grid, int paramsIdx, String protoModelKey) {
        Key[] modelKeys;
        final long checksum = ((Model.Parameters)params).checksum();
        Key<Model> key = grid.getModelKey(checksum);
        if (key != null) {
            if (DKV.get(key) == null) {
                Log.info("GridSearch.buildModel(): model with these parameters was built but removed, rebuilding; checksum: " + checksum);
            } else {
                Log.info("GridSearch.buildModel(): model with these parameters already exists, skipping; checksum: " + checksum);
                return key.get();
            }
        }
        if ((modelKeys = KeySnapshot.globalSnapshot().filter(new KeySnapshot.KVFilter(){

            @Override
            public boolean filter(KeySnapshot.KeyInfo k) {
                return Value.isSubclassOf(k._type, Model.class) && ((Model.Parameters)((Model)k._key.get())._parms).checksum() == checksum;
            }
        }).keys()).length > 0) {
            grid.putModel(checksum, modelKeys[0]);
            return (Model)modelKeys[0].get();
        }
        Key<Model> result = Key.make(protoModelKey + paramsIdx);
        Model m = (Model)this.startBuildModel(result, params, grid).dest().get();
        grid.putModel(checksum, result);
        return m;
    }

    private ModelBuilder startBuildModel(Key result, MP params, Grid<MP> grid) {
        if (grid.getModel(params) != null) {
            return null;
        }
        Object mb = ModelBuilder.make(((Model.Parameters)params).algoName(), this._job, result);
        ((ModelBuilder)mb)._parms = params;
        ((ModelBuilder)mb).trainModelNested(null);
        return mb;
    }

    protected static Key<Grid> gridKeyName(String modelName, Frame fr) {
        if (fr == null || fr._key == null) {
            throw new IllegalArgumentException("The frame being grid-searched over must have a Key");
        }
        return Key.make("Grid_" + modelName + "_" + fr._key.toString() + H2O.calcNextUniqueModelId(""));
    }

    public static <MP extends Model.Parameters> Job<Grid> startGridSearch(Key<Grid> destKey, MP params, Map<String, Object[]> hyperParams, ModelParametersBuilderFactory<MP> paramsBuilderFactory, HyperSpaceSearchCriteria search_criteria) {
        return GridSearch.startGridSearch(destKey, HyperSpaceWalker.BaseWalker.WalkerFactory.create(params, hyperParams, paramsBuilderFactory, search_criteria));
    }

    public static <MP extends Model.Parameters> Job<Grid> startGridSearch(Key<Grid> destKey, MP params, Map<String, Object[]> hyperParams) {
        return GridSearch.startGridSearch(destKey, params, hyperParams, new SimpleParametersBuilderFactory(), new HyperSpaceSearchCriteria.CartesianSearchCriteria());
    }

    public static <MP extends Model.Parameters> Job<Grid> startGridSearch(Key<Grid> destKey, HyperSpaceWalker<MP, ?> hyperSpaceWalker) {
        MP params = hyperSpaceWalker.getParams();
        Key<Grid> gridKey = destKey != null ? destKey : GridSearch.gridKeyName(((Model.Parameters)params).algoName(), ((Model.Parameters)params).train());
        return new GridSearch<MP>(gridKey, hyperSpaceWalker).start();
    }

    public static class SimpleParametersBuilderFactory<MP extends Model.Parameters>
    implements ModelParametersBuilderFactory<MP> {
        @Override
        public ModelParametersBuilderFactory.ModelParametersBuilder<MP> get(MP initialParams) {
            return new SimpleParamsBuilder<MP>(initialParams);
        }

        @Override
        public PojoUtils.FieldNaming getFieldNamingStrategy() {
            return PojoUtils.FieldNaming.CONSISTENT;
        }

        public static class SimpleParamsBuilder<MP extends Model.Parameters>
        implements ModelParametersBuilderFactory.ModelParametersBuilder<MP> {
            private final MP params;

            public SimpleParamsBuilder(MP initialParams) {
                this.params = initialParams;
            }

            @Override
            public ModelParametersBuilderFactory.ModelParametersBuilder<MP> set(String name, Object value) {
                PojoUtils.setField(this.params, name, value, PojoUtils.FieldNaming.CONSISTENT);
                return this;
            }

            @Override
            public MP build() {
                return this.params;
            }
        }
    }
}

