/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.ann;

import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseVector;
import java.io.Serializable;
import org.apache.spark.ml.ann.FeedForwardModel$;
import org.apache.spark.ml.ann.FeedForwardTopology;
import org.apache.spark.ml.ann.Layer;
import org.apache.spark.ml.ann.LayerModel;
import org.apache.spark.ml.ann.LossFunction;
import org.apache.spark.ml.ann.TopologyModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import scala.Function1;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001\u0005\u0005e!B\u0001\u0003\u0001\u0011a!\u0001\u0005$fK\u00124uN]<be\u0012lu\u000eZ3m\u0015\t\u0019A!A\u0002b]:T!!\u0002\u0004\u0002\u00055d'BA\u0004\t\u0003\u0015\u0019\b/\u0019:l\u0015\tI!\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u0017\u0005\u0019qN]4\u0014\u0007\u0001i1\u0003\u0005\u0002\u000f#5\tqBC\u0001\u0011\u0003\u0015\u00198-\u00197b\u0013\t\u0011rB\u0001\u0004B]f\u0014VM\u001a\t\u0003)Ui\u0011AA\u0005\u0003-\t\u0011Q\u0002V8q_2|w-_'pI\u0016d\u0007\u0002\u0003\r\u0001\u0005\u000b\u0007I\u0011\u0001\u000e\u0002\u000f],\u0017n\u001a5ug\u000e\u0001Q#A\u000e\u0011\u0005qyR\"A\u000f\u000b\u0005y!\u0011A\u00027j]\u0006dw-\u0003\u0002!;\t1a+Z2u_JD\u0001B\t\u0001\u0003\u0002\u0003\u0006IaG\u0001\to\u0016Lw\r\u001b;tA!AA\u0005\u0001BC\u0002\u0013\u0005Q%\u0001\u0005u_B|Gn\\4z+\u00051\u0003C\u0001\u000b(\u0013\tA#AA\nGK\u0016$gi\u001c:xCJ$Gk\u001c9pY><\u0017\u0010\u0003\u0005+\u0001\t\u0005\t\u0015!\u0003'\u0003%!x\u000e]8m_\u001eL\b\u0005C\u0003-\u0001\u0011%Q&\u0001\u0004=S:LGO\u0010\u000b\u0004]=\u0002\u0004C\u0001\u000b\u0001\u0011\u0015A2\u00061\u0001\u001c\u0011\u0015!3\u00061\u0001'\u0011\u001d\u0011\u0004A1A\u0005\u0002M\na\u0001\\1zKJ\u001cX#\u0001\u001b\u0011\u00079)t'\u0003\u00027\u001f\t)\u0011I\u001d:bsB\u0011A\u0003O\u0005\u0003s\t\u0011Q\u0001T1zKJDaa\u000f\u0001!\u0002\u0013!\u0014a\u00027bs\u0016\u00148\u000f\t\u0005\b{\u0001\u0011\r\u0011\"\u0001?\u0003-a\u0017-_3s\u001b>$W\r\\:\u0016\u0003}\u00022AD\u001bA!\t!\u0012)\u0003\u0002C\u0005\tQA*Y=fe6{G-\u001a7\t\r\u0011\u0003\u0001\u0015!\u0003@\u00031a\u0017-_3s\u001b>$W\r\\:!\u0011\u001d1\u0005\u00011A\u0005\n\u001d\u000baa\u001c4gg\u0016$X#\u0001%\u0011\u00059I\u0015B\u0001&\u0010\u0005\rIe\u000e\u001e\u0005\b\u0019\u0002\u0001\r\u0011\"\u0003N\u0003)ygMZ:fi~#S-\u001d\u000b\u0003\u001dF\u0003\"AD(\n\u0005A{!\u0001B+oSRDqAU&\u0002\u0002\u0003\u0007\u0001*A\u0002yIEBa\u0001\u0016\u0001!B\u0013A\u0015aB8gMN,G\u000f\t\u0005\b-\u0002\u0001\r\u0011\"\u0003X\u
0003\u001dyW\u000f\u001e9viN,\u0012\u0001\u0017\t\u0004\u001dUJ\u0006c\u0001._A6\t1L\u0003\u0002\u001f9*\tQ,\u0001\u0004ce\u0016,'0Z\u0005\u0003?n\u00131\u0002R3og\u0016l\u0015\r\u001e:jqB\u0011a\"Y\u0005\u0003E>\u0011a\u0001R8vE2,\u0007b\u00023\u0001\u0001\u0004%I!Z\u0001\f_V$\b/\u001e;t?\u0012*\u0017\u000f\u0006\u0002OM\"9!kYA\u0001\u0002\u0004A\u0006B\u00025\u0001A\u0003&\u0001,\u0001\u0005pkR\u0004X\u000f^:!\u0011\u001dQ\u0007\u00011A\u0005\n]\u000ba\u0001Z3mi\u0006\u001c\bb\u00027\u0001\u0001\u0004%I!\\\u0001\u000bI\u0016dG/Y:`I\u0015\fHC\u0001(o\u0011\u001d\u00116.!AA\u0002aCa\u0001\u001d\u0001!B\u0013A\u0016a\u00023fYR\f7\u000f\t\u0005\u0006e\u0002!\te]\u0001\bM>\u0014x/\u0019:e)\rAFO\u001e\u0005\u0006kF\u0004\r!W\u0001\u0005I\u0006$\u0018\rC\u0003xc\u0002\u0007\u00010\u0001\tj]\u000edW\u000fZ3MCN$H*Y=feB\u0011a\"_\u0005\u0003u>\u0011qAQ8pY\u0016\fg\u000eC\u0003}\u0001\u0011\u0005S0A\bd_6\u0004X\u000f^3He\u0006$\u0017.\u001a8u)\u001d\u0001gp`A\u0002\u0003\u000fAQ!^>A\u0002eCa!!\u0001|\u0001\u0004I\u0016A\u0002;be\u001e,G\u000f\u0003\u0004\u0002\u0006m\u0004\raG\u0001\fGVlwI]1eS\u0016tG\u000f\u0003\u0004\u0002\nm\u0004\r\u0001S\u0001\u000ee\u0016\fGNQ1uG\"\u001c\u0016N_3\t\u000f\u00055\u0001\u0001\"\u0011\u0002\u0010\u00059\u0001O]3eS\u000e$HcA\u000e\u0002\u0012!1Q/a\u0003A\u0002mAq!!\u0006\u0001\t\u0003\n9\"\u0001\u0006qe\u0016$\u0017n\u0019;SC^$2aGA\r\u0011\u0019)\u00181\u0003a\u00017!9\u0011Q\u0004\u0001\u0005B\u0005}\u0011A\u0006:boJ\u0002&o\u001c2bE&d\u0017\u000e^=J]Bc\u0017mY3\u0015\u0007m\t\t\u0003\u0003\u0004v\u00037\u0001\raG\u0004\t\u0003K\u0011\u0001\u0012\u0001\u0002\u0002(\u0005\u0001b)Z3e\r>\u0014x/\u0019:e\u001b>$W\r\u001c\t\u0004)\u0005%baB\u0001\u0003\u0011\u0003\u0011\u00111F\n\u0006\u0003Si\u0011Q\u0006\t\u0004\u001d\u0005=\u0012bAA\u0019\u001f\ta1+\u001a:jC2L'0\u00192mK\"9A&!\u000b\u0005\u0002\u0005UBCAA\u0014\u0011!\tI$!\u000b\u0005\u0002\u0005m\u0012!B1qa2LH#\u0002\u0018\u0002>\u0005}\u0002B\u0002\u0013\u00028\u0001\u0007a\u0005\u0003\u00
04\u0019\u0003o\u0001\ra\u0007\u0005\t\u0003s\tI\u0003\"\u0001\u0002DQ)a&!\u0012\u0002H!1A%!\u0011A\u0002\u0019B!\"!\u0013\u0002BA\u0005\t\u0019AA&\u0003\u0011\u0019X-\u001a3\u0011\u00079\ti%C\u0002\u0002P=\u0011A\u0001T8oO\"Q\u00111KA\u0015#\u0003%\t!!\u0016\u0002\u001f\u0005\u0004\b\u000f\\=%I\u00164\u0017-\u001e7uII*\"!a\u0016+\t\u0005-\u0013\u0011L\u0016\u0003\u00037\u0002B!!\u0018\u0002h5\u0011\u0011q\f\u0006\u0005\u0003C\n\u0019'A\u0005v]\u000eDWmY6fI*\u0019\u0011QM\b\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0003\u0002j\u0005}#!E;oG\",7m[3e-\u0006\u0014\u0018.\u00198dK\"Q\u0011QNA\u0015\u0003\u0003%I!a\u001c\u0002\u0017I,\u0017\r\u001a*fg>dg/\u001a\u000b\u0003\u0003c\u0002B!a\u001d\u0002~5\u0011\u0011Q\u000f\u0006\u0005\u0003o\nI(\u0001\u0003mC:<'BAA>\u0003\u0011Q\u0017M^1\n\t\u0005}\u0014Q\u000f\u0002\u0007\u001f\nTWm\u0019;")
/**
 * Feed-forward neural network model: a {@link FeedForwardTopology} bound to a flat weight
 * vector, with one {@link LayerModel} per layer of the topology.
 *
 * <p>Layer {@code i}'s model is created over the slice of {@code weights.toArray()} that starts
 * right after the slices of layers {@code 0..i-1} and is {@code layers[i].weightSize()} long.
 *
 * <p>Forward outputs and backward deltas are cached in per-instance buffers. They are reused for
 * as long as the batch size (number of columns) stays the same. These buffers are reassigned
 * without any synchronization, so one instance should not be used from several threads at once.
 */
public class FeedForwardModel
implements TopologyModel {
    private final Vector weights;
    private final FeedForwardTopology topology;
    private final Layer[] layers;
    private final LayerModel[] layerModels;
    /** Running offset into the weights while layer models are built; afterwards the total consumed. */
    private int offset;
    /** Cached per-layer output buffers, or null until the first forward pass. */
    private DenseMatrix<Object>[] outputs;
    /** Cached per-layer delta buffers, or null until the first gradient computation. */
    private DenseMatrix<Object>[] deltas;

    /** Static forwarder to the companion object's default value for the seed parameter of {@code apply}. */
    public static long apply$default$2() {
        return FeedForwardModel$.MODULE$.apply$default$2();
    }

    /** Static forwarder: builds a model for {@code feedForwardTopology} via the companion object using seed {@code l}. */
    public static FeedForwardModel apply(FeedForwardTopology feedForwardTopology, long l) {
        return FeedForwardModel$.MODULE$.apply(feedForwardTopology, l);
    }

    /** Static forwarder: builds a model for {@code feedForwardTopology} via the companion object from {@code vector}. */
    public static FeedForwardModel apply(FeedForwardTopology feedForwardTopology, Vector vector) {
        return FeedForwardModel$.MODULE$.apply(feedForwardTopology, vector);
    }

    @Override
    public Vector weights() {
        return this.weights;
    }

    public FeedForwardTopology topology() {
        return this.topology;
    }

    @Override
    public Layer[] layers() {
        return this.layers;
    }

    @Override
    public LayerModel[] layerModels() {
        return this.layerModels;
    }

    // Private accessors mirroring the original Scala `var` fields.

    private int offset() {
        return this.offset;
    }

    private void offset_$eq(int x$1) {
        this.offset = x$1;
    }

    private DenseMatrix<Object>[] outputs() {
        return this.outputs;
    }

    private void outputs_$eq(DenseMatrix<Object>[] x$1) {
        this.outputs = x$1;
    }

    private DenseMatrix<Object>[] deltas() {
        return this.deltas;
    }

    private void deltas_$eq(DenseMatrix<Object>[] x$1) {
        this.deltas = x$1;
    }

    /**
     * Runs the network forward on a batch whose columns are the examples.
     *
     * <p>The output buffers are (re)allocated only when there are none yet or the batch size
     * changed. A layer reporting {@code inPlace()} reuses the previous layer's buffer as its output.
     *
     * @param data input batch; its row count is the input size of the first layer
     * @param includeLastLayer whether to also evaluate the final layer; if false, the last slot
     *     holds whatever it held before (or aliases the previous slot for an in-place last layer)
     * @return the internal per-layer output buffers (not a copy)
     * @throws UnsupportedOperationException if the first layer is in-place (it has no previous
     *     buffer to share)
     */
    @Override
    public DenseMatrix<Object>[] forward(DenseMatrix<Object> data, boolean includeLastLayer) {
        int currentBatchSize = data.cols();
        if (this.outputs() == null || this.outputs()[0].cols() != currentBatchSize) {
            Layer[] layerArray = this.layers();
            @SuppressWarnings("unchecked")
            DenseMatrix<Object>[] buffers = (DenseMatrix<Object>[]) new DenseMatrix[layerArray.length];
            int inputSize = data.rows();
            for (int i = 0; i < layerArray.length; i++) {
                if (layerArray[i].inPlace()) {
                    if (i == 0) {
                        throw new UnsupportedOperationException("The first layer cannot be in-place.");
                    }
                    buffers[i] = buffers[i - 1];
                } else {
                    int outputSize = layerArray[i].getOutputSize(inputSize);
                    buffers[i] = new DenseMatrix.mcD.sp(outputSize, currentBatchSize, ClassTag$.MODULE$.Double());
                    inputSize = outputSize;
                }
            }
            // Install only once fully populated, so a failure leaves the previous cache intact.
            this.outputs_$eq(buffers);
        }
        DenseMatrix<Object>[] layerOutputs = this.outputs();
        LayerModel[] models = this.layerModels();
        models[0].eval(data, layerOutputs[0]);
        int end = includeLastLayer ? models.length : models.length - 1;
        for (int i = 1; i < end; i++) {
            models[i].eval(layerOutputs[i - 1], layerOutputs[i]);
        }
        return layerOutputs;
    }

    /**
     * Runs a forward and a backward pass on one batch and accumulates per-layer gradients.
     *
     * <p>The top layer model must implement {@link LossFunction}. Its {@code loss} receives the
     * delta buffer of layer {@code L-1}. The backward loop then propagates deltas from layer
     * {@code L-2} down to 0 via {@code computePrevDelta}. Finally, each layer's {@code grad} is
     * called with a slice of {@code cumGradient.toArray()} laid out like the weights.
     *
     * <p>NOTE(review): gradients land in {@code cumGradient} itself only if {@code toArray()}
     * returns its backing array (presumably dense vectors) — verify callers pass a dense vector.
     *
     * @param data input batch (columns are examples)
     * @param target targets for the batch, passed to the loss
     * @param cumGradient accumulator for the gradient, in weight layout
     * @param realBatchSize not used by this implementation
     * @return the loss reported by the top layer
     * @throws UnsupportedOperationException if the top layer model is not a {@link LossFunction}
     */
    @Override
    public double computeGradient(DenseMatrix<Object> data, DenseMatrix<Object> target, Vector cumGradient, int realBatchSize) {
        DenseMatrix<Object>[] layerOutputs = this.forward(data, true);
        int currentBatchSize = data.cols();
        LayerModel[] models = this.layerModels();
        Layer[] layerArray = this.layers();
        if (this.deltas() == null || this.deltas()[0].cols() != currentBatchSize) {
            @SuppressWarnings("unchecked")
            DenseMatrix<Object>[] buffers = (DenseMatrix<Object>[]) new DenseMatrix[models.length];
            int inputSize = data.rows();
            // The top layer's slot is intentionally left null; only layers 0..L-1 get a buffer.
            for (int i = 0; i < models.length - 1; i++) {
                int outputSize = layerArray[i].getOutputSize(inputSize);
                buffers[i] = new DenseMatrix.mcD.sp(outputSize, currentBatchSize, ClassTag$.MODULE$.Double());
                inputSize = outputSize;
            }
            this.deltas_$eq(buffers);
        }
        DenseMatrix<Object>[] layerDeltas = this.deltas();
        int L = models.length - 1;
        LayerModel top = models[L];
        if (!(top instanceof LossFunction)) {
            throw new UnsupportedOperationException("Top layer is required to have objective.");
        }
        double loss = ((LossFunction) top).loss(layerOutputs[layerOutputs.length - 1], target, layerDeltas[L - 1]);
        // Inclusive descending range L-2 .. 0 (empty when L < 2).
        for (int i = L - 2; i >= 0; i--) {
            models[i + 1].computePrevDelta(layerDeltas[i + 1], layerOutputs[i + 1], layerDeltas[i]);
        }
        double[] cumGradientArray = cumGradient.toArray();
        int gradOffset = 0;
        for (int i = 0; i < models.length; i++) {
            DenseMatrix<Object> input = i == 0 ? data : layerOutputs[i - 1];
            int size = layerArray[i].weightSize();
            // NOTE(review): for i == L this passes the null top-layer delta; presumably the top
            // layer's grad ignores it (e.g. it has no weights) — confirm.
            models[i].grad(layerDeltas[i], input, (DenseVector<Object>) new DenseVector.mcD.sp(cumGradientArray, gradOffset, 1, size));
            gradOffset += size;
        }
        return loss;
    }

    /**
     * Predicts a single example by running every layer, including the last.
     *
     * @param data the input features
     * @return the last layer's output as a dense vector
     */
    @Override
    public Vector predict(Vector data) {
        DenseMatrix<Object>[] result = this.forward(columnMatrix(data.toArray()), true);
        return Vectors$.MODULE$.dense(result[result.length - 1].toArray$mcD$sp());
    }

    /**
     * Runs every layer except the last on a single example.
     *
     * @param data the input features
     * @return the output of the second-to-last layer (the input the last layer would receive)
     */
    @Override
    public Vector predictRaw(Vector data) {
        DenseMatrix<Object>[] result = this.forward(columnMatrix(data.toArray()), false);
        return Vectors$.MODULE$.dense(result[result.length - 2].toArray$mcD$sp());
    }

    /**
     * Applies the last layer model to {@code data}, using the same matrix as input and output.
     *
     * <p>For a dense input, {@code toArray()} exposes the vector's backing array (per Spark's
     * DenseVector), so the update lands in {@code data} itself and {@code data} is returned. For
     * any other vector, {@code toArray()} is a copy, and returning {@code data} would silently
     * discard the result. A dense vector over the updated values is returned instead.
     *
     * @param data raw scores for one example
     * @return the transformed values
     */
    @Override
    public Vector raw2ProbabilityInPlace(Vector data) {
        double[] values = data.toArray();
        DenseMatrix<Object> dataMatrix = columnMatrix(values);
        LayerModel[] models = this.layerModels();
        models[models.length - 1].eval(dataMatrix, dataMatrix);
        if (data instanceof org.apache.spark.ml.linalg.DenseVector) {
            return data;
        }
        return Vectors$.MODULE$.dense(values);
    }

    /** Wraps {@code values} (without copying) as a single-column matrix. */
    private static DenseMatrix<Object> columnMatrix(double[] values) {
        return (DenseMatrix<Object>) new DenseMatrix.mcD.sp(values.length, 1, values);
    }

    /**
     * Binds {@code topology} to {@code weights}, creating one layer model per layer.
     *
     * <p>Layer models are created over consecutive slices of {@code weights.toArray()} in layer
     * order. No check is done here that the weight vector is large enough.
     *
     * @param weights flat weight vector for all layers
     * @param topology the network topology
     */
    public FeedForwardModel(Vector weights, FeedForwardTopology topology) {
        this.weights = weights;
        this.topology = topology;
        this.layers = topology.layers();
        this.layerModels = new LayerModel[this.layers.length];
        this.offset = 0;
        double[] weightArray = weights.toArray();
        for (int i = 0; i < this.layers.length; i++) {
            int size = this.layers[i].weightSize();
            this.layerModels[i] = this.layers[i].createModel((DenseVector<Object>) new DenseVector.mcD.sp(weightArray, this.offset, 1, size));
            this.offset += size;
        }
        this.outputs = null;
        this.deltas = null;
    }
}

