/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.scala.spark;

import java.io.Serializable;
import java.util.Map;
import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.IRabitTracker;
import ml.dmlc.xgboost4j.java.RabitTracker;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.DMatrix;
import ml.dmlc.xgboost4j.scala.EvalTrait;
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager;
import ml.dmlc.xgboost4j.scala.ExternalCheckpointParams;
import ml.dmlc.xgboost4j.scala.ObjectiveTrait;
import ml.dmlc.xgboost4j.scala.spark.TrackerConf;
import ml.dmlc.xgboost4j.scala.spark.Watches;
import ml.dmlc.xgboost4j.scala.spark.XGBoostExecutionParams;
import ml.dmlc.xgboost4j.scala.spark.XGBoostExecutionParamsFactory;
import ml.dmlc.xgboost4j.scala.spark.package$;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDDBarrier;
import org.apache.spark.resource.ResourceInformation;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.Iterable$;
import scala.collection.Iterator;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

public final class XGBoost$
implements scala.Serializable {
    // Singleton instance of this object; assigned by the private constructor.
    public static XGBoost$ MODULE$;
    // Commons-logging logger, created in the constructor under the name "XGBoostSpark".
    private final Log logger;

    // Instantiating the class once at load time is what populates MODULE$:
    // the constructor stores `this` into the static field.
    static {
        new XGBoost$();
    }

    /**
     * Accessor for the logger held by this singleton.
     *
     * @return the {@code Log} instance created in the constructor
     */
    private Log logger() {
        return logger;
    }

    /**
     * Reads the GPU address assigned to the current Spark task and returns it as an int.
     *
     * <p>The address is taken from the {@code "gpu"} entry of the task's resources. When more
     * than one address is listed, a warning is logged and only the first one is used.
     *
     * @return the first GPU address of the current task, parsed as an integer
     * @throws RuntimeException if there is no task context, if the task has no {@code "gpu"}
     *         resource, or if the {@code "gpu"} resource lists no addresses
     * @throws NumberFormatException if the first address is not a valid integer
     */
    public int getGPUAddrFromResources() {
        TaskContext tc = TaskContext$.MODULE$.get();
        if (tc == null) {
            throw new RuntimeException("Something wrong for task context");
        }
        scala.collection.immutable.Map resources = tc.resources();
        if (!resources.contains((Object)"gpu")) {
            throw new RuntimeException("gpu is not allocated by spark, please check if gpu scheduling is enabled");
        }
        String[] addrs = ((ResourceInformation)resources.apply((Object)"gpu")).addresses();
        // Fail with a descriptive message instead of an opaque NoSuchElementException /
        // NullPointerException when the resource entry carries no address at all.
        if (addrs == null || addrs.length == 0) {
            throw new RuntimeException("gpu resource is allocated by spark but it has no device address");
        }
        if (addrs.length > 1) {
            this.logger().warn((Object)"XGBoost only supports 1 gpu per worker");
        }
        // Only the first address is used.
        return Integer.parseInt(addrs[0]);
    }

    /**
     * Invokes the supplied builder and verifies that the result contains a {@code "train"} entry.
     *
     * @param buildWatchesFun deferred construction of the {@code Watches}
     * @return the built {@code Watches} when its map contains the key {@code "train"}
     * @throws XGBoostError if the built watches have no {@code "train"} entry; the message
     *         includes the current partition ID
     */
    private Watches buildWatchesAndCheck(Function0<Watches> buildWatchesFun) {
        Watches built = (Watches)buildWatchesFun.apply();
        boolean hasTrainingMatrix = built.toMap().contains((Object)"train");
        if (hasTrainingMatrix) {
            return built;
        }
        int partitionId = TaskContext$.MODULE$.getPartitionId();
        String message = "detected an empty partition in the training data, partition ID:" + " " + partitionId;
        throw new XGBoostError(message);
    }

    /**
     * Runs the training for the current partition and returns the result iterator.
     *
     * <p>Steps, in order: record the partition ID and attempt number in {@code rabitEnv}
     * (the passed map is mutated), initialise the {@code Communicator}, build and validate the
     * watches, optionally redirect the {@code "device"} parameter to a specific CUDA device, and
     * train. Checkpoints are saved only when a checkpoint parameter is defined and the partition
     * ID is 0.
     *
     * @param buildWatches deferred construction of the watches for this partition
     * @param xgbExecutionParam execution parameters (rounds, early stopping, device, checkpointing)
     * @param rabitEnv environment map handed to {@code Communicator.init}; receives
     *        {@code DMLC_TASK_ID} and {@code DMLC_NUM_ATTEMPT}
     * @param obj objective passed through to the trainer
     * @param eval evaluation passed through to the trainer
     * @param prevBooster booster passed through to the trainer as the previous booster
     * @return a one-element iterator of (booster, metrics keyed by watch name) on partition 0,
     *         an empty iterator on every other partition
     */
    private Iterator<Tuple2<Booster, scala.collection.immutable.Map<String, float[]>>> buildDistributedBooster(Function0<Watches> buildWatches, XGBoostExecutionParams xgbExecutionParam, Map<String, String> rabitEnv, ObjectiveTrait obj, EvalTrait eval, Booster prevBooster) {
        Iterator iterator;
        Watches watches = null;
        int partitionId = TaskContext$.MODULE$.getPartitionId();
        String taskId = Integer.toString(partitionId);
        String attempt = Integer.toString(TaskContext$.MODULE$.get().attemptNumber());
        rabitEnv.put("DMLC_TASK_ID", taskId);
        rabitEnv.put("DMLC_NUM_ATTEMPT", attempt);
        int numRounds = xgbExecutionParam.numRounds();
        // Only the task with ID 0 writes checkpoints.
        boolean makeCheckpoint = xgbExecutionParam.checkpointParam().isDefined() && partitionId == 0;
        try {
            try {
                Communicator.init(rabitEnv);
                watches = this.buildWatchesAndCheck(buildWatches);
                int numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingRounds();
                // One float[numRounds] metric buffer per watch.
                float[][] metrics = (float[][])Array$.MODULE$.tabulate(watches.size(), (Function1 & Serializable & scala.Serializable)x$7 -> XGBoost$.$anonfun$buildDistributedBooster$1(numRounds, BoxesRunTime.unboxToInt((Object)x$7)), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
                Option<ExternalCheckpointParams> externalCheckpointParams = xgbExecutionParam.checkpointParam();
                scala.collection.immutable.Map params = xgbExecutionParam.toMap();
                if (xgbExecutionParam.device().exists((Function1 & Serializable & scala.Serializable)m -> BoxesRunTime.boxToBoolean((boolean)XGBoost$.$anonfun$buildDistributedBooster$2(m)))) {
                    // Local mode always uses device 0; otherwise use the address Spark assigned to the task.
                    int gpuId = xgbExecutionParam.isLocal() ? 0 : this.getGPUAddrFromResources();
                    this.logger().info((Object)new StringBuilder(31).append("Leveraging gpu device ").append(gpuId).append(" to train").toString());
                    params = params.$plus(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"device"), (Object)new StringBuilder(5).append("cuda:").append(gpuId).toString()));
                }
                Booster booster = makeCheckpoint ? ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.trainAndSaveCheckpoint((DMatrix)watches.toMap().apply((Object)"train"), params, numRounds, watches.toMap(), metrics, obj, eval, numEarlyStoppingRounds, prevBooster, externalCheckpointParams) : ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.train((DMatrix)watches.toMap().apply((Object)"train"), params, numRounds, watches.toMap(), metrics, obj, eval, numEarlyStoppingRounds, prevBooster);
                // Only partition 0 hands the trained booster and its metrics back to the caller.
                iterator = TaskContext$.MODULE$.get().partitionId() == 0 ? scala.package$.MODULE$.Iterator().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)booster), (Object)((TraversableOnce)watches.toMap().keys().zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])metrics), Iterable$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms()))})) : scala.package$.MODULE$.Iterator().empty();
            }
            catch (XGBoostError xgbException) {
                this.logger().error((Object)new StringBuilder(43).append("XGBooster worker ").append(taskId).append(" has failed ").append(attempt).append(" times due to ").toString(), (Throwable)xgbException);
                throw xgbException;
            }
        }
        finally {
            // Nested try/finally: the watches must be released even if the communicator
            // shutdown itself throws, otherwise they would be leaked.
            try {
                Communicator.shutdown();
            }
            finally {
                if (watches != null) {
                    watches.delete();
                }
            }
        }
        return iterator;
    }

    /**
     * Creates a {@code RabitTracker} for the given number of workers.
     *
     * @param nWorkers number of workers passed to the tracker constructor
     * @param trackerConf supplies the host IP and python executable passed to the tracker
     * @return the newly constructed tracker (not started by this method)
     */
    public IRabitTracker getTracker(int nWorkers, TrackerConf trackerConf) {
        return new RabitTracker(nWorkers, trackerConf.hostIp(), trackerConf.pythonExec());
    }

    /**
     * Obtains a tracker via {@link #getTracker} and starts it.
     *
     * @param nWorkers number of workers the tracker is created for
     * @param trackerConf tracker configuration; its worker connection timeout is passed to
     *        {@code start}
     * @return the started tracker; {@code Predef.require} rejects a {@code false} start result
     *         with the message "FAULT: Failed to start tracker"
     */
    private IRabitTracker startTracker(int nWorkers, TrackerConf trackerConf) {
        IRabitTracker rabitTracker = this.getTracker(nWorkers, trackerConf);
        boolean started = rabitTracker.start(trackerConf.workerConnectionTimeout());
        Predef$.MODULE$.require(started, (Function0 & Serializable & scala.Serializable)() -> "FAULT: Failed to start tracker");
        return rabitTracker;
    }

    /**
     * Drives a distributed training run and returns the trained booster with its metrics.
     *
     * <p>Builds the execution and rabit parameters from {@code params}, loads a previous booster
     * from the checkpoint path when a checkpoint parameter is defined, starts a tracker, trains
     * on every partition of the training RDD in barrier mode, and collects the single
     * (booster, metrics) result. The tracker is always stopped and the optional cached RDD is
     * always unpersisted, whether training succeeds or fails.
     *
     * @param sc Spark context; used to build the execution parameters and to obtain the Hadoop
     *        configuration for checkpoint handling
     * @param buildTrainingData produces the training RDD of watches builders and an optional
     *        cached RDD from the execution parameters
     * @param params user parameters; logged and passed to the parameter factory
     * @return the trained booster paired with a map of metric arrays
     * @throws XGBoostError if the tracker returns a non-zero exit code
     */
    public Tuple2<Booster, scala.collection.immutable.Map<String, float[]>> trainDistributed(SparkContext sc, Function1<XGBoostExecutionParams, Tuple2<RDD<Function0<Watches>>, Option<RDD<?>>>> buildTrainingData, scala.collection.immutable.Map<String, Object> params) throws XGBoostError {
        Tuple2 tuple2;
        this.logger().info((Object)new StringBuilder(34).append("Running XGBoost ").append(package$.MODULE$.VERSION()).append(" with parameters:\n").append(params.mkString("\n")).toString());
        XGBoostExecutionParamsFactory xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc);
        XGBoostExecutionParams xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams();
        Map xgbRabitParams = (Map)JavaConverters$.MODULE$.mapAsJavaMapConverter(xgbParamsFactory.buildRabitParams()).asJava();
        // When checkpointing is configured: clean up checkpoint versions higher than numRounds
        // and load the checkpointed booster. Otherwise prevBooster is null (orNull).
        Booster prevBooster = (Booster)xgbExecParams.checkpointParam().map((Function1 & Serializable & scala.Serializable)checkpointParam -> {
            ExternalCheckpointManager checkpointManager = new ExternalCheckpointManager(checkpointParam.checkpointPath(), FileSystem.get((Configuration)sc.hadoopConfiguration()));
            checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds());
            return checkpointManager.loadCheckpointAsScalaBooster();
        }).orNull(Predef$.MODULE$.$conforms());
        // Destructure (trainingRDD, optionalCachedRDD); the repeated tuple wrapping below is
        // decompiler residue of a Scala pattern match.
        Tuple2 tuple22 = (Tuple2)buildTrainingData.apply((Object)xgbExecParams);
        if (tuple22 == null) {
            throw new MatchError((Object)tuple22);
        }
        RDD trainingRDD = (RDD)tuple22._1();
        Option optionalCachedRDD = (Option)tuple22._2();
        Tuple2 tuple23 = new Tuple2((Object)trainingRDD, (Object)optionalCachedRDD);
        RDD trainingRDD2 = (RDD)tuple23._1();
        Option optionalCachedRDD2 = (Option)tuple23._2();
        try {
            try {
                Tuple2 tuple24;
                IRabitTracker tracker = this.startTracker(xgbExecParams.numWorkers(), xgbExecParams.trackerConf());
                try {
                    // Merge the rabit parameters into the tracker's worker environment, which is
                    // then captured by the per-partition closure below.
                    tracker.getWorkerEnvs().putAll(xgbRabitParams);
                    Map rabitEnv = tracker.getWorkerEnvs();
                    RDDBarrier qual$1 = trainingRDD2.barrier();
                    // Per-partition task: take the first element of the partition iterator as
                    // the watches builder and train with it; a partition with no element fails
                    // with "No Watches to train".
                    Function1 & Serializable & scala.Serializable x$1 = (Function1 & Serializable & scala.Serializable)iter -> {
                        None$ optionWatches;
                        block0: {
                            optionWatches = None$.MODULE$;
                            if (!iter.hasNext()) break block0;
                            optionWatches = new Some(iter.next());
                        }
                        return (Iterator)optionWatches.map((Function1 & Serializable & scala.Serializable)buildWatches -> MODULE$.buildDistributedBooster((Function0<Watches>)buildWatches, xgbExecParams, rabitEnv, xgbExecParams.obj(), xgbExecParams.eval(), prevBooster)).getOrElse((Function0 & Serializable & scala.Serializable)() -> {
                            throw new RuntimeException("No Watches to train");
                        });
                    };
                    boolean x$2 = qual$1.mapPartitions$default$2();
                    RDD boostersAndMetrics = qual$1.mapPartitions((Function1)x$1, x$2, ClassTag$.MODULE$.apply(Tuple2.class));
                    int x$3 = 1;
                    Ordering x$4 = boostersAndMetrics.repartition$default$2(x$3);
                    // Repartition to a single partition, collect, and use the first element.
                    Tuple2 tuple25 = ((Tuple2[])boostersAndMetrics.repartition(x$3, x$4).collect())[0];
                    if (tuple25 == null) {
                        throw new MatchError((Object)tuple25);
                    }
                    Booster booster = (Booster)tuple25._1();
                    scala.collection.immutable.Map metrics = (scala.collection.immutable.Map)tuple25._2();
                    Tuple2 tuple26 = new Tuple2((Object)booster, (Object)metrics);
                    Booster booster2 = (Booster)tuple26._1();
                    scala.collection.immutable.Map metrics2 = (scala.collection.immutable.Map)tuple26._2();
                    // NOTE(review): the 0L argument presumably means "no timeout" - confirm
                    // against the IRabitTracker.waitFor contract.
                    int trackerReturnVal = tracker.waitFor(0L);
                    this.logger().info((Object)new StringBuilder(29).append("Rabit returns with exit code ").append(trackerReturnVal).toString());
                    if (trackerReturnVal != 0) {
                        throw new XGBoostError("XGBoostModel training failed.");
                    }
                    tuple24 = new Tuple2((Object)booster2, (Object)metrics2);
                }
                finally {
                    // The tracker is stopped on both the success and the failure path.
                    tracker.stop();
                }
                Tuple2 tuple27 = tuple24;
                if (tuple27 == null) {
                    throw new MatchError((Object)tuple27);
                }
                Booster booster = (Booster)tuple27._1();
                scala.collection.immutable.Map metrics = (scala.collection.immutable.Map)tuple27._2();
                Tuple2 tuple28 = new Tuple2((Object)booster, (Object)metrics);
                Booster booster3 = (Booster)tuple28._1();
                scala.collection.immutable.Map metrics3 = (scala.collection.immutable.Map)tuple28._2();
                // Reached only after successful training: clean the checkpoint path unless
                // skipCleanCheckpoint is set (see $anonfun$trainDistributed$5).
                xgbExecParams.checkpointParam().foreach((Function1 & Serializable & scala.Serializable)cpParam -> {
                    XGBoost$.$anonfun$trainDistributed$5(xgbExecParams, sc, cpParam);
                    return BoxedUnit.UNIT;
                });
                tuple2 = new Tuple2((Object)booster3, (Object)metrics3);
            }
            catch (Throwable t) {
                // Log and rethrow unchanged so the caller still sees the original failure.
                this.logger().error((Object)"the job was aborted due to ", t);
                throw t;
            }
        }
        finally {
            // Unpersist the cached RDD, if any, regardless of the outcome.
            optionalCachedRDD2.foreach((Function1 & Serializable & scala.Serializable)x$11 -> x$11.unpersist(x$11.unpersist$default$1()));
        }
        return tuple2;
    }

    /**
     * Serialization hook that returns the value of the static {@code MODULE$} field.
     *
     * @return the singleton stored in {@code MODULE$}
     */
    private Object readResolve() {
        return XGBoost$.MODULE$;
    }

    /**
     * Allocates one zero-filled metric buffer with a slot per training round.
     *
     * <p>Uses a plain Java array allocation instead of the reflective
     * {@code Array$.ofDim} / {@code ClassTag} route; both yield a {@code float[]} of the
     * requested length.
     *
     * @param numRounds$1 length of the array to allocate
     * @param x$7 index supplied by the tabulating caller; not used
     * @return a new {@code float[]} of length {@code numRounds$1}
     * @throws NegativeArraySizeException if {@code numRounds$1} is negative
     */
    public static final /* synthetic */ float[] $anonfun$buildDistributedBooster$1(int numRounds$1, int x$7) {
        return new float[numRounds$1];
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    /**
     * Tells whether a device string selects GPU training.
     *
     * @param m device string to test; may be {@code null}
     * @return {@code true} exactly when {@code m} equals {@code "cuda"} or {@code "gpu"}
     *         (case-sensitive); {@code false} otherwise, including for {@code null}
     */
    public static final /* synthetic */ boolean $anonfun$buildDistributedBooster$2(String m) {
        // Calling equals on the literals keeps the check null-safe.
        return "cuda".equals(m) || "gpu".equals(m);
    }

    /**
     * Cleans the checkpoint path unless the execution parameters ask to skip cleaning.
     *
     * @param xgbExecParams$1 execution parameters; its checkpoint parameter must be defined,
     *        since it is read with {@code get()} to check {@code skipCleanCheckpoint}
     * @param sc$1 Spark context supplying the Hadoop configuration for the file system
     * @param cpParam checkpoint parameters providing the path to clean
     */
    public static final /* synthetic */ void $anonfun$trainDistributed$5(XGBoostExecutionParams xgbExecParams$1, SparkContext sc$1, ExternalCheckpointParams cpParam) {
        ExternalCheckpointParams configured = (ExternalCheckpointParams)xgbExecParams$1.checkpointParam().get();
        if (configured.skipCleanCheckpoint()) {
            return;
        }
        new ExternalCheckpointManager(cpParam.checkpointPath(), FileSystem.get((Configuration)sc$1.hadoopConfiguration())).cleanPath();
    }

    /**
     * Private constructor: publishes the new instance as {@code MODULE$} and then creates the
     * logger named {@code "XGBoostSpark"}.
     */
    private XGBoost$() {
        MODULE$ = this;
        logger = LogFactory.getLog("XGBoostSpark");
    }
}

