/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.util;

import ai.h2o.xgboost4j.java.DMatrix;
import ai.h2o.xgboost4j.java.INativeLibLoader;
import ai.h2o.xgboost4j.java.NativeLibLoader;
import ai.h2o.xgboost4j.java.Rabit;
import ai.h2o.xgboost4j.java.XGBoost;
import ai.h2o.xgboost4j.java.XGBoostError;
import hex.tree.xgboost.util.NativeLibrary;
import hex.tree.xgboost.util.NativeLibraryLoaderChain;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import org.apache.log4j.Logger;
import water.DTask;
import water.H2O;
import water.H2ONode;
import water.RPC;

public class GpuUtils {
    private static final Logger LOG = Logger.getLogger(GpuUtils.class);
    public static final int[] DEFAULT_GPU_ID = new int[]{0};
    private static volatile boolean defaultGpuIdNotValid = false;
    private static volatile boolean gpuSearchPerformed = false;
    private static final Set<Integer> GPUS = new HashSet<Integer>();

    static boolean isGpuSupportEnabled() {
        try {
            INativeLibLoader loader = NativeLibLoader.getLoader();
            if (!(loader instanceof NativeLibraryLoaderChain)) {
                return false;
            }
            NativeLibraryLoaderChain chainLoader = (NativeLibraryLoaderChain)loader;
            NativeLibrary lib = chainLoader.getLoadedLibrary();
            return lib.hasCompilationFlag(NativeLibrary.CompilationFlags.WITH_GPU);
        }
        catch (IOException e) {
            LOG.debug((Object)e);
            return false;
        }
    }

    private static boolean gpuCheckEnabled() {
        return H2O.getSysBoolProperty((String)"xgboost.gpu.check.enabled", (boolean)true);
    }

    public static int numGPUs(H2ONode node) {
        return GpuUtils.allGPUs(node).size();
    }

    public static Set<Integer> allGPUs(H2ONode node) {
        if (H2O.SELF.equals((Object)node)) {
            return GpuUtils.allGPUs();
        }
        AllGPUsTask t = new AllGPUsTask();
        new RPC(node, (DTask)t).call().get();
        return new HashSet<Integer>(Arrays.asList(t.gpuIds));
    }

    public static Set<Integer> allGPUs() {
        if (gpuSearchPerformed) {
            return Collections.unmodifiableSet(GPUS);
        }
        int nextGpuId = 0;
        while (GpuUtils.hasGPU(new int[]{nextGpuId++})) {
        }
        gpuSearchPerformed = true;
        return Collections.unmodifiableSet(GPUS);
    }

    public static boolean hasGPU(H2ONode node, int[] gpu_id) {
        boolean hasGPU;
        if (H2O.SELF.equals((Object)node)) {
            hasGPU = GpuUtils.hasGPU(gpu_id);
        } else {
            HasGPUTask t = new HasGPUTask(gpu_id);
            new RPC(node, (DTask)t).call().get();
            hasGPU = t._hasGPU;
        }
        LOG.debug((Object)("Availability of GPU (id=" + Arrays.toString(gpu_id) + ") on node " + node + ": " + hasGPU));
        return hasGPU;
    }

    public static boolean hasGPU(int[] gpu_id) {
        if (!GpuUtils.gpuCheckEnabled()) {
            return true;
        }
        if (gpu_id == null && defaultGpuIdNotValid) {
            return false;
        }
        boolean hasGPU = true;
        if (gpu_id == null) {
            gpu_id = DEFAULT_GPU_ID;
        }
        for (int i = 0; hasGPU && i < gpu_id.length; ++i) {
            hasGPU = GpuUtils.hasGPU_impl(gpu_id[i]);
        }
        if (Arrays.equals(gpu_id, DEFAULT_GPU_ID) && !hasGPU) {
            defaultGpuIdNotValid = true;
        }
        return hasGPU;
    }

    public static boolean hasGPU() {
        return GpuUtils.hasGPU(null);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private static synchronized boolean hasGPU_impl(int gpu_id) {
        DMatrix trainMat;
        if (!GpuUtils.isGpuSupportEnabled()) {
            return false;
        }
        if (GPUS.contains(gpu_id)) {
            return true;
        }
        try {
            trainMat = new DMatrix(new float[]{1.0f, 2.0f, 1.0f, 2.0f}, 2, 2);
            trainMat.setLabel(new float[]{1.0f, 0.0f});
        }
        catch (XGBoostError xgBoostError) {
            throw new IllegalStateException("Couldn't prepare training matrix for XGBoost.", xgBoostError);
        }
        HashMap<String, Object> params = new HashMap<String, Object>();
        params.put("tree_method", "gpu_hist");
        params.put("silent", 1);
        params.put("fail_on_invalid_gpu_id", true);
        params.put("gpu_id", gpu_id);
        HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
        watches.put("train", trainMat);
        try {
            HashMap localRabitEnv = new HashMap();
            Rabit.init(localRabitEnv);
            XGBoost.train((DMatrix)trainMat, params, (int)1, watches, null, null);
            GPUS.add(gpu_id);
            boolean bl = true;
            return bl;
        }
        catch (XGBoostError xgBoostError) {
            boolean bl = false;
            return bl;
        }
        finally {
            try {
                Rabit.shutdown();
            }
            catch (XGBoostError e) {
                LOG.warn((Object)"Cannot shutdown XGBoost Rabit for current thread.");
            }
        }
    }

    private static class HasGPUTask
    extends DTask<HasGPUTask> {
        private final int[] _gpu_id;
        private boolean _hasGPU;

        private HasGPUTask(int[] gpu_id) {
            this._gpu_id = gpu_id;
        }

        public void compute2() {
            this._hasGPU = GpuUtils.hasGPU(this._gpu_id);
            this.tryComplete();
        }
    }

    private static class AllGPUsTask
    extends DTask<HasGPUTask> {
        private Integer[] gpuIds;

        private AllGPUsTask() {
        }

        public void compute2() {
            this.gpuIds = GpuUtils.allGPUs().toArray(new Integer[0]);
            this.tryComplete();
        }
    }
}

