/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.optimization;

import org.apache.spark.Logging;
import org.apache.spark.mllib.optimization.Gradient;
import org.apache.spark.mllib.optimization.GradientDescent$$anonfun$runMiniBatchSGD$1$;
import org.apache.spark.mllib.optimization.Updater;
import org.apache.spark.rdd.RDD;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuffer;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/**
 * Singleton ({@code MODULE$}) implementing MLlib's mini-batch stochastic gradient descent.
 *
 * <p>This is CFR output decompiled from Scala bytecode (see the file header), so it mirrors the
 * Scala compiler's encoding. {@code GradientDescent$} is the class behind a Scala {@code object}.
 * The {@code org$apache$spark$Logging$$log_} accessors back the mixed-in {@code Logging} trait.
 * The {@code Logging.class.xxx(...)} calls are, presumably, CFR's rendering of the trait's static
 * implementation methods.
 *
 * <p>NOTE(review): as written this is not valid Java. Examples: anonymous {@code Serializable}
 * classes take constructor arguments, and a {@code static final} field is assigned in the
 * constructor. Treat it as a readable view of the bytecode, not as compilable source.
 */
public final class GradientDescent$
implements Logging {
    public static final GradientDescent$ MODULE$;
    private transient Logger org$apache$spark$Logging$$log_;

    static {
        new GradientDescent$();
    }

    // ---- Logging trait plumbing: field accessors plus forwarders that pass `this` along. ----

    public Logger org$apache$spark$Logging$$log_() {
        return this.org$apache$spark$Logging$$log_;
    }

    public void org$apache$spark$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$Logging$$log_ = x$1;
    }

    public Logger log() {
        return Logging.class.log((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.class.logInfo((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.class.logDebug((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.class.logTrace((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.class.logWarning((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.class.logError((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.class.logInfo((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.class.logDebug((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.class.logTrace((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.class.logWarning((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.class.logError((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.class.isTraceEnabled((Logging)this);
    }

    /**
     * Runs mini-batch stochastic gradient descent for iterations {@code 1..numIterations}.
     *
     * <p>Iteration {@code i} performs these steps:
     * <ul>
     *   <li>Samples {@code data} without replacement, with fraction {@code miniBatchFraction} and
     *       seed {@code 42 + i}.</li>
     *   <li>Sums the per-example gradients and losses over the sample.</li>
     *   <li>Appends {@code lossSum / miniBatchSize + regVal} to the loss history. Here
     *       {@code regVal} is the regularization value of the weights the loss was computed
     *       with.</li>
     *   <li>Asks {@code updater} for the next weights and their regularization value.</li>
     * </ul>
     *
     * @param data training set; presumably (label, features) pairs. TODO confirm against the
     *     per-example mapping closure, which lives in a synthetic class not shown here.
     * @param gradient per-example gradient. It is captured by the iteration closure; its actual
     *     use is inside that synthetic class.
     * @param updater produces the next weights and their regularization value from the current
     *     weights and the averaged gradient.
     * @param stepSize step size passed to {@code updater} on every iteration.
     * @param numIterations number of iterations; the loop runs from 1 to this value inclusive.
     * @param regParam regularization parameter passed to {@code updater}.
     * @param miniBatchFraction fraction of {@code data} sampled per iteration.
     * @param initialWeights starting weights, wrapped as a column vector.
     * @return a tuple of (final weights, stochastic loss history) with one history entry per
     *     iteration.
     */
    public Tuple2<double[], double[]> runMiniBatchSGD(RDD<Tuple2<Object, double[]>> data, Gradient gradient, Updater updater, double stepSize, int numIterations, double regParam, double miniBatchFraction, double[] initialWeights) {
        ArrayBuffer stochasticLossHistory = new ArrayBuffer(numIterations);
        long nexamples = data.count();
        // NOTE(review): this is the *expected* mini-batch size, not the number of rows the
        // sample actually returns. Losses and gradients are averaged by this value.
        double miniBatchSize = (double)nexamples * miniBatchFraction;
        ObjectRef weights = new ObjectRef((Object)new DoubleMatrix(initialWeights.length, 1, initialWeights));
        // Seed regVal with the regularization value of the *initial* weights. This makes the
        // first loss-history entry include it, as every later entry already does. Previously
        // regVal started at 0.0, which under-reported the first loss for a regularizing updater
        // whenever initialWeights is non-zero.
        // A zero step (stepSize 0.0) with a zero gradient (a zero-filled column) should leave the
        // weights unchanged, so only the returned regularization value (_2) is kept.
        // NOTE(review): assumes the Updater honors that; confirm for custom updaters.
        double initialRegVal = updater.compute((DoubleMatrix)weights.elem, new DoubleMatrix(initialWeights.length, 1), 0.0, 1, regParam)._2$mcD$sp();
        DoubleRef regVal = new DoubleRef(initialRegVal);
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), numIterations).foreach$mVc$sp((Function1)new Serializable(data, gradient, updater, stepSize, regParam, miniBatchFraction, stochasticLossHistory, miniBatchSize, weights, regVal){
            public static final long serialVersionUID = 0L;
            private final RDD data$1;
            public final Gradient gradient$1;
            private final Updater updater$1;
            private final double stepSize$1;
            private final double regParam$1;
            private final double miniBatchFraction$1;
            private final ArrayBuffer stochasticLossHistory$1;
            private final double miniBatchSize$1;
            public final ObjectRef weights$1;
            private final DoubleRef regVal$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            /** One SGD iteration: sample, aggregate, record the loss, then update weights and regVal. */
            public void apply$mcVI$sp(int i) {
                // Map each sampled example through the synthetic closure anonfun$1 (not shown),
                // then combine the results with anonfun$2 (not shown) into a
                // (gradientSum, lossSum) pair.
                // NOTE(review): if the sample is empty, RDD.reduce presumably throws.
                Tuple2 tuple2 = (Tuple2)this.data$1.sample(false, this.miniBatchFraction$1, 42 + i).map((Function1)new anonfun$runMiniBatchSGD$1$$anonfun$1(this), ClassTag$.MODULE$.apply(Tuple2.class)).reduce((Function2)new anonfun$runMiniBatchSGD$1$$anonfun$2(this));
                if (tuple2 != null) {
                    Tuple2 tuple22;
                    DoubleMatrix gradientSum = (DoubleMatrix)tuple2._1();
                    double lossSum = tuple2._2$mcD$sp();
                    Tuple2 tuple23 = tuple22 = new Tuple2((Object)gradientSum, (Object)BoxesRunTime.boxToDouble((double)lossSum));
                    DoubleMatrix gradientSum2 = (DoubleMatrix)tuple23._1();
                    double lossSum2 = tuple23._2$mcD$sp();
                    // Record this batch's average loss plus the regularization of the current
                    // weights (the weights the loss was computed with).
                    this.stochasticLossHistory$1.append((Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{lossSum2 / this.miniBatchSize$1 + this.regVal$1.elem}));
                    Tuple2<DoubleMatrix, Object> update = this.updater$1.compute((DoubleMatrix)this.weights$1.elem, gradientSum2.div(this.miniBatchSize$1), this.stepSize$1, i, this.regParam$1);
                    this.weights$1.elem = (DoubleMatrix)update._1();
                    this.regVal$1.elem = update._2$mcD$sp();
                    return;
                }
                // Decompiled tuple pattern match: only a null aggregate reaches here.
                throw new MatchError((Object)tuple2);
            }
            {
                this.data$1 = data$1;
                this.gradient$1 = gradient$1;
                this.updater$1 = updater$1;
                this.stepSize$1 = stepSize$1;
                this.regParam$1 = regParam$1;
                this.miniBatchFraction$1 = miniBatchFraction$1;
                this.stochasticLossHistory$1 = stochasticLossHistory$1;
                this.miniBatchSize$1 = miniBatchSize$1;
                this.weights$1 = weights$1;
                this.regVal$1 = regVal$1;
            }
        });
        // Lazily built log message containing the last 10 recorded losses.
        this.logInfo((Function0<String>)new Serializable(stochasticLossHistory){
            public static final long serialVersionUID = 0L;
            private final ArrayBuffer stochasticLossHistory$1;

            public final String apply() {
                return new StringOps(Predef$.MODULE$.augmentString("GradientDescent finished. Last 10 stochastic losses %s")).format((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{((TraversableOnce)this.stochasticLossHistory$1.takeRight(10)).mkString(", ")}));
            }
            {
                this.stochasticLossHistory$1 = stochasticLossHistory$1;
            }
        });
        return new Tuple2((Object)((DoubleMatrix)weights.elem).toArray(), stochasticLossHistory.toArray(ClassTag$.MODULE$.Double()));
    }

    private GradientDescent$() {
        MODULE$ = this;
        Logging.class.$init$((Logging)this);
    }
}

