/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost;

import hex.CVModelBuilder;
import hex.ModelBuilder;
import hex.tree.xgboost.XGBoost;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.util.GpuUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import org.apache.log4j.Logger;
import water.Job;

/**
 * Cross-validation model builder that distributes XGBoost CV models over a set of GPUs.
 *
 * <p>Before each CV model is built ({@link #prepare}) the currently least-utilized GPU is written
 * into the model's {@code _gpu_id} parameter; after the model is done ({@link #finished}) that GPU
 * is released so later models can use it.
 */
public class XGBoostGPUCVModelBuilder extends CVModelBuilder {
    private static final Logger LOG = Logger.getLogger(XGBoostGPUCVModelBuilder.class);
    private final GPUAllocator _allocator;

    /**
     * @param job             parent job, passed through to {@link CVModelBuilder}
     * @param modelBuilders   CV model builders; each is cast to {@link XGBoost} in prepare/finished
     * @param parallelization passed through to {@link CVModelBuilder}
     * @param gpuIds          explicit GPU ids to use; when {@code null} or empty, all GPUs reported by
     *                        {@code GpuUtils.allGPUs()} are used
     * @throws IllegalStateException    if no GPU ids are available
     * @throws IllegalArgumentException if any GPU id is negative
     */
    public XGBoostGPUCVModelBuilder(Job<?> job, ModelBuilder<?, ?, ?>[] modelBuilders, int parallelization, int[] gpuIds) {
        super(job, modelBuilders, parallelization);
        final List<Integer> availableGpus;
        if (gpuIds != null && gpuIds.length > 0) {
            availableGpus = new ArrayList<Integer>(gpuIds.length);
            for (int id : gpuIds) {
                availableGpus.add(id);
            }
        } else {
            availableGpus = new ArrayList<Integer>(GpuUtils.allGPUs());
        }
        LOG.info("Available #GPUs for CV model training: " + availableGpus.size());
        _allocator = new GPUAllocator(availableGpus);
    }

    /**
     * Assigns the least-utilized GPU to the given CV model builder, overwriting any previously
     * configured {@code _gpu_id} with a single-element array.
     */
    protected void prepare(ModelBuilder<?, ?, ?> m) {
        XGBoost xgb = (XGBoost) m;
        XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) xgb._parms;
        int gpuId = _allocator.takeLeastUtilizedGPU();
        parms._gpu_id = new int[]{gpuId};
        LOG.info("Building " + xgb.dest() + " on GPU " + gpuId);
    }

    /**
     * Releases the GPU that {@link #prepare} assigned to the given CV model builder.
     */
    protected void finished(ModelBuilder<?, ?, ?> m) {
        XGBoost xgb = (XGBoost) m;
        XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) xgb._parms;
        _allocator.releaseGPU(parms._gpu_id[0]);
    }

    /**
     * Tracks how many CV models are currently assigned to each GPU id.
     *
     * <p>Not synchronized. NOTE(review): assumes prepare/finished are invoked from a single driver
     * thread — confirm against {@link CVModelBuilder}.
     */
    static class GPUAllocator {
        /** Utilization sentinel for GPU ids that are not available for training. */
        private static final int UNAVAILABLE = -1;

        /**
         * Indexed by GPU id. A negative value ({@link #UNAVAILABLE}) means the id is not usable;
         * otherwise the value is the number of models currently assigned to that GPU.
         */
        final int[] _gpu_utilization;

        GPUAllocator(List<Integer> gpuIds) {
            this(initUtilization(gpuIds));
        }

        GPUAllocator(int[] gpuUtilization) {
            _gpu_utilization = gpuUtilization;
        }

        /**
         * Builds the utilization table: every listed id starts at 0, every other id up to the
         * maximum listed id is marked {@link #UNAVAILABLE}.
         *
         * @throws IllegalStateException    if {@code gpus} is empty
         * @throws IllegalArgumentException if any id is negative (it could not index the table)
         */
        static int[] initUtilization(List<Integer> gpus) {
            int maxGpuId = gpus.stream()
                    .max(Integer::compareTo)
                    .orElseThrow(() -> new IllegalStateException("There are no GPUs available for XGBoost (" + gpus + ")."));
            for (int id : gpus) {
                if (id < 0) {
                    throw new IllegalArgumentException("Invalid (negative) GPU id " + id + " in " + gpus + ".");
                }
            }
            int[] utilization = new int[maxGpuId + 1];
            Arrays.fill(utilization, UNAVAILABLE);
            for (int id : gpus) {
                utilization[id] = 0;
            }
            return utilization;
        }

        /**
         * Marks one model as no longer running on GPU {@code id}.
         *
         * <p>Releasing an id that has no models assigned (unknown, unavailable, or already at zero)
         * is ignored with a warning: blindly decrementing would either turn a usable GPU into the
         * {@link #UNAVAILABLE} sentinel or make an unavailable GPU look like the least-utilized one.
         */
        void releaseGPU(int id) {
            if (id < 0 || id >= _gpu_utilization.length || _gpu_utilization[id] <= 0) {
                LOG.warn("Ignoring release of GPU " + id + " which has no CV models assigned to it.");
                return;
            }
            _gpu_utilization[id]--;
        }

        /**
         * Picks the available GPU with the fewest assigned models (lowest id wins ties), records
         * one more model on it and returns its id.
         *
         * @throws IllegalStateException if no GPU is available
         */
        int takeLeastUtilizedGPU() {
            int best = UNAVAILABLE;
            for (int i = 0; i < _gpu_utilization.length; i++) {
                int load = _gpu_utilization[i];
                if (load < 0) {
                    continue; // not an available GPU
                }
                if (best == UNAVAILABLE || load < _gpu_utilization[best]) {
                    best = i;
                }
            }
            if (best == UNAVAILABLE) {
                // Explicit check instead of `assert`, which is disabled by default at runtime.
                throw new IllegalStateException("No GPU is available for XGBoost CV model training.");
            }
            _gpu_utilization[best]++;
            return best;
        }
    }
}

