/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim;

import java.io.Serializable;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.feature.OffsetInstance;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.IterativelyReweightedLeastSquaresModel;
import org.apache.spark.ml.optim.WeightedLeastSquares;
import org.apache.spark.ml.optim.WeightedLeastSquares$;
import org.apache.spark.ml.optim.WeightedLeastSquaresModel;
import org.apache.spark.ml.util.OptionalInstrumentation;
import org.apache.spark.ml.util.OptionalInstrumentation$;
import org.apache.spark.rdd.RDD;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import scala.runtime.java8.JFunction2;

@ScalaSignature(bytes="\u0006\u0001\u0005\u0005a!B\t\u0013\u0001Qa\u0002\u0002\u0003\u0014\u0001\u0005\u000b\u0007I\u0011\u0001\u0015\t\u00115\u0002!\u0011!Q\u0001\n%B\u0001B\f\u0001\u0003\u0006\u0004%\ta\f\u0005\t\u007f\u0001\u0011\t\u0011)A\u0005a!A\u0001\t\u0001BC\u0002\u0013\u0005\u0011\t\u0003\u0005F\u0001\t\u0005\t\u0015!\u0003C\u0011!1\u0005A!b\u0001\n\u00039\u0005\u0002\u0003%\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u001f\t\u0011%\u0003!Q1A\u0005\u0002)C\u0001B\u0014\u0001\u0003\u0002\u0003\u0006Ia\u0013\u0005\t\u001f\u0002\u0011)\u0019!C\u0001\u000f\"A\u0001\u000b\u0001B\u0001B\u0003%A\bC\u0003R\u0001\u0011\u0005!\u000bC\u0003[\u0001\u0011\u00051\fC\u0004r\u0001E\u0005I\u0011\u0001:\t\u000fu\u0004\u0011\u0013!C\u0001}\n\t\u0013\n^3sCRLg/\u001a7z%\u0016<X-[4ii\u0016$G*Z1tiN\u000bX/\u0019:fg*\u00111\u0003F\u0001\u0006_B$\u0018.\u001c\u0006\u0003+Y\t!!\u001c7\u000b\u0005]A\u0012!B:qCJ\\'BA\r\u001b\u0003\u0019\t\u0007/Y2iK*\t1$A\u0002pe\u001e\u001c2\u0001A\u000f$!\tq\u0012%D\u0001 \u0015\u0005\u0001\u0013!B:dC2\f\u0017B\u0001\u0012 
\u0005\u0019\te.\u001f*fMB\u0011a\u0004J\u0005\u0003K}\u0011AbU3sS\u0006d\u0017N_1cY\u0016\fA\"\u001b8ji&\fG.T8eK2\u001c\u0001!F\u0001*!\tQ3&D\u0001\u0013\u0013\ta#CA\rXK&<\u0007\u000e^3e\u0019\u0016\f7\u000f^*rk\u0006\u0014Xm]'pI\u0016d\u0017!D5oSRL\u0017\r\\'pI\u0016d\u0007%\u0001\u0007sK^,\u0017n\u001a5u\rVt7-F\u00011!\u0015q\u0012gM\u0015:\u0013\t\u0011tDA\u0005Gk:\u001cG/[8oeA\u0011AgN\u0007\u0002k)\u0011a\u0007F\u0001\bM\u0016\fG/\u001e:f\u0013\tATG\u0001\bPM\u001a\u001cX\r^%ogR\fgnY3\u0011\tyQD\bP\u0005\u0003w}\u0011a\u0001V;qY\u0016\u0014\u0004C\u0001\u0010>\u0013\tqtD\u0001\u0004E_V\u0014G.Z\u0001\u000ee\u0016<X-[4ii\u001a+hn\u0019\u0011\u0002\u0019\u0019LG/\u00138uKJ\u001cW\r\u001d;\u0016\u0003\t\u0003\"AH\"\n\u0005\u0011{\"a\u0002\"p_2,\u0017M\\\u0001\u000eM&$\u0018J\u001c;fe\u000e,\u0007\u000f\u001e\u0011\u0002\u0011I,w\rU1sC6,\u0012\u0001P\u0001\ne\u0016<\u0007+\u0019:b[\u0002\nq!\\1y\u0013R,'/F\u0001L!\tqB*\u0003\u0002N?\t\u0019\u0011J\u001c;\u0002\u00115\f\u00070\u0013;fe\u0002\n1\u0001^8m\u0003\u0011!x\u000e\u001c\u0011\u0002\rqJg.\u001b;?)\u001d\u0019F+\u0016,X1f\u0003\"A\u000b\u0001\t\u000b\u0019j\u0001\u0019A\u0015\t\u000b9j\u0001\u0019\u0001\u0019\t\u000b\u0001k\u0001\u0019\u0001\"\t\u000b\u0019k\u0001\u0019\u0001\u001f\t\u000b%k\u0001\u0019A&\t\u000b=k\u0001\u0019\u0001\u001f\u0002\u0007\u0019LG\u000f\u0006\u0003]?\u001e|\u0007C\u0001\u0016^\u0013\tq&C\u0001\u0014Ji\u0016\u0014\u0018\r^5wK2L(+Z<fS\u001eDG/\u001a3MK\u0006\u001cHoU9vCJ,7/T8eK2DQ\u0001\u0019\bA\u0002\u0005\f\u0011\"\u001b8ti\u0006t7-Z:\u0011\u0007\t,7'D\u0001d\u0015\t!g#A\u0002sI\u0012L!AZ2\u0003\u0007I#E\tC\u0004i\u001dA\u0005\t\u0019A5\u0002\u000b%t7\u000f\u001e:\u0011\u0005)lW\"A6\u000b\u00051$\u0012\u0001B;uS2L!A\\6\u0003/=\u0003H/[8oC2Len\u001d;sk6,g\u000e^1uS>t\u0007b\u00029\u000f!\u0003\u0005\raS\u0001\u0006I\u0016\u0004H\u000f[\u0001\u000eM&$H\u0005Z3gCVdG\u000f\n\u001a\u0016\u0003MT#!\u001b;,\u0003U\u0004\"A^>\u000e\u0003]T!\u0001_=\u0002\u0013Ut7\r[3dW\u0016$'B\u0001> 
\u0003)\tgN\\8uCRLwN\\\u0005\u0003y^\u0014\u0011#\u001e8dQ\u0016\u001c7.\u001a3WCJL\u0017M\\2f\u000351\u0017\u000e\u001e\u0013eK\u001a\fW\u000f\u001c;%gU\tqP\u000b\u0002Li\u0002")
/**
 * Iteratively reweighted least squares (IRLS) driver.
 *
 * <p>Each iteration maps every {@link OffsetInstance} through {@link #reweightFunc()} (using the
 * model from the previous iteration) to obtain a new label and weight. It then solves a weighted
 * least squares problem on the reweighted data. Iteration stops once the largest absolute change in
 * the coefficients or intercept drops below {@link #tol()}, or after {@link #maxIter()} iterations.
 */
public class IterativelyReweightedLeastSquares
implements scala.Serializable {
    /** Model used to compute the weights of the first iteration; never modified by {@link #fit}. */
    private final WeightedLeastSquaresModel initialModel;
    /** Maps (instance, current model) to a (new label, new weight) pair of boxed doubles. */
    private final Function2<OffsetInstance, WeightedLeastSquaresModel, Tuple2<Object, Object>> reweightFunc;
    /** Whether each weighted least squares solve fits an intercept. */
    private final boolean fitIntercept;
    /** Regularization parameter forwarded to each weighted least squares solve. */
    private final double regParam;
    /** Upper bound on the number of IRLS iterations. */
    private final int maxIter;
    /** Convergence threshold on the maximum absolute change of coefficients / intercept. */
    private final double tol;

    public WeightedLeastSquaresModel initialModel() {
        return this.initialModel;
    }

    public Function2<OffsetInstance, WeightedLeastSquaresModel, Tuple2<Object, Object>> reweightFunc() {
        return this.reweightFunc;
    }

    public boolean fitIntercept() {
        return this.fitIntercept;
    }

    public double regParam() {
        return this.regParam;
    }

    public int maxIter() {
        return this.maxIter;
    }

    public double tol() {
        return this.tol;
    }

    /**
     * Runs IRLS on {@code instances}.
     *
     * @param instances input data; it is re-mapped once per iteration, so callers may want to persist it
     * @param instr instrumentation used for progress logging and forwarded to each solve
     * @param depth forwarded unchanged to {@link WeightedLeastSquares#fit}
     *     (presumably the tree-aggregation depth — confirm against WeightedLeastSquares)
     * @return the final coefficients, intercept and {@code diagInvAtWA} of the last solve, together
     *     with the number of iterations actually performed
     * @throws IllegalArgumentException if a solve yields a coefficient vector whose length differs
     *     from the previous model's
     */
    public IterativelyReweightedLeastSquaresModel fit(RDD<OffsetInstance> instances, OptionalInstrumentation instr, int depth) {
        // Capture the function itself rather than `this`, so task closures do not drag along the
        // whole driver object (including the initial model).
        final Function2<OffsetInstance, WeightedLeastSquaresModel, Tuple2<Object, Object>> reweight = this.reweightFunc();
        boolean converged = false;
        int iter = 0;
        WeightedLeastSquaresModel model = this.initialModel();
        while (iter < this.maxIter() && !converged) {
            // Effectively final so the serialized closure sees exactly this iteration's model.
            final WeightedLeastSquaresModel oldModel = model;
            RDD<Instance> newInstances = instances.map(
                    (Function1<OffsetInstance, Instance> & Serializable) instance ->
                            reweightInstance(reweight, instance, oldModel),
                    ClassTag$.MODULE$.<Instance>apply(Instance.class));
            model = new WeightedLeastSquares(this.fitIntercept(), this.regParam(), 0.0, false, false,
                    WeightedLeastSquares$.MODULE$.$lessinit$greater$default$6(),
                    WeightedLeastSquares$.MODULE$.$lessinit$greater$default$7(),
                    WeightedLeastSquares$.MODULE$.$lessinit$greater$default$8())
                .fit(newInstances, instr, depth);

            // Read-only comparison: the previous model (possibly the caller's initialModel) must not
            // be mutated, otherwise a later fit() on this object would start from corrupted values.
            final double maxTol = maxAbsoluteChange(oldModel, model);
            final int iterationIndex = iter;
            if (maxTol < this.tol()) {
                converged = true;
                // Report the number of completed iterations (index + 1), matching the returned model.
                instr.logInfo(() -> "IRLS converged in " + (iterationIndex + 1) + " iterations.");
            }
            instr.logInfo(() -> "Iteration " + iterationIndex + " : relative tolerance = " + maxTol);
            ++iter;
            // Only report hitting the limit when the loop is stopping because of it.
            if (!converged && iter == this.maxIter()) {
                instr.logInfo(() -> "IRLS reached the max number of iterations: " + this.maxIter() + ".");
            }
        }
        return new IterativelyReweightedLeastSquaresModel(model.coefficients(), model.intercept(), model.diagInvAtWA(), iter);
    }

    /**
     * Applies the reweighting function to one instance.
     *
     * @throws MatchError if the reweighting function returns {@code null}
     */
    private static Instance reweightInstance(
            Function2<OffsetInstance, WeightedLeastSquaresModel, Tuple2<Object, Object>> reweight,
            OffsetInstance instance,
            WeightedLeastSquaresModel currentModel) {
        Tuple2<Object, Object> labelAndWeight = reweight.apply(instance, currentModel);
        if (labelAndWeight == null) {
            throw new MatchError(null);
        }
        double newLabel = BoxesRunTime.unboxToDouble(labelAndWeight._1());
        double newWeight = BoxesRunTime.unboxToDouble(labelAndWeight._2());
        return new Instance(newLabel, newWeight, instance.features());
    }

    /**
     * Returns the largest absolute difference between the two models' coefficients and intercepts.
     * Neither model is modified; NaN differences propagate (as with {@link Math#max}).
     */
    private static double maxAbsoluteChange(WeightedLeastSquaresModel oldModel, WeightedLeastSquaresModel newModel) {
        double[] oldValues = oldModel.coefficients().toArray();
        double[] newValues = newModel.coefficients().toArray();
        if (oldValues.length != newValues.length) {
            throw new IllegalArgumentException("Coefficient vector size changed between IRLS iterations: "
                    + oldValues.length + " vs " + newValues.length);
        }
        double maxChange = Math.abs(oldModel.intercept() - newModel.intercept());
        for (int i = 0; i < oldValues.length; i++) {
            maxChange = Math.max(maxChange, Math.abs(oldValues[i] - newValues[i]));
        }
        return maxChange;
    }

    /** Default value for the {@code instr} parameter of {@link #fit}. */
    public OptionalInstrumentation fit$default$2() {
        return OptionalInstrumentation$.MODULE$.create(IterativelyReweightedLeastSquares.class);
    }

    /** Default value for the {@code depth} parameter of {@link #fit}. */
    public int fit$default$3() {
        return 2;
    }

    public IterativelyReweightedLeastSquares(WeightedLeastSquaresModel initialModel, Function2<OffsetInstance, WeightedLeastSquaresModel, Tuple2<Object, Object>> reweightFunc, boolean fitIntercept, double regParam, int maxIter, double tol) {
        this.initialModel = initialModel;
        this.reweightFunc = reweightFunc;
        this.fitIntercept = fitIntercept;
        this.regParam = regParam;
        this.maxIter = maxIter;
        this.tol = tol;
    }
}

