/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.io.Serializable;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import org.apache.spark.mllib.util.MLUtils$;
import org.slf4j.Logger;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ScalaSignature;
import scala.runtime.DoubleRef;
import scala.runtime.java8.JFunction1;
import scala.runtime.java8.JFunction2;

@ScalaSignature(bytes="\u0006\u0001\u0005\u001da!B\n\u0015\u0001a\u0001\u0003\u0002\u0003\u001d\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u001e\t\u0011\u0019\u0003!\u0011!Q\u0001\n\u001dC\u0001B\u0013\u0001\u0003\u0002\u0003\u0006Ia\u0013\u0005\t\u001d\u0002\u0011\t\u0011)A\u0005\u0017\"Aq\n\u0001B\u0001B\u0003%\u0001\u000bC\u0003X\u0001\u0011\u0005\u0001\fC\u0004`\u0001\t\u0007I\u0011\u00021\t\r\u0005\u0004\u0001\u0015!\u0003H\u0011\u001d\u0011\u0007A1A\u0005\n\u0001Daa\u0019\u0001!\u0002\u00139\u0005b\u00023\u0001\u0005\u0004%I\u0001\u0019\u0005\u0007K\u0002\u0001\u000b\u0011B$\t\u000f\u0019\u0004!\u0019!C)A\"1q\r\u0001Q\u0001\n\u001dC\u0001\u0002\u001b\u0001\t\u0006\u0004%I!\u001b\u0005\u0006]\u0002!Ia\u001c\u0005\u0006s\u0002!IA\u001f\u0005\u0006}\u0002!\ta \u0002\u0013\u0019><\u0017n\u001d;jG\u0006;wM]3hCR|'O\u0003\u0002\u0016-\u0005Q\u0011mZ4sK\u001e\fGo\u001c:\u000b\u0005]A\u0012!B8qi&l'BA\r\u001b\u0003\tiGN\u0003\u0002\u001c9\u0005)1\u000f]1sW*\u0011QDH\u0001\u0007CB\f7\r[3\u000b\u0003}\t1a\u001c:h'\u0011\u0001\u0011e\n\u001a\u0011\u0005\t*S\"A\u0012\u000b\u0003\u0011\nQa]2bY\u0006L!AJ\u0012\u0003\r\u0005s\u0017PU3g!\u0011A\u0013fK\u0019\u000e\u0003QI!A\u000b\u000b\u00039\u0011KgMZ3sK:$\u0018.\u00192mK2{7o]!hOJ,w-\u0019;peB\u0011AfL\u0007\u0002[)\u0011a\u0006G\u0001\bM\u0016\fG/\u001e:f\u0013\t\u0001TF\u0001\u0005J]N$\u0018M\\2f!\tA\u0003\u0001\u0005\u00024m5\tAG\u0003\u000265\u0005A\u0011N\u001c;fe:\fG.\u0003\u00028i\t9Aj\\4hS:<\u0017!\u00042d\r\u0016\fG/\u001e:fgN#Hm\u0001\u0001\u0011\u0007mr\u0004)D\u0001=\u0015\ti$$A\u0005ce>\fGmY1ti&\u0011q\b\u0010\u0002\n\u0005J|\u0017\rZ2bgR\u00042AI!D\u0013\t\u00115EA\u0003BeJ\f\u0017\u0010\u0005\u0002#\t&\u0011Qi\t\u0002\u0007\t>,(\r\\3\u0002\u00159,Xn\u00117bgN,7\u000f\u0005\u0002#\u0011&\u0011\u0011j\t\u0002\u0004\u0013:$\u0018\u0001\u00044ji&sG/\u001a:dKB$\bC\u0001\u0012M\u0013\ti5EA\u0004C_>dW-\u00198\u0002\u00175,H\u000e^5o_6L\u0017\r\\\u0001\u000fE\u000e\u001cu.\u001a4gS\u000eLWM\u001c;t!\rYd(\u0015\t\u0003%Vk\u0011a\u0015\u0006\u0003)b\ta\u0001\\5oC2<\u0017B\u0001,T\u0005\u00191Vm\u0019;pe\u00061A(\u001b8jiz\"R!W.];z#\"!\r.\t\u000b=3\u0001\u0019\u0001)\t\u000ba2\u0001\u0019\u0001\u001e\t\u000b\u00193\u0001\u0019A$\t\u000b)3\u0001\u0019A&\t\u000b93\u0001\u0019A&\u0002\u00179,XNR3biV\u0014Xm]\u000b\u0002\u000f\u0006aa.^7GK\u0006$XO]3tA\u0005Ab.^7GK\u0006$XO]3t!2,8/\u00138uKJ\u001cW\r\u001d;\u000239,XNR3biV\u0014Xm\u001d)mkNLe\u000e^3sG\u0016\u0004H\u000fI\u0001\u0010G>,gMZ5dS\u0016tGoU5{K\u0006\u00012m\\3gM&\u001c\u0017.\u001a8u'&TX\rI\u0001\u0004I&l\u0017\u0001\u00023j[\u0002\n\u0011cY8fM\u001aL7-[3oiN\f%O]1z+\u0005\u0001\u0005FA\bl!\t\u0011C.\u0003\u0002nG\tIAO]1og&,g\u000e^\u0001\u0014E&t\u0017M]=Va\u0012\fG/Z%o!2\f7-\u001a\u000b\u0005aN,x\u000f\u0005\u0002#c&\u0011!o\t\u0002\u0005+:LG\u000fC\u0003u!\u0001\u0007\u0011+\u0001\u0005gK\u0006$XO]3t\u0011\u00151\b\u00031\u0001D\u0003\u00199X-[4ii\")\u0001\u0010\u0005a\u0001\u0007\u0006)A.\u00192fY\u0006AR.\u001e7uS:|W.[1m+B$\u0017\r^3J]Bc\u0017mY3\u0015\tA\\H0 \u0005\u0006iF\u0001\r!\u0015\u0005\u0006mF\u0001\ra\u0011\u0005\u0006qF\u0001\raQ\u0001\u0004C\u0012$G\u0003BA\u0001\u0003\u0007i\u0011\u0001\u0001\u0005\u0007\u0003\u000b\u0011\u0002\u0019A\u0016\u0002\u0011%t7\u000f^1oG\u0016\u0004")
public class LogisticAggregator
implements DifferentiableLossAggregator<Instance, LogisticAggregator>,
Logging {
    private transient double[] coefficientsArray;
    private final Broadcast<double[]> bcFeaturesStd;
    private final int numClasses;
    private final boolean fitIntercept;
    private final boolean multinomial;
    private final Broadcast<Vector> bcCoefficients;
    private final int numFeatures;
    private final int numFeaturesPlusIntercept;
    private final int coefficientSize;
    private final int dim;
    private transient Logger org$apache$spark$internal$Logging$$log_;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    private volatile transient boolean bitmap$trans$0;
    private volatile boolean bitmap$0;

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    @Override
    public double weightSum() {
        return this.weightSum;
    }

    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    @Override
    public double lossSum() {
        return this.lossSum;
    }

    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    private double[] gradientSumArray$lzycompute() {
        LogisticAggregator logisticAggregator = this;
        synchronized (logisticAggregator) {
            if (!this.bitmap$0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    @Override
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? this.gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    private int numFeaturesPlusIntercept() {
        return this.numFeaturesPlusIntercept;
    }

    private int coefficientSize() {
        return this.coefficientSize;
    }

    @Override
    public int dim() {
        return this.dim;
    }

    private double[] coefficientsArray$lzycompute() {
        LogisticAggregator logisticAggregator = this;
        synchronized (logisticAggregator) {
            if (!this.bitmap$trans$0) {
                double[] values;
                DenseVector denseVector;
                Option option;
                Vector vector = (Vector)this.bcCoefficients.value();
                if (!(vector instanceof DenseVector) || (option = DenseVector$.MODULE$.unapply(denseVector = (DenseVector)vector)).isEmpty()) {
                    throw new IllegalArgumentException(new StringBuilder(55).append("coefficients only supports dense vector but ").append("got type ").append(this.bcCoefficients.value().getClass()).append(".)").toString());
                }
                double[] dArray = values = (double[])option.get();
                this.coefficientsArray = dArray;
                this.bitmap$trans$0 = true;
            }
        }
        return this.coefficientsArray;
    }

    private double[] coefficientsArray() {
        return !this.bitmap$trans$0 ? this.coefficientsArray$lzycompute() : this.coefficientsArray;
    }

    private void binaryUpdateInPlace(Vector features, double weight, double label) {
        double[] localFeaturesStd = (double[])this.bcFeaturesStd.value();
        double[] localCoefficients = this.coefficientsArray();
        double[] localGradientArray = this.gradientSumArray();
        DoubleRef sum = DoubleRef.create((double)0.0);
        features.foreachActive((Function2)(JFunction2.mcVID.sp & Serializable & scala.Serializable)(index, value) -> {
            block0: {
                if (localFeaturesStd[index] == 0.0 || value == 0.0) break block0;
                sum$1.elem += localCoefficients[index] * value / localFeaturesStd[index];
            }
        });
        if (this.fitIntercept) {
            sum.elem += localCoefficients[this.numFeaturesPlusIntercept() - 1];
        }
        double margin = -sum.elem;
        double multiplier = weight * (1.0 / (1.0 + package$.MODULE$.exp(margin)) - label);
        features.foreachActive((Function2)(JFunction2.mcVID.sp & Serializable & scala.Serializable)(index, value) -> {
            block0: {
                if (localFeaturesStd[index] == 0.0 || value == 0.0) break block0;
                localGradientArray$1[index] = localGradientArray[index] + multiplier * value / localFeaturesStd[index];
            }
        });
        if (this.fitIntercept) {
            int n = this.numFeaturesPlusIntercept() - 1;
            localGradientArray[n] = localGradientArray[n] + multiplier;
        }
        if (label > 0.0) {
            this.lossSum_$eq(this.lossSum() + weight * MLUtils$.MODULE$.log1pExp(margin));
        } else {
            this.lossSum_$eq(this.lossSum() + weight * (MLUtils$.MODULE$.log1pExp(margin) - margin));
        }
    }

    private void multinomialUpdateInPlace(Vector features, double weight, double label) {
        double[] localFeaturesStd = (double[])this.bcFeaturesStd.value();
        double[] localCoefficients = this.coefficientsArray();
        double[] localGradientArray = this.gradientSumArray();
        double marginOfLabel = 0.0;
        double maxMargin = Double.NEGATIVE_INFINITY;
        double[] margins = new double[this.numClasses];
        features.foreachActive((Function2)(JFunction2.mcVID.sp & Serializable & scala.Serializable)(index, value) -> {
            if (localFeaturesStd[index] != 0.0 && value != 0.0) {
                double stdValue = value / localFeaturesStd[index];
                for (int j = 0; j < $this.numClasses; ++j) {
                    int n = j;
                    margins$1[n] = margins[n] + localCoefficients[index * $this.numClasses + j] * stdValue;
                }
            }
        });
        for (int i2 = 0; i2 < this.numClasses; ++i2) {
            if (this.fitIntercept) {
                int n = i2;
                margins[n] = margins[n] + localCoefficients[this.numClasses * this.numFeatures() + i2];
            }
            if (i2 == (int)label) {
                marginOfLabel = margins[i2];
            }
            if (!(margins[i2] > maxMargin)) continue;
            maxMargin = margins[i2];
        }
        double[] multipliers = new double[this.numClasses];
        double temp = 0.0;
        for (int i3 = 0; i3 < this.numClasses; ++i3) {
            if (maxMargin > 0.0) {
                int n = i3;
                margins[n] = margins[n] - maxMargin;
            }
            double exp = package$.MODULE$.exp(margins[i3]);
            temp += exp;
            multipliers[i3] = exp;
        }
        double sum = temp;
        new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(margins)).indices().foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)i -> {
            multipliers$1[i] = multipliers[i] / sum - (label == (double)i ? 1.0 : 0.0);
        });
        features.foreachActive((Function2)(JFunction2.mcVID.sp & Serializable & scala.Serializable)(index, value) -> {
            if (localFeaturesStd[index] != 0.0 && value != 0.0) {
                double stdValue = value / localFeaturesStd[index];
                for (int j = 0; j < $this.numClasses; ++j) {
                    int n = index * $this.numClasses + j;
                    localGradientArray$2[n] = localGradientArray[n] + weight * multipliers[j] * stdValue;
                }
            }
        });
        if (this.fitIntercept) {
            for (int i4 = 0; i4 < this.numClasses; ++i4) {
                int n = this.numFeatures() * this.numClasses + i4;
                localGradientArray[n] = localGradientArray[n] + weight * multipliers[i4];
            }
        }
        double loss = maxMargin > 0.0 ? package$.MODULE$.log(sum) - marginOfLabel + maxMargin : package$.MODULE$.log(sum) - marginOfLabel;
        this.lossSum_$eq(this.lossSum() + weight * loss);
    }

    @Override
    public LogisticAggregator add(Instance instance) {
        double weight;
        Instance instance2 = instance;
        if (instance2 != null) {
            double label = instance2.label();
            weight = instance2.weight();
            Vector features = instance2.features();
            Predef$.MODULE$.require(this.numFeatures() == features.size(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(66).append("Dimensions mismatch when adding new instance.").append(" Expecting ").append(this.numFeatures()).append(" but got ").append(features.size()).append(".").toString());
            Predef$.MODULE$.require(weight >= 0.0, (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(34).append("instance weight, ").append(weight).append(" has to be >= 0.0").toString());
            if (weight == 0.0) {
                return this;
            }
            if (this.multinomial) {
                this.multinomialUpdateInPlace(features, weight, label);
            } else {
                this.binaryUpdateInPlace(features, weight, label);
            }
        } else {
            throw new MatchError((Object)instance2);
        }
        this.weightSum_$eq(this.weightSum() + weight);
        LogisticAggregator logisticAggregator = this;
        return logisticAggregator;
    }

    public LogisticAggregator(Broadcast<double[]> bcFeaturesStd, int numClasses, boolean fitIntercept, boolean multinomial, Broadcast<Vector> bcCoefficients) {
        block2: {
            this.bcFeaturesStd = bcFeaturesStd;
            this.numClasses = numClasses;
            this.fitIntercept = fitIntercept;
            this.multinomial = multinomial;
            this.bcCoefficients = bcCoefficients;
            DifferentiableLossAggregator.$init$(this);
            Logging.$init$((Logging)this);
            this.numFeatures = ((double[])bcFeaturesStd.value()).length;
            this.numFeaturesPlusIntercept = fitIntercept ? this.numFeatures() + 1 : this.numFeatures();
            this.coefficientSize = ((Vector)bcCoefficients.value()).size();
            this.dim = this.coefficientSize();
            if (multinomial) {
                Predef$.MODULE$.require(numClasses == this.coefficientSize() / this.numFeaturesPlusIntercept(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(46).append("The number of ").append("coefficients should be ").append($this.numClasses * this.numFeaturesPlusIntercept()).append(" but was ").append(this.coefficientSize()).toString());
            } else {
                Predef$.MODULE$.require(this.coefficientSize() == this.numFeaturesPlusIntercept(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(31).append("Expected ").append(this.numFeaturesPlusIntercept()).append(" ").append("coefficients but got ").append(this.coefficientSize()).toString());
                Predef$.MODULE$.require(numClasses == 1 || numClasses == 2, (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(68).append("Binary logistic aggregator requires numClasses ").append("in {1, 2} but found ").append($this.numClasses).append(".").toString());
            }
            if (!multinomial || numClasses > 2) break block2;
            this.logInfo((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> new StringBuilder(324).append("Multinomial logistic regression for binary classification yields separate ").append("coefficients for positive and negative classes. When no regularization is applied, the").append("result will be effectively the same as binary logistic regression. When regularization").append("is applied, multinomial loss will produce a result different from binary loss.").toString());
        }
    }
}

