/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Option;
import scala.Predef$;
import scala.collection.ArrayOps$;
import scala.math.package$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u000514QAD\b\u0001'mA\u0001b\r\u0001\u0003\u0002\u0003\u0006I!\u000e\u0005\t\u0003\u0002\u0011\t\u0011)A\u0005k!A!\t\u0001B\u0001B\u0003%1\t\u0003\u0005G\u0001\t\u0005\t\u0015!\u0003?\u0011!9\u0005A!A!\u0002\u0013A\u0005\"B(\u0001\t\u0003\u0001\u0006bB,\u0001\u0005\u0004%I\u0001\u0017\u0005\u00079\u0002\u0001\u000b\u0011B-\t\u000fu\u0003!\u0019!C)1\"1a\f\u0001Q\u0001\neC\u0001b\u0018\u0001\t\u0006\u0004%I\u0001\u0019\u0005\tK\u0002A)\u0019!C\u0005M\")q\r\u0001C\u0001Q\n!\u0002*\u001e2fe\ncwnY6BO\u001e\u0014XmZ1u_JT!\u0001E\t\u0002\u0015\u0005<wM]3hCR|'O\u0003\u0002\u0013'\u0005)q\u000e\u001d;j[*\u0011A#F\u0001\u0003[2T!AF\f\u0002\u000bM\u0004\u0018M]6\u000b\u0005aI\u0012AB1qC\u000eDWMC\u0001\u001b\u0003\ry'oZ\n\u0005\u0001q\u0011S\u0006\u0005\u0002\u001eA5\taDC\u0001 \u0003\u0015\u00198-\u00197b\u0013\t\tcD\u0001\u0004B]f\u0014VM\u001a\t\u0005G\u00112C&D\u0001\u0010\u0013\t)sB\u0001\u000fES\u001a4WM]3oi&\f'\r\\3M_N\u001c\u0018iZ4sK\u001e\fGo\u001c:\u0011\u0005\u001dRS\"\u0001\u0015\u000b\u0005%\u001a\u0012a\u00024fCR,(/Z\u0005\u0003W!\u0012Q\"\u00138ti\u0006t7-\u001a\"m_\u000e\\\u0007CA\u0012\u0001!\tq\u0013'D\u00010\u0015\t\u0001T#\u0001\u0005j]R,'O\\1m\u0013\t\u0011tFA\u0004M_\u001e<\u0017N\\4\u0002\u0019\t\u001c\u0017J\u001c<feN,7\u000b\u001e3\u0004\u0001A\u0019a'O\u001e\u000e\u0003]R!\u0001O\u000b\u0002\u0013\t\u0014x.\u00193dCN$\u0018B\u0001\u001e8\u0005%\u0011%o\\1eG\u0006\u001cH\u000fE\u0002\u001eyyJ!!\u0010\u0010\u0003\u000b\u0005\u0013(/Y=\u0011\u0005uy\u0014B\u0001!\u001f\u0005\u0019!u.\u001e2mK\u0006a!mY*dC2,G-T3b]\u0006aa-\u001b;J]R,'oY3qiB\u0011Q\u0004R\u0005\u0003\u000bz\u0011qAQ8pY\u0016\fg.A\u0004faNLGn\u001c8\u0002\u001d\t\u001c7i\\3gM&\u001c\u0017.\u001a8ugB\u0019a'O%\u0011\u0005)kU\"A&\u000b\u00051\u001b\u0012A\u00027j]\u0006dw-\u0003\u0002O\u0017\n1a+Z2u_J\fa\u0001P5oSRtD#B)T)V3FC\u0001\u0017S\u0011\u00159e\u00011\u0001I\u0011\u0015\u0019d\u00011\u00016\u0011\u0015\te\u00011\u00016\u0011\u0015\u0011e\u00011\u0001D\u0
011\u00151e\u00011\u0001?\u0003-qW/\u001c$fCR,(/Z:\u0016\u0003e\u0003\"!\b.\n\u0005ms\"aA%oi\u0006aa.^7GK\u0006$XO]3tA\u0005\u0019A-[7\u0002\t\u0011LW\u000eI\u0001\u0012G>,gMZ5dS\u0016tGo]!se\u0006LX#A\u001e)\u0005-\u0011\u0007CA\u000fd\u0013\t!gDA\u0005ue\u0006t7/[3oi\u0006aQ.\u0019:hS:|eMZ:fiV\ta(A\u0002bI\u0012$\"!\u001b6\u000e\u0003\u0001AQa[\u0007A\u0002\u0019\nQA\u00197pG.\u0004")
/**
 * Accumulates the Huber regression loss, the total instance weight, and the loss gradient over
 * {@link InstanceBlock}s.
 *
 * <p>This source was produced by the CFR decompiler from Scala bytecode. Members whose names
 * contain {@code $} (for example {@code bitmap$0} or {@code coefficientsArray$lzycompute}) are
 * compiler-generated plumbing for Scala lazy vals and trait mix-ins, not hand-written logic.
 *
 * <p>Coefficient layout, as read by this class:
 * <ul>
 *   <li>the first {@code numFeatures} entries are feature weights;</li>
 *   <li>index {@code dim - 2} is read as the intercept when {@code fitIntercept} is true;</li>
 *   <li>the last entry (index {@code dim - 1}) is the scale parameter {@code sigma}.</li>
 * </ul>
 */
public class HuberBlockAggregator
implements DifferentiableLossAggregator<InstanceBlock, HuberBlockAggregator>,
Logging {
    // Values of the broadcast coefficient vector, materialized lazily by coefficientsArray().
    // Transient (as is its init flag bitmap$trans$0), so it is recomputed after deserialization.
    private transient double[] coefficientsArray;
    // Constant added to every margin in add() when fitIntercept is true; computed lazily,
    // with initialization tracked by bit 1 of bitmap$0.
    private double marginOffset;
    // Broadcast per-feature scaled means. The constructor requires it to be non-null and the same
    // length as the inverse-std array when fitIntercept is true.
    private final Broadcast<double[]> bcScaledMean;
    // Whether an intercept coefficient (index dim - 2) is fitted.
    private final boolean fitIntercept;
    // Huber threshold: residuals with |r| <= sigma * epsilon use the quadratic branch of the loss.
    private final double epsilon;
    // Broadcast coefficient vector; coefficientsArray() only accepts a DenseVector.
    private final Broadcast<Vector> bcCoefficients;
    // Number of features, taken from the length of the inverse-std broadcast.
    private final int numFeatures;
    // Length of the coefficient vector.
    private final int dim;
    // Backing field for the Logging trait's logger.
    private transient Logger org$apache$spark$internal$Logging$$log_;
    // Running accumulators declared by DifferentiableLossAggregator.
    private double weightSum;
    private double lossSum;
    // Gradient accumulator; allocated lazily by the trait's default initializer.
    private double[] gradientSumArray;
    // True once coefficientsArray has been initialized.
    private volatile transient boolean bitmap$trans$0;
    // Lazy-init flags: bit 1 = marginOffset initialized, bit 2 = gradientSumArray initialized.
    private volatile byte bitmap$0;

    // ---- Forwarders to the default implementations of the Scala Logging trait ----

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    // ---- Forwarders to the default implementations of DifferentiableLossAggregator ----

    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    // ---- Accessors for trait-declared fields ----

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    @Override
    public double weightSum() {
        return this.weightSum;
    }

    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    @Override
    public double lossSum() {
        return this.lossSum;
    }

    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    /**
     * Slow path of {@link #gradientSumArray()}.
     *
     * <p>Under the instance lock, this allocates the gradient accumulator via the trait's default
     * initializer if bit 2 of {@code bitmap$0} is not yet set, and then sets that bit.
     */
    private double[] gradientSumArray$lzycompute() {
        HuberBlockAggregator huberBlockAggregator = this;
        synchronized (huberBlockAggregator) {
            if ((byte)(this.bitmap$0 & 2) == 0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = (byte)(this.bitmap$0 | 2);
            }
        }
        return this.gradientSumArray;
    }

    /** Returns the gradient accumulator, initializing it on first access. */
    @Override
    public double[] gradientSumArray() {
        return (byte)(this.bitmap$0 & 2) == 0 ? this.gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    @Override
    public int dim() {
        return this.dim;
    }

    /**
     * Slow path of {@link #coefficientsArray()}: extracts the backing array of the broadcast
     * coefficients under the instance lock.
     *
     * @throws IllegalArgumentException if the broadcast coefficients are not a {@link DenseVector}
     */
    private double[] coefficientsArray$lzycompute() {
        HuberBlockAggregator huberBlockAggregator = this;
        synchronized (huberBlockAggregator) {
            if (!this.bitmap$trans$0) {
                double[] values;
                DenseVector denseVector;
                Option option;
                Vector vector = (Vector)this.bcCoefficients.value();
                if (!(vector instanceof DenseVector) || (option = DenseVector$.MODULE$.unapply(denseVector = (DenseVector)vector)).isEmpty()) {
                    // NOTE(review): the message ends with ".)" -- the ')' looks like a stray typo.
                    throw new IllegalArgumentException(new StringBuilder(0).append("coefficients only supports dense vector but ").append(new StringBuilder(11).append("got type ").append(this.bcCoefficients.value().getClass()).append(".)").toString()).toString());
                }
                double[] dArray = values = (double[])option.get();
                this.coefficientsArray = dArray;
                this.bitmap$trans$0 = true;
            }
        }
        return this.coefficientsArray;
    }

    /** Returns the coefficient values as a plain array, materializing them on first access. */
    private double[] coefficientsArray() {
        return !this.bitmap$trans$0 ? this.coefficientsArray$lzycompute() : this.coefficientsArray;
    }

    /**
     * Slow path of {@link #marginOffset()}.
     *
     * <p>With an intercept, the offset is {@code intercept - dot(coef[0..numFeatures), scaledMean)}.
     * Adding it to {@code matrix x coef} in {@link #add} therefore yields
     * {@code intercept + coef . (x - scaledMean)} for each row. Without an intercept, the offset
     * is {@code NaN}; {@link #add} only reads it when {@code fitIntercept} is true.
     */
    private double marginOffset$lzycompute() {
        HuberBlockAggregator huberBlockAggregator = this;
        synchronized (huberBlockAggregator) {
            if ((byte)(this.bitmap$0 & 1) == 0) {
                this.marginOffset = this.fitIntercept ? this.coefficientsArray()[this.dim() - 2] - BLAS$.MODULE$.javaBLAS().ddot(this.numFeatures(), this.coefficientsArray(), 1, (double[])this.bcScaledMean.value(), 1) : Double.NaN;
                this.bitmap$0 = (byte)(this.bitmap$0 | 1);
            }
        }
        return this.marginOffset;
    }

    /** Returns the lazily computed margin offset (see {@link #marginOffset$lzycompute()}). */
    private double marginOffset() {
        return (byte)(this.bitmap$0 & 1) == 0 ? this.marginOffset$lzycompute() : this.marginOffset;
    }

    /**
     * Adds one block of instances to the running Huber loss, weight, and gradient.
     *
     * <p>For residual {@code r = label - margin} and instance weight {@code w}:
     * <ul>
     *   <li>if {@code |r| <= sigma * epsilon}, the loss is {@code 0.5 w (sigma + r^2 / sigma)};</li>
     *   <li>otherwise, it is {@code 0.5 w (sigma + 2 epsilon |r| - sigma epsilon^2)}.</li>
     * </ul>
     * The method accumulates the derivatives of these expressions with respect to the margin (per
     * instance) and with respect to {@code sigma} (summed into the last gradient slot).
     *
     * @param block instances to add; its matrix must be stored transposed, it must have
     *     {@code numFeatures} features, and all of its weights must be non-negative
     * @return this aggregator, for chaining
     * @throws IllegalArgumentException if any of the requirements on {@code block} are violated
     */
    @Override
    public HuberBlockAggregator add(InstanceBlock block) {
        Predef$.MODULE$.require(block.matrix().isTransposed());
        Predef$.MODULE$.require(this.numFeatures() == block.numFeatures(), (Function0 & Serializable)() -> new StringBuilder(0).append("Dimensions mismatch when adding new ").append(new StringBuilder(30).append("instance. Expecting ").append(this.numFeatures()).append(" but got ").append(block.numFeatures()).append(".").toString()).toString());
        Predef$.MODULE$.require(block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable)x$1 -> x$1 >= 0.0), (Function0 & Serializable)() -> new StringBuilder(34).append("instance weights ").append(block.weightIter().mkString("[", ",", "]")).append(" has to be >= 0.0").toString());
        // Nothing to accumulate when every instance weight is zero.
        if (block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable)x$2 -> x$2 == 0.0)) {
            return this;
        }
        int size = block.size();
        // arr first holds one margin per instance: marginOffset (with an intercept) + matrix x coef.
        double[] arr = (double[])Array$.MODULE$.ofDim(size, (ClassTag)ClassTag$.MODULE$.Double());
        if (this.fitIntercept) {
            Arrays.fill(arr, this.marginOffset());
        }
        // NOTE(review): coefficientsArray has dim entries (more than numFeatures); this relies on
        // gemv reading only as many entries as the matrix has columns -- confirm.
        BLAS$.MODULE$.gemv(1.0, block.matrix(), this.coefficientsArray(), 1.0, arr);
        // sigma is the last coefficient.
        double sigma = BoxesRunTime.unboxToDouble((Object)ArrayOps$.MODULE$.last$extension(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())));
        double sigmaGradSum = 0.0;
        double localLossSum = 0.0;
        double localWeightSum = 0.0;
        double multiplierSum = 0.0;
        // Overwrite each margin in place with the derivative of that instance's loss with respect
        // to its margin (the "multiplier"), while accumulating loss, weight, and the sigma gradient.
        for (int i = 0; i < size; ++i) {
            double weight = block.getWeight().apply$mcDI$sp(i);
            localWeightSum += weight;
            if (weight > 0.0) {
                double multiplier;
                double margin;
                double label = block.getLabel(i);
                double linearLoss = label - (margin = arr[i]);
                if (package$.MODULE$.abs(linearLoss) <= sigma * this.epsilon) {
                    // Quadratic branch: dLoss/dMargin = -w r / sigma,
                    // dLoss/dSigma = 0.5 w (1 - (r / sigma)^2).
                    double multiplier2;
                    localLossSum += 0.5 * weight * (sigma + package$.MODULE$.pow(linearLoss, 2.0) / sigma);
                    double linearLossDivSigma = linearLoss / sigma;
                    arr[i] = multiplier2 = -1.0 * weight * linearLossDivSigma;
                    multiplierSum += multiplier2;
                    sigmaGradSum += 0.5 * weight * (1.0 - package$.MODULE$.pow(linearLossDivSigma, 2.0));
                    continue;
                }
                // Linear branch: dLoss/dMargin = -w epsilon sign(r),
                // dLoss/dSigma = 0.5 w (1 - epsilon^2).
                localLossSum += 0.5 * weight * (sigma + 2.0 * this.epsilon * package$.MODULE$.abs(linearLoss) - sigma * this.epsilon * this.epsilon);
                double sign = linearLoss >= 0.0 ? -1.0 : 1.0;
                arr[i] = multiplier = weight * sign * this.epsilon;
                multiplierSum += multiplier;
                sigmaGradSum += 0.5 * weight * (1.0 - this.epsilon * this.epsilon);
                continue;
            }
            // Zero-weight instances contribute nothing to the gradient.
            arr[i] = 0.0;
        }
        this.lossSum_$eq(this.lossSum() + localLossSum);
        this.weightSum_$eq(this.weightSum() + localWeightSum);
        // Feature gradient: transpose(matrix) x multipliers, added into the accumulator.
        BLAS$.MODULE$.gemv(1.0, block.matrix().transpose(), arr, 1.0, this.gradientSumArray());
        if (this.fitIntercept) {
            // The margins included marginOffset = intercept - coef . scaledMean. Its derivative is
            // -scaledMean for the feature weights and 1 for the intercept (slot dim - 2).
            BLAS$.MODULE$.javaBLAS().daxpy(this.numFeatures(), -multiplierSum, (double[])this.bcScaledMean.value(), 1, this.gradientSumArray(), 1);
            int n = this.dim() - 2;
            this.gradientSumArray()[n] = this.gradientSumArray()[n] + multiplierSum;
        }
        // The sigma gradient goes into the last slot.
        int n = this.dim() - 1;
        this.gradientSumArray()[n] = this.gradientSumArray()[n] + sigmaGradSum;
        return this;
    }

    /**
     * Creates an aggregator for one set of broadcast coefficients.
     *
     * @param bcInverseStd only its length is used (as {@code numFeatures}); it is not stored
     * @param bcScaledMean per-feature scaled means; required, with the same length as
     *     {@code bcInverseStd}, when {@code fitIntercept} is true
     * @param fitIntercept whether slot {@code dim - 2} of the coefficients is an intercept
     * @param epsilon Huber threshold parameter
     * @param bcCoefficients broadcast coefficients; its size becomes {@code dim}
     * @throws IllegalArgumentException if {@code fitIntercept} is true and the scaled means are
     *     missing or have the wrong length
     */
    public HuberBlockAggregator(Broadcast<double[]> bcInverseStd, Broadcast<double[]> bcScaledMean, boolean fitIntercept, double epsilon, Broadcast<Vector> bcCoefficients) {
        this.bcScaledMean = bcScaledMean;
        this.fitIntercept = fitIntercept;
        this.epsilon = epsilon;
        this.bcCoefficients = bcCoefficients;
        DifferentiableLossAggregator.$init$(this);
        Logging.$init$((Logging)this);
        if (fitIntercept) {
            Predef$.MODULE$.require(bcScaledMean != null && ((double[])bcScaledMean.value()).length == ((double[])bcInverseStd.value()).length, (Function0 & Serializable)() -> "scaled means is required when center the vectors");
        }
        this.numFeatures = ((double[])bcInverseStd.value()).length;
        this.dim = ((Vector)bcCoefficients.value()).size();
    }
}

