/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.core.optimizer.loss;

import com.tencent.angel.ml.core.network.layers.AngelGraph;
import com.tencent.angel.ml.core.optimizer.loss.LossFunc;
import com.tencent.angel.ml.core.optimizer.loss.LossFunc$class;
import com.tencent.angel.ml.core.utils.paramsutils.ParamKeys$;
import com.tencent.angel.ml.math2.MFactory;
import com.tencent.angel.ml.math2.matrix.BlasDoubleMatrix;
import com.tencent.angel.ml.math2.matrix.BlasFloatMatrix;
import com.tencent.angel.ml.math2.matrix.Matrix;
import com.tencent.angel.ml.math2.ufuncs.LossFuncs;
import org.json4s.JsonAST;
import org.json4s.JsonDSL$;
import scala.Function1;
import scala.MatchError;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.reflect.ScalaSignature;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u000114A!\u0001\u0002\u0001#\tI\u0001*\u001e2fe2{7o\u001d\u0006\u0003\u0007\u0011\tA\u0001\\8tg*\u0011QAB\u0001\n_B$\u0018.\\5{KJT!a\u0002\u0005\u0002\t\r|'/\u001a\u0006\u0003\u0013)\t!!\u001c7\u000b\u0005-a\u0011!B1oO\u0016d'BA\u0007\u000f\u0003\u001d!XM\\2f]RT\u0011aD\u0001\u0004G>l7\u0001A\n\u0004\u0001IA\u0002CA\n\u0017\u001b\u0005!\"\"A\u000b\u0002\u000bM\u001c\u0017\r\\1\n\u0005]!\"AB!osJ+g\r\u0005\u0002\u001a55\t!!\u0003\u0002\u001c\u0005\tAAj\\:t\rVt7\r\u0003\u0005\u001e\u0001\t\u0005\t\u0015!\u0003\u001f\u0003\u0015!W\r\u001c;b!\t\u0019r$\u0003\u0002!)\t1Ai\\;cY\u0016DQA\t\u0001\u0005\u0002\r\na\u0001P5oSRtDC\u0001\u0013&!\tI\u0002\u0001C\u0003\u001eC\u0001\u0007a\u0004C\u0003(\u0001\u0011\u0005\u0003&A\u0004dC2dun]:\u0015\u0007yI3\u0007C\u0003+M\u0001\u00071&\u0001\u0005n_\u0012,GnT;u!\ta\u0013'D\u0001.\u0015\tqs&\u0001\u0004nCR\u0014\u0018\u000e\u001f\u0006\u0003a!\tQ!\\1uQJJ!AM\u0017\u0003\r5\u000bGO]5y\u0011\u0015!d\u00051\u00016\u0003\u00159'/\u00199i!\t14(D\u00018\u0015\tA\u0014(\u0001\u0004mCf,'o\u001d\u0006\u0003u\u0019\tqA\\3uo>\u00148.\u0003\u0002=o\tQ\u0011I\\4fY\u001e\u0013\u0018\r\u001d5\t\u000b\r\u0001A\u0011\t 
\u0015\u0007yy\u0014\tC\u0003A{\u0001\u0007a$\u0001\u0003qe\u0016$\u0007\"\u0002\">\u0001\u0004q\u0012!\u00027bE\u0016d\u0007\"\u0002#\u0001\t\u0003*\u0015aB2bY\u001e\u0013\u0018\r\u001a\u000b\u0004W\u0019;\u0005\"\u0002\u0016D\u0001\u0004Y\u0003\"\u0002\u001bD\u0001\u0004)\u0004\"B%\u0001\t\u0003R\u0015a\u00029sK\u0012L7\r\u001e\u000b\u0004W-c\u0005\"\u0002\u0016I\u0001\u0004Y\u0003\"\u0002\u001bI\u0001\u0004)\u0004\"\u0002(\u0001\t\u0003z\u0015\u0001\u0003;p'R\u0014\u0018N\\4\u0015\u0003A\u0003\"!\u0015+\u000f\u0005M\u0011\u0016BA*\u0015\u0003\u0019\u0001&/\u001a3fM&\u0011QK\u0016\u0002\u0007'R\u0014\u0018N\\4\u000b\u0005M#\u0002\"\u0002-\u0001\t\u0003J\u0016A\u0002;p\u0015N|g.F\u0001[!\tY\u0016N\u0004\u0002]M:\u0011Ql\u0019\b\u0003=\u0006l\u0011a\u0018\u0006\u0003AB\ta\u0001\u0010:p_Rt\u0014\"\u00012\u0002\u0007=\u0014x-\u0003\u0002eK\u00061!n]8oiMT\u0011AY\u0005\u0003O\"\fqAS:p]\u0006\u001bFK\u0003\u0002eK&\u0011!n\u001b\u0002\b\u0015>\u0013'.Z2u\u0015\t9\u0007\u000e")
public class HuberLoss
implements LossFunc {
    private final double delta;

    @Override
    public double calLoss(Matrix modelOut, AngelGraph graph) {
        return LossFuncs.huberloss((Matrix)modelOut, (Matrix)graph.placeHolder().getLabel(), (double)this.delta).average();
    }

    @Override
    public double loss(double pred, double label) {
        double diff = Math.abs(pred - label);
        return diff > this.delta ? this.delta * diff - 0.5 * this.delta * this.delta : 0.5 * diff * diff;
    }

    @Override
    public Matrix calGrad(Matrix modelOut, AngelGraph graph) {
        return LossFuncs.gradhuberloss((Matrix)modelOut, (Matrix)graph.placeHolder().getLabel(), (double)this.delta);
    }

    @Override
    public Matrix predict(Matrix modelOut, AngelGraph graph) {
        Matrix matrix;
        block4: {
            BlasDoubleMatrix blasDoubleMatrix;
            block3: {
                block2: {
                    matrix = modelOut;
                    if (!(matrix instanceof BlasDoubleMatrix)) break block2;
                    BlasDoubleMatrix blasDoubleMatrix2 = (BlasDoubleMatrix)matrix;
                    BlasDoubleMatrix mat = MFactory.denseDoubleMatrix((int)blasDoubleMatrix2.getNumRows(), (int)3);
                    mat.setCol(0, blasDoubleMatrix2.getCol(0));
                    blasDoubleMatrix = mat;
                    break block3;
                }
                if (!(matrix instanceof BlasFloatMatrix)) break block4;
                BlasFloatMatrix blasFloatMatrix = (BlasFloatMatrix)matrix;
                BlasFloatMatrix mat = MFactory.denseFloatMatrix((int)blasFloatMatrix.getNumRows(), (int)3);
                mat.setCol(0, blasFloatMatrix.getCol(0));
                blasDoubleMatrix = mat;
            }
            return blasDoubleMatrix;
        }
        throw new MatchError((Object)matrix);
    }

    public String toString() {
        return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"HuberLoss"})).s((Seq)Nil$.MODULE$);
    }

    @Override
    public JsonAST.JObject toJson() {
        return JsonDSL$.MODULE$.pair2Assoc(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)ParamKeys$.MODULE$.typeName()), (Object)new JsonAST.JString(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{this.getClass().getSimpleName()})))), (Function1)Predef$.MODULE$.$conforms()).$tilde(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"delta"), (Object)BoxesRunTime.boxToDouble((double)this.delta)), (Function1)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final JsonAST.JValue apply(double x) {
                return JsonDSL$.MODULE$.double2jvalue(x);
            }
        });
    }

    public HuberLoss(double delta) {
        this.delta = delta;
        LossFunc$class.$init$(this);
    }
}

