/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.classification;

import java.io.Serializable;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS$;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.optimization.L1Updater;
import org.apache.spark.mllib.optimization.LBFGS;
import org.apache.spark.mllib.optimization.LogisticGradient;
import org.apache.spark.mllib.optimization.SquaredL2Updater;
import org.apache.spark.mllib.optimization.Updater;
import org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.DataValidators$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import scala.Function1;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.immutable.;
import scala.collection.immutable.List;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.reflect.api.JavaUniverse;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001\u0005\u001da\u0001B\u0001\u0003\u00015\u00111\u0004T8hSN$\u0018n\u0019*fOJ,7o]5p]^KG\u000f\u001b'C\r\u001e\u001b&BA\u0002\u0005\u00039\u0019G.Y:tS\u001aL7-\u0019;j_:T!!\u0002\u0004\u0002\u000b5dG.\u001b2\u000b\u0005\u001dA\u0011!B:qCJ\\'BA\u0005\u000b\u0003\u0019\t\u0007/Y2iK*\t1\"A\u0002pe\u001e\u001c\u0001aE\u0002\u0001\u001da\u00012a\u0004\n\u0015\u001b\u0005\u0001\"BA\t\u0005\u0003)\u0011Xm\u001a:fgNLwN\\\u0005\u0003'A\u0011!dR3oKJ\fG.\u001b>fI2Kg.Z1s\u00032<wN]5uQ6\u0004\"!\u0006\f\u000e\u0003\tI!a\u0006\u0002\u0003/1{w-[:uS\u000e\u0014Vm\u001a:fgNLwN\\'pI\u0016d\u0007CA\r\u001d\u001b\u0005Q\"\"A\u000e\u0002\u000bM\u001c\u0017\r\\1\n\u0005uQ\"\u0001D*fe&\fG.\u001b>bE2,\u0007\"B\u0010\u0001\t\u0003\u0001\u0013A\u0002\u001fj]&$h\bF\u0001\"!\t)\u0002\u0001C\u0004$\u0001\t\u0007I\u0011\t\u0013\u0002\u0013=\u0004H/[7ju\u0016\u0014X#A\u0013\u0011\u0005\u0019JS\"A\u0014\u000b\u0005!\"\u0011\u0001D8qi&l\u0017N_1uS>t\u0017B\u0001\u0016(\u0005\u0015a%IR$TQ\r\u0011CF\r\t\u0003[Aj\u0011A\f\u0006\u0003_\u0019\t!\"\u00198o_R\fG/[8o\u0013\t\tdFA\u0003TS:\u001cW-I\u00014\u0003\u0015\td&\r\u00181\u0011\u0019)\u0004\u0001)A\u0005K\u0005Qq\u000e\u001d;j[&TXM\u001d\u0011)\u0007Qb#\u0007C\u00049\u0001\t\u0007I\u0011K\u001d\u0002\u0015Y\fG.\u001b3bi>\u00148/F\u0001;!\rY\u0004IQ\u0007\u0002y)\u0011QHP\u0001\nS6lW\u000f^1cY\u0016T!a\u0010\u000e\u0002\u0015\r|G\u000e\\3di&|g.\u0003\u0002By\t!A*[:u!\u0011I2)\u0012(\n\u0005\u0011S\"!\u0003$v]\u000e$\u0018n\u001c82!\r1\u0015jS\u0007\u0002\u000f*\u0011\u0001JB\u0001\u0004e\u0012$\u0017B\u0001&H\u0005\r\u0011F\t\u0012\t\u0003\u001f1K!!\u0014\t\u0003\u00191\u000b'-\u001a7fIB{\u0017N\u001c;\u0011\u0005ey\u0015B\u0001)\u001b\u0005\u001d\u0011un\u001c7fC:DaA\u0015\u0001!\u0002\u0013Q\u0014a\u0003<bY&$\u0017\r^8sg\u0002BQ\u0001\u0016\u0001\u0005\nU\u000b1#\\;mi&d\u0015MY3m-\u0006d\u0017\u000eZ1u_J,\u0012A\u0011\u0005\u0006/\u0002!\t\u0001W\u0001\u000eg\u0016$h*^7DY\u0006\u001c8/Z:\u0015\u0005eSV\"\u0001\u0001\
t\u000bm3\u0006\u0019\u0001/\u0002\u00159,Xn\u00117bgN,7\u000f\u0005\u0002\u001a;&\u0011aL\u0007\u0002\u0004\u0013:$\bf\u0001,-A\u0006\n\u0011-A\u00032]Mr\u0003\u0007C\u0003d\u0001\u0011EC-A\u0006de\u0016\fG/Z'pI\u0016dGc\u0001\u000bf[\")aM\u0019a\u0001O\u00069q/Z5hQR\u001c\bC\u00015l\u001b\u0005I'B\u00016\u0005\u0003\u0019a\u0017N\\1mO&\u0011A.\u001b\u0002\u0007-\u0016\u001cGo\u001c:\t\u000b9\u0014\u0007\u0019A8\u0002\u0013%tG/\u001a:dKB$\bCA\rq\u0013\t\t(D\u0001\u0004E_V\u0014G.\u001a\u0005\u0006g\u0002!\t\u0005^\u0001\u0004eVtGC\u0001\u000bv\u0011\u00151(\u000f1\u0001F\u0003\u0015Ig\u000e];u\u0011\u0015\u0019\b\u0001\"\u0011y)\r!\u0012P\u001f\u0005\u0006m^\u0004\r!\u0012\u0005\u0006w^\u0004\raZ\u0001\u000fS:LG/[1m/\u0016Lw\r\u001b;t\u0011\u0015\u0019\b\u0001\"\u0003~)\u0015!bp`A\u0001\u0011\u00151H\u00101\u0001F\u0011\u0015YH\u00101\u0001h\u0011\u0019\t\u0019\u0001 a\u0001\u001d\u0006\u0019Ro]3s'V\u0004\b\u000f\\5fI^+\u0017n\u001a5ug\"\u001a\u0001\u0001\f\u001a")
/*
 * LogisticRegressionWithLBFGS: the RDD-based (spark.mllib) logistic regression trainer that
 * uses an LBFGS optimizer.
 *
 * This is CFR output decompiled from Scala bytecode. It therefore contains:
 *  - compiler-generated members (names containing '$');
 *  - unused locals produced from Scala if-expressions;
 *  - constructs that are not valid Java source as written (the method-local TypeCreator
 *    class carrying access modifiers, and `new .colon.colon(...)`).
 *
 * Behaviour visible in this class:
 *  - The constructor enables feature scaling and installs an LBFGS optimizer built from a
 *    default LogisticGradient and a SquaredL2Updater.
 *  - setNumClasses(k) sets numOfLinearPredictor to k - 1. For k > 2 it also installs a
 *    k-class LogisticGradient.
 *  - Binary training (numOfLinearPredictor == 1) with a SquaredL2Updater or L1Updater is
 *    delegated to the DataFrame-based org.apache.spark.ml.classification.LogisticRegression.
 *    Every other configuration falls back to GeneralizedLinearAlgorithm.run.
 */
public class LogisticRegressionWithLBFGS
extends GeneralizedLinearAlgorithm<LogisticRegressionModel> {
    // LBFGS optimizer created in the constructor; setNumClasses may replace its gradient.
    private final LBFGS optimizer;
    // Input validators. The constructor stores a one-element list holding multiLabelValidator().
    private final List<Function1<RDD<LabeledPoint>, Object>> validators;

    /**
     * Returns the LBFGS optimizer used for training.
     *
     * @return the optimizer created in the constructor (LogisticGradient + SquaredL2Updater
     *         unless reconfigured by callers or by setNumClasses)
     */
    @Override
    public LBFGS optimizer() {
        return this.optimizer;
    }

    /**
     * Returns the validators for this algorithm's input RDD.
     * Presumably the GeneralizedLinearAlgorithm superclass applies them before training;
     * confirm against that class.
     *
     * @return a Scala List whose single element is the function from multiLabelValidator()
     */
    public List<Function1<RDD<LabeledPoint>, Object>> validators() {
        return this.validators;
    }

    /**
     * Builds the label validator stored in {@link #validators()}.
     * The returned function captures {@code this} and delegates to the synthetic method
     * $anonfun$multiLabelValidator$1, boxing its boolean result. numOfLinearPredictor() is
     * therefore read each time the validator runs, not when it is created. Calling
     * setNumClasses after construction changes which label check is applied.
     *
     * @return a serializable Scala Function1 from an RDD of LabeledPoint to a boxed Boolean
     */
    private Function1<RDD<LabeledPoint>, Object> multiLabelValidator() {
        return (Function1 & Serializable & scala.Serializable)data -> BoxesRunTime.boxToBoolean((boolean)LogisticRegressionWithLBFGS.$anonfun$multiLabelValidator$1(this, data));
    }

    /**
     * Sets the number of label classes to train for.
     * Sets numOfLinearPredictor to {@code numClasses - 1}. When {@code numClasses > 2}, also
     * replaces the optimizer's gradient with a {@code LogisticGradient(numClasses)}.
     *
     * @param numClasses number of classes; must be greater than 1
     * @return this instance, for call chaining
     * @throws IllegalArgumentException if {@code numClasses <= 1} (raised by Predef.require)
     */
    public LogisticRegressionWithLBFGS setNumClasses(int numClasses) {
        Predef$.MODULE$.require(numClasses > 1);
        this.numOfLinearPredictor_$eq(numClasses - 1);
        // `object` is an unused decompiler artifact of a Scala if-expression; only the
        // setGradient side effect matters.
        // NOTE(review): for numClasses == 2 the gradient is left unchanged. An earlier call
        // with a larger value would keep the multinomial LogisticGradient installed; confirm
        // this is intended.
        Object object = numClasses > 2 ? this.optimizer().setGradient(new LogisticGradient(numClasses)) : BoxedUnit.UNIT;
        return this;
    }

    /**
     * Wraps trained weights and an intercept in an mllib LogisticRegressionModel.
     * With one linear predictor (binary), the two-argument constructor is used. Otherwise
     * numFeatures() and {@code numOfLinearPredictor() + 1} are also passed, presumably as
     * the feature count and the class count; confirm against the model's constructor.
     *
     * @param weights   trained weight vector
     * @param intercept trained intercept
     * @return the model built from the given parameters
     */
    @Override
    public LogisticRegressionModel createModel(Vector weights, double intercept) {
        return this.numOfLinearPredictor() == 1 ? new LogisticRegressionModel(weights, intercept) : new LogisticRegressionModel(weights, intercept, this.numFeatures(), this.numOfLinearPredictor() + 1);
    }

    /**
     * Trains a model from {@code input}.
     * Initial weights come from generateInitialWeights(input). They are flagged as not
     * user-supplied, so the ml-delegation path does not set an initial model.
     *
     * @param input training data
     * @return the trained model
     */
    @Override
    public LogisticRegressionModel run(RDD<LabeledPoint> input) {
        return this.run(input, this.generateInitialWeights(input), false);
    }

    /**
     * Trains a model from {@code input}, starting from caller-provided weights.
     *
     * @param input          training data
     * @param initialWeights starting weights, flagged as user-supplied
     * @return the trained model
     */
    @Override
    public LogisticRegressionModel run(RDD<LabeledPoint> input, Vector initialWeights) {
        return this.run(input, initialWeights, true);
    }

    /**
     * Chooses the training path.
     * <ul>
     *   <li>Binary (numOfLinearPredictor() == 1) with a SquaredL2Updater: ml LogisticRegression
     *       with elasticNetParam 0.0.</li>
     *   <li>Binary with an L1Updater: ml LogisticRegression with elasticNetParam 1.0.</li>
     *   <li>Binary with any other updater, or any multinomial configuration:
     *       GeneralizedLinearAlgorithm.run(input, initialWeights).</li>
     * </ul>
     * userSuppliedWeights only affects the ml-delegation path.
     *
     * @param input               training data
     * @param initialWeights      starting weights
     * @param userSuppliedWeights whether initialWeights came from the caller
     * @return the trained model
     */
    private LogisticRegressionModel run(RDD<LabeledPoint> input, Vector initialWeights, boolean userSuppliedWeights) {
        LogisticRegressionModel logisticRegressionModel;
        if (this.numOfLinearPredictor() == 1) {
            Updater updater = this.optimizer().getUpdater();
            LogisticRegressionModel logisticRegressionModel2 = updater instanceof SquaredL2Updater ? this.runWithMlLogisticRegression$1(0.0, input, initialWeights, userSuppliedWeights) : (updater instanceof L1Updater ? this.runWithMlLogisticRegression$1(1.0, input, initialWeights, userSuppliedWeights) : (LogisticRegressionModel)super.run(input, initialWeights));
            logisticRegressionModel = logisticRegressionModel2;
        } else {
            logisticRegressionModel = (LogisticRegressionModel)super.run(input, initialWeights);
        }
        return logisticRegressionModel;
    }

    /**
     * Synthetic body of the validator lambda built by multiLabelValidator().
     * When numOfLinearPredictor() &gt; 1, applies
     * DataValidators.multiLabelValidator(numOfLinearPredictor() + 1) to {@code data}.
     * Otherwise it applies DataValidators.binaryLabelValidator(). The unboxed Boolean result
     * is returned; presumably true means the labels are valid.
     *
     * @param $this the enclosing algorithm instance
     * @param data  the RDD to validate
     * @return the chosen validator's result
     */
    public static final /* synthetic */ boolean $anonfun$multiLabelValidator$1(LogisticRegressionWithLBFGS $this, RDD data) {
        return $this.numOfLinearPredictor() > 1 ? BoxesRunTime.unboxToBoolean((Object)DataValidators$.MODULE$.multiLabelValidator($this.numOfLinearPredictor() + 1).apply((Object)data)) : BoxesRunTime.unboxToBoolean((Object)DataValidators$.MODULE$.binaryLabelValidator().apply((Object)data));
    }

    /**
     * Trains a binary model by delegating to the DataFrame-based ml LogisticRegression.
     * (This is a compiler-lifted local function of run.)
     *
     * The ml estimator is configured from this algorithm's settings:
     * <ul>
     *   <li>regParam from optimizer().getRegParam();</li>
     *   <li>elasticNetParam from the argument;</li>
     *   <li>standardization from useFeatureScaling();</li>
     *   <li>fitIntercept from addIntercept();</li>
     *   <li>maxIter from optimizer().getNumIterations();</li>
     *   <li>tol from optimizer().getConvergenceTol().</li>
     * </ul>
     * The mllib points are converted with asML() into a DataFrame and trained with
     * lr.train. The resulting coefficients and intercept are repackaged via createModel.
     *
     * @param elasticNetParam       0.0 for the L2 path, 1.0 for the L1 path (as called from run)
     * @param input$1               training data
     * @param initialWeights$1      starting weights; only used when userSuppliedWeights$1 is true
     * @param userSuppliedWeights$1 whether to seed lr with an initial model built from the weights
     * @return the mllib model wrapping the ml model's coefficients and intercept
     */
    private final LogisticRegressionModel runWithMlLogisticRegression$1(double elasticNetParam, RDD input$1, Vector initialWeights$1, boolean userSuppliedWeights$1) {
        // Unused decompiler artifact of a Scala if-expression.
        Object object;
        LogisticRegression lr = new LogisticRegression();
        lr.setRegParam(this.optimizer().getRegParam());
        lr.setElasticNetParam(elasticNetParam);
        lr.setStandardization(this.useFeatureScaling());
        if (userSuppliedWeights$1) {
            String uid = Identifiable$.MODULE$.randomUID("logreg-static");
            // The initial model holds the weights as a DenseMatrix(1, size, values) and an
            // intercept vector equal to Vectors.dense(1.0) (empty varargs), followed by 2 and
            // false. These are presumably numClasses and isMultinomial; confirm against the ml
            // LogisticRegressionModel constructor.
            // NOTE(review): the initial intercept is hard-coded to 1.0 and not derived from
            // the caller's weights; confirm this is intended.
            object = lr.setInitialModel(new org.apache.spark.ml.classification.LogisticRegressionModel(uid, (Matrix)new DenseMatrix(1, initialWeights$1.size(), initialWeights$1.toArray()), Vectors$.MODULE$.dense(1.0, (Seq<Object>)Predef$.MODULE$.wrapDoubleArray(new double[0])).asML(), 2, false));
        } else {
            object = BoxedUnit.UNIT;
        }
        lr.setFitIntercept(this.addIntercept());
        lr.setMaxIter(this.optimizer().getNumIterations());
        lr.setTol(this.optimizer().getConvergenceTol());
        // Obtain (or create) a SparkSession bound to the input RDD's SparkContext.
        SparkSession spark = SparkSession$.MODULE$.builder().sparkContext(input$1.context()).getOrCreate();
        // The universe, mirror and local TypeCreator below are the decompiled form of the
        // implicit Scala TypeTag for org.apache.spark.ml.feature.LabeledPoint, which
        // createDataFrame requires.
        JavaUniverse $u = package$.MODULE$.universe();
        JavaUniverse.JavaMirror $m = package$.MODULE$.universe().runtimeMirror(LogisticRegressionWithLBFGS.class.getClassLoader());
        // Compiler-generated TypeCreator. Java does not allow access modifiers on a
        // method-local class, so this is decompiler output, not compilable source.
        public final class Org_apache_spark_mllib_classification_LogisticRegressionWithLBFGS$$typecreator1$1
        extends TypeCreator {
            // Resolves the static class org.apache.spark.ml.feature.LabeledPoint in the given
            // mirror and returns its type constructor.
            public <U extends Universe> Types.TypeApi apply(Mirror<U> $m$untyped) {
                Universe $u = $m$untyped.universe();
                Mirror<U> $m = $m$untyped;
                return $m.staticClass("org.apache.spark.ml.feature.LabeledPoint").asType().toTypeConstructor();
            }

            // The $outer argument is ignored; the only call site passes null.
            public Org_apache_spark_mllib_classification_LogisticRegressionWithLBFGS$$typecreator1$1(LogisticRegressionWithLBFGS $outer) {
            }
        }
        // Convert each mllib LabeledPoint to its ml counterpart via asML() and build a DataFrame.
        Dataset df = spark.createDataFrame(input$1.map((Function1 & Serializable & scala.Serializable)x$3 -> x$3.asML(), ClassTag$.MODULE$.apply(org.apache.spark.ml.feature.LabeledPoint.class)), ((TypeTags)$u).TypeTag().apply((Mirror)$m, (TypeCreator)new Org_apache_spark_mllib_classification_LogisticRegressionWithLBFGS$$typecreator1$1(null)));
        StorageLevel storageLevel = input$1.getStorageLevel();
        StorageLevel storageLevel2 = StorageLevel$.MODULE$.NONE();
        // Null-safe equality: true exactly when the input's storage level equals
        // StorageLevel.NONE, i.e. the caller did not persist the input. Presumably this tells
        // lr.train whether to cache internally; confirm in ml LogisticRegression.
        boolean handlePersistence = !(storageLevel != null ? !storageLevel.equals(storageLevel2) : storageLevel2 != null);
        org.apache.spark.ml.classification.LogisticRegressionModel mlLogisticRegressionModel = lr.train(df, handlePersistence);
        // Repackage the ml coefficients as an mllib dense vector for the mllib model.
        Vector weights = Vectors$.MODULE$.dense(mlLogisticRegressionModel.coefficients().toArray());
        return this.createModel(weights, mlLogisticRegressionModel.intercept());
    }

    /**
     * Creates a trainer with these settings:
     * <ul>
     *   <li>feature scaling enabled;</li>
     *   <li>an LBFGS optimizer using a default LogisticGradient and a SquaredL2Updater;</li>
     *   <li>a validator list containing only multiLabelValidator().</li>
     * </ul>
     * setFeatureScaling(true) runs before the final fields below are assigned.
     */
    public LogisticRegressionWithLBFGS() {
        this.setFeatureScaling(true);
        this.optimizer = new LBFGS(new LogisticGradient(), new SquaredL2Updater());
        // Decompiled form of the Scala expression `multiLabelValidator :: Nil`
        // (a one-element immutable List).
        this.validators = new .colon.colon(this.multiLabelValidator(), (List)Nil$.MODULE$);
    }
}

