/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import org.apache.spark.mllib.linalg.BLAS$;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.util.MLUtils$;
import scala.Array$;
import scala.Function0;
import scala.Function2;
import scala.NotImplementedError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.StringBuilder;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;

@ScalaSignature(bytes="\u0006\u0001i4A!\u0001\u0002\u0005\u001b\t\u0011Bj\\4jgRL7-Q4he\u0016<\u0017\r^8s\u0015\t\u0019A!\u0001\bdY\u0006\u001c8/\u001b4jG\u0006$\u0018n\u001c8\u000b\u0005\u00151\u0011AA7m\u0015\t9\u0001\"A\u0003ta\u0006\u00148N\u0003\u0002\n\u0015\u00051\u0011\r]1dQ\u0016T\u0011aC\u0001\u0004_J<7\u0001A\n\u0004\u00019!\u0002CA\b\u0013\u001b\u0005\u0001\"\"A\t\u0002\u000bM\u001c\u0017\r\\1\n\u0005M\u0001\"AB!osJ+g\r\u0005\u0002\u0010+%\u0011a\u0003\u0005\u0002\r'\u0016\u0014\u0018.\u00197ju\u0006\u0014G.\u001a\u0005\t1\u0001\u0011\t\u0011)A\u00053\u00059q/Z5hQR\u001c\bC\u0001\u000e \u001b\u0005Y\"B\u0001\u000f\u001e\u0003\u0019a\u0017N\\1mO*\u0011aDB\u0001\u0006[2d\u0017NY\u0005\u0003Am\u0011aAV3di>\u0014\b\u0002\u0003\u0012\u0001\u0005\u0003\u0005\u000b\u0011B\u0012\u0002\u00159,Xn\u00117bgN,7\u000f\u0005\u0002\u0010I%\u0011Q\u0005\u0005\u0002\u0004\u0013:$\b\u0002C\u0014\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u0015\u0002\u0019\u0019LG/\u00138uKJ\u001cW\r\u001d;\u0011\u0005=I\u0013B\u0001\u0016\u0011\u0005\u001d\u0011un\u001c7fC:D\u0001\u0002\f\u0001\u0003\u0002\u0003\u0006I!L\u0001\fM\u0016\fG/\u001e:fgN#H\rE\u0002\u0010]AJ!a\f\t\u0003\u000b\u0005\u0013(/Y=\u0011\u0005=\t\u0014B\u0001\u001a\u0011\u0005\u0019!u.\u001e2mK\"AA\u0007\u0001B\u0001B\u0003%Q&\u0001\u0007gK\u0006$XO]3t\u001b\u0016\fg\u000eC\u00037\u0001\u0011\u0005q'\u0001\u0004=S:LGO\u0010\u000b\u0007qiZD(\u0010 
\u0011\u0005e\u0002Q\"\u0001\u0002\t\u000ba)\u0004\u0019A\r\t\u000b\t*\u0004\u0019A\u0012\t\u000b\u001d*\u0004\u0019\u0001\u0015\t\u000b1*\u0004\u0019A\u0017\t\u000bQ*\u0004\u0019A\u0017\t\u000f\u0001\u0003\u0001\u0019!C\u0005\u0003\u0006AAo\u001c;bY\u000esG/F\u0001C!\ty1)\u0003\u0002E!\t!Aj\u001c8h\u0011\u001d1\u0005\u00011A\u0005\n\u001d\u000bA\u0002^8uC2\u001ce\u000e^0%KF$\"\u0001S&\u0011\u0005=I\u0015B\u0001&\u0011\u0005\u0011)f.\u001b;\t\u000f1+\u0015\u0011!a\u0001\u0005\u0006\u0019\u0001\u0010J\u0019\t\r9\u0003\u0001\u0015)\u0003C\u0003%!x\u000e^1m\u0007:$\b\u0005C\u0004Q\u0001\u0001\u0007I\u0011B)\u0002\u000f1|7o]*v[V\t\u0001\u0007C\u0004T\u0001\u0001\u0007I\u0011\u0002+\u0002\u00171|7o]*v[~#S-\u001d\u000b\u0003\u0011VCq\u0001\u0014*\u0002\u0002\u0003\u0007\u0001\u0007\u0003\u0004X\u0001\u0001\u0006K\u0001M\u0001\tY>\u001c8oU;nA!9\u0011\f\u0001b\u0001\n\u0013Q\u0016\u0001D<fS\u001eDGo]!se\u0006LX#A\u0017\t\rq\u0003\u0001\u0015!\u0003.\u000359X-[4iiN\f%O]1zA!9a\f\u0001b\u0001\n\u0013y\u0016a\u00013j[V\t1\u0005\u0003\u0004b\u0001\u0001\u0006IaI\u0001\u0005I&l\u0007\u0005C\u0004d\u0001\t\u0007I\u0011\u0002.\u0002!\u001d\u0014\u0018\rZ5f]R\u001cV/\\!se\u0006L\bBB3\u0001A\u0003%Q&A\the\u0006$\u0017.\u001a8u'Vl\u0017I\u001d:bs\u0002BQa\u001a\u0001\u0005\u0002!\f1!\u00193e)\rI'\u000e\\\u0007\u0002\u0001!)1N\u001aa\u0001a\u0005)A.\u00192fY\")QN\u001aa\u00013\u0005!A-\u0019;b\u0011\u0015y\u0007\u0001\"\u0001q\u0003\u0015iWM]4f)\tI\u0017\u000fC\u0003s]\u0002\u0007\u0001(A\u0003pi\",'\u000fC\u0003u\u0001\u0011\u0005\u0011)A\u0003d_VtG\u000fC\u0003w\u0001\u0011\u0005\u0011+\u0001\u0003m_N\u001c\b\"\u0002=\u0001\t\u0003I\u0018\u0001C4sC\u0012LWM\u001c;\u0016\u0003e\u0001")
public class LogisticAggregator
implements Serializable {
    private final int numClasses;
    private final boolean fitIntercept;
    public final double[] org$apache$spark$ml$classification$LogisticAggregator$$featuresStd;
    private long totalCnt;
    private double lossSum;
    private final double[] weightsArray;
    private final int org$apache$spark$ml$classification$LogisticAggregator$$dim;
    private final double[] gradientSumArray;

    private long totalCnt() {
        return this.totalCnt;
    }

    private void totalCnt_$eq(long x$1) {
        this.totalCnt = x$1;
    }

    private double lossSum() {
        return this.lossSum;
    }

    private void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    private double[] weightsArray() {
        return this.weightsArray;
    }

    public int org$apache$spark$ml$classification$LogisticAggregator$$dim() {
        return this.org$apache$spark$ml$classification$LogisticAggregator$$dim;
    }

    private double[] gradientSumArray() {
        return this.gradientSumArray;
    }

    public LogisticAggregator add(double label, Vector data) {
        Predef$.MODULE$.require(this.org$apache$spark$ml$classification$LogisticAggregator$$dim() == data.size(), (Function0)new Serializable(this, data){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ LogisticAggregator $outer;
            private final Vector data$1;

            public final String apply() {
                return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Dimensions mismatch when adding new sample."})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{" Expecting ", " but got ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$dim()), BoxesRunTime.boxToInteger((int)this.data$1.size())}))).toString();
            }
            {
                if ($outer == null) {
                    throw new NullPointerException();
                }
                this.$outer = $outer;
                this.data$1 = data$1;
            }
        });
        double[] localWeightsArray = this.weightsArray();
        double[] localGradientSumArray = this.gradientSumArray();
        int n = this.numClasses;
        switch (n) {
            default: {
                NotImplementedError notImplementedError = new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports binary classification for now.");
                break;
            }
            case 2: {
                DoubleRef sum = new DoubleRef(0.0);
                data.foreachActive((Function2<Object, Object, BoxedUnit>)new Serializable(this, localWeightsArray, sum){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ LogisticAggregator $outer;
                    private final double[] localWeightsArray$1;
                    private final DoubleRef sum$1;

                    public final void apply(int index, double value) {
                        this.apply$mcVID$sp(index, value);
                    }

                    public void apply$mcVID$sp(int index, double value) {
                        if (this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$featuresStd[index] != 0.0 && value != 0.0) {
                            this.sum$1.elem += this.localWeightsArray$1[index] * (value / this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$featuresStd[index]);
                        }
                    }
                    {
                        if ($outer == null) {
                            throw new NullPointerException();
                        }
                        this.$outer = $outer;
                        this.localWeightsArray$1 = localWeightsArray$1;
                        this.sum$1 = sum$1;
                    }
                });
                double margin = -(sum.elem + (this.fitIntercept ? localWeightsArray[this.org$apache$spark$ml$classification$LogisticAggregator$$dim()] : 0.0));
                double multiplier = 1.0 / (1.0 + package$.MODULE$.exp(margin)) - label;
                data.foreachActive((Function2<Object, Object, BoxedUnit>)new Serializable(this, localGradientSumArray, multiplier){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ LogisticAggregator $outer;
                    private final double[] localGradientSumArray$1;
                    private final double multiplier$1;

                    public final void apply(int index, double value) {
                        this.apply$mcVID$sp(index, value);
                    }

                    public void apply$mcVID$sp(int index, double value) {
                        if (this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$featuresStd[index] != 0.0 && value != 0.0) {
                            this.localGradientSumArray$1[index] = this.localGradientSumArray$1[index] + this.multiplier$1 * (value / this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$featuresStd[index]);
                        }
                    }
                    {
                        if ($outer == null) {
                            throw new NullPointerException();
                        }
                        this.$outer = $outer;
                        this.localGradientSumArray$1 = localGradientSumArray$1;
                        this.multiplier$1 = multiplier$1;
                    }
                });
                if (this.fitIntercept) {
                    localGradientSumArray[this.org$apache$spark$ml$classification$LogisticAggregator$$dim()] = localGradientSumArray[this.org$apache$spark$ml$classification$LogisticAggregator$$dim()] + multiplier;
                }
                if (label > 0.0) {
                    this.lossSum_$eq(this.lossSum() + MLUtils$.MODULE$.log1pExp(margin));
                } else {
                    this.lossSum_$eq(this.lossSum() + (MLUtils$.MODULE$.log1pExp(margin) - margin));
                }
                NotImplementedError notImplementedError = BoxedUnit.UNIT;
            }
        }
        this.totalCnt_$eq(this.totalCnt() + 1L);
        return this;
    }

    public LogisticAggregator merge(LogisticAggregator other) {
        Predef$.MODULE$.require(this.org$apache$spark$ml$classification$LogisticAggregator$$dim() == other.org$apache$spark$ml$classification$LogisticAggregator$$dim(), (Function0)new Serializable(this, other){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ LogisticAggregator $outer;
            private final LogisticAggregator other$1;

            public final String apply() {
                return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Dimensions mismatch when merging with another "})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"LeastSquaresAggregator. Expecting ", " but got ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$dim()), BoxesRunTime.boxToInteger((int)this.other$1.org$apache$spark$ml$classification$LogisticAggregator$$dim())}))).toString();
            }
            {
                if ($outer == null) {
                    throw new NullPointerException();
                }
                this.$outer = $outer;
                this.other$1 = other$1;
            }
        });
        if (other.totalCnt() != 0L) {
            this.totalCnt_$eq(this.totalCnt() + other.totalCnt());
            this.lossSum_$eq(this.lossSum() + other.lossSum());
            double[] localThisGradientSumArray = this.gradientSumArray();
            double[] localOtherGradientSumArray = other.gradientSumArray();
            int len = localThisGradientSumArray.length;
            for (int i = 0; i < len; ++i) {
                int n = i;
                localThisGradientSumArray[n] = localThisGradientSumArray[n] + localOtherGradientSumArray[i];
            }
        }
        return this;
    }

    public long count() {
        return this.totalCnt();
    }

    public double loss() {
        return this.lossSum() / (double)this.totalCnt();
    }

    /*
     * WARNING - void declaration
     */
    public Vector gradient() {
        void var1_1;
        Vector result = Vectors$.MODULE$.dense((double[])this.gradientSumArray().clone());
        BLAS$.MODULE$.scal(1.0 / (double)this.totalCnt(), result);
        return var1_1;
    }

    public LogisticAggregator(Vector weights2, int numClasses, boolean fitIntercept, double[] featuresStd, double[] featuresMean) {
        this.numClasses = numClasses;
        this.fitIntercept = fitIntercept;
        this.org$apache$spark$ml$classification$LogisticAggregator$$featuresStd = featuresStd;
        this.totalCnt = 0L;
        this.lossSum = 0.0;
        Vector vector = weights2;
        if (vector instanceof DenseVector) {
            DenseVector denseVector = (DenseVector)vector;
            double[] dArray = denseVector.values();
            this.weightsArray = dArray;
            this.org$apache$spark$ml$classification$LogisticAggregator$$dim = fitIntercept ? this.weightsArray().length - 1 : this.weightsArray().length;
            this.gradientSumArray = (double[])Array$.MODULE$.ofDim(this.weightsArray().length, ClassTag$.MODULE$.Double());
            return;
        }
        throw new IllegalArgumentException(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"weights only supports dense vector but got type ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{weights2.getClass()})));
    }
}

