/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.optimization;

import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.NumericOps;
import breeze.storage.Zero;
import org.apache.spark.Logging;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.optimization.Gradient;
import org.apache.spark.mllib.optimization.GradientDescent$;
import org.apache.spark.mllib.optimization.Updater;
import org.apache.spark.mllib.rdd.RDDFunctions;
import org.apache.spark.mllib.rdd.RDDFunctions$;
import org.apache.spark.rdd.RDD;
import org.slf4j.Logger;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuffer;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/**
 * Singleton module class ({@code MODULE$}) that drives mini-batch stochastic gradient descent
 * over a Spark RDD.
 *
 * <p>NOTE(review): this is CFR output of Scala-compiled bytecode (it appears to be the module
 * class of a Scala {@code object}), not hand-written Java. Constructs such as
 * {@code Logging.class.logName(...)}, the {@code static final MODULE$} assigned in the private
 * constructor, and anonymous {@code Serializable} classes that take constructor arguments are
 * decompiler artifacts and will not compile as Java. Make behavioral changes in the Scala source.
 */
@DeveloperApi
public final class GradientDescent$
implements Logging,
Serializable {
    // The singleton instance. The private constructor sets it, and the static initializer
    // below runs that constructor.
    public static final GradientDescent$ MODULE$;
    // Backing field for the Logging mixin's logger; transient, so it is skipped by Java serialization.
    private transient Logger org$apache$spark$Logging$$log_;

    static {
        new GradientDescent$();
    }

    // Accessor/mutator pair for the Logging backing field above.
    public Logger org$apache$spark$Logging$$log_() {
        return this.org$apache$spark$Logging$$log_;
    }

    public void org$apache$spark$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$Logging$$log_ = x$1;
    }

    // Each method from here through isTraceEnabled() delegates to the same-named static method
    // on Logging, passing this instance as the first argument.
    public String logName() {
        return Logging.class.logName((Logging)this);
    }

    public Logger log() {
        return Logging.class.log((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.class.logInfo((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.class.logDebug((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.class.logTrace((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.class.logWarning((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.class.logError((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.class.logInfo((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.class.logDebug((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.class.logTrace((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.class.logWarning((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.class.logError((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.class.isTraceEnabled((Logging)this);
    }

    /**
     * Runs mini-batch stochastic gradient descent for iterations {@code i = 1..numIterations}.
     *
     * <p>Each iteration does the following:
     * <ul>
     *   <li>broadcasts the current weights;</li>
     *   <li>samples {@code data} without replacement at fraction {@code miniBatchFraction}, using
     *       seed {@code 42 + i};</li>
     *   <li>sums the per-example gradient and loss over the sample with {@code treeAggregate};</li>
     *   <li>passes the gradient sum divided by the sample size to {@code updater}, which returns
     *       the next weights and regularization value.</li>
     * </ul>
     * An iteration whose sample is empty logs a warning and changes nothing. If {@code data} is
     * empty, {@code initialWeights} is returned unchanged together with an empty history.
     *
     * @param data examples as pairs. The first element is unboxed as a double and passed as the
     *     label argument of {@code gradient.compute}; the second is passed as the data vector.
     * @param gradient returns the per-example loss, which is summed into the loss total. It
     *     presumably accumulates the gradient in place into the cumulative vector passed as its
     *     last argument, because that vector is returned unchanged from each seqOp step.
     * @param updater turns (current weights, averaged gradient, stepSize, i, regParam) into
     *     (new weights, regularization value)
     * @param stepSize step size passed to {@code updater} on every iteration
     * @param numIterations number of iterations; the index {@code i} runs from 1 to this value
     *     inclusive
     * @param regParam regularization parameter passed to {@code updater}
     * @param miniBatchFraction sampling fraction per iteration. A warning is logged when
     *     {@code count * miniBatchFraction < 1}.
     * @param initialWeights starting weights, wrapped as
     *     {@code Vectors.dense(initialWeights.toArray())}
     * @return a tuple of the final weights and the stochastic loss history. The history has one
     *     entry per iteration with a non-empty sample. Each entry is the mini-batch average loss
     *     plus the regularization value of the weights used in that iteration.
     */
    public Tuple2<Vector, double[]> runMiniBatchSGD(RDD<Tuple2<Object, Vector>> data, Gradient gradient, Updater updater, double stepSize, int numIterations, double regParam, double miniBatchFraction, Vector initialWeights) {
        // Initial capacity is numIterations. One entry is appended per non-empty mini-batch.
        ArrayBuffer stochasticLossHistory = new ArrayBuffer(numIterations);
        // data.count() is a Spark action over the whole RDD, used only for the two checks below.
        long numExamples = data.count();
        if (numExamples == 0L) {
            this.logWarning((Function0<String>)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return "GradientDescent.runMiniBatchSGD returning initial weights, no data found";
                }
            });
            return new Tuple2((Object)initialWeights, stochasticLossHistory.toArray(ClassTag$.MODULE$.Double()));
        }
        // Expected mini-batch size below one: warn but keep going. Individual iterations may then
        // draw an empty sample and be skipped in the loop below.
        if ((double)numExamples * miniBatchFraction < 1.0) {
            this.logWarning((Function0<String>)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return "The miniBatchFraction is too small";
                }
            });
        }
        // ObjectRef/DoubleRef are mutable holders, so the per-iteration closure can replace them.
        // NOTE(review): Vectors.dense(initialWeights.toArray()) may wrap the caller's own array
        // rather than copy it, depending on the Vector implementation. Confirm before assuming
        // initialWeights is isolated from updater side effects.
        ObjectRef weights = new ObjectRef((Object)Vectors$.MODULE$.dense(initialWeights.toArray()));
        int n = ((Vector)weights.elem).size();
        // Regularization value of the starting weights. It calls updater with an all-zero gradient,
        // step size 0.0 and iteration 1, and keeps only the double half of the returned tuple.
        DoubleRef regVal = new DoubleRef(updater.compute((Vector)weights.elem, Vectors$.MODULE$.dense(new double[((Vector)weights.elem).size()]), 0.0, 1, regParam)._2$mcD$sp());
        // Main loop: for i in 1..numIterations (inclusive).
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), numIterations).foreach$mVc$sp((Function1)new Serializable(data, gradient, updater, stepSize, numIterations, regParam, miniBatchFraction, stochasticLossHistory, weights, n, regVal){
            public static final long serialVersionUID = 0L;
            private final RDD data$1;
            public final Gradient gradient$1;
            private final Updater updater$1;
            private final double stepSize$1;
            public final int numIterations$1;
            private final double regParam$1;
            private final double miniBatchFraction$1;
            private final ArrayBuffer stochasticLossHistory$1;
            private final ObjectRef weights$1;
            private final int n$1;
            private final DoubleRef regVal$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            // Body of one SGD iteration; i is the 1-based iteration index.
            public void apply$mcVI$sp(int i) {
                int x$5;
                Serializable x$4;
                Serializable x$3;
                Tuple3 x$2;
                // Broadcast the current weights for this iteration's aggregation.
                // NOTE(review): nothing in this method unpersists or destroys bcWeights. Each of
                // the numIterations iterations leaves a broadcast behind; consider releasing it
                // after treeAggregate returns.
                Broadcast bcWeights = this.data$1.context().broadcast((Object)((Vector)this.weights$1.elem), ClassTag$.MODULE$.apply(Vector.class));
                // Sample without replacement at miniBatchFraction with seed 42 + i, so a given
                // iteration draws the same sample on every run.
                // NOTE(review): callers cannot change the seed base 42.
                RDDFunctions<T> qual$1 = RDDFunctions$.MODULE$.fromRDD(this.data$1.sample(false, this.miniBatchFraction$1, (long)(42 + i)), ClassTag$.MODULE$.apply(Tuple2.class));
                // Aggregate (gradientSum, lossSum, count) over the sample. The zero value is a
                // dense zero vector of length n, 0.0 and 0L. Depth is treeAggregate's default.
                Tuple3 tuple3 = qual$1.treeAggregate(x$2 = new Tuple3((Object)DenseVector$.MODULE$.zeros$mDc$sp(this.n$1, ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$), (Object)BoxesRunTime.boxToDouble((double)0.0), (Object)BoxesRunTime.boxToLong((long)0L)), x$3 = new Serializable(this, bcWeights){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ anonfun.runMiniBatchSGD.1 $outer;
                    private final Broadcast bcWeights$1;

                    // seqOp: add one example's loss and count it. The gradient vector c._1() is
                    // returned as-is, so any accumulation by gradient.compute must happen in place.
                    public final Tuple3<DenseVector<Object>, Object, Object> apply(Tuple3<DenseVector<Object>, Object, Object> c, Tuple2<Object, Vector> v) {
                        double l = this.$outer.gradient$1.compute((Vector)v._2(), v._1$mcD$sp(), (Vector)this.bcWeights$1.value(), Vectors$.MODULE$.fromBreeze((breeze.linalg.Vector<Object>)((breeze.linalg.Vector)c._1())));
                        return new Tuple3(c._1(), (Object)BoxesRunTime.boxToDouble((double)(BoxesRunTime.unboxToDouble((Object)c._2()) + l)), (Object)BoxesRunTime.boxToLong((long)(BoxesRunTime.unboxToLong((Object)c._3()) + 1L)));
                    }
                    {
                        if ($outer == null) {
                            throw new NullPointerException();
                        }
                        this.$outer = $outer;
                        this.bcWeights$1 = bcWeights$1;
                    }
                }, x$4 = new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    // combOp: add c2's gradient into c1's vector ($plus$eq), and sum the losses
                    // and the counts.
                    public final Tuple3<DenseVector<Object>, Object, Object> apply(Tuple3<DenseVector<Object>, Object, Object> c1, Tuple3<DenseVector<Object>, Object, Object> c2) {
                        return new Tuple3(((NumericOps)c1._1()).$plus$eq(c2._1(), DenseVector$.MODULE$.canAddIntoD()), (Object)BoxesRunTime.boxToDouble((double)(BoxesRunTime.unboxToDouble((Object)c1._2()) + BoxesRunTime.unboxToDouble((Object)c2._2()))), (Object)BoxesRunTime.boxToLong((long)(BoxesRunTime.unboxToLong((Object)c1._3()) + BoxesRunTime.unboxToLong((Object)c2._3()))));
                    }
                }, x$5 = qual$1.treeAggregate$default$4(x$2), ClassTag$.MODULE$.apply(Tuple3.class));
                // Destructure the aggregate result. The re-packing into tuple32/tuple33 looks like
                // generated code from a Scala tuple pattern.
                if (tuple3 != null) {
                    Tuple3 tuple32;
                    DenseVector gradientSum = (DenseVector)tuple3._1();
                    double lossSum = BoxesRunTime.unboxToDouble((Object)tuple3._2());
                    long miniBatchSize = BoxesRunTime.unboxToLong((Object)tuple3._3());
                    Tuple3 tuple33 = tuple32 = new Tuple3((Object)gradientSum, (Object)BoxesRunTime.boxToDouble((double)lossSum), (Object)BoxesRunTime.boxToLong((long)miniBatchSize));
                    DenseVector gradientSum2 = (DenseVector)tuple33._1();
                    double lossSum2 = BoxesRunTime.unboxToDouble((Object)tuple33._2());
                    long miniBatchSize2 = BoxesRunTime.unboxToLong((Object)tuple33._3());
                    if (miniBatchSize2 > 0L) {
                        // Record the average loss plus the regularization value of the weights
                        // used in this iteration. regVal is updated only after this line.
                        this.stochasticLossHistory$1.append((Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{lossSum2 / (double)miniBatchSize2 + this.regVal$1.elem}));
                        // Step with the averaged gradient (gradientSum / miniBatchSize).
                        Tuple2<Vector, Object> update2 = this.updater$1.compute((Vector)this.weights$1.elem, Vectors$.MODULE$.fromBreeze((breeze.linalg.Vector<Object>)((breeze.linalg.Vector)gradientSum2.$div((Object)BoxesRunTime.boxToDouble((double)miniBatchSize2), DenseVector$.MODULE$.dv_s_Op_Double_OpDiv()))), this.stepSize$1, i, this.regParam$1);
                        this.weights$1.elem = (Vector)update2._1();
                        this.regVal$1.elem = update2._2$mcD$sp();
                    } else {
                        // Empty sample: weights, regVal and history are left unchanged for this iteration.
                        GradientDescent$.MODULE$.logWarning((Function0<String>)new Serializable(this, i){
                            public static final long serialVersionUID = 0L;
                            private final /* synthetic */ anonfun.runMiniBatchSGD.1 $outer;
                            private final int i$1;

                            public final String apply() {
                                return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Iteration (", "/", "). The size of sampled batch is zero"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.i$1), BoxesRunTime.boxToInteger((int)this.$outer.numIterations$1)}));
                            }
                            {
                                if ($outer == null) {
                                    throw new NullPointerException();
                                }
                                this.$outer = $outer;
                                this.i$1 = i$1;
                            }
                        });
                    }
                    return;
                }
                // Reached only if treeAggregate returned null.
                throw new MatchError((Object)tuple3);
            }
            {
                this.data$1 = data$1;
                this.gradient$1 = gradient$1;
                this.updater$1 = updater$1;
                this.stepSize$1 = stepSize$1;
                this.numIterations$1 = numIterations$1;
                this.regParam$1 = regParam$1;
                this.miniBatchFraction$1 = miniBatchFraction$1;
                this.stochasticLossHistory$1 = stochasticLossHistory$1;
                this.weights$1 = weights$1;
                this.n$1 = n$1;
                this.regVal$1 = regVal$1;
            }
        });
        // Log at most the last 10 recorded losses.
        this.logInfo((Function0<String>)new Serializable(stochasticLossHistory){
            public static final long serialVersionUID = 0L;
            private final ArrayBuffer stochasticLossHistory$1;

            public final String apply() {
                return new StringOps(Predef$.MODULE$.augmentString("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s")).format((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{((TraversableOnce)this.stochasticLossHistory$1.takeRight(10)).mkString(", ")}));
            }
            {
                this.stochasticLossHistory$1 = stochasticLossHistory$1;
            }
        });
        return new Tuple2((Object)((Vector)weights.elem), stochasticLossHistory.toArray(ClassTag$.MODULE$.Double()));
    }

    // Java serialization hook: deserializing returns the existing singleton, not a new instance.
    private Object readResolve() {
        return MODULE$;
    }

    // Publishes the singleton, then runs the Logging mixin's initializer.
    private GradientDescent$() {
        MODULE$ = this;
        Logging.class.$init$((Logging)this);
    }
}

