/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.core.optimizer.loss;

import com.tencent.angel.ml.core.network.layers.AngelGraph;
import com.tencent.angel.ml.core.optimizer.loss.LossFunc;
import com.tencent.angel.ml.core.optimizer.loss.LossFunc$class;
import com.tencent.angel.ml.math2.MFactory;
import com.tencent.angel.ml.math2.matrix.BlasDoubleMatrix;
import com.tencent.angel.ml.math2.matrix.BlasFloatMatrix;
import com.tencent.angel.ml.math2.matrix.Matrix;
import com.tencent.angel.ml.math2.ufuncs.LossFuncs;
import org.json4s.JsonAST;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001Q3A!\u0001\u0002\u0001#\tI\u0001*\u001b8hK2{7o\u001d\u0006\u0003\u0007\u0011\tA\u0001\\8tg*\u0011QAB\u0001\n_B$\u0018.\\5{KJT!a\u0002\u0005\u0002\t\r|'/\u001a\u0006\u0003\u0013)\t!!\u001c7\u000b\u0005-a\u0011!B1oO\u0016d'BA\u0007\u000f\u0003\u001d!XM\\2f]RT\u0011aD\u0001\u0004G>l7\u0001A\n\u0004\u0001IA\u0002CA\n\u0017\u001b\u0005!\"\"A\u000b\u0002\u000bM\u001c\u0017\r\\1\n\u0005]!\"AB!osJ+g\r\u0005\u0002\u001a55\t!!\u0003\u0002\u001c\u0005\tAAj\\:t\rVt7\rC\u0003\u001e\u0001\u0011\u0005a$\u0001\u0004=S:LGO\u0010\u000b\u0002?A\u0011\u0011\u0004\u0001\u0005\u0006C\u0001!\tEI\u0001\bG\u0006dGj\\:t)\r\u0019c\u0005\r\t\u0003'\u0011J!!\n\u000b\u0003\r\u0011{WO\u00197f\u0011\u00159\u0003\u00051\u0001)\u0003!iw\u000eZ3m\u001fV$\bCA\u0015/\u001b\u0005Q#BA\u0016-\u0003\u0019i\u0017\r\u001e:jq*\u0011Q\u0006C\u0001\u0006[\u0006$\bNM\u0005\u0003_)\u0012a!T1ue&D\b\"B\u0019!\u0001\u0004\u0011\u0014!B4sCBD\u0007CA\u001a9\u001b\u0005!$BA\u001b7\u0003\u0019a\u0017-_3sg*\u0011qGB\u0001\b]\u0016$xo\u001c:l\u0013\tIDG\u0001\u0006B]\u001e,Gn\u0012:ba\"DQa\u0001\u0001\u0005Bm\"2a\t\u001f?\u0011\u0015i$\b1\u0001$\u0003\u0011\u0001(/\u001a3\t\u000b}R\u0004\u0019A\u0012\u0002\u000b1\f'-\u001a7\t\u000b\u0005\u0003A\u0011\t\"\u0002\u000f\r\fGn\u0012:bIR\u0019\u0001f\u0011#\t\u000b\u001d\u0002\u0005\u0019\u0001\u0015\t\u000bE\u0002\u0005\u0019\u0001\u001a\t\u000b\u0019\u0003A\u0011I$\u0002\u000fA\u0014X\rZ5diR\u0019\u0001\u0006S%\t\u000b\u001d*\u0005\u0019\u0001\u0015\t\u000bE*\u0005\u0019\u0001\u001a\t\u000b-\u0003A\u0011\t'\u0002\u0011Q|7\u000b\u001e:j]\u001e$\u0012!\u0014\t\u0003\u001dFs!aE(\n\u0005A#\u0012A\u0002)sK\u0012,g-\u0003\u0002S'\n11\u000b\u001e:j]\u001eT!\u0001\u0015\u000b")
/**
 * Hinge loss: {@code loss(pred, label) = max(0, 1 - pred * label)}.
 *
 * <p>Batch loss and gradient are delegated to {@link LossFuncs}; {@link #predict} turns raw
 * scores into a three-column matrix of (score, squashed score, predicted label in {-1, +1}).
 */
public class HingeLoss
implements LossFunc {
    /** Number of columns emitted per input score by {@link #predict}. */
    private static final int PREDICTION_COLUMNS = 3;

    /** Serializes this loss via the trait's default implementation. */
    @Override
    public JsonAST.JObject toJson() {
        return LossFunc$class.toJson(this);
    }

    /**
     * Computes the batch loss as the mean of {@link LossFuncs#hingeloss} over the model output
     * and the labels held by the graph's placeholder.
     *
     * @param modelOut raw model scores
     * @param graph graph whose placeholder supplies the labels
     * @return the average hinge loss
     */
    @Override
    public double calLoss(Matrix modelOut, AngelGraph graph) {
        Matrix label = graph.placeHolder().getLabel();
        return LossFuncs.hingeloss(modelOut, label).average();
    }

    /**
     * Hinge loss for a single prediction.
     *
     * @param pred raw model score
     * @param label target label (the labels produced by {@link #predict} are {@code +1}/{@code -1})
     * @return {@code max(0, 1 - pred * label)}
     */
    @Override
    public double loss(double pred, double label) {
        return Math.max(0.0, 1.0 - pred * label);
    }

    /**
     * Computes the gradient of the hinge loss with respect to the model output.
     *
     * @param modelOut raw model scores
     * @param graph graph whose placeholder supplies the labels
     * @return the gradient matrix returned by {@link LossFuncs#gradhingeloss}
     */
    @Override
    public Matrix calGrad(Matrix modelOut, AngelGraph graph) {
        Matrix label = graph.placeHolder().getLabel();
        return LossFuncs.gradhingeloss(modelOut, label);
    }

    /**
     * Turns raw model scores into a three-column prediction matrix.
     *
     * <p>Each element {@code v} of {@code modelOut}'s backing data array becomes one output row
     * {@code [v, 1 / (1 + exp(-2 * v)), v > 0 ? 1 : -1]}. The second column is a logistic
     * squashing of twice the score (presumably consumed as a pseudo-probability; confirm with
     * callers). The third column is the predicted label; a NaN score maps to {@code -1}.
     *
     * @param modelOut raw scores. NOTE(review): the backing array is treated as a flat vector,
     *     so a multi-column input would yield one row per element. Confirm callers pass a
     *     single column.
     * @param graph not used by this method; present for the {@link LossFunc} contract
     * @return a {@code n x 3} matrix with the same precision (double or float) as
     *     {@code modelOut}, where {@code n} is the length of its data array
     * @throws MatchError if {@code modelOut} is neither a {@link BlasDoubleMatrix} nor a
     *     {@link BlasFloatMatrix}
     */
    @Override
    public Matrix predict(Matrix modelOut, AngelGraph graph) {
        if (modelOut instanceof BlasDoubleMatrix) {
            return predictDouble((BlasDoubleMatrix) modelOut);
        }
        if (modelOut instanceof BlasFloatMatrix) {
            return predictFloat((BlasFloatMatrix) modelOut);
        }
        // Same exception type the original Scala pattern match raised for other matrix kinds.
        throw new MatchError(modelOut);
    }

    /** Builds the three-column prediction for double-precision scores. */
    private static Matrix predictDouble(BlasDoubleMatrix scores) {
        double[] values = scores.getData();
        double[] out = new double[values.length * PREDICTION_COLUMNS];
        for (int i = 0; i < values.length; i++) {
            double v = values[i];
            int base = PREDICTION_COLUMNS * i;
            out[base] = v;
            out[base + 1] = 1.0 / (1.0 + Math.exp(-2.0 * v));
            out[base + 2] = v > 0.0 ? 1.0 : -1.0;
        }
        return MFactory.denseDoubleMatrix(values.length, PREDICTION_COLUMNS, out);
    }

    /** Builds the three-column prediction for single-precision scores (math done in double). */
    private static Matrix predictFloat(BlasFloatMatrix scores) {
        float[] values = scores.getData();
        float[] out = new float[values.length * PREDICTION_COLUMNS];
        for (int i = 0; i < values.length; i++) {
            float v = values[i];
            int base = PREDICTION_COLUMNS * i;
            out[base] = v;
            // Compute the logistic in double precision, then narrow, as the original did.
            out[base + 1] = (float) (1.0 / (1.0 + Math.exp(-2.0 * (double) v)));
            out[base + 2] = v > 0.0f ? 1.0f : -1.0f;
        }
        return MFactory.denseFloatMatrix(values.length, PREDICTION_COLUMNS, out);
    }

    @Override
    public String toString() {
        return "HingeLoss";
    }

    /** Creates the loss and runs the {@code LossFunc} trait initializer. */
    public HingeLoss() {
        LossFunc$class.$init$(this);
    }
}

