/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.LogEntry;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Option;
import scala.Predef$;
import scala.StringContext;
import scala.collection.ArrayOps$;
import scala.math.package$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0005e4Q!\u0005\n\u0001-yA\u0001B\u000e\u0001\u0003\u0002\u0003\u0006I\u0001\u000f\u0005\t\t\u0002\u0011\t\u0011)A\u0005q!AQ\t\u0001B\u0001B\u0003%a\t\u0003\u0005J\u0001\t\u0005\t\u0015!\u0003B\u0011!Q\u0005A!A!\u0002\u0013Y\u0005\"\u0002*\u0001\t\u0003\u0019\u0006b\u0002.\u0001\u0005\u0004%Ia\u0017\u0005\u0007?\u0002\u0001\u000b\u0011\u0002/\t\u000f\u0001\u0004!\u0019!C)7\"1\u0011\r\u0001Q\u0001\nqC\u0001B\u0019\u0001\t\u0006\u0004%Ia\u0019\u0005\tQ\u0002A)\u0019!C\u0005S\"I!\u000e\u0001a\u0001\u0002\u0004%Ia\u0019\u0005\nW\u0002\u0001\r\u00111A\u0005\n1D\u0011B\u001d\u0001A\u0002\u0003\u0005\u000b\u0015\u0002 \t\u000bQ\u0004A\u0011A;\u0003)!+(-\u001a:CY>\u001c7.Q4he\u0016<\u0017\r^8s\u0015\t\u0019B#\u0001\u0006bO\u001e\u0014XmZ1u_JT!!\u0006\f\u0002\u000b=\u0004H/[7\u000b\u0005]A\u0012AA7m\u0015\tI\"$A\u0003ta\u0006\u00148N\u0003\u0002\u001c9\u00051\u0011\r]1dQ\u0016T\u0011!H\u0001\u0004_J<7\u0003\u0002\u0001 KA\u0002\"\u0001I\u0012\u000e\u0003\u0005R\u0011AI\u0001\u0006g\u000e\fG.Y\u0005\u0003I\u0005\u0012a!\u00118z%\u00164\u0007\u0003\u0002\u0014(S=j\u0011AE\u0005\u0003QI\u0011A\u0004R5gM\u0016\u0014XM\u001c;jC\ndW\rT8tg\u0006;wM]3hCR|'\u000f\u0005\u0002+[5\t1F\u0003\u0002--\u00059a-Z1ukJ,\u0017B\u0001\u0018,\u00055Ien\u001d;b]\u000e,'\t\\8dWB\u0011a\u0005\u0001\t\u0003cQj\u0011A\r\u0006\u0003ga\t\u0001\"\u001b8uKJt\u0017\r\\\u0005\u0003kI\u0012q\u0001T8hO&tw-\u0001\u0007cG&sg/\u001a:tKN#Hm\u0001\u0001\u0011\u0007ebd(D\u0001;\u0015\tY\u0004$A\u0005ce>\fGmY1ti&\u0011QH\u000f\u0002\n\u0005J|\u0017\rZ2bgR\u00042\u0001I 
B\u0013\t\u0001\u0015EA\u0003BeJ\f\u0017\u0010\u0005\u0002!\u0005&\u00111)\t\u0002\u0007\t>,(\r\\3\u0002\u0019\t\u001c7kY1mK\u0012lU-\u00198\u0002\u0019\u0019LG/\u00138uKJ\u001cW\r\u001d;\u0011\u0005\u0001:\u0015B\u0001%\"\u0005\u001d\u0011un\u001c7fC:\fq!\u001a9tS2|g.\u0001\bcG\u000e{WM\u001a4jG&,g\u000e^:\u0011\u0007ebD\n\u0005\u0002N!6\taJ\u0003\u0002P-\u00051A.\u001b8bY\u001eL!!\u0015(\u0003\rY+7\r^8s\u0003\u0019a\u0014N\\5u}Q)AKV,Y3R\u0011q&\u0016\u0005\u0006\u0015\u001a\u0001\ra\u0013\u0005\u0006m\u0019\u0001\r\u0001\u000f\u0005\u0006\t\u001a\u0001\r\u0001\u000f\u0005\u0006\u000b\u001a\u0001\rA\u0012\u0005\u0006\u0013\u001a\u0001\r!Q\u0001\f]Vlg)Z1ukJ,7/F\u0001]!\t\u0001S,\u0003\u0002_C\t\u0019\u0011J\u001c;\u0002\u00199,XNR3biV\u0014Xm\u001d\u0011\u0002\u0007\u0011LW.\u0001\u0003eS6\u0004\u0013!E2pK\u001a4\u0017nY5f]R\u001c\u0018I\u001d:bsV\ta\b\u000b\u0002\fKB\u0011\u0001EZ\u0005\u0003O\u0006\u0012\u0011\u0002\u001e:b]NLWM\u001c;\u0002\u00195\f'oZ5o\u001f\u001a47/\u001a;\u0016\u0003\u0005\u000baAY;gM\u0016\u0014\u0018A\u00032vM\u001a,'o\u0018\u0013fcR\u0011Q\u000e\u001d\t\u0003A9L!a\\\u0011\u0003\tUs\u0017\u000e\u001e\u0005\bc:\t\t\u00111\u0001?\u0003\rAH%M\u0001\bEV4g-\u001a:!Q\tyQ-A\u0002bI\u0012$\"A^<\u000e\u0003\u0001AQ\u0001\u001f\tA\u0002%\nQA\u00197pG.\u0004")
/*
 * Aggregates the Huber loss, the total instance weight and the gradient over
 * blocks of instances (InstanceBlock).
 *
 * Coefficient layout as used by this class:
 *   - the last element of the coefficient array is read as sigma;
 *   - when fitIntercept is true, the element at index dim - 2 is used together
 *     with the scaled means to build the margin offset (presumably the
 *     intercept slot - confirm against the caller that builds the vector).
 *
 * This file is decompiler output for a Scala class, so the trait forwarders,
 * the bitmap-guarded lazy accessors and the "$"-mangled names are part of the
 * binary contract and are kept as they are.
 */
public class HuberBlockAggregator
        implements DifferentiableLossAggregator<InstanceBlock, HuberBlockAggregator>, Logging {
    // Lazily extracted values of the broadcast coefficient vector (guarded by bitmap$trans$0).
    private transient double[] coefficientsArray;
    // Lazily computed margin offset (guarded by bit 1 of bitmap$0).
    private double marginOffset;
    private final Broadcast<double[]> bcScaledMean;
    private final boolean fitIntercept;
    private final double epsilon;
    private final Broadcast<Vector> bcCoefficients;
    // Length of the inverse-std array handed to the constructor.
    private final int numFeatures;
    // Size of the broadcast coefficient vector.
    private final int dim;
    // Scratch array reused across add() calls; grown when a block is larger than it.
    private transient double[] buffer;
    private transient Logger org$apache$spark$internal$Logging$$log_;
    private double weightSum;
    private double lossSum;
    // Lazily created gradient accumulator (guarded by bit 2 of bitmap$0).
    private double[] gradientSumArray;
    // Initialisation flags: bit 1 -> marginOffset, bit 2 -> gradientSumArray.
    private volatile byte bitmap$0;
    // Initialisation flag for the transient coefficientsArray.
    private volatile transient boolean bitmap$trans$0;

    /** Forwards to {@code Logging.logName$}. */
    public String logName() {
        return Logging.logName$(this);
    }

    /** Forwards to {@code Logging.log$}. */
    public Logger log() {
        return Logging.log$(this);
    }

    /** Forwards to {@code Logging.LogStringContext$}. */
    public Logging.LogStringContext LogStringContext(StringContext sc) {
        return Logging.LogStringContext$(this, sc);
    }

    /** Forwards to {@code Logging.withLogContext$}. */
    public void withLogContext(HashMap<String, String> context, Function0<BoxedUnit> body) {
        Logging.withLogContext$(this, context, body);
    }

    /** Forwards to {@code Logging.logInfo$} (lazy message). */
    public void logInfo(Function0<String> msg) {
        Logging.logInfo$(this, msg);
    }

    /** Forwards to {@code Logging.logInfo$} (log entry). */
    public void logInfo(LogEntry entry) {
        Logging.logInfo$(this, entry);
    }

    /** Forwards to {@code Logging.logInfo$} (log entry plus throwable). */
    public void logInfo(LogEntry entry, Throwable throwable) {
        Logging.logInfo$(this, entry, throwable);
    }

    /** Forwards to {@code Logging.logDebug$} (lazy message). */
    public void logDebug(Function0<String> msg) {
        Logging.logDebug$(this, msg);
    }

    /** Forwards to {@code Logging.logDebug$} (log entry). */
    public void logDebug(LogEntry entry) {
        Logging.logDebug$(this, entry);
    }

    /** Forwards to {@code Logging.logDebug$} (log entry plus throwable). */
    public void logDebug(LogEntry entry, Throwable throwable) {
        Logging.logDebug$(this, entry, throwable);
    }

    /** Forwards to {@code Logging.logTrace$} (lazy message). */
    public void logTrace(Function0<String> msg) {
        Logging.logTrace$(this, msg);
    }

    /** Forwards to {@code Logging.logTrace$} (log entry). */
    public void logTrace(LogEntry entry) {
        Logging.logTrace$(this, entry);
    }

    /** Forwards to {@code Logging.logTrace$} (log entry plus throwable). */
    public void logTrace(LogEntry entry, Throwable throwable) {
        Logging.logTrace$(this, entry, throwable);
    }

    /** Forwards to {@code Logging.logWarning$} (lazy message). */
    public void logWarning(Function0<String> msg) {
        Logging.logWarning$(this, msg);
    }

    /** Forwards to {@code Logging.logWarning$} (log entry). */
    public void logWarning(LogEntry entry) {
        Logging.logWarning$(this, entry);
    }

    /** Forwards to {@code Logging.logWarning$} (log entry plus throwable). */
    public void logWarning(LogEntry entry, Throwable throwable) {
        Logging.logWarning$(this, entry, throwable);
    }

    /** Forwards to {@code Logging.logError$} (lazy message). */
    public void logError(Function0<String> msg) {
        Logging.logError$(this, msg);
    }

    /** Forwards to {@code Logging.logError$} (log entry). */
    public void logError(LogEntry entry) {
        Logging.logError$(this, entry);
    }

    /** Forwards to {@code Logging.logError$} (log entry plus throwable). */
    public void logError(LogEntry entry, Throwable throwable) {
        Logging.logError$(this, entry, throwable);
    }

    /** Forwards to {@code Logging.logInfo$} (lazy message plus throwable). */
    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$(this, msg, throwable);
    }

    /** Forwards to {@code Logging.logDebug$} (lazy message plus throwable). */
    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$(this, msg, throwable);
    }

    /** Forwards to {@code Logging.logTrace$} (lazy message plus throwable). */
    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$(this, msg, throwable);
    }

    /** Forwards to {@code Logging.logWarning$} (lazy message plus throwable). */
    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$(this, msg, throwable);
    }

    /** Forwards to {@code Logging.logError$} (lazy message plus throwable). */
    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$(this, msg, throwable);
    }

    /** Forwards to {@code Logging.isTraceEnabled$}. */
    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    /** Forwards to {@code Logging.initializeLogIfNecessary$} (one-argument form). */
    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$(this, isInterpreter);
    }

    /** Forwards to {@code Logging.initializeLogIfNecessary$} (two-argument form). */
    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$(this, isInterpreter, silent);
    }

    /** Forwards to {@code Logging.initializeLogIfNecessary$default$2$}. */
    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    /** Forwards to {@code Logging.initializeForcefully$}. */
    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$(this, isInterpreter, silent);
    }

    /** Forwards to the trait implementation {@code DifferentiableLossAggregator.merge$}. */
    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    /** Forwards to the trait implementation {@code DifferentiableLossAggregator.gradient$}. */
    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    /** Forwards to the trait implementation {@code DifferentiableLossAggregator.weight$}. */
    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    /** Forwards to the trait implementation {@code DifferentiableLossAggregator.loss$}. */
    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    /** Accessor for the transient logger slot used by the Logging trait. */
    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    /** Mutator for the transient logger slot used by the Logging trait. */
    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    /** @return the weight accumulated so far */
    @Override
    public double weightSum() {
        return this.weightSum;
    }

    /** Replaces the accumulated weight. */
    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    /** @return the loss accumulated so far */
    @Override
    public double lossSum() {
        return this.lossSum;
    }

    /** Replaces the accumulated loss. */
    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    /**
     * Slow path of {@link #gradientSumArray()}: under this object's monitor, asks the trait
     * for the array once and then sets bit 2 of {@code bitmap$0}.
     */
    private double[] gradientSumArray$lzycompute() {
        synchronized (this) {
            boolean alreadyBuilt = (byte)(this.bitmap$0 & 2) != 0;
            if (!alreadyBuilt) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = (byte)(this.bitmap$0 | 2);
            }
        }
        return this.gradientSumArray;
    }

    /** @return the gradient accumulator, creating it on first use */
    @Override
    public double[] gradientSumArray() {
        return (byte)(this.bitmap$0 & 2) == 0 ? this.gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    /** @return the feature count captured at construction time */
    private int numFeatures() {
        return this.numFeatures;
    }

    /** @return the size of the coefficient vector captured at construction time */
    @Override
    public int dim() {
        return this.dim;
    }

    /**
     * Slow path of {@link #coefficientsArray()}: under this object's monitor, extracts the
     * backing values of the broadcast coefficients, which must be a {@code DenseVector}.
     *
     * @throws IllegalArgumentException if the broadcast vector is not dense
     */
    private double[] coefficientsArray$lzycompute() {
        synchronized (this) {
            if (!this.bitmap$trans$0) {
                Vector coefficients = (Vector)this.bcCoefficients.value();
                boolean dense = coefficients instanceof DenseVector;
                Option extracted = dense ? DenseVector$.MODULE$.unapply((DenseVector)coefficients) : null;
                if (!dense || extracted.isEmpty()) {
                    throw new IllegalArgumentException("coefficients only supports dense vector but got type " + this.bcCoefficients.value().getClass() + ".)");
                }
                this.coefficientsArray = (double[])extracted.get();
                this.bitmap$trans$0 = true;
            }
        }
        return this.coefficientsArray;
    }

    /** @return the dense coefficient values, extracting them on first use */
    private double[] coefficientsArray() {
        return this.bitmap$trans$0 ? this.coefficientsArray : this.coefficientsArray$lzycompute();
    }

    /**
     * Slow path of {@link #marginOffset()}: under this object's monitor, computes
     * {@code coefficients[dim - 2] - dot(coefficients[0..numFeatures), scaledMean)} when
     * fitting an intercept, or {@code NaN} otherwise, then sets bit 1 of {@code bitmap$0}.
     */
    private double marginOffset$lzycompute() {
        synchronized (this) {
            if ((byte)(this.bitmap$0 & 1) == 0) {
                double offset;
                if (this.fitIntercept) {
                    double slotValue = this.coefficientsArray()[this.dim() - 2];
                    double meanProjection = BLAS$.MODULE$.javaBLAS().ddot(this.numFeatures(), this.coefficientsArray(), 1, (double[])this.bcScaledMean.value(), 1);
                    offset = slotValue - meanProjection;
                } else {
                    offset = Double.NaN;
                }
                this.marginOffset = offset;
                this.bitmap$0 = (byte)(this.bitmap$0 | 1);
            }
        }
        return this.marginOffset;
    }

    /** @return the margin offset, computing it on first use */
    private double marginOffset() {
        return (byte)(this.bitmap$0 & 1) == 0 ? this.marginOffset$lzycompute() : this.marginOffset;
    }

    /** @return the current scratch array, possibly {@code null} */
    private double[] buffer() {
        return this.buffer;
    }

    /** Replaces the scratch array. */
    private void buffer_$eq(double[] x$1) {
        this.buffer = x$1;
    }

    /**
     * Folds one block of instances into the running loss, weight and gradient sums.
     *
     * <p>The block is rejected through {@code Predef.require} unless its matrix is
     * transposed, its feature count equals this aggregator's, and every weight is
     * {@code >= 0.0}. A block whose weights are all zero leaves the aggregator untouched.
     *
     * @param block the instances to add
     * @return this aggregator
     */
    @Override
    public HuberBlockAggregator add(InstanceBlock block) {
        Predef$.MODULE$.require(block.matrix().isTransposed());
        Predef$.MODULE$.require(this.numFeatures() == block.numFeatures(), (Function0 & Serializable)() -> "Dimensions mismatch when adding new instance. Expecting " + this.numFeatures() + " but got " + block.numFeatures() + ".");
        Predef$.MODULE$.require(block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable)weightValue -> weightValue >= 0.0), (Function0 & Serializable)() -> "instance weights " + block.weightIter().mkString("[", ",", "]") + " has to be >= 0.0");
        boolean allWeightsZero = block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable)weightValue -> weightValue == 0.0);
        if (allWeightsZero) {
            return this;
        }

        int rows = block.size();
        boolean scratchFits = this.buffer() != null && this.buffer().length >= rows;
        if (!scratchFits) {
            this.buffer_$eq((double[])Array$.MODULE$.ofDim(rows, (ClassTag)ClassTag$.MODULE$.Double()));
        }
        double[] scratch = this.buffer();

        // Phase 1: scratch[i] receives the margin of row i. With an intercept the first
        // `rows` slots are pre-filled with the margin offset and gemv runs with beta = 1.0;
        // without one gemv runs with beta = 0.0.
        double beta;
        if (this.fitIntercept) {
            Arrays.fill(scratch, 0, rows, this.marginOffset());
            beta = 1.0;
        } else {
            beta = 0.0;
        }
        BLAS$.MODULE$.gemv(1.0, block.matrix(), this.coefficientsArray(), beta, scratch);

        // sigma is the last entry of the coefficient array.
        Object lastCoefficient = ArrayOps$.MODULE$.last$extension(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray()));
        double sigma = BoxesRunTime.unboxToDouble(lastCoefficient);

        double sigmaGradient = 0.0;
        double blockLoss = 0.0;
        double blockWeight = 0.0;
        double multiplierTotal = 0.0;

        // Phase 2: turn each margin into its gradient multiplier (in place) while
        // accumulating the loss, the weight and the sigma gradient.
        for (int row = 0; row < rows; ++row) {
            double w = block.getWeight().apply$mcDI$sp(row);
            blockWeight += w;
            if (!(w > 0.0)) {
                // Rows without positive weight contribute nothing to the gradient.
                scratch[row] = 0.0;
                continue;
            }
            double target = block.getLabel(row);
            double residual = target - scratch[row];
            double rowMultiplier;
            if (package$.MODULE$.abs(residual) <= sigma * this.epsilon) {
                // Branch taken when |residual| <= sigma * epsilon (quadratic in the residual).
                blockLoss += 0.5 * w * (sigma + package$.MODULE$.pow(residual, 2.0) / sigma);
                double scaledResidual = residual / sigma;
                rowMultiplier = -1.0 * w * scaledResidual;
                sigmaGradient += 0.5 * w * (1.0 - package$.MODULE$.pow(scaledResidual, 2.0));
            } else {
                // Branch taken otherwise (linear in |residual|).
                blockLoss += 0.5 * w * (sigma + 2.0 * this.epsilon * package$.MODULE$.abs(residual) - sigma * this.epsilon * this.epsilon);
                double sign = residual >= 0.0 ? -1.0 : 1.0;
                rowMultiplier = w * sign * this.epsilon;
                sigmaGradient += 0.5 * w * (1.0 - this.epsilon * this.epsilon);
            }
            scratch[row] = rowMultiplier;
            multiplierTotal += rowMultiplier;
        }

        this.lossSum_$eq(this.lossSum() + blockLoss);
        this.weightSum_$eq(this.weightSum() + blockWeight);

        // Phase 3: accumulate transpose(matrix) * multipliers into the gradient sums.
        BLAS$.MODULE$.gemv(1.0, block.matrix().transpose(), scratch, 1.0, this.gradientSumArray());
        if (this.fitIntercept) {
            // Subtract multiplierTotal * scaledMean from the first numFeatures slots, then
            // add multiplierTotal to the slot at dim - 2.
            BLAS$.MODULE$.javaBLAS().daxpy(this.numFeatures(), -multiplierTotal, (double[])this.bcScaledMean.value(), 1, this.gradientSumArray(), 1);
            int offsetSlot = this.dim() - 2;
            this.gradientSumArray()[offsetSlot] = this.gradientSumArray()[offsetSlot] + multiplierTotal;
        }
        int sigmaSlot = this.dim() - 1;
        this.gradientSumArray()[sigmaSlot] = this.gradientSumArray()[sigmaSlot] + sigmaGradient;
        return this;
    }

    /**
     * Creates an aggregator bound to the given broadcast state.
     *
     * @param bcInverseStd   broadcast array whose length becomes the feature count; the
     *                       values themselves are not read by this class
     * @param bcScaledMean   broadcast scaled means; when {@code fitIntercept} is true it must
     *                       be non-null and as long as {@code bcInverseStd}
     * @param fitIntercept   whether the margin offset and its gradient slot are maintained
     * @param epsilon        threshold multiplier that selects the loss branch in {@code add}
     * @param bcCoefficients broadcast coefficient vector; its size becomes {@code dim}
     */
    public HuberBlockAggregator(Broadcast<double[]> bcInverseStd, Broadcast<double[]> bcScaledMean, boolean fitIntercept, double epsilon, Broadcast<Vector> bcCoefficients) {
        this.bcScaledMean = bcScaledMean;
        this.fitIntercept = fitIntercept;
        this.epsilon = epsilon;
        this.bcCoefficients = bcCoefficients;
        DifferentiableLossAggregator.$init$(this);
        Logging.$init$(this);
        if (fitIntercept) {
            boolean meansUsable = bcScaledMean != null && ((double[])bcScaledMean.value()).length == ((double[])bcInverseStd.value()).length;
            Predef$.MODULE$.require(meansUsable, (Function0 & Serializable)() -> "scaled means is required when center the vectors");
        }
        this.numFeatures = ((double[])bcInverseStd.value()).length;
        this.dim = ((Vector)bcCoefficients.value()).size();
    }
}

