/*
 * Decompiled with CFR 0.152.
 */
package ai.catboost.spark;

import ai.catboost.CatBoostError;
import ai.catboost.spark.DataHelpers$;
import ai.catboost.spark.Pool;
import ai.catboost.spark.Pool$;
import ai.catboost.spark.params.Helpers$;
import ai.catboost.spark.params.PoolLoadParams;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.UUID;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.AttributeGroup$;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.Metadata;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.QuantizedFeaturesInfoPtr;
import ru.yandex.catboost.spark.catboost4j_spark.core.src.native_impl.TFeaturesLayout;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Map;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.Map$;
import scala.collection.mutable.StringBuilder;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.reflect.api.JavaUniverse;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;

public final class Pool$
implements Serializable {
    public static final Pool$ MODULE$;

    static {
        new Pool$();
    }

    private Dataset<Row> updateSparseFeaturesSize(Dataset<Row> data) {
        SparkSession spark = data.sparkSession();
        Dataset maxFeatureCountDF = data.mapPartitions((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Iterator<Object> apply(Iterator<Row> rows) {
                IntRef maxFeatureCount = IntRef.create((int)0);
                rows.foreach((Function1)new Serializable(this, maxFeatureCount){
                    public static final long serialVersionUID = 0L;
                    private final IntRef maxFeatureCount$2;

                    public final void apply(Row row) {
                        int featureCount = ((SparseVector)row.getAs(0)).size();
                        if (featureCount > this.maxFeatureCount$2.elem) {
                            this.maxFeatureCount$2.elem = featureCount;
                        }
                    }
                    {
                        this.maxFeatureCount$2 = maxFeatureCount$2;
                    }
                });
                return package$.MODULE$.Iterator().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{maxFeatureCount.elem}));
            }
        }, spark.implicits().newIntEncoder());
        IntRef maxFeatureCount = IntRef.create((int)0);
        Predef$.MODULE$.intArrayOps((int[])maxFeatureCountDF.collect()).foreach((Function1)new Serializable(maxFeatureCount){
            public static final long serialVersionUID = 0L;
            private final IntRef maxFeatureCount$1;

            public final void apply(int featureCount) {
                this.apply$mcVI$sp(featureCount);
            }

            public void apply$mcVI$sp(int featureCount) {
                if (featureCount > this.maxFeatureCount$1.elem) {
                    this.maxFeatureCount$1.elem = featureCount;
                }
            }
            {
                this.maxFeatureCount$1 = maxFeatureCount$1;
            }
        });
        String[] existingFeatureNames = this.getFeatureNames(data, "features");
        String[] extendedFeatureNames = (String[])Arrays.copyOf((Object[])existingFeatureNames, maxFeatureCount.elem);
        Arrays.fill(extendedFeatureNames, existingFeatureNames.length, maxFeatureCount.elem, "");
        Metadata updatedMetadata = DataHelpers$.MODULE$.makeFeaturesMetadata(extendedFeatureNames);
        JavaUniverse $u = scala.reflect.runtime.package$.MODULE$.universe();
        JavaUniverse.JavaMirror $m = scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(this.getClass().getClassLoader());
        JavaUniverse $u2 = scala.reflect.runtime.package$.MODULE$.universe();
        JavaUniverse.JavaMirror $m2 = scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(this.getClass().getClassLoader());
        public final class Ai_catboost_spark_Pool$$typecreator1$1
        extends TypeCreator {
            public <U extends Universe> Types.TypeApi apply(Mirror<U> $m$untyped) {
                Universe $u = $m$untyped.universe();
                Mirror<U> $m = $m$untyped;
                return $m.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }

            public Ai_catboost_spark_Pool$$typecreator1$1() {
            }
        }
        public final class Ai_catboost_spark_Pool$$typecreator2$1
        extends TypeCreator {
            public <U extends Universe> Types.TypeApi apply(Mirror<U> $m$untyped) {
                Universe $u = $m$untyped.universe();
                Mirror<U> $m = $m$untyped;
                return $m.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
            }

            public Ai_catboost_spark_Pool$$typecreator2$1() {
            }
        }
        UserDefinedFunction updateFeaturesSize = functions$.MODULE$.udf((Function1)new Serializable(maxFeatureCount){
            public static final long serialVersionUID = 0L;
            private final IntRef maxFeatureCount$1;

            public final Vector apply(Vector features) {
                SparseVector sparseFeatures = (SparseVector)features;
                return Vectors$.MODULE$.sparse(this.maxFeatureCount$1.elem, sparseFeatures.indices(), sparseFeatures.values());
            }
            {
                this.maxFeatureCount$1 = maxFeatureCount$1;
            }
        }, ((TypeTags)$u).TypeTag().apply((Mirror)$m, (TypeCreator)new Ai_catboost_spark_Pool$$typecreator1$1()), ((TypeTags)$u2).TypeTag().apply((Mirror)$m2, (TypeCreator)new Ai_catboost_spark_Pool$$typecreator2$1()));
        return data.withColumn("features", updateFeaturesSize.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{spark.implicits().StringToColumn(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"features"}))).$((Seq)Nil$.MODULE$)})).as("_", updatedMetadata));
    }

    /*
     * WARNING - void declaration
     */
    public Pool load(SparkSession spark, String dataPathWithScheme, Path columnDescription, PoolLoadParams params) {
        Tuple2 tuple2;
        String[] pathParts = dataPathWithScheme.split("://", 2);
        Tuple2 tuple22 = tuple2 = Predef$.MODULE$.refArrayOps((Object[])pathParts).size() == 1 ? new Tuple2((Object)"dsv", (Object)pathParts[0]) : new Tuple2((Object)pathParts[0], (Object)pathParts[1]);
        if (tuple2 != null) {
            void var11_11;
            Tuple2 tuple23;
            String dataScheme = (String)tuple2._1();
            String dataPath = (String)tuple2._2();
            Tuple2 tuple24 = tuple23 = new Tuple2((Object)dataScheme, (Object)dataPath);
            String dataScheme2 = (String)tuple24._1();
            String dataPath2 = (String)tuple24._2();
            String string = dataScheme2;
            boolean bl = "dsv".equals(string) ? true : "libsvm".equals(string);
            if (bl) {
                String string2;
                String format = string2 = "ai.catboost.spark.CatBoostTextFileFormat";
                scala.collection.mutable.Map dataSourceOptions = (scala.collection.mutable.Map)Map$.MODULE$.apply((Seq)Nil$.MODULE$);
                dataSourceOptions.update((Object)"dataScheme", (Object)dataScheme2);
                params.extractParamMap().toSeq().foreach((Function1)new Serializable(dataSourceOptions){
                    public static final long serialVersionUID = 0L;
                    private final scala.collection.mutable.Map dataSourceOptions$1;

                    public final void apply(ParamPair<?> x0$1) {
                        ParamPair<?> paramPair = x0$1;
                        if (paramPair != null) {
                            Param param = paramPair.param();
                            Object value = paramPair.value();
                            this.dataSourceOptions$1.update((Object)param.name(), (Object)value.toString());
                            BoxedUnit boxedUnit = BoxedUnit.UNIT;
                            return;
                        }
                        throw new MatchError(paramPair);
                    }
                    {
                        this.dataSourceOptions$1 = dataSourceOptions$1;
                    }
                });
                if (columnDescription != null) {
                    dataSourceOptions.update((Object)"columnDescription", (Object)((Object)columnDescription).toString());
                }
                dataSourceOptions.update((Object)"catboostJsonParams", (Object)Helpers$.MODULE$.sparkMlParamsToCatBoostJsonParamsString(params));
                dataSourceOptions.update((Object)"uuid", (Object)UUID.randomUUID().toString());
                Dataset<Row> data = spark.read().format(format).options((Map)dataSourceOptions).load(dataPath2);
                String string3 = dataScheme2;
                String string4 = "libsvm";
                Pool pool = new Pool(!(string3 != null ? !string3.equals(string4) : string4 != null) ? this.updateSparseFeaturesSize(data) : data);
                this.setColumnParamsFromLoadedData(pool);
                return pool;
            }
            throw new CatBoostError(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Loading pool from scheme ", " is not supported"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{var11_11})));
        }
        throw new MatchError((Object)tuple2);
    }

    public Path load$default$3() {
        return null;
    }

    public PoolLoadParams load$default$4() {
        return new PoolLoadParams();
    }

    public void setColumnParamsFromLoadedData(Pool pool) {
        Predef$.MODULE$.refArrayOps((Object[])pool.data().columns()).foreach((Function1)new Serializable(pool){
            public static final long serialVersionUID = 0L;
            private final Pool pool$1;

            public final Pool apply(String name) {
                return (Pool)this.pool$1.set(new StringBuilder().append((Object)name).append((Object)"Col").toString(), name);
            }
            {
                this.pool$1 = pool$1;
            }
        });
    }

    public int getFeatureCount(Dataset<Row> data, String featuresCol) {
        int n;
        AttributeGroup attributeGroup = AttributeGroup$.MODULE$.fromStructField(data.schema().apply(featuresCol));
        Option optNumAttributes = attributeGroup.numAttributes();
        if (optNumAttributes.isDefined()) {
            n = BoxesRunTime.unboxToInt((Object)optNumAttributes.get());
        } else {
            Option optAttributes = attributeGroup.attributes();
            if (optAttributes.isDefined()) {
                return Predef$.MODULE$.refArrayOps((Object[])optAttributes.get()).size();
            }
            if (data.count() == 0L) {
                throw new CatBoostError("Cannot get feature count from empty DataFrame without attributes");
            }
            n = ((Vector)((Row)data.first()).getAs(featuresCol)).size();
        }
        return n;
    }

    public String[] getFeatureNames(Dataset<Row> data, String featuresCol) {
        String[] stringArray;
        int featureCount = this.getFeatureCount(data, featuresCol);
        Option attributes = AttributeGroup$.MODULE$.fromStructField(data.schema().apply(featuresCol)).attributes();
        if (attributes.isEmpty()) {
            String[] featureNames = new String[featureCount];
            Arrays.fill(featureNames, 0, featureCount, "");
            stringArray = featureNames;
        } else {
            if (Predef$.MODULE$.refArrayOps((Object[])attributes.get()).size() != featureCount) {
                throw new CatBoostError(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"number of attributes (", ") is not equal to featureCount (", ")"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)Predef$.MODULE$.refArrayOps((Object[])attributes.get()).size()), BoxesRunTime.boxToInteger((int)featureCount)})));
            }
            stringArray = (String[])Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])attributes.get()).map((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final String apply(Attribute attribute) {
                    return (String)attribute.name().getOrElse((Function0)new Serializable(this){
                        public static final long serialVersionUID = 0L;

                        public final String apply() {
                            return "";
                        }
                    });
                }
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)))).toArray(ClassTag$.MODULE$.apply(String.class));
        }
        return stringArray;
    }

    public Dataset<Row> $lessinit$greater$default$2() {
        return null;
    }

    public TFeaturesLayout $lessinit$greater$default$3() {
        return null;
    }

    public QuantizedFeaturesInfoPtr $lessinit$greater$default$4() {
        return null;
    }

    private Object readResolve() {
        return MODULE$;
    }

    private Pool$() {
        MODULE$ = this;
    }
}

