/*
 * Decompiled with CFR 0.152.
 */
package ai.catboost.spark;

import ai.catboost.spark.CatBoostPredictorTrait;
import ai.catboost.spark.Pool;
import ai.catboost.spark.SparkHelpers$;
import ai.catboost.spark.TrainingDriver;
import ai.catboost.spark.WorkerInfo;
import ai.catboost.spark.impl.Master;
import ai.catboost.spark.impl.Master$;
import ai.catboost.spark.impl.Workers;
import ai.catboost.spark.params.Helpers$;
import ai.catboost.spark.params.QuantizationParams;
import ai.catboost.spark.params.TrainingParamsTrait;
import java.time.Duration;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.Params;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.json4s.JsonAST;
import org.json4s.jackson.JsonMethods$;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple3;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/**
 * Static bodies of the concrete methods of the Scala trait {@link CatBoostPredictorTrait}.
 *
 * <p>This is decompiler output (CFR) of a Scala trait implementation class: each static method
 * takes the trait instance as its first argument ({@code $this}) and implements the method of the
 * same name.
 *
 * <p>NOTE(review): the anonymous {@code new Serializable(...){...}} expressions below are CFR's
 * rendering of Scala closures (they implement {@code Function0}/{@code Function1}); they are not
 * compilable Java as written. Treat this file as a reference view of the Scala source.
 */
public abstract class CatBoostPredictorTrait$class {
    /**
     * Default implementation of {@code preprocessBeforeTraining}.
     *
     * <p>Passes the quantized pools through unchanged and pairs them with the CatBoost JSON
     * training parameters derived from this estimator's Spark ML params.
     *
     * @param $this the trait instance
     * @param quantizedTrainPool quantized training pool
     * @param quantizedEvalPools quantized evaluation pools
     * @return a {@code Tuple3} of (train pool, eval pools, {@code JsonAST.JObject} params)
     */
    public static Tuple3 preprocessBeforeTraining(CatBoostPredictorTrait $this, Pool quantizedTrainPool, Pool[] quantizedEvalPools) {
        return new Tuple3((Object)quantizedTrainPool, (Object)quantizedEvalPools, (Object)Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParams((Params)$this, Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParams$default$2()));
    }

    /**
     * Spark ML {@code train} entry point.
     *
     * <p>Wraps {@code dataset} in a {@link Pool}, copies this estimator's param values onto it and
     * delegates to {@code fit} with the default (empty) eval pools.
     *
     * @param $this the trait instance
     * @param dataset training data
     * @return the trained model
     */
    public static PredictionModel train(CatBoostPredictorTrait $this, Dataset dataset) {
        Pool pool = new Pool((Dataset<Row>)dataset);
        $this.copyValues((Params)pool, $this.copyValues$default$2());
        return $this.fit(pool, $this.fit$default$2());
    }

    /**
     * Trains a model on {@code trainPool}, optionally evaluating on {@code evalPools}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Checks that the train pool's and every eval pool's params are compatible with this
     *       estimator's params.</li>
     *   <li>Uses the {@code sparkPartitionCount} param if set, otherwise the worker count reported
     *       by {@code SparkHelpers.getWorkerCount}.</li>
     *   <li>Quantizes the train pool (if it is not already quantized) with
     *       {@link QuantizationParams} copied from this estimator, then repartitions it to that
     *       partition count.</li>
     *   <li>Quantizes unquantized eval pools using the train pool's {@code quantizedFeaturesInfo}.</li>
     *   <li>Runs {@code preprocessBeforeTraining}, then runs the {@link TrainingDriver} (master side)
     *       and {@link Workers} concurrently, and waits for both.</li>
     *   <li>Builds the model from the master's native result.</li>
     * </ol>
     *
     * @param $this the trait instance
     * @param trainPool training data
     * @param evalPools evaluation data (may be empty)
     * @return the trained model
     * @throws MatchError if {@code preprocessBeforeTraining} returns {@code null}
     */
    public static PredictionModel fit(CatBoostPredictorTrait $this, Pool trainPool, Pool[] evalPools) {
        Pool pool;
        Helpers$.MODULE$.checkParamsCompatibility($this.getClass().getName(), (Params)$this, "trainPool", (Params)trainPool);
        // Closure: for each index i, check params of evalPools[i] under the label "evalPool #<i>".
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), evalPools.length).foreach$mVc$sp((Function1)new Serializable($this, evalPools){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ CatBoostPredictorTrait $outer;
            private final Pool[] evalPools$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                Helpers$.MODULE$.checkParamsCompatibility(this.$outer.getClass().getName(), (Params)this.$outer, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"evalPool #", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)i)})), (Params)this.evalPools$1[i]);
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
                this.evalPools$1 = evalPools$1;
            }
        });
        SparkSession spark = trainPool.data().sparkSession();
        // Explicit sparkPartitionCount wins; otherwise fall back to the Spark worker count.
        int partitionCount = BoxesRunTime.unboxToInt((Object)$this.get((Param)((TrainingParamsTrait)((Object)$this)).sparkPartitionCount()).getOrElse((Function0)new Serializable($this, spark){
            public static final long serialVersionUID = 0L;
            private final SparkSession spark$1;

            public final int apply() {
                return this.apply$mcI$sp();
            }

            public int apply$mcI$sp() {
                return SparkHelpers$.MODULE$.getWorkerCount(this.spark$1);
            }
            {
                this.spark$1 = spark$1;
            }
        }));
        if (trainPool.isQuantized()) {
            pool = trainPool;
        } else {
            QuantizationParams quantizationParams = new QuantizationParams();
            $this.copyValues(quantizationParams, $this.copyValues$default$2());
            Pool qual$1 = trainPool.quantize(quantizationParams);
            int x$2 = partitionCount;
            boolean x$3 = qual$1.repartition$default$2();
            pool = qual$1.repartition(x$2, x$3);
        }
        Pool quantizedTrainPool = pool;
        // Eval pools must share the train pool's quantization so feature borders match.
        Pool[] quantizedEvalPools = (Pool[])Predef$.MODULE$.refArrayOps((Object[])evalPools).map((Function1)new Serializable($this, quantizedTrainPool){
            public static final long serialVersionUID = 0L;
            private final Pool quantizedTrainPool$1;

            public final Pool apply(Pool evalPool) {
                return evalPool.isQuantized() ? evalPool : evalPool.quantize(this.quantizedTrainPool$1.quantizedFeaturesInfo());
            }
            {
                this.quantizedTrainPool$1 = quantizedTrainPool$1;
            }
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Pool.class)));
        Tuple3<Pool, Pool[], JsonAST.JObject> tuple3 = $this.preprocessBeforeTraining(quantizedTrainPool, quantizedEvalPools);
        if (tuple3 != null) {
            Future firstCompletedFuture;
            Tuple3 tuple32;
            // Destructure the tuple; the second copy (tuple33) is a synthetic artifact of the Scala
            // `val (a, b, c) = ...` pattern and carries the same values.
            Pool preprocessedTrainPool = (Pool)tuple3._1();
            Pool[] preprocessedEvalPools = (Pool[])tuple3._2();
            JsonAST.JObject catBoostJsonParams = (JsonAST.JObject)tuple3._3();
            Tuple3 tuple33 = tuple32 = new Tuple3((Object)preprocessedTrainPool, (Object)preprocessedEvalPools, (Object)catBoostJsonParams);
            Pool preprocessedTrainPool2 = (Pool)tuple33._1();
            Pool[] preprocessedEvalPools2 = (Pool[])tuple33._2();
            JsonAST.JObject catBoostJsonParams2 = (JsonAST.JObject)tuple33._3();
            Master master = Master$.MODULE$.apply(preprocessedTrainPool2, preprocessedEvalPools2, JsonMethods$.MODULE$.compact((JsonAST.JValue)catBoostJsonParams2));
            // The driver hands the collected WorkerInfo[] to master.trainCallback.
            TrainingDriver trainingDriver = new TrainingDriver(0, partitionCount, (Function1<WorkerInfo[], BoxedUnit>)new Serializable($this, master){
                public static final long serialVersionUID = 0L;
                private final Master master$1;

                public final void apply(WorkerInfo[] workersInfo) {
                    this.master$1.trainCallback(workersInfo);
                }
                {
                    this.master$1 = master$1;
                }
            }, (Duration)$this.getOrDefault(((TrainingParamsTrait)((Object)$this)).workerInitializationTimeout()));
            int listeningPort = trainingDriver.getListeningPort();
            // Keep a handle on the pool so it can be shut down: threads of a fixed thread pool
            // live until the pool is explicitly shut down, so without this each fit() call
            // would leak two non-daemon threads.
            ExecutorService executor = Executors.newFixedThreadPool(2);
            try {
                ExecutorCompletionService<BoxedUnit> ecs = new ExecutorCompletionService<BoxedUnit>(executor);
                Future<BoxedUnit> trainingDriverFuture = ecs.submit(trainingDriver, BoxedUnit.UNIT);
                Workers workers = new Workers(spark, listeningPort, preprocessedTrainPool2, catBoostJsonParams2);
                Future<BoxedUnit> workersFuture = ecs.submit(workers, BoxedUnit.UNIT);
                // Block until whichever of driver/workers finishes first.
                Future future = firstCompletedFuture = ecs.take();
                Future<BoxedUnit> future2 = workersFuture;
                // Null-safe equality: true when the first completed future is workersFuture.
                if (!(future != null ? !future.equals(future2) : future2 != null)) {
                    ai.catboost.spark.impl.Helpers$.MODULE$.checkOneFutureAndWaitForOther(workersFuture, trainingDriverFuture, "workers");
                } else {
                    ai.catboost.spark.impl.Helpers$.MODULE$.checkOneFutureAndWaitForOther(trainingDriverFuture, workersFuture, "master");
                }
            } finally {
                // Non-interrupting: lets any still-running task finish, then releases the threads.
                executor.shutdown();
            }
            return $this.createModel(master.nativeModelResult());
        }
        throw new MatchError(tuple3);
    }

    /**
     * Default value of {@code fit}'s {@code evalPools} argument.
     *
     * @param $this the trait instance (unused)
     * @return an empty {@code Pool[]}
     */
    public static Pool[] fit$default$2(CatBoostPredictorTrait $this) {
        return (Pool[])Array$.MODULE$.apply((Seq)Nil$.MODULE$, ClassTag$.MODULE$.apply(Pool.class));
    }

    /**
     * Trait initializer; the trait declares no state to initialize, so this is a no-op.
     *
     * @param $this the trait instance
     */
    public static void $init$(CatBoostPredictorTrait $this) {
    }
}

