/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.task;

import ai.h2o.xgboost4j.java.DMatrix;
import ai.h2o.xgboost4j.java.XGBoostError;
import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.matrix.MatrixLoader;
import hex.tree.xgboost.task.AbstractXGBoostTask;
import hex.tree.xgboost.task.XGBoostUpdater;
import java.io.File;
import java.util.Map;
import org.apache.log4j.Logger;
import water.H2O;
import water.Key;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.IcedHashMapGeneric;

/**
 * Per-node setup step for distributed XGBoost training.
 *
 * <p>When executed on a node, this task builds the node-local {@link DMatrix} via the configured
 * {@link MatrixLoader}. If an export directory was given, it also writes that matrix to disk. It
 * then sets {@code DMLC_TASK_ID} in the Rabit environment to this node's index, and creates and
 * starts an {@link XGBoostUpdater} for the model.
 */
public class XGBoostSetupTask
extends AbstractXGBoostTask<XGBoostSetupTask> {
    private static final Logger LOG = Logger.getLogger(XGBoostSetupTask.class);

    /** Booster parameters handed to the updater. */
    private final BoosterParms _boosterParms;
    /** Checkpoint bytes to resume from, exactly as supplied by the caller. */
    private final byte[] _checkpoint;
    /** Rabit environment; a private copy of the map given to the constructor. */
    private final IcedHashMapGeneric.IcedHashMapStringString _rabitEnv;
    /** Builds the node-local training matrix. */
    private final MatrixLoader _matrixLoader;
    /** Directory to export the node-local matrix into, or {@code null} to skip the export. */
    private final String _saveMatrixDirectory;

    /**
     * Creates the setup task.
     *
     * @param modelKey            key of the model being trained
     * @param saveMatrixDirectory directory for exporting each node's matrix, or {@code null} to skip
     * @param boosterParms        booster parameters handed to the updater
     * @param checkpointToResume  checkpoint bytes handed to the updater
     * @param rabitEnv            Rabit environment entries; copied, so later changes to the
     *                            caller's map are not seen by this task
     * @param nodes               nodes this task targets, passed to the superclass
     * @param matrixLoader        loader used to build the node-local matrix
     */
    public XGBoostSetupTask(Key modelKey, String saveMatrixDirectory, BoosterParms boosterParms, byte[] checkpointToResume, Map<String, String> rabitEnv, boolean[] nodes, MatrixLoader matrixLoader) {
        super(modelKey, nodes);
        this._boosterParms = boosterParms;
        this._checkpoint = checkpointToResume;
        this._matrixLoader = matrixLoader;
        this._saveMatrixDirectory = saveMatrixDirectory;
        this._rabitEnv = new IcedHashMapGeneric.IcedHashMapStringString();
        this._rabitEnv.putAll(rabitEnv);
    }

    /**
     * Builds the local matrix, optionally exports it, and starts the XGBoost updater thread.
     *
     * <p>This method returns right after {@code start()} is called on the updater; it does not
     * wait for the updater to do any work.
     *
     * @throws IllegalStateException if the DMatrix cannot be created
     */
    @Override
    protected void execute() {
        DMatrix matrix;
        try {
            matrix = this._matrixLoader.makeLocalMatrix().get();
        } catch (XGBoostError e) {
            throw new IllegalStateException("Failed to create XGBoost DMatrix", e);
        }

        XGBoostUpdater updater;
        boolean handedOff = false;
        try {
            if (this._saveMatrixDirectory != null) {
                saveLocalMatrix(matrix, this._saveMatrixDirectory);
            }
            this._rabitEnv.put("DMLC_TASK_ID", String.valueOf(H2O.SELF.index()));
            updater = XGBoostUpdater.make(this._modelKey, matrix, this._boosterParms, this._checkpoint, this._rabitEnv);
            handedOff = true;
        } finally {
            if (!handedOff) {
                // Setup failed before the updater existed, so nothing else references the native
                // matrix; release it here instead of leaking it.
                // NOTE(review): presumably the updater takes ownership of the matrix once created,
                // which is why it is not disposed on the success path - confirm in XGBoostUpdater.
                matrix.dispose();
            }
        }
        updater.start();
    }

    /** Writes {@code matrix} to {@code <directoryName>/matrix.part<nodeIndex>} via saveBinary. */
    private static void saveLocalMatrix(DMatrix matrix, String directoryName) {
        File directory = new File(directoryName);
        if (directory.mkdirs()) {
            LOG.debug("Created directory for matrix export: " + directory.getAbsolutePath());
        }
        File path = new File(directory, "matrix.part" + H2O.SELF.index());
        LOG.info("Saving node-local portion of XGBoost training dataset to " + path.getAbsolutePath() + ".");
        matrix.saveBinary(path.getAbsolutePath());
    }

    /**
     * Determines which cloud nodes are home to at least one chunk of {@code fr}.
     *
     * <p>Only one vector ({@code fr.anyVec()}) is inspected.
     * NOTE(review): this assumes all vectors of the frame share the same chunk layout - confirm.
     *
     * @param fr frame to inspect
     * @return the frame plus a per-node flag array indexed by node index
     * @throws IllegalArgumentException if the frame has no vector to inspect
     */
    public static FrameNodes findFrameNodes(Frame fr) {
        Vec vec = fr.anyVec();
        if (vec == null) {
            // Guard against frames with no columns (anyVec() is presumably null then) so callers
            // get a descriptive error instead of a NullPointerException.
            throw new IllegalArgumentException("Cannot determine nodes holding a frame without any vectors");
        }
        boolean[] nodesHoldingFrame = new boolean[H2O.CLOUD.size()];
        for (int chunkNr = 0; chunkNr < vec.nChunks(); ++chunkNr) {
            nodesHoldingFrame[vec.chunkKey(chunkNr).home_node().index()] = true;
        }
        return new FrameNodes(fr, nodesHoldingFrame);
    }

    /** A frame together with the set of nodes that hold its chunks. */
    public static class FrameNodes {
        /** The inspected frame. */
        public final Frame _fr;
        /** {@code _nodes[i]} is {@code true} when node {@code i} is home to a chunk of the frame. */
        public final boolean[] _nodes;
        /** Number of {@code true} entries in {@link #_nodes}. */
        public final int _numNodes;

        private FrameNodes(Frame fr, boolean[] nodes) {
            this._fr = fr;
            this._nodes = nodes;
            int count = 0;
            for (boolean holdsChunk : nodes) {
                if (holdsChunk) {
                    ++count;
                }
            }
            this._numNodes = count;
        }

        /** @return number of nodes holding at least one chunk of the frame */
        public int getNumNodes() {
            return this._numNodes;
        }
    }
}

