/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.optimization;

import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.mllib.linalg.BLAS$;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.optimization.Gradient;
import org.apache.spark.mllib.optimization.LogisticGradient$;
import org.apache.spark.mllib.util.MLUtils$;
import scala.Array$;
import scala.Function1;
import scala.Function2;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.AbstractFunction2;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;

@DeveloperApi
@ScalaSignature(bytes="\u0006\u0001\u00113A!\u0001\u0002\u0001\u001b\t\u0001Bj\\4jgRL7m\u0012:bI&,g\u000e\u001e\u0006\u0003\u0007\u0011\tAb\u001c9uS6L'0\u0019;j_:T!!\u0002\u0004\u0002\u000b5dG.\u001b2\u000b\u0005\u001dA\u0011!B:qCJ\\'BA\u0005\u000b\u0003\u0019\t\u0007/Y2iK*\t1\"A\u0002pe\u001e\u001c\u0001a\u0005\u0002\u0001\u001dA\u0011q\u0002E\u0007\u0002\u0005%\u0011\u0011C\u0001\u0002\t\u000fJ\fG-[3oi\"A1\u0003\u0001B\u0001B\u0003%A#\u0001\u0006ok6\u001cE.Y:tKN\u0004\"!\u0006\r\u000e\u0003YQ\u0011aF\u0001\u0006g\u000e\fG.Y\u0005\u00033Y\u00111!\u00138u\u0011\u0015Y\u0002\u0001\"\u0001\u001d\u0003\u0019a\u0014N\\5u}Q\u0011QD\b\t\u0003\u001f\u0001AQa\u0005\u000eA\u0002QAQa\u0007\u0001\u0005\u0002\u0001\"\u0012!\b\u0005\u0006E\u0001!\teI\u0001\bG>l\u0007/\u001e;f)\u0011!\u0003G\r\u001b\u0011\tU)s%L\u0005\u0003MY\u0011a\u0001V;qY\u0016\u0014\u0004C\u0001\u0015,\u001b\u0005I#B\u0001\u0016\u0005\u0003\u0019a\u0017N\\1mO&\u0011A&\u000b\u0002\u0007-\u0016\u001cGo\u001c:\u0011\u0005Uq\u0013BA\u0018\u0017\u0005\u0019!u.\u001e2mK\")\u0011'\ta\u0001O\u0005!A-\u0019;b\u0011\u0015\u0019\u0014\u00051\u0001.\u0003\u0015a\u0017MY3m\u0011\u0015)\u0014\u00051\u0001(\u0003\u001d9X-[4iiNDQA\t\u0001\u0005B]\"R!\f\u001d:umBQ!\r\u001cA\u0002\u001dBQa\r\u001cA\u00025BQ!\u000e\u001cA\u0002\u001dBQ\u0001\u0010\u001cA\u0002\u001d\n1bY;n\u000fJ\fG-[3oi\"\u0012\u0001A\u0010\t\u0003\u007f\tk\u0011\u0001\u0011\u0006\u0003\u0003\u001a\t!\"\u00198o_R\fG/[8o\u0013\t\u0019\u0005I\u0001\u0007EKZ,Gn\u001c9fe\u0006\u0003\u0018\u000e")
public class LogisticGradient
extends Gradient {
    private final int numClasses;

    @Override
    public Tuple2<Vector, Object> compute(Vector data, double label, Vector weights2) {
        Vector gradient2 = Vectors$.MODULE$.zeros(weights2.size());
        double loss2 = this.compute(data, label, weights2, gradient2);
        return new Tuple2((Object)gradient2, (Object)BoxesRunTime.boxToDouble((double)loss2));
    }

    @Override
    public double compute(Vector data, double label, Vector weights2, Vector cumGradient) {
        double d;
        int dataSize = data.size();
        Predef$.MODULE$.require(weights2.size() % dataSize == 0 && this.numClasses == weights2.size() / dataSize + 1);
        int n = this.numClasses;
        switch (n) {
            default: {
                Vector vector = weights2;
                if (vector instanceof DenseVector) {
                    double[] dArray;
                    DenseVector denseVector = (DenseVector)vector;
                    double[] weightsArray = dArray = denseVector.values();
                    Vector vector2 = cumGradient;
                    if (vector2 instanceof DenseVector) {
                        double loss2;
                        double[] dArray2;
                        DenseVector denseVector2 = (DenseVector)vector2;
                        double[] cumGradientArray = dArray2 = denseVector2.values();
                        DoubleRef marginY = new DoubleRef(0.0);
                        DoubleRef maxMargin = new DoubleRef(Double.NEGATIVE_INFINITY);
                        IntRef maxMarginIndex = new IntRef(0);
                        double[] margins = (double[])Array$.MODULE$.tabulate(this.numClasses - 1, (Function1)new Serializable(this, data, label, dataSize, weightsArray, marginY, maxMargin, maxMarginIndex){
                            public static final long serialVersionUID = 0L;
                            private final Vector data$1;
                            private final double label$1;
                            public final int dataSize$1;
                            public final double[] weightsArray$1;
                            private final DoubleRef marginY$1;
                            private final DoubleRef maxMargin$1;
                            private final IntRef maxMarginIndex$1;

                            public final double apply(int i) {
                                return this.apply$mcDI$sp(i);
                            }

                            public double apply$mcDI$sp(int i) {
                                DoubleRef margin = new DoubleRef(0.0);
                                this.data$1.foreachActive((Function2<Object, Object, BoxedUnit>)new Serializable(this, margin, i){
                                    public static final long serialVersionUID = 0L;
                                    private final /* synthetic */ $anonfun$1 $outer;
                                    private final DoubleRef margin$1;
                                    private final int i$1;

                                    public final void apply(int index2, double value) {
                                        this.apply$mcVID$sp(index2, value);
                                    }

                                    public void apply$mcVID$sp(int index2, double value) {
                                        if (value != 0.0) {
                                            this.margin$1.elem += value * this.$outer.weightsArray$1[this.i$1 * this.$outer.dataSize$1 + index2];
                                        }
                                    }
                                    {
                                        if ($outer == null) {
                                            throw new NullPointerException();
                                        }
                                        this.$outer = $outer;
                                        this.margin$1 = margin$1;
                                        this.i$1 = i$1;
                                    }
                                });
                                if (i == (int)this.label$1 - 1) {
                                    this.marginY$1.elem = margin.elem;
                                }
                                if (margin.elem > this.maxMargin$1.elem) {
                                    this.maxMargin$1.elem = margin.elem;
                                    this.maxMarginIndex$1.elem = i;
                                }
                                return margin.elem;
                            }
                            {
                                this.data$1 = data$1;
                                this.label$1 = label$1;
                                this.dataSize$1 = dataSize$1;
                                this.weightsArray$1 = weightsArray$1;
                                this.marginY$1 = marginY$1;
                                this.maxMargin$1 = maxMargin$1;
                                this.maxMarginIndex$1 = maxMarginIndex$1;
                            }
                        }, ClassTag$.MODULE$.Double());
                        DoubleRef temp = new DoubleRef(0.0);
                        if (maxMargin.elem > 0.0) {
                            RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this.numClasses - 1).foreach$mVc$sp((Function1)new Serializable(this, maxMargin, maxMarginIndex, margins, temp){
                                public static final long serialVersionUID = 0L;
                                private final DoubleRef maxMargin$1;
                                private final IntRef maxMarginIndex$1;
                                private final double[] margins$1;
                                private final DoubleRef temp$1;

                                public final void apply(int i) {
                                    this.apply$mcVI$sp(i);
                                }

                                public void apply$mcVI$sp(int i) {
                                    this.margins$1[i] = this.margins$1[i] - this.maxMargin$1.elem;
                                    this.temp$1.elem = i == this.maxMarginIndex$1.elem ? (this.temp$1.elem += package$.MODULE$.exp(-this.maxMargin$1.elem)) : (this.temp$1.elem += package$.MODULE$.exp(this.margins$1[i]));
                                }
                                {
                                    this.maxMargin$1 = maxMargin$1;
                                    this.maxMarginIndex$1 = maxMarginIndex$1;
                                    this.margins$1 = margins$1;
                                    this.temp$1 = temp$1;
                                }
                            });
                        } else {
                            RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this.numClasses - 1).foreach$mVc$sp((Function1)new Serializable(this, margins, temp){
                                public static final long serialVersionUID = 0L;
                                private final double[] margins$1;
                                private final DoubleRef temp$1;

                                public final void apply(int i) {
                                    this.apply$mcVI$sp(i);
                                }

                                public void apply$mcVI$sp(int i) {
                                    this.temp$1.elem += package$.MODULE$.exp(this.margins$1[i]);
                                }
                                {
                                    this.margins$1 = margins$1;
                                    this.temp$1 = temp$1;
                                }
                            });
                        }
                        double sum = temp.elem;
                        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this.numClasses - 1).foreach$mVc$sp((Function1)new Serializable(this, data, label, dataSize, cumGradientArray, margins, sum){
                            public static final long serialVersionUID = 0L;
                            private final Vector data$1;
                            private final double label$1;
                            public final int dataSize$1;
                            public final double[] cumGradientArray$1;
                            private final double[] margins$1;
                            private final double sum$1;

                            public final void apply(int i) {
                                this.apply$mcVI$sp(i);
                            }

                            public void apply$mcVI$sp(int i) {
                                double multiplier = package$.MODULE$.exp(this.margins$1[i]) / (this.sum$1 + 1.0) - (this.label$1 != 0.0 && this.label$1 == (double)(i + 1) ? 1.0 : 0.0);
                                this.data$1.foreachActive((Function2<Object, Object, BoxedUnit>)new Serializable(this, multiplier, i){
                                    public static final long serialVersionUID = 0L;
                                    private final /* synthetic */ $anonfun$compute$1 $outer;
                                    private final double multiplier$1;
                                    private final int i$2;

                                    public final void apply(int index2, double value) {
                                        this.apply$mcVID$sp(index2, value);
                                    }

                                    public void apply$mcVID$sp(int index2, double value) {
                                        if (value != 0.0) {
                                            int n = this.i$2 * this.$outer.dataSize$1 + index2;
                                            this.$outer.cumGradientArray$1[n] = this.$outer.cumGradientArray$1[n] + this.multiplier$1 * value;
                                        }
                                    }
                                    {
                                        if ($outer == null) {
                                            throw new NullPointerException();
                                        }
                                        this.$outer = $outer;
                                        this.multiplier$1 = multiplier$1;
                                        this.i$2 = i$2;
                                    }
                                });
                            }
                            {
                                this.data$1 = data$1;
                                this.label$1 = label$1;
                                this.dataSize$1 = dataSize$1;
                                this.cumGradientArray$1 = cumGradientArray$1;
                                this.margins$1 = margins$1;
                                this.sum$1 = sum$1;
                            }
                        });
                        double d2 = loss2 = label > 0.0 ? package$.MODULE$.log1p(sum) - marginY.elem : package$.MODULE$.log1p(sum);
                        if (maxMargin.elem > 0.0) {
                            d = loss2 + maxMargin.elem;
                            break;
                        }
                        d = loss2;
                        break;
                    }
                    throw new IllegalArgumentException(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"cumGradient only supports dense vector but got type ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{cumGradient.getClass()})));
                }
                throw new IllegalArgumentException(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"weights only supports dense vector but got type ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{weights2.getClass()})));
            }
            case 2: {
                double margin = -1.0 * BLAS$.MODULE$.dot(data, weights2);
                double multiplier = 1.0 / (1.0 + package$.MODULE$.exp(margin)) - label;
                BLAS$.MODULE$.axpy(multiplier, data, cumGradient);
                d = label > 0.0 ? MLUtils$.MODULE$.log1pExp(margin) : MLUtils$.MODULE$.log1pExp(margin) - margin;
            }
        }
        return d;
    }

    public LogisticGradient(int numClasses) {
        this.numClasses = numClasses;
    }

    public LogisticGradient() {
        this(2);
    }
}

