/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.clustering;

import breeze.linalg.DenseVector;
import org.apache.spark.SparkContext;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian;
import org.apache.spark.mllib.util.MLUtils$;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.math.Numeric;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

@Experimental
@ScalaSignature(bytes="\u0006\u0001-4A!\u0001\u0002\u0001\u001b\t!r)Y;tg&\fg.T5yiV\u0014X-T8eK2T!a\u0001\u0003\u0002\u0015\rdWo\u001d;fe&twM\u0003\u0002\u0006\r\u0005)Q\u000e\u001c7jE*\u0011q\u0001C\u0001\u0006gB\f'o\u001b\u0006\u0003\u0013)\ta!\u00199bG\",'\"A\u0006\u0002\u0007=\u0014xm\u0001\u0001\u0014\u0007\u0001qA\u0003\u0005\u0002\u0010%5\t\u0001CC\u0001\u0012\u0003\u0015\u00198-\u00197b\u0013\t\u0019\u0002C\u0001\u0004B]f\u0014VM\u001a\t\u0003\u001fUI!A\u0006\t\u0003\u0019M+'/[1mSj\f'\r\\3\t\u0011a\u0001!Q1A\u0005\u0002e\tqa^3jO\"$8/F\u0001\u001b!\ry1$H\u0005\u00039A\u0011Q!\u0011:sCf\u0004\"a\u0004\u0010\n\u0005}\u0001\"A\u0002#pk\ndW\r\u0003\u0005\"\u0001\t\u0005\t\u0015!\u0003\u001b\u0003!9X-[4iiN\u0004\u0003\u0002C\u0012\u0001\u0005\u000b\u0007I\u0011\u0001\u0013\u0002\u0013\u001d\fWo]:jC:\u001cX#A\u0013\u0011\u0007=Yb\u0005\u0005\u0002(Y5\t\u0001F\u0003\u0002*U\u0005aA-[:ue&\u0014W\u000f^5p]*\u00111\u0006B\u0001\u0005gR\fG/\u0003\u0002.Q\t!R*\u001e7uSZ\f'/[1uK\u001e\u000bWo]:jC:D\u0001b\f\u0001\u0003\u0002\u0003\u0006I!J\u0001\u000bO\u0006,8o]5b]N\u0004\u0003\"B\u0019\u0001\t\u0003\u0011\u0014A\u0002\u001fj]&$h\bF\u00024kY\u0002\"\u0001\u000e\u0001\u000e\u0003\tAQ\u0001\u0007\u0019A\u0002iAQa\t\u0019A\u0002\u0015BQ\u0001\u000f\u0001\u0005\u0002e\n\u0011a[\u000b\u0002uA\u0011qbO\u0005\u0003yA\u00111!\u00138u\u0011\u0015q\u0004\u0001\"\u0001@\u0003\u001d\u0001(/\u001a3jGR$\"\u0001\u0011$\u0011\u0007\u0005#%(D\u0001C\u0015\t\u0019e!A\u0002sI\u0012L!!\u0012\"\u0003\u0007I#E\tC\u0003H{\u0001\u0007\u0001*\u0001\u0004q_&tGo\u001d\t\u0004\u0003\u0012K\u0005C\u0001&N\u001b\u0005Y%B\u0001'\u0005\u0003\u0019a\u0017N\\1mO&\u0011aj\u0013\u0002\u0007-\u0016\u001cGo\u001c:\t\u000bA\u0003A\u0011A)\u0002\u0017A\u0014X\rZ5diN{g\r\u001e\u000b\u0003%N\u00032!\u0011#\u001b\u0011\u00159u\n1\u0001I\u0011\u0015)\u0006\u0001\"\u0003W\u0003Y\u0019w.\u001c9vi\u0016\u001cvN\u001a;BgNLwM\\7f]R\u001cH#\u0002\u000eXA\n\u001c\u0007\"\u0002-U\u0001\u0004I\u0016A\u00019u!\rQf,H\u0007\u00
027*\u0011A\n\u0018\u0006\u0002;\u00061!M]3fu\u0016L!aX.\u0003\u0017\u0011+gn]3WK\u000e$xN\u001d\u0005\u0006CR\u0003\r!J\u0001\u0006I&\u001cHo\u001d\u0005\u00061Q\u0003\rA\u0007\u0005\u0006qQ\u0003\rA\u000f\u0015\u0003\u0001\u0015\u0004\"AZ5\u000e\u0003\u001dT!\u0001\u001b\u0004\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0002kO\naQ\t\u001f9fe&lWM\u001c;bY\u0002")
/*
 * A Gaussian mixture model: weights()[i] is the mixing weight paired with
 * gaussians()[i]; the constructor requires both arrays to have the same length.
 *
 * NOTE(review): this file is CFR output for a compiled Scala class (see the file
 * header and the @ScalaSignature annotation). The anonymous
 * `new Serializable(...){...}` bodies are the decompiler's rendering of Scala
 * closures: they are given constructor arguments, and their initializers read
 * names such as `$outer` and `bcDists$1` that are not declared as parameters, so
 * they do not follow Java anonymous-class rules. Treat this as a reading aid for
 * the bytecode rather than as compilable Java source -- confirm before building.
 */
public class GaussianMixtureModel
implements Serializable {
    // Mixing weight of each component; index i pairs with gaussians[i].
    private final double[] weights;
    // Component distributions; same length as weights (checked in the constructor).
    private final MultivariateGaussian[] gaussians;

    /**
     * Returns the mixing weights.
     *
     * @return the internal array itself (not a copy); writes to it change this model
     */
    public double[] weights() {
        return this.weights;
    }

    /**
     * Returns the component Gaussians.
     *
     * @return the internal array itself (not a copy)
     */
    public MultivariateGaussian[] gaussians() {
        return this.gaussians;
    }

    /**
     * Returns the number of mixture components.
     *
     * @return {@code weights().length}, which equals {@code gaussians().length}
     *         because the constructor rejects arrays of different lengths
     */
    public int k() {
        return this.weights().length;
    }

    /**
     * Assigns each point to a single component (hard clustering).
     *
     * <p>Computes per-point responsibilities with {@link #predictSoft} and maps
     * each row to the index of its largest entry. {@code indexOf} returns the
     * first matching index, so ties resolve to the lowest component index.
     *
     * @param points input vectors
     * @return one component index per point; the map is built with an Int
     *         ClassTag, and the erased Java signature shows the element type as Object
     */
    public RDD<Object> predict(RDD<Vector> points) {
        RDD<double[]> responsibilityMatrix = this.predictSoft(points);
        return responsibilityMatrix.map((Function1)new Serializable(this){
            public static final long serialVersionUID = 0L;

            // Index of the first occurrence of the row's maximum value.
            public final int apply(double[] r) {
                return Predef$.MODULE$.doubleArrayOps(r).indexOf(Predef$.MODULE$.doubleArrayOps(r).max((Ordering)Ordering.Double$.MODULE$));
            }
        }, ClassTag$.MODULE$.Int());
    }

    /**
     * Computes soft assignments: for each point, the normalized responsibility of
     * every component.
     *
     * <p>The gaussians and weights are broadcast through the points' SparkContext.
     * Each point is then converted to a dense Breeze vector and passed, together
     * with the broadcast values and {@link #k()}, to the mangled
     * {@code computeSoftAssignments} method of this class.
     *
     * <p>NOTE(review): new broadcasts are created on every call and are not
     * unpersisted or destroyed in this method.
     *
     * @param points input vectors
     * @return one {@code double[]} per point; entry i belongs to component i and
     *         the entries of each row have been divided by the row's sum
     */
    public RDD<double[]> predictSoft(RDD<Vector> points) {
        SparkContext sc = points.sparkContext();
        Broadcast bcDists = sc.broadcast((Object)this.gaussians(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(MultivariateGaussian.class)));
        Broadcast bcWeights = sc.broadcast((Object)this.weights(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE)));
        return points.map((Function1)new Serializable(this, bcDists, bcWeights){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ GaussianMixtureModel $outer;
            private final Broadcast bcDists$1;
            private final Broadcast bcWeights$1;

            // Reads both broadcast values and delegates to the enclosing model.
            // NOTE(review): this closure keeps a reference to the enclosing model
            // ($outer) for k() and the helper call, so serializing it presumably
            // also serializes the model's own weights/gaussians arrays, partly
            // duplicating what the broadcasts carry -- confirm.
            public final double[] apply(Vector x) {
                return this.$outer.org$apache$spark$mllib$clustering$GaussianMixtureModel$$computeSoftAssignments((DenseVector<Object>)x.toBreeze().toDenseVector$mcD$sp(ClassTag$.MODULE$.Double()), (MultivariateGaussian[])this.bcDists$1.value(), (double[])this.bcWeights$1.value(), this.$outer.k());
            }
            // Decompiled closure initializer: stores the enclosing model and both broadcasts.
            {
                if ($outer == null) {
                    throw new NullPointerException();
                }
                this.$outer = $outer;
                this.bcDists$1 = bcDists$1;
                this.bcWeights$1 = bcWeights$1;
            }
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE)));
    }

    /**
     * Computes normalized component responsibilities for a single point.
     *
     * <p>The method works in two steps:
     * <ol>
     *   <li>For each (weight, dist) pair from {@code weights} zipped with
     *       {@code dists}, it computes {@code EPSILON + weight * dist.pdf(pt)}.</li>
     *   <li>It divides entries {@code 0 .. k-1} of that array in place by the sum
     *       of all its entries.</li>
     * </ol>
     *
     * <p>The mangled name suggests this was a private Scala method. It is public in
     * this bytecode and is called from the closure in {@link #predictSoft}.
     *
     * @param pt      the point as a dense Breeze vector
     * @param dists   component distributions
     * @param weights mixing weights; zip truncates to the shorter of the two arrays
     * @param k       number of entries to normalize; a value larger than the
     *                zipped length makes the division loop index past the array end
     * @return the responsibility array built in step 1 and normalized in step 2
     */
    public double[] org$apache$spark$mllib$clustering$GaussianMixtureModel$$computeSoftAssignments(DenseVector<Object> pt, MultivariateGaussian[] dists, double[] weights, int k) {
        // p[i] = EPSILON + weights[i] * dists[i].pdf(pt), over the zipped pairs.
        // NOTE(review): EPSILON presumably guards against a zero pSum when every
        // weighted density is 0 -- confirm.
        double[] p = (double[])Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.doubleArrayOps(weights).zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])dists), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).map((Function1)new Serializable(this, pt){
            public static final long serialVersionUID = 0L;
            private final DenseVector pt$1;

            // Destructures the (weight, dist) tuple; MatchError is thrown only for a null tuple.
            public final double apply(Tuple2<Object, MultivariateGaussian> x0$1) {
                Tuple2<Object, MultivariateGaussian> tuple2 = x0$1;
                if (tuple2 != null) {
                    double weight = tuple2._1$mcD$sp();
                    MultivariateGaussian dist = (MultivariateGaussian)tuple2._2();
                    double d = MLUtils$.MODULE$.EPSILON() + weight * dist.pdf((breeze.linalg.Vector<Object>)this.pt$1);
                    return d;
                }
                throw new MatchError(tuple2);
            }
            {
                this.pt$1 = pt$1;
            }
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()));
        // Sum of all entries of p, used as the normalizer.
        double pSum = BoxesRunTime.unboxToDouble((Object)Predef$.MODULE$.doubleArrayOps(p).sum((Numeric)Numeric.DoubleIsFractional$.MODULE$));
        // For i in [0, k): p[i] = p[i] / pSum, mutating p in place.
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), k).foreach$mVc$sp((Function1)new Serializable(this, p, pSum){
            public static final long serialVersionUID = 0L;
            private final double[] p$1;
            private final double pSum$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            // Primitive-specialized body that performs the division.
            public void apply$mcVI$sp(int i) {
                this.p$1[i] = this.p$1[i] / this.pSum$1;
            }
            {
                this.p$1 = p$1;
                this.pSum$1 = pSum$1;
            }
        });
        return p;
    }

    /**
     * Creates a model from parallel arrays of mixing weights and component Gaussians.
     *
     * <p>Both arrays are stored as given, with no defensive copy. Only their
     * lengths are validated; the weight values themselves are not checked here.
     *
     * @param weights   mixing weights, one per component
     * @param gaussians component distributions, same length as {@code weights}
     * @throws IllegalArgumentException if the lengths differ (raised via Scala's
     *         {@code Predef.require})
     */
    public GaussianMixtureModel(double[] weights, MultivariateGaussian[] gaussians) {
        // Fields are assigned before the length check below.
        this.weights = weights;
        this.gaussians = gaussians;
        Predef$.MODULE$.require(weights.length == gaussians.length, (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;

            // Supplies the failure message to require (passed as a Function0).
            public final String apply() {
                return "Length of weight and Gaussian arrays must match";
            }
        });
    }
}

