/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.scala.spark;

import java.io.Serializable;
import java.util.Map;
import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.ConfigContext;
import ml.dmlc.xgboost4j.java.RabitTracker;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.DMatrix;
import ml.dmlc.xgboost4j.scala.EvalTrait;
import ml.dmlc.xgboost4j.scala.ObjectiveTrait;
import ml.dmlc.xgboost4j.scala.spark.RuntimeParams;
import ml.dmlc.xgboost4j.scala.spark.StageLevelScheduling;
import ml.dmlc.xgboost4j.scala.spark.TrackerConf;
import ml.dmlc.xgboost4j.scala.spark.Utils$;
import ml.dmlc.xgboost4j.scala.spark.Watches;
import ml.dmlc.xgboost4j.scala.spark.package$;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDDBarrier;
import org.apache.spark.resource.ResourceInformation;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.Iterable$;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.jdk.CollectionConverters$;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/**
 * Entry point for distributed XGBoost training on Spark (barrier execution + Rabit tracker).
 *
 * <p>This is CFR-decompiled bytecode of a Scala singleton {@code object}: {@code MODULE$} holds the
 * sole instance, which the static initializer creates. It is not compilable Java as written. For
 * example, it contains intersection-typed lambda locals, a {@code final} field assigned in a setter,
 * and lambda locals that shadow enclosing locals. The comments therefore describe the behavior the
 * bytecode encodes.
 *
 * <p>Stage-level scheduling behavior comes from {@link StageLevelScheduling}. The three public
 * overrides below only forward to that trait's static implementations.
 */
public final class XGBoost$
implements StageLevelScheduling {
    /** The singleton instance; assigned in the private constructor. */
    public static XGBoost$ MODULE$;
    /** Commons-logging logger named "XGBoostSpark". */
    private final Log logger;
    /**
     * Backing field for the StageLevelScheduling trait's logger. It is assigned through the trait
     * setter below, presumably during {@code StageLevelScheduling.$init$(this)}. NOTE(review): javac
     * would reject assigning a final field outside the constructor; this is a decompiler artifact.
     */
    private final Log ml$dmlc$xgboost4j$scala$spark$StageLevelScheduling$$logger;

    // Creates the singleton; the constructor stores itself into MODULE$.
    static {
        new XGBoost$();
    }

    /** Forwards to {@code StageLevelScheduling.isStandaloneOrLocalCluster$}. */
    @Override
    public boolean isStandaloneOrLocalCluster(SparkConf conf) {
        return StageLevelScheduling.isStandaloneOrLocalCluster$(this, conf);
    }

    /** Forwards to {@code StageLevelScheduling.skipStageLevelScheduling$}. */
    @Override
    public boolean skipStageLevelScheduling(String sparkVersion, boolean runOnGpu, SparkConf conf) {
        return StageLevelScheduling.skipStageLevelScheduling$(this, sparkVersion, runOnGpu, conf);
    }

    /** Forwards to {@code StageLevelScheduling.tryStageLevelScheduling$}. */
    @Override
    public <T> RDD<T> tryStageLevelScheduling(SparkContext sc, RuntimeParams xgbExecParams, RDD<T> rdd) {
        return StageLevelScheduling.tryStageLevelScheduling$(this, sc, xgbExecParams, rdd);
    }

    /** Trait accessor for the StageLevelScheduling logger field. */
    @Override
    public Log ml$dmlc$xgboost4j$scala$spark$StageLevelScheduling$$logger() {
        return this.ml$dmlc$xgboost4j$scala$spark$StageLevelScheduling$$logger;
    }

    /** Trait setter used to initialize the StageLevelScheduling logger field. */
    @Override
    public final void ml$dmlc$xgboost4j$scala$spark$StageLevelScheduling$_setter_$ml$dmlc$xgboost4j$scala$spark$StageLevelScheduling$$logger_$eq(Log x$1) {
        this.ml$dmlc$xgboost4j$scala$spark$StageLevelScheduling$$logger = x$1;
    }

    /** Returns the "XGBoostSpark" logger. */
    private Log logger() {
        return this.logger;
    }

    /**
     * Returns the GPU ordinal that Spark assigned to the currently running task.
     *
     * <p>This must be called inside a task. It reads {@code TaskContext.get()} and looks up the
     * {@code "gpu"} resource. If more than one address is assigned, a warning is logged and the
     * first address is used.
     *
     * @return the first {@code "gpu"} resource address, parsed as an {@code int}
     * @throws RuntimeException if there is no task context, or if no {@code "gpu"} resource was
     *     allocated to the task
     */
    public int getGPUAddrFromResources() {
        TaskContext tc = TaskContext$.MODULE$.get();
        if (tc == null) {
            throw new RuntimeException("Something wrong for task context");
        }
        scala.collection.immutable.Map resources = tc.resources();
        if (resources.contains((Object)"gpu")) {
            String[] addrs = ((ResourceInformation)resources.apply((Object)"gpu")).addresses();
            if (new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])addrs)).size() > 1) {
                this.logger().warn((Object)"XGBoost only supports 1 gpu per worker");
            }
            // NOTE(review): if "gpu" is present with zero addresses, head() throws instead of raising
            // the descriptive error below. A non-numeric address makes toInt throw.
            return new StringOps(Predef$.MODULE$.augmentString((String)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])addrs)).head())).toInt();
        }
        throw new RuntimeException("gpu is not allocated by spark, please check if gpu scheduling is enabled");
    }

    /**
     * Trains a booster on a single partition's {@link Watches}.
     *
     * <p>This method:
     * <ul>
     *   <li>allocates a {@code watches.size() x numRounds} metrics matrix;
     *   <li>when running on GPU, adds {@code "device" -> "cuda:<id>"} to the parameters;
     *   <li>delegates to the Scala {@code XGBoost.train}, with the {@code "train"} watch as the
     *       training matrix and all watches as evaluation sets.
     * </ul>
     *
     * @param watches the partition's named DMatrix set; expected to contain a {@code "train"} entry
     * @param runtimeParams supplies rounds, early-stopping rounds, GPU/local flags, worker count, and
     *     the optional custom objective and evaluation (both passed as {@code null} when absent)
     * @param xgboostParams booster parameters. The immutable map is not modified; a new map is built
     *     when "device" is added.
     * @return the trained booster paired with the metrics matrix passed to {@code train}, which
     *     {@code train} presumably fills in place (confirm)
     */
    private Tuple2<Booster, float[][]> trainBooster(Watches watches, RuntimeParams runtimeParams, scala.collection.immutable.Map<String, Object> xgboostParams) {
        int numEarlyStoppingRounds = runtimeParams.earlyStoppingRounds();
        // One float[numRounds] row per watch; see $anonfun$trainBooster$1.
        float[][] metrics = (float[][])Array$.MODULE$.tabulate(watches.size(), (Function1 & Serializable & scala.Serializable)x$1 -> XGBoost$.$anonfun$trainBooster$1(runtimeParams, BoxesRunTime.unboxToInt((Object)x$1)), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
        scala.collection.immutable.Map params = xgboostParams;
        if (runtimeParams.runOnGpu()) {
            // Local mode: spread tasks across devices by partition id.
            // Cluster mode: use the GPU Spark assigned to this task.
            int gpuId = runtimeParams.isLocal() ? TaskContext$.MODULE$.get().partitionId() % runtimeParams.numWorkers() : this.getGPUAddrFromResources();
            this.logger().info((Object)new StringBuilder(31).append("Leveraging gpu device ").append(gpuId).append(" to train").toString());
            params = params.$plus(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"device"), (Object)new StringBuilder(5).append("cuda:").append(gpuId).toString()));
        }
        // The last argument is the Scala default for train's 9th parameter.
        Booster booster = ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.train((DMatrix)watches.toMap().apply((Object)"train"), params, runtimeParams.numRounds(), watches.toMap(), metrics, (ObjectiveTrait)runtimeParams.obj().getOrElse((Function0 & Serializable & scala.Serializable)() -> null), (EvalTrait)runtimeParams.eval().getOrElse((Function0 & Serializable & scala.Serializable)() -> null), numEarlyStoppingRounds, ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.train$default$9());
        return new Tuple2((Object)booster, (Object)metrics);
    }

    /**
     * Trains one booster across the partitions of {@code input} using a Spark barrier stage.
     *
     * <p>Driver-side flow:
     * <ol>
     *   <li>Start a {@link RabitTracker} sized to {@code runtimeParams.numWorkers()}.
     *   <li>Run a barrier {@code mapPartitions} in which each task initializes the
     *       {@link Communicator} with the tracker's worker args (plus its own
     *       {@code DMLC_TASK_ID}) and trains on its partition.
     *   <li>Pass the resulting RDD through {@code tryStageLevelScheduling}.
     *   <li>Repartition to a single partition and collect.
     * </ol>
     * Only partition 0 emits a result, so the collected array holds partition 0's booster and
     * metrics. The tracker is stopped on both success and failure; errors from {@code stop} are only
     * logged.
     *
     * @param input RDD of per-partition {@link Watches}. Each task consumes only the first element
     *     of its partition and fails via {@code require} if the partition is empty.
     * @param runtimeParams tracker config, worker count, config-context settings, and the training
     *     options forwarded to {@code trainBooster}
     * @param xgboostParams booster parameters; logged and forwarded to {@code trainBooster}
     * @return the booster trained on partition 0, paired with a map from watch name to that watch's
     *     per-round metric array
     * @throws IllegalArgumentException (from {@code Predef.require}) if the tracker fails to start
     */
    public Tuple2<Booster, scala.collection.immutable.Map<String, float[]>> train(RDD<Watches> input, RuntimeParams runtimeParams, scala.collection.immutable.Map<String, Object> xgboostParams) {
        Tuple2 tuple2;
        SparkContext sc = input.sparkContext();
        this.logger().info((Object)new StringBuilder(34).append("Running XGBoost ").append(package$.MODULE$.VERSION()).append(" with parameters: ").append(xgboostParams).toString());
        TrackerConf trackerConf = runtimeParams.trackerConf();
        RabitTracker tracker = new RabitTracker(runtimeParams.numWorkers(), trackerConf.hostIp(), trackerConf.port(), trackerConf.timeout());
        Predef$.MODULE$.require(tracker.start(), (Function0 & Serializable & scala.Serializable)() -> "FAULT: Failed to start tracker");
        // The outer try/catch plus the post-block stop() below encode a Scala try/finally that
        // stops the tracker. The inner try/catch logs the abort and rethrows.
        try {
            try {
                Map<String, Object> rabitEnv = tracker.getWorkerArgs();
                RDDBarrier qual$1 = input.barrier();
                // Per-partition task body; runs inside the barrier stage.
                Function1 & Serializable & scala.Serializable x$1 = (Function1 & Serializable & scala.Serializable)iter -> {
                    Iterator iterator;
                    int partitionId = TaskContext$.MODULE$.getPartitionId();
                    // NOTE(review): mutates the captured worker-args map. This assumes each task
                    // works on its own copy of the closure; confirm.
                    rabitEnv.put("DMLC_TASK_ID", Integer.toString(partitionId));
                    // The catch/rethrow plus the shutdown after the block encode try/finally:
                    // Communicator.shutdown() runs on every path, and its own errors are only logged.
                    try {
                        Communicator.init(rabitEnv);
                        Predef$.MODULE$.require(iter.hasNext(), (Function0 & Serializable & scala.Serializable)() -> "Failed to create DMatrix");
                        // Training runs inside a ConfigContext built from runtimeParams.configs().
                        // Utils.withResource (not shown) presumably closes it afterwards.
                        iterator = (Iterator)Utils$.MODULE$.withResource(new ConfigContext((Map)CollectionConverters$.MODULE$.mapAsJavaMapConverter(runtimeParams.configs()).asJava()), (Function1 & Serializable & scala.Serializable)x$2 -> {
                            Iterator iterator;
                            // Only the first element of the partition is used.
                            Watches watches = (Watches)iter.next();
                            try {
                                Tuple2<Booster, float[][]> tuple2 = MODULE$.trainBooster(watches, runtimeParams, xgboostParams);
                                if (tuple2 == null) {
                                    throw new MatchError(tuple2);
                                }
                                Booster booster = (Booster)tuple2._1();
                                float[][] metrics = (float[][])tuple2._2();
                                Tuple2 tuple22 = new Tuple2((Object)booster, (Object)metrics);
                                Booster booster2 = (Booster)tuple22._1();
                                float[][] metrics2 = (float[][])tuple22._2();
                                // Only partition 0 emits (booster -> Map(watchName -> metricsRow)).
                                // NOTE(review): pairing names with rows relies on watches.toMap()
                                // iterating in the same order as the map passed to train.
                                iterator = partitionId == 0 ? scala.package$.MODULE$.Iterator().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)booster2), (Object)((TraversableOnce)watches.toMap().keys().zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])metrics2), Iterable$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms()))})) : scala.package$.MODULE$.Iterator().empty();
                            }
                            finally {
                                // Free the partition's native DMatrix resources whether or not training succeeded.
                                if (watches != null) {
                                    watches.delete();
                                }
                            }
                            return iterator;
                        });
                    }
                    catch (Throwable throwable) {
                        try {
                            Communicator.shutdown();
                        }
                        catch (Throwable e) {
                            MODULE$.logger().error((Object)"Communicator.shutdown error: ", e);
                        }
                        throw throwable;
                    }
                    Iterator iterator2 = iterator;
                    try {
                        Communicator.shutdown();
                    }
                    catch (Throwable e) {
                        MODULE$.logger().error((Object)"Communicator.shutdown error: ", e);
                    }
                    return iterator2;
                };
                boolean x$2 = qual$1.mapPartitions$default$2();
                RDD boostersAndMetrics = qual$1.mapPartitions((Function1)x$1, x$2, ClassTag$.MODULE$.apply(Tuple2.class));
                RDD rdd = this.tryStageLevelScheduling(sc, runtimeParams, boostersAndMetrics);
                int x$3 = 1;
                Ordering x$4 = rdd.repartition$default$2(x$3);
                // Only partition 0 produced an element, so index 0 is its result.
                // NOTE(review): throws ArrayIndexOutOfBoundsException if nothing was emitted
                // (e.g. an input with no partitions).
                Tuple2 tuple22 = ((Tuple2[])rdd.repartition(x$3, x$4).collect())[0];
                if (tuple22 == null) {
                    throw new MatchError((Object)tuple22);
                }
                Booster booster = (Booster)tuple22._1();
                scala.collection.immutable.Map metrics = (scala.collection.immutable.Map)tuple22._2();
                Tuple2 tuple23 = new Tuple2((Object)booster, (Object)metrics);
                Booster booster2 = (Booster)tuple23._1();
                scala.collection.immutable.Map metrics2 = (scala.collection.immutable.Map)tuple23._2();
                tuple2 = new Tuple2((Object)booster2, (Object)metrics2);
            }
            catch (Throwable t) {
                this.logger().error((Object)"XGBoost job was aborted due to ", t);
                throw t;
            }
        }
        catch (Throwable throwable) {
            try {
                tracker.stop();
            }
            catch (Throwable t) {
                // NOTE(review): passes the throwable as the message object rather than using
                // error(Object, Throwable), so the stack trace may not be logged.
                this.logger().error((Object)t);
            }
            throw throwable;
        }
        Tuple2 tuple24 = tuple2;
        try {
            tracker.stop();
        }
        catch (Throwable t) {
            this.logger().error((Object)t);
        }
        return tuple24;
    }

    /**
     * Java serialization hook that resolves any deserialized instance to the singleton
     * {@code MODULE$}. It only takes effect if the class is serializable, which is not visible
     * here (presumably via the StageLevelScheduling supertype).
     */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Lambda body used by {@code trainBooster} to build each metrics row.
     *
     * @param runtimeParams$1 supplies the row length ({@code numRounds()})
     * @param x$1 the row (watch) index; unused
     * @return a new zero-initialized {@code float[numRounds]}
     */
    public static final /* synthetic */ float[] $anonfun$trainBooster$1(RuntimeParams runtimeParams$1, int x$1) {
        return (float[])Array$.MODULE$.ofDim(runtimeParams$1.numRounds(), ClassTag$.MODULE$.Float());
    }

    /**
     * Publishes the singleton, runs the StageLevelScheduling trait initializer, and creates the
     * "XGBoostSpark" logger.
     */
    private XGBoost$() {
        MODULE$ = this;
        StageLevelScheduling.$init$(this);
        this.logger = LogFactory.getLog((String)"XGBoostSpark");
    }
}

