/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost;

import hex.CVModelBuilder;
import hex.ModelBuilder;
import hex.tree.xgboost.XGBoost;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.util.GpuUtils;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.LinkedList;
import java.util.NoSuchElementException;
import org.apache.log4j.Logger;
import water.Job;

/**
 * {@link CVModelBuilder} that spreads XGBoost cross-validation models across a pool of GPUs.
 *
 * <p>Each model is assigned one GPU id taken from the pool in {@link #prepare(ModelBuilder)}, and
 * that id is handed back to the pool in {@link #finished(ModelBuilder)}. A GPU id is therefore
 * only reassigned after the model that held it has been passed to {@code finished}.
 *
 * <p>NOTE(review): the pool is a plain, unsynchronized {@link Deque}; this assumes the parent class
 * invokes {@code prepare}/{@code finished} from a single thread — confirm in CVModelBuilder.
 */
public class XGBoostGPUCVModelBuilder extends CVModelBuilder {

    private static final Logger LOG = Logger.getLogger(XGBoostGPUCVModelBuilder.class);

    /** GPU ids not currently assigned to any model, used as a LIFO stack. */
    private final Deque<Integer> availableGpus;

    /**
     * Creates a builder whose GPU pool is either the explicitly requested ids or every GPU
     * reported by {@link GpuUtils#allGPUs()}.
     *
     * @param job the parent job, passed through to {@link CVModelBuilder}
     * @param modelBuilders the per-fold builders; each must be an {@link XGBoost} instance, since
     *     {@code prepare}/{@code finished} cast to it
     * @param parallelization passed through to {@link CVModelBuilder}; presumably the number of
     *     models built concurrently, which should not exceed the pool size — TODO confirm
     * @param gpuIds explicit GPU ids to use; when {@code null} or empty, all detected GPUs are used
     */
    public XGBoostGPUCVModelBuilder(
            Job job, ModelBuilder<?, ?, ?>[] modelBuilders, int parallelization, int[] gpuIds) {
        super(job, modelBuilders, parallelization);
        if (gpuIds != null && gpuIds.length > 0) {
            this.availableGpus = new ArrayDeque<>(gpuIds.length);
            for (int id : gpuIds) {
                this.availableGpus.add(id);
            }
        } else {
            this.availableGpus = new ArrayDeque<>(GpuUtils.allGPUs());
        }
        LOG.info("Using parallel GPU building on " + this.availableGpus.size() + " GPUs.");
    }

    /**
     * Takes a free GPU from the pool and assigns it to the given builder by overwriting its
     * {@code _gpu_id} parameter with a single-element array.
     *
     * @param m the builder about to be trained; must be an {@link XGBoost}
     * @throws NoSuchElementException if the pool currently holds no free GPU
     */
    @Override
    protected void prepare(ModelBuilder m) {
        XGBoost xgb = (XGBoost) m;
        // ArrayDeque never stores nulls, so a null from poll() means the pool is empty.
        Integer gpuId = this.availableGpus.poll();
        if (gpuId == null) {
            // Same exception type Deque.pop() would throw, but with an actionable message.
            throw new NoSuchElementException("No free GPU left in the pool to build " + xgb.dest()
                    + " (the pool may be empty, or more models are running than there are GPUs).");
        }
        xgb._parms._gpu_id = new int[]{gpuId};
        LOG.info("Building " + xgb.dest() + " on GPU " + gpuId);
    }

    /**
     * Returns the GPU that {@code prepare} assigned to this builder back to the pool.
     *
     * @param m the builder previously passed to {@code prepare}; must be an {@link XGBoost}
     */
    @Override
    protected void finished(ModelBuilder m) {
        XGBoost xgb = (XGBoost) m;
        this.availableGpus.push(xgb._parms._gpu_id[0]);
    }
}

