/*
 * Decompiled with CFR 0.152.
 */
package hex.grid;

import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ScoringInfo;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.reflect.Array;
import java.net.URI;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Objects;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.Keyed;
import water.Lockable;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;
import water.persist.Persist;
import water.util.ArrayUtils;
import water.util.FileUtils;
import water.util.IcedHashMap;
import water.util.IcedLong;
import water.util.Log;
import water.util.PojoUtils;
import water.util.StringUtils;
import water.util.TwoDimTable;

public class Grid<MP extends Model.Parameters>
extends Lockable<Grid<MP>> {
    /** Shared empty prototype instance: no key, no parameters, no hyper-parameter names. */
    public static final Grid GRID_PROTO = new Grid(null, null, null, null);

    // NOTE: the declaration order of the instance fields below is kept as-is on purpose; Iced serialization
    // (used e.g. by exportBinary) is field-order sensitive.

    /** Model keys indexed by the checksum of the parameters that produced them. */
    private final IcedHashMap<IcedLong, Key<Model>> _models = new IcedHashMap<>();
    /** Failed model builds, grouped by model key (NO_MODEL_FAILURES_KEY when no key was available). */
    private final IcedHashMap<Key<Model>, SearchFailure> _failures;
    /** Private copy of the base parameters the grid was created with (null for GRID_PROTO). */
    private final MP _params;
    /** Names of the hyper-parameters explored by this grid. */
    private final String[] _hyper_names;
    /** Naming strategy used to resolve hyper-parameter names to parameter fields. */
    private final PojoUtils.FieldNaming _field_naming_strategy;
    private ScoringInfo[] _scoring_infos = null;
    private static final Key<Model> NO_MODEL_FAILURES_KEY = Key.makeUserHidden("GridSearchFailureEmptyModelKey");

    /**
     * Creates a grid.
     *
     * @param key         key of the grid
     * @param params      base model parameters; cloned so later changes to the caller's object do not leak in
     * @param hyperNames  names of the hyper-parameters explored by the grid
     * @param fieldNaming strategy used to map hyper-parameter names to parameter fields
     */
    @SuppressWarnings("unchecked")
    protected Grid(Key key, MP params, String[] hyperNames, PojoUtils.FieldNaming fieldNaming) {
        super(key);
        this._params = params != null ? (MP) params.clone() : null;
        this._hyper_names = hyperNames;
        this._field_naming_strategy = fieldNaming;
        this._failures = new IcedHashMap<>();
    }

    /** @return the algorithm name of the grid's base parameters */
    public String getModelName() {
        return _params.algoName();
    }

    /** @return the scoring infos attached to this grid (may be null) */
    public ScoringInfo[] getScoringInfos() {
        return _scoring_infos;
    }

    /** Replaces the scoring infos attached to this grid. */
    public void setScoringInfos(ScoringInfo[] scoring_infos) {
        this._scoring_infos = scoring_infos;
    }

    /** @return the training frame referenced by the grid's base parameters */
    public Frame getTrainingFrame() {
        return _params.train();
    }

    /**
     * Looks up the model built from the given parameters.
     *
     * @return the model, or null if no model is registered for the parameters' checksum
     */
    public Model getModel(MP params) {
        Key<Model> mKey = getModelKey(params);
        return mKey != null ? mKey.get() : null;
    }

    /** @return the key of the model registered for the checksum of {@code params}, or null */
    public Key<Model> getModelKey(MP params) {
        return getModelKey(params.checksum());
    }

    /** @return the key of the model registered for the given parameters checksum, or null */
    Key<Model> getModelKey(long paramsChecksum) {
        return _models.get(IcedLong.valueOf(paramsChecksum));
    }

    /**
     * Registers a model key under a parameters checksum.
     *
     * @return the key previously registered for this checksum, or null
     */
    synchronized Key<Model> putModel(long checksum, Key<Model> modelKey) {
        return _models.put(IcedLong.valueOf(checksum), modelKey);
    }

    /**
     * Records a failed model build. Failures are grouped by model key; failures without a key are grouped
     * under NO_MODEL_FAILURES_KEY.
     */
    private void appendFailedModelParameters(Key<Model> modelKey, MP params, String[] rawParams, Throwable t) {
        String failureDetails = isJobCanceled(t) ? "Job Canceled" : t.getMessage();
        String stackTrace = StringUtils.toString(t);
        Key<Model> searchedKey = modelKey != null ? modelKey : NO_MODEL_FAILURES_KEY;
        SearchFailure searchFailure = _failures.get(searchedKey);
        if (searchFailure == null) {
            searchFailure = new SearchFailure(_params.getClass());
            _failures.put(searchedKey, searchFailure);
        }
        searchFailure.appendFailedModelParameters(params, rawParams, failureDetails, stackTrace);
    }

    /** @return true if {@code t} or any exception in its cause chain is a job cancellation */
    private static boolean isJobCanceled(Throwable t) {
        for (Throwable ex = t; ex != null; ex = ex.getCause()) {
            if (ex instanceof Job.JobCancelledException) {
                return true;
            }
        }
        return false;
    }

    /**
     * Records a failed model build for known parameters; the raw values are derived from the
     * grid's hyper-parameters.
     */
    void appendFailedModelParameters(Key<Model> modelKey, MP params, Throwable t) {
        assert params != null : "Model parameters should be always != null !";
        String[] rawParams = ArrayUtils.toString(getHyperValues(params));
        appendFailedModelParameters(modelKey, params, rawParams, t);
    }

    /** Records a failed model build for which only raw hyper-parameter values are available. */
    void appendFailedModelParameters(Key<Model> modelKey, Object[] rawParams, Exception e) {
        assert rawParams != null : "Raw parameters should be always != null !";
        appendFailedModelParameters(modelKey, null, ArrayUtils.toString(rawParams), e);
    }

    /** @return a new, sorted array with the keys of all models registered in this grid */
    @SuppressWarnings("unchecked")
    public Key<Model>[] getModelKeys() {
        Key<Model>[] keys = _models.values().toArray(new Key[_models.size()]);
        Arrays.sort(keys);
        return keys;
    }

    /**
     * Fetches all registered models.
     *
     * @return one entry per registered key; an entry is null when its key is null or no longer resolves
     */
    public Model[] getModels() {
        Collection<Key<Model>> modelKeys = _models.values();
        Model[] models = new Model[modelKeys.size()];
        int i = 0;
        for (Key<Model> mKey : modelKeys) {
            models[i++] = mKey != null ? mKey.get() : null;
        }
        return models;
    }

    /** @return the number of models registered in this grid */
    public int getModelCount() {
        return _models.size();
    }

    /** @return a fresh SearchFailure merging every recorded failure of this grid */
    public SearchFailure getFailures() {
        Collection<SearchFailure> values = _failures.values();
        SearchFailure searchFailure = new SearchFailure(_params != null ? _params.getClass() : null);
        for (SearchFailure f : values) {
            searchFailure.appendFailedModelParameters(f._failed_params, f._failed_raw_params, f._failure_details,
                    f._failure_stack_traces);
        }
        return searchFailure;
    }

    /** Drops the failures that were recorded without a model key. */
    protected void clearNonRelatedFailures() {
        _failures.remove(NO_MODEL_FAILURES_KEY);
    }

    /** @return the values of this grid's hyper-parameters read from {@code parms}, in hyper-name order */
    public Object[] getHyperValues(MP parms) {
        Object[] result = new Object[_hyper_names.length];
        for (int i = 0; i < _hyper_names.length; ++i) {
            result[i] = PojoUtils.getFieldValue(parms, _hyper_names[i], _field_naming_strategy);
        }
        return result;
    }

    /** @return the hyper-parameter names (internal array, not a copy) */
    public String[] getHyperNames() {
        return _hyper_names;
    }

    /** Forgets all model keys; when {@code cascade} is set, the models themselves are removed too. */
    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        if (cascade) {
            for (Key<Model> k : _models.values()) {
                Keyed.remove(k, fs, true);
            }
        }
        _models.clear();
        return super.remove_impl(fs, cascade);
    }

    /** Writes the keys of all models before delegating to the superclass. */
    @Override
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        for (Key<Model> k : _models.values()) {
            ab.putKey(k);
        }
        return super.writeAll_impl(ab);
    }

    /** Serializes the grid object itself without writing the model keys. */
    protected AutoBuffer writeWithoutModels(AutoBuffer autoBuffer) {
        return super.writeAll_impl(autoBuffer.put(this));
    }

    /** Not supported for grids. */
    @Override
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        throw H2O.unimpl();
    }

    /** Not supported for grids. */
    @Override
    protected long checksum_impl() {
        throw H2O.unimpl();
    }

    @Override
    public Class<KeyV3.GridKeyV3> makeSchema() {
        return KeyV3.GridKeyV3.class;
    }

    /**
     * Builds a summary table with one row per entry of {@code model_ids}. Each row holds the model's
     * hyper-parameter values, its key and, if {@code sort_by} is non-null, the {@code sort_by} metric.
     * The table is also logged at info level.
     *
     * <p>Every key must resolve to a model in the DKV; a missing model causes a NullPointerException.
     *
     * @param model_ids  models to list, in the desired row order
     * @param sort_by    metric name to add as the last column, or null
     * @param decreasing only used in the table description
     * @return the table, or null when there are no hyper-parameter names or no model ids
     */
    public TwoDimTable createSummaryTable(Key<Model>[] model_ids, String sort_by, boolean decreasing) {
        if (_hyper_names == null || model_ids == null || model_ids.length == 0) {
            return null;
        }
        int extra_len = sort_by != null ? 2 : 1;
        int numCols = _hyper_names.length + extra_len;
        String[] colTypes = new String[numCols];
        Arrays.fill(colTypes, "string");
        String[] colFormats = new String[numCols];
        Arrays.fill(colFormats, "%s");
        String[] colNames = Arrays.copyOf(_hyper_names, numCols);
        colNames[_hyper_names.length] = "model_ids";
        if (sort_by != null) {
            colNames[_hyper_names.length + 1] = sort_by;
        }
        String description = sort_by != null
                ? "ordered by " + (decreasing ? "decreasing " : "increasing ") + sort_by
                : null;
        // Size by the requested ids, not by _models: callers may pass a subset of the grid's models, and sizing by
        // _models.size() would leave blank rows (or overflow when more ids than registered models are passed).
        TwoDimTable table = new TwoDimTable("Hyper-Parameter Search Summary", description,
                new String[model_ids.length], colNames, colTypes, colFormats, "");
        for (int i = 0; i < model_ids.length; i++) {
            Key<Model> km = model_ids[i];
            Model m = DKV.getGet(km);
            Model.Parameters parms = m._parms;
            int j;
            for (j = 0; j < _hyper_names.length; ++j) {
                table.set(i, j, PojoUtils.getFieldValue(parms, _hyper_names[j], _field_naming_strategy));
            }
            table.set(i, j, km.toString());
            if (sort_by != null) {
                table.set(i, j + 1, ModelMetrics.getMetricFromModel(km, sort_by));
            }
        }
        Log.info(table);
        return table;
    }

    /**
     * Builds the scoring-history table from this grid's scoring infos.
     *
     * <p>The model category comes from the first model returned by the model map's iterator. The validation,
     * cross-validation and autoencoder flags come from the first scoring info, if any.
     * If the grid has no models, or that model no longer resolves, a Binomial category with all flags false
     * is used instead.
     */
    public TwoDimTable createScoringHistoryTable() {
        if (_models.isEmpty()) {
            return ScoringInfo.createScoringHistoryTable(_scoring_infos, false, false, ModelCategory.Binomial, false);
        }
        // NOTE(review): assumes all models of a grid share one model category — confirm.
        Key<Model> k = _models.values().iterator().next();
        Model m = k.get();
        if (m == null) {
            Log.warn("Cannot create grid scoring history table; Model has been removed: " + k);
            return ScoringInfo.createScoringHistoryTable(_scoring_infos, false, false, ModelCategory.Binomial, false);
        }
        ScoringInfo scoring_info = _scoring_infos != null && _scoring_infos.length > 0 ? _scoring_infos[0] : null;
        boolean hasValidation = scoring_info != null && scoring_info.validation;
        boolean hasCrossValidation = scoring_info != null && scoring_info.cross_validation;
        boolean isAutoencoder = scoring_info != null && scoring_info.is_autoencoder;
        return ScoringInfo.createScoringHistoryTable(_scoring_infos, hasValidation, hasCrossValidation,
                m._output.getModelCategory(), isAutoencoder);
    }

    /**
     * Writes this grid (without its models) to {@code gridExportDir + "/" + key}, overwriting an existing file.
     *
     * @param gridExportDir target directory (URI or path understood by the persistence layer)
     * @throws IOException if the target cannot be created or written
     */
    public void exportBinary(String gridExportDir) throws IOException {
        Objects.requireNonNull(gridExportDir);
        // Check the key before it is dereferenced; asserting afterwards could never trigger.
        assert _key != null;
        String gridFilePath = gridExportDir + "/" + _key.toString();
        URI gridUri = FileUtils.getURI(gridFilePath);
        Persist persist = H2O.getPM().getPersistForURI(gridUri);
        try (OutputStream outputStream = persist.create(gridUri.toString(), true)) {
            AutoBuffer autoBuffer = new AutoBuffer(outputStream, true);
            writeWithoutModels(autoBuffer);
            autoBuffer.close();
        }
    }

    /**
     * Exports every model of this grid to {@code exportDir + "/" + modelKey}, overwriting existing files.
     * A model that no longer resolves (null entry from getModels) causes a NullPointerException.
     *
     * @throws IOException if a model cannot be exported
     */
    public void exportModelsBinary(String exportDir) throws IOException {
        Objects.requireNonNull(exportDir);
        for (Model model : getModels()) {
            model.exportBinaryModel(exportDir + "/" + model._key.toString(), true);
        }
    }

    /** @return the grid's own copy of the base parameters (may be null for GRID_PROTO) */
    public MP getParams() {
        return _params;
    }

    /**
     * Accumulates failed model builds as four parallel arrays: parameters, raw hyper-parameter values,
     * failure messages and stack traces.
     */
    public static final class SearchFailure<MP extends Model.Parameters>
    extends Iced<SearchFailure> {
        // Field order kept as-is: Iced serialization is field-order sensitive.
        private MP[] _failed_params;
        private String[] _failure_details;
        private String[] _failure_stack_traces;
        private String[][] _failed_raw_params;

        /**
         * @param paramsClass concrete parameters class used to allocate the typed parameters array; when null,
         *                    the parameters array starts out null
         */
        @SuppressWarnings("unchecked")
        private SearchFailure(Class<MP> paramsClass) {
            _failed_params = paramsClass != null ? (MP[]) Array.newInstance(paramsClass, 0) : null;
            _failure_details = new String[0];
            _failed_raw_params = new String[0][];
            _failure_stack_traces = new String[0];
        }

        /** Appends one failure (params may be null when only raw values are known). */
        private void appendFailedModelParameters(MP params, String[] rawParams, String failureDetails, String stackTrace) {
            assert rawParams != null : "API has to always pass rawParams";
            MP[] a = _failed_params;
            MP[] na = Arrays.copyOf(a, a.length + 1);
            na[a.length] = params;
            _failed_params = na;

            String[] m = _failure_details;
            String[] nm = Arrays.copyOf(m, m.length + 1);
            nm[m.length] = failureDetails;
            _failure_details = nm;

            String[][] rp = _failed_raw_params;
            String[][] nrp = Arrays.copyOf(rp, rp.length + 1);
            nrp[rp.length] = rawParams;
            _failed_raw_params = nrp;

            String[] st = _failure_stack_traces;
            String[] nst = Arrays.copyOf(st, st.length + 1);
            nst[st.length] = stackTrace;
            _failure_stack_traces = nst;
        }

        /** Appends a batch of failures; the four arrays are expected to describe the same failures in order. */
        public void appendFailedModelParameters(MP[] params, String[][] rawParams, String[] failureDetails, String[] stackTraces) {
            assert rawParams != null : "API has to always pass rawParams";
            _failed_params = ArrayUtils.append(_failed_params, params);
            _failed_raw_params = ArrayUtils.append(_failed_raw_params, rawParams);
            _failure_details = ArrayUtils.append(_failure_details, failureDetails);
            _failure_stack_traces = ArrayUtils.append(_failure_stack_traces, stackTraces);
        }

        /** Appends a failure known only by its raw hyper-parameter values. */
        void appendFailedModelParameters(Object[] rawParams, Exception e) {
            assert rawParams != null : "Raw parameters should be always != null !";
            appendFailedModelParameters(null, ArrayUtils.toString(rawParams), e.getMessage(), StringUtils.toString(e));
        }

        /** @return the failed parameters (entries may be null; the array itself may be null) */
        public Model.Parameters[] getFailedParameters() {
            return _failed_params;
        }

        /** @return one failure message per failure */
        public String[] getFailureDetails() {
            return _failure_details;
        }

        /** @return one stack trace per failure */
        public String[] getFailureStackTraces() {
            return _failure_stack_traces;
        }

        /** @return the raw hyper-parameter values of each failure */
        public String[][] getFailedRawParameters() {
            return _failed_raw_params;
        }

        /** @return the number of recorded failures */
        public int getFailureCount() {
            // _failed_params is null when no parameters class was known; fall back to the always-allocated
            // details array instead of throwing a NullPointerException.
            return _failed_params != null ? _failed_params.length : _failure_details.length;
        }
    }
}

