/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.exec;

import hex.DataInfo;
import hex.genmodel.utils.IOUtils;
import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.exec.XGBoostExecReq;
import hex.tree.xgboost.exec.XGBoostExecutor;
import hex.tree.xgboost.matrix.FrameMatrixLoader;
import hex.tree.xgboost.matrix.MatrixLoader;
import hex.tree.xgboost.matrix.RemoteMatrixLoader;
import hex.tree.xgboost.rabit.RabitTrackerH2O;
import hex.tree.xgboost.remote.RemoteXGBoostUploadServlet;
import hex.tree.xgboost.task.XGBoostCleanupTask;
import hex.tree.xgboost.task.XGBoostSetupTask;
import hex.tree.xgboost.task.XGBoostUpdateTask;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.Map;
import water.H2O;
import water.Key;
import water.fvec.Frame;

/**
 * {@link XGBoostExecutor} that drives XGBoost training on the H2O nodes of this cluster.
 *
 * <p>Lifecycle:
 * <ol>
 *   <li>{@link #setup()} runs an {@link XGBoostSetupTask}, then an initial
 *       {@link XGBoostUpdateTask} with tree id 0.</li>
 *   <li>{@link #update(int)} runs one {@link XGBoostUpdateTask} per call.</li>
 *   <li>{@link #close()} runs {@link XGBoostCleanupTask} and stops the Rabit tracker.</li>
 * </ol>
 *
 * <p>A {@link RabitTrackerH2O} is started only when more than one node participates.
 * With a single node, no tracker exists and an empty worker environment is used.
 */
public class LocalXGBoostExecutor
implements XGBoostExecutor {
    public final Key modelKey;
    private final BoosterParms boosterParams;
    private final MatrixLoader loader;
    private final CheckpointProvider checkpointProvider;
    private final boolean[] nodes;
    private final String saveMatrixDirectory;
    private final RabitTrackerH2O rt;
    private XGBoostSetupTask setupTask;
    private XGBoostUpdateTask updateTask;

    /**
     * Creates an executor from a remote initialization request.
     *
     * <p>Training data is loaded through a {@link RemoteMatrixLoader}. When the request
     * declares a checkpoint, it is read from the file obtained from
     * {@link RemoteXGBoostUploadServlet#getCheckpointFile(String)} when {@link #setup()} runs.
     *
     * @param key  key of the model being trained
     * @param init initialization request carrying booster parameters, node layout,
     *             matrix save path and checkpoint flag
     */
    public LocalXGBoostExecutor(Key key, XGBoostExecReq.Init init) {
        this.modelKey = key;
        this.boosterParams = BoosterParms.fromMap((Map<String, Object>) init.parms);
        this.nodes = new boolean[H2O.CLOUD.size()];
        // NOTE(review): only the first init.num_nodes entries of init.nodes are inspected;
        // this assumes init.nodes has at least that many entries and that
        // init.num_nodes <= H2O.CLOUD.size() -- confirm against the request producer.
        for (int i = 0; i < init.num_nodes; ++i) {
            this.nodes[i] = init.nodes[i] != null;
        }
        this.loader = new RemoteMatrixLoader(this.modelKey);
        this.saveMatrixDirectory = init.save_matrix_path;
        this.checkpointProvider = () -> init.has_checkpoint ? readAndDeleteUploadedCheckpoint(key) : null;
        // Start the tracker last so a failure above cannot leave a running tracker behind.
        this.rt = this.setupRabitTracker(init.num_nodes);
    }

    /**
     * Creates an executor that trains {@code model} on the local frame {@code train}.
     *
     * <p>Booster parameters are derived from the model parameters. They are also published
     * to the model output as {@code _native_parameters}. When the model parameters request
     * a checkpoint, the model's current booster bytes are used as the checkpoint.
     *
     * @param model model being trained
     * @param train training frame; the nodes holding its data take part in training
     */
    public LocalXGBoostExecutor(XGBoostModel model, Frame train) {
        this.modelKey = model._key;
        XGBoostSetupTask.FrameNodes trainFrameNodes = XGBoostSetupTask.findFrameNodes(train);
        DataInfo dataInfo = model.model_info().dataInfo();
        XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) model._parms;
        XGBoostOutput output = (XGBoostOutput) model._output;
        this.boosterParams = XGBoostModel.createParams(parms, output.nclasses(), dataInfo.coefNames());
        output._native_parameters = this.boosterParams.toTwoDimTable();
        this.loader = new FrameMatrixLoader(model, train);
        this.nodes = trainFrameNodes._nodes;
        this.saveMatrixDirectory = parms._save_matrix_directory;
        // Read lazily from the model so the booster bytes current at setup() time are used.
        this.checkpointProvider = () -> ((XGBoostModel.XGBoostParameters) model._parms).hasCheckpoint()
                ? model.model_info()._boosterBytes
                : null;
        // Start the tracker last so a failure above cannot leave a running tracker behind.
        this.rt = this.setupRabitTracker(trainFrameNodes.getNumNodes());
    }

    /**
     * Reads the checkpoint uploaded for {@code modelKey} and deletes the file afterwards.
     *
     * <p>Deletion is attempted whether or not the read succeeded. Its result is ignored
     * (best effort).
     *
     * @param modelKey key whose string form identifies the uploaded checkpoint file
     * @return the full contents of the checkpoint file
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     */
    private static byte[] readAndDeleteUploadedCheckpoint(Key modelKey) {
        File checkpointFile = RemoteXGBoostUploadServlet.getCheckpointFile(modelKey.toString());
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (InputStream in = new FileInputStream(checkpointFile)) {
            IOUtils.copyStream(in, (OutputStream) bos);
        } catch (IOException e) {
            throw new RuntimeException("Failed reading checkpoint file " + checkpointFile + ".", e);
        } finally {
            checkpointFile.delete();
        }
        return bos.toByteArray();
    }

    /**
     * Runs the distributed setup, followed by an initial update with tree id 0.
     *
     * @return booster bytes produced by the initial update task
     */
    @Override
    public byte[] setup() {
        this.setupTask = new XGBoostSetupTask(this.modelKey, this.saveMatrixDirectory, this.boosterParams, this.checkpointProvider.get(), this.getRabitEnv(), this.nodes, this.loader);
        this.setupTask.run();
        // Same pattern as update(): build the task, run it, then keep it for updateBooster().
        XGBoostUpdateTask initialUpdate = new XGBoostUpdateTask(this.setupTask, 0);
        initialUpdate.run();
        this.updateTask = initialUpdate;
        return this.updateTask.getBoosterBytes();
    }

    /**
     * Starts a Rabit tracker when training spans more than one node.
     *
     * @param numNodes number of participating nodes
     * @return the started tracker, or {@code null} for single-node training
     */
    private RabitTrackerH2O setupRabitTracker(int numNodes) {
        if (numNodes <= 1) {
            return null;
        }
        RabitTrackerH2O tracker = new RabitTrackerH2O(numNodes);
        tracker.start(0L);
        return tracker;
    }

    /** Waits for and then stops the Rabit tracker, if one was started. */
    private void stopRabitTracker() {
        if (this.rt == null) {
            return;
        }
        try {
            this.rt.waitFor(0L);
        } finally {
            // Always release the tracker, even if waiting for it failed.
            this.rt.stop();
        }
    }

    /**
     * Returns the worker environment for the setup task.
     *
     * @return the tracker's worker environment, or a new empty map when no tracker is running
     */
    private Map<String, String> getRabitEnv() {
        if (this.rt != null) {
            return this.rt.getWorkerEnvs();
        }
        return new HashMap<String, String>();
    }

    /**
     * Runs one update task for {@code treeId}. The task is kept so that
     * {@link #updateBooster()} can return its booster bytes.
     *
     * @param treeId tree id passed to the update task
     */
    @Override
    public void update(int treeId) {
        this.updateTask = new XGBoostUpdateTask(this.setupTask, treeId);
        this.updateTask.run();
    }

    /**
     * Returns the booster bytes from the most recent update task and forgets that task.
     *
     * @return booster bytes, or {@code null} if no update has run since the last call
     */
    @Override
    public byte[] updateBooster() {
        if (this.updateTask == null) {
            return null;
        }
        byte[] booster = this.updateTask.getBoosterBytes();
        this.updateTask = null;
        return booster;
    }

    /**
     * Releases resources acquired by this executor.
     *
     * <p>Node-side state is cleaned up only if {@link #setup()} created a setup task.
     * The Rabit tracker is stopped even if the cleanup fails.
     */
    @Override
    public void close() {
        try {
            if (this.setupTask != null) {
                XGBoostCleanupTask.cleanUp(this.setupTask);
            }
        } finally {
            this.stopRabitTracker();
        }
    }

    /** Lazily supplies serialized checkpoint booster bytes, or {@code null} when there is no checkpoint. */
    @FunctionalInterface
    static interface CheckpointProvider {
        public byte[] get();
    }
}

