/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.core.optimizer.loss;

import com.tencent.angel.ml.core.network.layers.AngelGraph;
import com.tencent.angel.ml.core.optimizer.loss.LossFunc;
import com.tencent.angel.ml.core.optimizer.loss.LossFunc$class;
import com.tencent.angel.ml.math2.MFactory;
import com.tencent.angel.ml.math2.matrix.BlasDoubleMatrix;
import com.tencent.angel.ml.math2.matrix.BlasFloatMatrix;
import com.tencent.angel.ml.math2.matrix.Matrix;
import com.tencent.angel.ml.math2.ufuncs.LossFuncs;
import org.json4s.JsonAST;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001Q3A!\u0001\u0002\u0001#\t9Aj\\4M_N\u001c(BA\u0002\u0005\u0003\u0011awn]:\u000b\u0005\u00151\u0011!C8qi&l\u0017N_3s\u0015\t9\u0001\"\u0001\u0003d_J,'BA\u0005\u000b\u0003\tiGN\u0003\u0002\f\u0019\u0005)\u0011M\\4fY*\u0011QBD\u0001\bi\u0016t7-\u001a8u\u0015\u0005y\u0011aA2p[\u000e\u00011c\u0001\u0001\u00131A\u00111CF\u0007\u0002))\tQ#A\u0003tG\u0006d\u0017-\u0003\u0002\u0018)\t1\u0011I\\=SK\u001a\u0004\"!\u0007\u000e\u000e\u0003\tI!a\u0007\u0002\u0003\u00111{7o\u001d$v]\u000eDQ!\b\u0001\u0005\u0002y\ta\u0001P5oSRtD#A\u0010\u0011\u0005e\u0001\u0001\"B\u0011\u0001\t\u0003\u0012\u0013aB2bY2{7o\u001d\u000b\u0004G\u0019\u0002\u0004CA\n%\u0013\t)CC\u0001\u0004E_V\u0014G.\u001a\u0005\u0006O\u0001\u0002\r\u0001K\u0001\t[>$W\r\\(viB\u0011\u0011FL\u0007\u0002U)\u00111\u0006L\u0001\u0007[\u0006$(/\u001b=\u000b\u00055B\u0011!B7bi\"\u0014\u0014BA\u0018+\u0005\u0019i\u0015\r\u001e:jq\")\u0011\u0007\ta\u0001e\u0005)qM]1qQB\u00111\u0007O\u0007\u0002i)\u0011QGN\u0001\u0007Y\u0006LXM]:\u000b\u0005]2\u0011a\u00028fi^|'o[\u0005\u0003sQ\u0012!\"\u00118hK2<%/\u00199i\u0011\u0015\u0019\u0001\u0001\"\u0011<)\r\u0019CH\u0010\u0005\u0006{i\u0002\raI\u0001\u0005aJ,G\rC\u0003@u\u0001\u00071%A\u0003mC\n,G\u000eC\u0003B\u0001\u0011\u0005#)A\u0004dC2<%/\u00193\u0015\u0007!\u001aE\tC\u0003(\u0001\u0002\u0007\u0001\u0006C\u00032\u0001\u0002\u0007!\u0007C\u0003G\u0001\u0011\u0005s)A\u0004qe\u0016$\u0017n\u0019;\u0015\u0007!B\u0015\nC\u0003(\u000b\u0002\u0007\u0001\u0006C\u00032\u000b\u0002\u0007!\u0007C\u0003L\u0001\u0011\u0005C*\u0001\u0005u_N#(/\u001b8h)\u0005i\u0005C\u0001(R\u001d\t\u0019r*\u0003\u0002Q)\u00051\u0001K]3eK\u001aL!AU*\u0003\rM#(/\u001b8h\u0015\t\u0001F\u0003")
/**
 * Logistic (log) loss.
 *
 * <p>Per-sample loss is {@code log(1 + exp(-pred * label))}. Batch loss and gradient are delegated
 * to {@link LossFuncs#logloss} / {@link LossFuncs#gradlogloss} using the label matrix held by the
 * graph's placeholder. Labels are presumably in {-1, +1} (the sign column emitted by
 * {@link #predict} uses that convention) — TODO confirm against {@code LossFuncs}.
 */
public class LogLoss implements LossFunc {

    /** Number of columns produced by {@link #predict}: raw score, probability, predicted label. */
    private static final int PREDICT_COLUMNS = 3;

    @Override
    public JsonAST.JObject toJson() {
        return LossFunc$class.toJson(this);
    }

    /**
     * Computes the average log loss of {@code modelOut} against the graph's labels.
     *
     * @param modelOut model output (scores)
     * @param graph graph whose placeholder supplies the label matrix
     * @return mean of the element-wise log loss
     */
    @Override
    public double calLoss(Matrix modelOut, AngelGraph graph) {
        Matrix label = (Matrix) graph.placeHolder().getLabel();
        return LossFuncs.logloss(modelOut, label).average();
    }

    /**
     * Computes {@code log(1 + exp(-pred * label))} for a single sample.
     *
     * <p>The naive formula overflows ({@code exp} -> Infinity) once the margin drops below roughly
     * -709, and it rounds to exactly 0 for large positive margins. Splitting on the sign of the
     * margin keeps the argument of {@code exp} non-positive. {@code log1p} preserves precision for
     * tiny values. The result is mathematically identical, and at +/-Infinity and NaN it matches the
     * naive form.
     *
     * @param pred raw model score
     * @param label sample label
     * @return the non-negative log loss
     */
    @Override
    public double loss(double pred, double label) {
        double margin = pred * label;
        if (margin > 0.0) {
            return Math.log1p(Math.exp(-margin));
        }
        // log(1 + e^-m) = -m + log(1 + e^m); here m <= 0 (or NaN) so e^m cannot overflow.
        return -margin + Math.log1p(Math.exp(margin));
    }

    /**
     * Computes the gradient of the log loss w.r.t. {@code modelOut}.
     *
     * @param modelOut model output (scores)
     * @param graph graph whose placeholder supplies the label matrix
     * @return gradient matrix as produced by {@link LossFuncs#gradlogloss}
     */
    @Override
    public Matrix calGrad(Matrix modelOut, AngelGraph graph) {
        Matrix label = (Matrix) graph.placeHolder().getLabel();
        return LossFuncs.gradlogloss(modelOut, label);
    }

    /**
     * Builds a prediction matrix from raw scores.
     *
     * <p>Each element of {@code modelOut}'s backing data array becomes one row of a dense
     * {@code n x 3} matrix of the same precision: {@code [score, sigmoid(score), score > 0 ? 1 : -1]}.
     * This presumably assumes a single-column score matrix — TODO confirm with callers.
     *
     * @param modelOut raw scores; must be a {@link BlasDoubleMatrix} or {@link BlasFloatMatrix}
     * @param graph unused
     * @return the {@code n x 3} prediction matrix
     * @throws MatchError if {@code modelOut} is null or of any other matrix type
     */
    @Override
    public Matrix predict(Matrix modelOut, AngelGraph graph) {
        if (modelOut instanceof BlasDoubleMatrix) {
            double[] scores = ((BlasDoubleMatrix) modelOut).getData();
            double[] data = new double[scores.length * PREDICT_COLUMNS];
            for (int idx = 0; idx < scores.length; idx++) {
                double value = scores[idx];
                int base = PREDICT_COLUMNS * idx;
                data[base] = value;
                data[base + 1] = sigmoid(value);
                data[base + 2] = value > 0.0 ? 1.0 : -1.0;
            }
            return MFactory.denseDoubleMatrix(scores.length, PREDICT_COLUMNS, data);
        }
        if (modelOut instanceof BlasFloatMatrix) {
            float[] scores = ((BlasFloatMatrix) modelOut).getData();
            float[] data = new float[scores.length * PREDICT_COLUMNS];
            for (int idx = 0; idx < scores.length; idx++) {
                float value = scores[idx];
                int base = PREDICT_COLUMNS * idx;
                data[base] = value;
                data[base + 1] = (float) sigmoid(value);
                data[base + 2] = value > 0.0f ? 1.0f : -1.0f;
            }
            return MFactory.denseFloatMatrix(scores.length, PREDICT_COLUMNS, data);
        }
        // Preserve the Scala pattern-match failure type that the original code threw.
        throw new MatchError(modelOut);
    }

    /** Logistic function {@code 1 / (1 + exp(-x))}; saturates cleanly to 0 or 1 for large |x|. */
    private static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    @Override
    public String toString() {
        return "LogLoss";
    }

    public LogLoss() {
        LossFunc$class.$init$(this);
    }
}

