/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import scala.Function0;
import scala.Function2;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.reflect.ScalaSignature;
import scala.runtime.DoubleRef;
import scala.runtime.java8.JFunction2;

@ScalaSignature(bytes="\u0006\u0001\u00054Q!\u0004\b\u0001%iA\u0001\u0002\f\u0001\u0003\u0002\u0003\u0006IA\f\u0005\tu\u0001\u0011\t\u0011)A\u0005w!Aa\b\u0001B\u0001B\u0003%q\bC\u0003G\u0001\u0011\u0005q\tC\u0004M\u0001\t\u0007I\u0011B'\t\rE\u0003\u0001\u0015!\u0003O\u0011\u001d\u0011\u0006A1A\u0005\n5Caa\u0015\u0001!\u0002\u0013q\u0005\u0002\u0003+\u0001\u0011\u000b\u0007I\u0011B+\t\u000fi\u0003!\u0019!C)\u001b\"11\f\u0001Q\u0001\n9CQ\u0001\u0018\u0001\u0005\u0002u\u0013q\u0002S5oO\u0016\fum\u001a:fO\u0006$xN\u001d\u0006\u0003\u001fA\t!\"Y4he\u0016<\u0017\r^8s\u0015\t\t\"#A\u0003paRLWN\u0003\u0002\u0014)\u0005\u0011Q\u000e\u001c\u0006\u0003+Y\tQa\u001d9be.T!a\u0006\r\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005I\u0012aA8sON\u0019\u0001aG\u0011\u0011\u0005qyR\"A\u000f\u000b\u0003y\tQa]2bY\u0006L!\u0001I\u000f\u0003\r\u0005s\u0017PU3g!\u0011\u00113%J\u0016\u000e\u00039I!\u0001\n\b\u00039\u0011KgMZ3sK:$\u0018.\u00192mK2{7o]!hOJ,w-\u0019;peB\u0011a%K\u0007\u0002O)\u0011\u0001FE\u0001\bM\u0016\fG/\u001e:f\u0013\tQsE\u0001\u0005J]N$\u0018M\\2f!\t\u0011\u0003!A\u0007cG\u001a+\u0017\r^;sKN\u001cF\u000fZ\u0002\u0001!\ry#\u0007N\u0007\u0002a)\u0011\u0011\u0007F\u0001\nEJ|\u0017\rZ2bgRL!a\r\u0019\u0003\u0013\t\u0013x.\u00193dCN$\bc\u0001\u000f6o%\u0011a'\b\u0002\u0006\u0003J\u0014\u0018-\u001f\t\u00039aJ!!O\u000f\u0003\r\u0011{WO\u00197f\u000311\u0017\u000e^%oi\u0016\u00148-\u001a9u!\taB(\u0003\u0002>;\t9!i\\8mK\u0006t\u0017A\u00042d\u0007>,gMZ5dS\u0016tGo\u001d\t\u0004_I\u0002\u0005CA!E\u001b\u0005\u0011%BA\"\u0013\u0003\u0019a\u0017N\\1mO&\u0011QI\u0011\u0002\u0007-\u0016\u001cGo\u001c:\u0002\rqJg.\u001b;?)\rA%j\u0013\u000b\u0003W%CQA\u0010\u0003A\u0002}BQ\u0001\f\u0003A\u00029BQA\u000f\u0003A\u0002m\n1B\\;n\r\u0016\fG/\u001e:fgV\ta\n\u0005\u0002\u001d\u001f&\u0011\u0001+\b\u0002\u0004\u0013:$\u0018\u0001\u00048v[\u001a+\u0017\r^;sKN\u0004\u0013\u0001\u00078v[\u001a+\u0017\r^;sKN\u0004F.^:J]R,'oY3qi\u0006Ib.^7GK\u0006$XO]3t!2,8/\u00138uKJ\u001cW\r\u001d;!\u0003E\u0019w.\u001a4gS\u000eLWM\u001c;t\u0003J\u0014\u0018-_\u000b\u0002i!\u0012\u0011b\u0016\t\u00039aK!!W\u000f\u0003\u0013Q\u0014\u0018M\\:jK:$\u0018a\u00013j[\u0006!A-[7!\u0003\r\tG\r\u001a\u000b\u0003=~k\u0011\u0001\u0001\u0005\u0006A2\u0001\r!J\u0001\tS:\u001cH/\u00198dK\u0002")
public class HingeAggregator
implements DifferentiableLossAggregator<Instance, HingeAggregator> {
    private transient double[] coefficientsArray;
    private final Broadcast<double[]> bcFeaturesStd;
    private final boolean fitIntercept;
    private final Broadcast<Vector> bcCoefficients;
    private final int numFeatures;
    private final int numFeaturesPlusIntercept;
    private final int dim;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    private volatile transient boolean bitmap$trans$0;
    private volatile boolean bitmap$0;

    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    @Override
    public double weightSum() {
        return this.weightSum;
    }

    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    @Override
    public double lossSum() {
        return this.lossSum;
    }

    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    private double[] gradientSumArray$lzycompute() {
        HingeAggregator hingeAggregator = this;
        synchronized (hingeAggregator) {
            if (!this.bitmap$0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    @Override
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? this.gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    private int numFeaturesPlusIntercept() {
        return this.numFeaturesPlusIntercept;
    }

    private double[] coefficientsArray$lzycompute() {
        HingeAggregator hingeAggregator = this;
        synchronized (hingeAggregator) {
            if (!this.bitmap$trans$0) {
                double[] values;
                DenseVector denseVector;
                Option option;
                Vector vector = (Vector)this.bcCoefficients.value();
                if (!(vector instanceof DenseVector) || (option = DenseVector$.MODULE$.unapply(denseVector = (DenseVector)vector)).isEmpty()) {
                    throw new IllegalArgumentException(new StringBuilder(54).append("coefficients only supports dense vector").append(" but got type ").append(this.bcCoefficients.value().getClass()).append(".").toString());
                }
                double[] dArray = values = (double[])option.get();
                this.coefficientsArray = dArray;
                this.bitmap$trans$0 = true;
            }
        }
        return this.coefficientsArray;
    }

    private double[] coefficientsArray() {
        return !this.bitmap$trans$0 ? this.coefficientsArray$lzycompute() : this.coefficientsArray;
    }

    @Override
    public int dim() {
        return this.dim;
    }

    @Override
    public HingeAggregator add(Instance instance) {
        double loss;
        double weight;
        Instance instance2 = instance;
        if (instance2 != null) {
            double dotProduct;
            double labelScaled;
            double label = instance2.label();
            weight = instance2.weight();
            Vector features = instance2.features();
            Predef$.MODULE$.require(this.numFeatures() == features.size(), (Function0 & java.io.Serializable & Serializable)() -> new StringBuilder(66).append("Dimensions mismatch when adding new instance.").append(" Expecting ").append(this.numFeatures()).append(" but got ").append(features.size()).append(".").toString());
            Predef$.MODULE$.require(weight >= 0.0, (Function0 & java.io.Serializable & Serializable)() -> new StringBuilder(34).append("instance weight, ").append(weight).append(" has to be >= 0.0").toString());
            if (weight == 0.0) {
                return this;
            }
            double[] localFeaturesStd = (double[])this.bcFeaturesStd.value();
            double[] localCoefficients = this.coefficientsArray();
            double[] localGradientSumArray = this.gradientSumArray();
            DoubleRef sum = DoubleRef.create((double)0.0);
            features.foreachActive((Function2)(JFunction2.mcVID.sp & java.io.Serializable & Serializable)(index, value) -> {
                block0: {
                    if (localFeaturesStd[index] == 0.0 || value == 0.0) break block0;
                    sum$1.elem += localCoefficients[index] * value / localFeaturesStd[index];
                }
            });
            if (this.fitIntercept) {
                sum.elem += localCoefficients[this.numFeaturesPlusIntercept() - 1];
            }
            double d = loss = 1.0 > (labelScaled = (double)2 * label - 1.0) * (dotProduct = sum.elem) ? (1.0 - labelScaled * dotProduct) * weight : 0.0;
            if (1.0 > labelScaled * dotProduct) {
                double gradientScale = -labelScaled * weight;
                features.foreachActive((Function2)(JFunction2.mcVID.sp & java.io.Serializable & Serializable)(index, value) -> {
                    block0: {
                        if (localFeaturesStd[index] == 0.0 || value == 0.0) break block0;
                        localGradientSumArray$1[index] = localGradientSumArray[index] + value * gradientScale / localFeaturesStd[index];
                    }
                });
                if (this.fitIntercept) {
                    int n = localGradientSumArray.length - 1;
                    localGradientSumArray[n] = localGradientSumArray[n] + gradientScale;
                }
            }
        } else {
            throw new MatchError((Object)instance2);
        }
        this.lossSum_$eq(this.lossSum() + loss);
        this.weightSum_$eq(this.weightSum() + weight);
        HingeAggregator hingeAggregator = this;
        return hingeAggregator;
    }

    public HingeAggregator(Broadcast<double[]> bcFeaturesStd, boolean fitIntercept, Broadcast<Vector> bcCoefficients) {
        this.bcFeaturesStd = bcFeaturesStd;
        this.fitIntercept = fitIntercept;
        this.bcCoefficients = bcCoefficients;
        DifferentiableLossAggregator.$init$(this);
        this.numFeatures = ((double[])bcFeaturesStd.value()).length;
        this.numFeaturesPlusIntercept = fitIntercept ? this.numFeatures() + 1 : this.numFeatures();
        this.dim = this.numFeaturesPlusIntercept();
    }
}

