/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import java.io.File;
import java.util.Map;
import ml.dmlc.xgboost4j.java.AbstractXGBoostTask;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostModelInfo;
import ml.dmlc.xgboost4j.java.XGBoostUpdater;
import water.H2O;
import water.Key;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.IcedHashMapGeneric;
import water.util.Log;

public class XGBoostSetupTask
extends AbstractXGBoostTask<XGBoostSetupTask> {
    private final XGBoostModelInfo _sharedModel;
    private final XGBoostModel.XGBoostParameters _parms;
    private final boolean _sparse;
    private final BoosterParms _boosterParms;
    private final IcedHashMapGeneric.IcedHashMapStringString _rabitEnv;
    private final Frame _trainFrame;

    public XGBoostSetupTask(XGBoostModel model, XGBoostModel.XGBoostParameters parms, BoosterParms boosterParms, Map<String, String> rabitEnv, FrameNodes trainFrame) {
        super((Key<XGBoostModel>)model._key, trainFrame._nodes);
        this._sharedModel = model.model_info();
        this._parms = parms;
        this._sparse = ((XGBoostOutput)model._output)._sparse;
        this._boosterParms = boosterParms;
        this._rabitEnv = new IcedHashMapGeneric.IcedHashMapStringString();
        this._rabitEnv.putAll(rabitEnv);
        this._trainFrame = trainFrame._fr;
    }

    @Override
    protected void execute() {
        DMatrix matrix;
        try {
            matrix = this.makeLocalMatrix();
            if (this._parms._save_matrix_directory != null) {
                File directory = new File(this._parms._save_matrix_directory);
                if (directory.mkdirs()) {
                    Log.debug((Object[])new Object[]{"Created directory for matrix export: " + directory.getAbsolutePath()});
                }
                File path = new File(directory, "matrix.part" + H2O.SELF.index());
                Log.info((Object[])new Object[]{"Saving node-local portion of XGBoost training dataset to " + path.getAbsolutePath() + "."});
                matrix.saveBinary(path.getAbsolutePath());
            }
        }
        catch (XGBoostError xgBoostError) {
            throw new IllegalStateException("Failed XGBoost training.", xgBoostError);
        }
        if (matrix == null) {
            throw new IllegalStateException("Node " + H2O.SELF + " is supposed to participate in XGB training but it doesn't have a DMatrix!");
        }
        this._rabitEnv.put((Object)"DMLC_TASK_ID", (Object)String.valueOf(H2O.SELF.index()));
        XGBoostUpdater thread = XGBoostUpdater.make((Key<XGBoostModel>)this._modelKey, matrix, this._boosterParms, (Map<String, String>)this._rabitEnv);
        thread.start();
    }

    private DMatrix makeLocalMatrix() throws XGBoostError {
        return XGBoostUtils.convertFrameToDMatrix(this._sharedModel.dataInfo(), this._trainFrame, this._parms._response_column, this._parms._weights_column, this._sparse);
    }

    public static FrameNodes findFrameNodes(Frame fr) {
        boolean[] nodesHoldingFrame = new boolean[H2O.CLOUD.size()];
        Vec vec = fr.anyVec();
        for (int chunkNr = 0; chunkNr < vec.nChunks(); ++chunkNr) {
            int home = vec.chunkKey(chunkNr).home_node().index();
            if (nodesHoldingFrame[home]) continue;
            nodesHoldingFrame[home] = true;
        }
        return new FrameNodes(fr, nodesHoldingFrame);
    }

    public static class FrameNodes {
        final Frame _fr;
        final boolean[] _nodes;
        final int _numNodes;

        private FrameNodes(Frame fr, boolean[] nodes) {
            this._fr = fr;
            this._nodes = nodes;
            int n = 0;
            for (boolean f : this._nodes) {
                if (!f) continue;
                ++n;
            }
            this._numNodes = n;
        }

        public int getNumNodes() {
            return this._numNodes;
        }
    }
}

