/*
 * Decompiled with CFR 0.152.
 */
package ai.catboost.spark;

import ai.catboost.spark.CatBoostPredictorTrait;
import ai.catboost.spark.CatBoostPredictorTrait$;
import ai.catboost.spark.Pool;
import ai.catboost.spark.SparkHelpers$;
import ai.catboost.spark.TrainingDriver;
import ai.catboost.spark.WorkerInfo;
import ai.catboost.spark.impl.CatBoostMasterWrapper;
import ai.catboost.spark.impl.CatBoostMasterWrapper$;
import ai.catboost.spark.impl.CatBoostWorkers;
import ai.catboost.spark.impl.CtrFeatures$;
import ai.catboost.spark.impl.CtrsContext;
import ai.catboost.spark.impl.Helpers$;
import ai.catboost.spark.params.QuantizationParams;
import ai.catboost.spark.params.TrainingParamsTrait;
import java.time.Duration;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.Params;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.storage.StorageLevel$;
import org.json4s.JsonAST;
import org.json4s.jackson.JsonMethods$;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.TFullModel;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.native_impl;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple3;
import scala.Tuple4;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;

/**
 * Static implementation bodies of the concrete methods of {@link CatBoostPredictorTrait}.
 *
 * <p>Every method receives the trait instance as its first argument ({@code $this}); presumably
 * this is a Scala 2.11-style trait implementation class. NOTE(review): this file is CFR-decompiled
 * output and still contains decompiler artifacts (anonymous classes with constructor arguments,
 * {@code void} locals), so it is not compilable Java as-is — prefer editing the Scala source.
 */
public abstract class CatBoostPredictorTrait$class {
    /**
     * Adds estimated CTR features to the pools when categorical features have more unique values
     * than the one-hot encoding limit.
     *
     * <p>The maximum number of unique categorical values on the training pool is computed natively
     * and compared with the one-hot max size that native code derives from that count, from
     * whether the training pool has a label column, and from the compact JSON params. Only when the
     * count exceeds that limit are CTRs computed via {@code CtrFeatures.addCtrsAsEstimated}.
     *
     * @param $this the trait instance (also used as the {@link TrainingParamsTrait} source)
     * @param quantizedTrainPool quantized training pool
     * @param quantizedEvalPools quantized evaluation pools
     * @param catBoostJsonParams CatBoost training params in JSON form
     * @return {@code (trainPool, evalPools, ctrsContext)}; when no CTRs are needed the input pools
     *     are returned unchanged and {@code ctrsContext} is {@code null}
     */
    public static Tuple3 addEstimatedCtrFeatures(CatBoostPredictorTrait $this, Pool quantizedTrainPool, Pool[] quantizedEvalPools, JsonAST.JObject catBoostJsonParams) {
        int oneHotMaxSize;
        int catFeaturesMaxUniqValueCount = native_impl.CalcMaxCategoricalFeaturesUniqueValuesCountOnLearn(quantizedTrainPool.quantizedFeaturesInfo().__deref__());
        return catFeaturesMaxUniqValueCount > (oneHotMaxSize = native_impl.GetOneHotMaxSize(catFeaturesMaxUniqValueCount, quantizedTrainPool.isDefined(quantizedTrainPool.labelCol()), JsonMethods$.MODULE$.compact((JsonAST.JValue)catBoostJsonParams))) ? CtrFeatures$.MODULE$.addCtrsAsEstimated(quantizedTrainPool, quantizedEvalPools, (TrainingParamsTrait)((Object)$this), oneHotMaxSize) : new Tuple3<Pool, Pool[], CtrsContext>((Object)quantizedTrainPool, (Object)quantizedEvalPools, null);
    }

    /**
     * Converts this estimator's Spark ML params to CatBoost JSON params and applies
     * {@code $this.addEstimatedCtrFeatures} to the quantized pools.
     *
     * @param $this the trait instance
     * @param quantizedTrainPool quantized training pool
     * @param quantizedEvalPools quantized evaluation pools
     * @return {@code (preprocessedTrainPool, preprocessedEvalPools, catBoostJsonParams,
     *     ctrsContext)}; {@code ctrsContext} may be {@code null}
     * @throws MatchError if {@code addEstimatedCtrFeatures} returns {@code null} (residue of a
     *     Scala tuple pattern match)
     */
    public static Tuple4 preprocessBeforeTraining(CatBoostPredictorTrait $this, Pool quantizedTrainPool, Pool[] quantizedEvalPools) {
        JsonAST.JObject catBoostJsonParams = ai.catboost.spark.params.Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParams((Params)$this, ai.catboost.spark.params.Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParams$default$2());
        Tuple3<Pool, Pool[], CtrsContext> tuple3 = $this.addEstimatedCtrFeatures(quantizedTrainPool, quantizedEvalPools, catBoostJsonParams);
        if (tuple3 != null) {
            Tuple3 tuple32;
            Pool preprocessedTrainPool = (Pool)tuple3._1();
            Pool[] preprocessedEvalPools = (Pool[])tuple3._2();
            CtrsContext ctrsContext = (CtrsContext)tuple3._3();
            Tuple3 tuple33 = tuple32 = new Tuple3((Object)preprocessedTrainPool, (Object)preprocessedEvalPools, (Object)ctrsContext);
            Pool preprocessedTrainPool2 = (Pool)tuple33._1();
            Pool[] preprocessedEvalPools2 = (Pool[])tuple33._2();
            CtrsContext ctrsContext2 = (CtrsContext)tuple33._3();
            return new Tuple4((Object)preprocessedTrainPool2, (Object)preprocessedEvalPools2, (Object)catBoostJsonParams, (Object)ctrsContext2);
        }
        throw new MatchError(tuple3);
    }

    /**
     * Trains on a plain dataset: wraps it in a {@link Pool}, copies this estimator's param values
     * onto that pool, and calls {@code fit} with the default (empty) evaluation pools.
     *
     * @param $this the trait instance
     * @param dataset the training data
     * @return the trained model produced by {@code fit}
     */
    public static PredictionModel train(CatBoostPredictorTrait $this, Dataset dataset) {
        Pool pool = new Pool((Dataset<Row>)dataset);
        $this.copyValues((Params)pool, $this.copyValues$default$2());
        return $this.fit(pool, $this.fit$default$2());
    }

    /**
     * Trains a model on {@code trainPool}, optionally evaluating on {@code evalPools}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Checks that the train pool and every eval pool have params compatible with this
     *       estimator.</li>
     *   <li>Quantizes the train pool if it is not already quantized (with quantization params
     *       copied from this estimator), and quantizes each non-quantized eval pool using the train
     *       pool's {@code quantizedFeaturesInfo}.</li>
     *   <li>Runs {@code preprocessBeforeTraining} (JSON params + optional estimated CTRs).</li>
     *   <li>Persists the preprocessed train pool ({@code MEMORY_ONLY}) for the duration of
     *       training.</li>
     *   <li>Runs the {@link TrainingDriver} (master side) and {@link CatBoostWorkers} concurrently
     *       on a two-thread pool, checks whichever finishes first and waits for the other.</li>
     *   <li>Builds the model from the master's native result, adding a CtrProvider when a
     *       {@link CtrsContext} was produced.</li>
     * </ol>
     *
     * <p>The partition count is taken from {@code sparkPartitionCount} if set, otherwise from
     * {@code SparkHelpers.getWorkerCount(spark)}. The training thread pool is always shut down and
     * the persisted train pool is always unpersisted, including when training throws.
     *
     * @param $this the trait instance
     * @param trainPool training data
     * @param evalPools evaluation data (may be empty)
     * @return the model created by {@code $this.createModel}
     * @throws MatchError if {@code preprocessBeforeTraining} returns {@code null}
     */
    public static PredictionModel fit(CatBoostPredictorTrait $this, Pool trainPool, Pool[] evalPools) {
        Pool pool;
        ai.catboost.spark.params.Helpers$.MODULE$.checkParamsCompatibility($this.getClass().getName(), (Params)$this, "trainPool", (Params)trainPool);
        // Same compatibility check for every eval pool, labelled "evalPool #<i>".
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), evalPools.length).foreach$mVc$sp((Function1)new Serializable($this, evalPools){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ CatBoostPredictorTrait $outer;
            private final Pool[] evalPools$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                ai.catboost.spark.params.Helpers$.MODULE$.checkParamsCompatibility(this.$outer.getClass().getName(), (Params)this.$outer, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"evalPool #", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)i)})), (Params)this.evalPools$1[i]);
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
                this.evalPools$1 = evalPools$1;
            }
        });
        SparkSession spark = trainPool.data().sparkSession();
        if (trainPool.isQuantized()) {
            pool = trainPool;
        } else {
            QuantizationParams quantizationParams = new QuantizationParams();
            $this.copyValues(quantizationParams, $this.copyValues$default$2());
            ((Logging)$this).logInfo((Function0)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"fit. schedule quantization for train dataset"})).s((Seq)Nil$.MODULE$);
                }
            });
            pool = trainPool.quantize(quantizationParams);
        }
        Pool quantizedTrainPool = pool;
        // Counter used only for the eval-dataset log messages; incremented before each pool.
        IntRef evalIdx = IntRef.create((int)0);
        // Eval pools are quantized with the train pool's quantizedFeaturesInfo so that all pools
        // share the same feature quantization.
        Pool[] quantizedEvalPools = (Pool[])Predef$.MODULE$.refArrayOps((Object[])evalPools).map((Function1)new Serializable($this, quantizedTrainPool, evalIdx){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ CatBoostPredictorTrait $outer;
            private final Pool quantizedTrainPool$1;
            public final IntRef evalIdx$1;

            public final Pool apply(Pool evalPool) {
                Pool pool;
                ++this.evalIdx$1.elem;
                if (evalPool.isQuantized()) {
                    pool = evalPool;
                } else {
                    ((Logging)this.$outer).logInfo((Function0)new Serializable(this){
                        public static final long serialVersionUID = 0L;
                        private final /* synthetic */ CatBoostPredictorTrait$.anonfun.3 $outer;

                        public final String apply() {
                            return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"fit. schedule quantization for eval dataset #", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)(this.$outer.evalIdx$1.elem - 1))}));
                        }
                        {
                            if ($outer == null) {
                                throw null;
                            }
                            this.$outer = $outer;
                        }
                    });
                    pool = evalPool.quantize(this.quantizedTrainPool$1.quantizedFeaturesInfo());
                }
                return pool;
            }
            {
                // NOTE(review): CFR artifact; presumably stands for the captured evalIdx
                // constructor argument.
                void var3_3;
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
                this.quantizedTrainPool$1 = quantizedTrainPool$1;
                this.evalIdx$1 = var3_3;
            }
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Pool.class)));
        Tuple4<Pool, Pool[], JsonAST.JObject, CtrsContext> tuple4 = $this.preprocessBeforeTraining(quantizedTrainPool, quantizedEvalPools);
        if (tuple4 != null) {
            TFullModel tFullModel;
            Future firstCompletedFuture;
            Tuple4 tuple42;
            Pool preprocessedTrainPool = (Pool)tuple4._1();
            Pool[] preprocessedEvalPools = (Pool[])tuple4._2();
            JsonAST.JObject catBoostJsonParams = (JsonAST.JObject)tuple4._3();
            CtrsContext ctrsContext = (CtrsContext)tuple4._4();
            Tuple4 tuple43 = tuple42 = new Tuple4((Object)preprocessedTrainPool, (Object)preprocessedEvalPools, (Object)catBoostJsonParams, (Object)ctrsContext);
            Pool preprocessedTrainPool2 = (Pool)tuple43._1();
            Pool[] preprocessedEvalPools2 = (Pool[])tuple43._2();
            JsonAST.JObject catBoostJsonParams2 = (JsonAST.JObject)tuple43._3();
            CtrsContext ctrsContext2 = (CtrsContext)tuple43._4();
            ((Logging)$this).logInfo((Function0)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return "fit. persist preprocessedTrainPool: start";
                }
            });
            preprocessedTrainPool2.persist(StorageLevel$.MODULE$.MEMORY_ONLY());
            ((Logging)$this).logInfo((Function0)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return "fit. persist preprocessedTrainPool: finish";
                }
            });
            // Everything after persist() runs under try/finally so the cached train pool is
            // released even when training fails (previously it was unpersisted on success only).
            try {
                int partitionCount = BoxesRunTime.unboxToInt((Object)$this.get((Param)((TrainingParamsTrait)((Object)$this)).sparkPartitionCount()).getOrElse((Function0)new Serializable($this, spark){
                    public static final long serialVersionUID = 0L;
                    private final SparkSession spark$1;

                    public final int apply() {
                        return this.apply$mcI$sp();
                    }

                    public int apply$mcI$sp() {
                        return SparkHelpers$.MODULE$.getWorkerCount(this.spark$1);
                    }
                    {
                        this.spark$1 = spark$1;
                    }
                }));
                ((Logging)$this).logInfo((Function0)new Serializable($this, partitionCount){
                    public static final long serialVersionUID = 0L;
                    private final int partitionCount$1;

                    public final String apply() {
                        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"fit. partitionCount=", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.partitionCount$1)}));
                    }
                    {
                        this.partitionCount$1 = partitionCount$1;
                    }
                });
                String precomputedOnlineCtrMetaDataAsJsonString = ctrsContext2 == null ? null : ctrsContext2.precomputedOnlineCtrMetaDataAsJsonString();
                CatBoostMasterWrapper master = CatBoostMasterWrapper$.MODULE$.apply(preprocessedTrainPool2, preprocessedEvalPools2, JsonMethods$.MODULE$.compact((JsonAST.JValue)catBoostJsonParams2), precomputedOnlineCtrMetaDataAsJsonString);
                // The driver invokes master.trainCallback with the collected WorkerInfo[].
                TrainingDriver trainingDriver = new TrainingDriver(0, partitionCount, (Function1<WorkerInfo[], BoxedUnit>)new Serializable($this, master){
                    public static final long serialVersionUID = 0L;
                    private final CatBoostMasterWrapper master$1;

                    public final void apply(WorkerInfo[] workersInfo) {
                        this.master$1.trainCallback(workersInfo);
                    }
                    {
                        this.master$1 = master$1;
                    }
                }, (Duration)$this.getOrDefault(((TrainingParamsTrait)((Object)$this)).workerInitializationTimeout()));
                int listeningPort = trainingDriver.getListeningPort();
                ((Logging)$this).logInfo((Function0)new Serializable($this, listeningPort){
                    public static final long serialVersionUID = 0L;
                    private final int listeningPort$1;

                    public final String apply() {
                        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"fit. TrainingDriver listening port = ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.listeningPort$1)}));
                    }
                    {
                        this.listeningPort$1 = listeningPort$1;
                    }
                });
                ((Logging)$this).logInfo((Function0)new Serializable($this){
                    public static final long serialVersionUID = 0L;

                    public final String apply() {
                        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"fit. Training started"})).s((Seq)Nil$.MODULE$);
                    }
                });
                // Keep a handle on the pool so it can be shut down: per the Executors docs, threads
                // of a fixed thread pool exist until it is explicitly shut down, so an unshut pool
                // leaks two threads per fit() call.
                java.util.concurrent.ExecutorService trainingThreadPool = Executors.newFixedThreadPool(2);
                try {
                    ExecutorCompletionService<BoxedUnit> ecs = new ExecutorCompletionService<BoxedUnit>(trainingThreadPool);
                    Future<BoxedUnit> trainingDriverFuture = ecs.submit(trainingDriver, BoxedUnit.UNIT);
                    CatBoostWorkers workers = new CatBoostWorkers(spark, partitionCount, listeningPort, preprocessedTrainPool2, catBoostJsonParams2, precomputedOnlineCtrMetaDataAsJsonString, master.savedPoolsFuture());
                    Future<BoxedUnit> workersFuture = ecs.submit(workers, BoxedUnit.UNIT);
                    // Block until either side finishes, then check that side and wait for the other.
                    Future future = firstCompletedFuture = ecs.take();
                    Future<BoxedUnit> future2 = workersFuture;
                    if (!(future != null ? !future.equals(future2) : future2 != null)) {
                        Helpers$.MODULE$.checkOneFutureAndWaitForOther(workersFuture, trainingDriverFuture, "workers");
                    } else {
                        Helpers$.MODULE$.checkOneFutureAndWaitForOther(trainingDriverFuture, workersFuture, "master");
                    }
                } finally {
                    // shutdown() rather than shutdownNow(): a task still running after a failed
                    // check is not interrupted, but the pool's threads exit once idle.
                    trainingThreadPool.shutdown();
                }
                ((Logging)$this).logInfo((Function0)new Serializable($this){
                    public static final long serialVersionUID = 0L;

                    public final String apply() {
                        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"fit. Training finished"})).s((Seq)Nil$.MODULE$);
                    }
                });
                if (ctrsContext2 == null) {
                    tFullModel = master.nativeModelResult();
                } else {
                    ((Logging)$this).logInfo((Function0)new Serializable($this){
                        public static final long serialVersionUID = 0L;

                        public final String apply() {
                            return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"fit. Add CtrProvider to model"})).s((Seq)Nil$.MODULE$);
                        }
                    });
                    tFullModel = CtrFeatures$.MODULE$.addCtrProviderToModel(master.nativeModelResult(), ctrsContext2, preprocessedTrainPool2, preprocessedEvalPools2);
                }
                Object resultModel = $this.createModel(tFullModel);
                return resultModel;
            } finally {
                preprocessedTrainPool2.unpersist();
            }
        }
        throw new MatchError(tuple4);
    }

    /**
     * Default value of {@code fit}'s {@code evalPools} parameter: an empty {@code Pool[]}.
     *
     * @param $this the trait instance (unused)
     * @return a new empty array of pools
     */
    public static Pool[] fit$default$2(CatBoostPredictorTrait $this) {
        return (Pool[])Array$.MODULE$.apply((Seq)Nil$.MODULE$, ClassTag$.MODULE$.apply(Pool.class));
    }

    /**
     * Trait initializer; it is empty for this trait.
     *
     * @param $this the trait instance
     */
    public static void $init$(CatBoostPredictorTrait $this) {
    }
}

