/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.core.optimizer.loss;

import com.tencent.angel.ml.core.network.layers.AngelGraph;
import com.tencent.angel.ml.core.optimizer.loss.LossFunc;
import com.tencent.angel.ml.core.optimizer.loss.LossFunc$class;
import com.tencent.angel.ml.math2.MFactory;
import com.tencent.angel.ml.math2.matrix.BlasDoubleMatrix;
import com.tencent.angel.ml.math2.matrix.BlasFloatMatrix;
import com.tencent.angel.ml.math2.matrix.BlasMatrix;
import com.tencent.angel.ml.math2.matrix.Matrix;
import com.tencent.angel.ml.math2.ufuncs.Ufuncs;
import org.json4s.JsonAST;
import scala.MatchError;
import scala.Predef$;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.reflect.ScalaSignature;

@ScalaSignature(bytes="\u0006\u0001Q3A!\u0001\u0002\u0001#\t1AJ\r'pgNT!a\u0001\u0003\u0002\t1|7o\u001d\u0006\u0003\u000b\u0019\t\u0011b\u001c9uS6L'0\u001a:\u000b\u0005\u001dA\u0011\u0001B2pe\u0016T!!\u0003\u0006\u0002\u00055d'BA\u0006\r\u0003\u0015\tgnZ3m\u0015\tia\"A\u0004uK:\u001cWM\u001c;\u000b\u0003=\t1aY8n\u0007\u0001\u00192\u0001\u0001\n\u0019!\t\u0019b#D\u0001\u0015\u0015\u0005)\u0012!B:dC2\f\u0017BA\f\u0015\u0005\u0019\te.\u001f*fMB\u0011\u0011DG\u0007\u0002\u0005%\u00111D\u0001\u0002\t\u0019>\u001c8OR;oG\")Q\u0004\u0001C\u0001=\u00051A(\u001b8jiz\"\u0012a\b\t\u00033\u0001AQ!\t\u0001\u0005B\t\nqaY1m\u0019>\u001c8\u000fF\u0002$MA\u0002\"a\u0005\u0013\n\u0005\u0015\"\"A\u0002#pk\ndW\rC\u0003(A\u0001\u0007\u0001&\u0001\u0005n_\u0012,GnT;u!\tIc&D\u0001+\u0015\tYC&\u0001\u0004nCR\u0014\u0018\u000e\u001f\u0006\u0003[!\tQ!\\1uQJJ!a\f\u0016\u0003\r5\u000bGO]5y\u0011\u0015\t\u0004\u00051\u00013\u0003\u00159'/\u00199i!\t\u0019\u0004(D\u00015\u0015\t)d'\u0001\u0004mCf,'o\u001d\u0006\u0003o\u0019\tqA\\3uo>\u00148.\u0003\u0002:i\tQ\u0011I\\4fY\u001e\u0013\u0018\r\u001d5\t\u000b\r\u0001A\u0011A\u001e\u0015\u0007\rbd\bC\u0003>u\u0001\u00071%\u0001\u0003qe\u0016$\u0007\"B ;\u0001\u0004\u0019\u0013!\u00027bE\u0016d\u0007\"B!\u0001\t\u0003\u0012\u0015aB2bY\u001e\u0013\u0018\r\u001a\u000b\u0004Q\r#\u0005\"B\u0014A\u0001\u0004A\u0003\"B\u0019A\u0001\u0004\u0011\u0004\"\u0002$\u0001\t\u0003:\u0015a\u00029sK\u0012L7\r\u001e\u000b\u0004Q!K\u0005\"B\u0014F\u0001\u0004A\u0003\"B\u0019F\u0001\u0004\u0011\u0004\"B&\u0001\t\u0003b\u0015\u0001\u0003;p'R\u0014\u0018N\\4\u0015\u00035\u0003\"AT)\u000f\u0005My\u0015B\u0001)\u0015\u0003\u0019\u0001&/\u001a3fM&\u0011!k\u0015\u0002\u0007'R\u0014\u0018N\\4\u000b\u0005A#\u0002")
public class L2Loss implements LossFunc {
    /*
     * Squared-error (L2) loss: loss(pred, label) = 0.5 * (pred - label)^2.
     *
     * This class was decompiled from Scala. The LossFunc$class calls are the
     * Scala trait-implementation entry points for the LossFunc trait and must be
     * kept so the trait's default behavior and initialization still run.
     */

    /**
     * Serializes this loss to JSON by delegating to the LossFunc trait's
     * default implementation.
     */
    @Override
    public JsonAST.JObject toJson() {
        return LossFunc$class.toJson(this);
    }

    /**
     * Computes {@code 0.5 * average((modelOut - label)^2)}, where the labels
     * come from {@code graph.placeHolder().getLabel()}.
     *
     * @param modelOut model output; must be a {@link BlasMatrix}
     * @param graph    graph whose placeholder supplies the labels
     * @return half of {@code Matrix.average()} of the element-wise squared residuals
     * @throws MatchError if {@code modelOut} is not a {@link BlasMatrix}
     *         (mirrors the non-exhaustive Scala pattern match)
     */
    @Override
    public double calLoss(Matrix modelOut, AngelGraph graph) {
        if (!(modelOut instanceof BlasMatrix)) {
            throw new MatchError((Object)modelOut);
        }
        Matrix residual = ((BlasMatrix)modelOut).sub(graph.placeHolder().getLabel());
        return 0.5 * Ufuncs.pow(residual, 2.0).average();
    }

    /**
     * Scalar L2 loss for a single prediction.
     *
     * @param pred  predicted value
     * @param label target value
     * @return {@code 0.5 * (pred - label)^2}
     */
    @Override
    public double loss(double pred, double label) {
        double diff = pred - label;
        return 0.5 * diff * diff;
    }

    /**
     * Element-wise derivative of {@code 0.5 * (pred - label)^2} with respect to
     * the prediction: {@code modelOut - label}.
     *
     * NOTE(review): calLoss averages the squared residuals, but no 1/N scaling
     * is applied here. Presumably the caller or optimizer handles batch scaling;
     * confirm.
     *
     * @param modelOut model output
     * @param graph    graph whose placeholder supplies the labels
     * @return the residual matrix {@code modelOut - label}
     */
    @Override
    public Matrix calGrad(Matrix modelOut, AngelGraph graph) {
        return modelOut.sub(graph.placeHolder().getLabel());
    }

    /**
     * Builds the prediction matrix. The result is a dense matrix with the same
     * row count as {@code modelOut} and 3 columns, and column 0 is copied from
     * column 0 of {@code modelOut}. The result's precision (double or float)
     * matches the input.
     *
     * NOTE(review): columns 1 and 2 are not written here, and {@code graph} is
     * unused. Confirm whether callers fill the remaining columns.
     *
     * @param modelOut model output; must be a {@link BlasDoubleMatrix} or
     *                 {@link BlasFloatMatrix}
     * @param graph    unused
     * @return a new 3-column dense matrix of the same precision as {@code modelOut}
     * @throws MatchError for any other matrix type (mirrors the Scala match)
     */
    @Override
    public Matrix predict(Matrix modelOut, AngelGraph graph) {
        if (modelOut instanceof BlasDoubleMatrix) {
            BlasDoubleMatrix src = (BlasDoubleMatrix)modelOut;
            BlasDoubleMatrix out = MFactory.denseDoubleMatrix((int)src.getNumRows(), 3);
            out.setCol(0, src.getCol(0));
            return out;
        }
        if (modelOut instanceof BlasFloatMatrix) {
            // Returned via the Matrix return type. The decompiled code wrongly
            // stored this BlasFloatMatrix in a BlasDoubleMatrix-typed local.
            BlasFloatMatrix src = (BlasFloatMatrix)modelOut;
            BlasFloatMatrix out = MFactory.denseFloatMatrix((int)src.getNumRows(), 3);
            out.setCol(0, src.getCol(0));
            return out;
        }
        throw new MatchError((Object)modelOut);
    }

    /**
     * Returns the loss name. This is equivalent to the Scala {@code s"L2Loss"},
     * which has no arguments and no escapes.
     */
    @Override
    public String toString() {
        return "L2Loss";
    }

    /** Creates the loss and runs the LossFunc trait's initializer. */
    public L2Loss() {
        LossFunc$class.$init$(this);
    }
}

