/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.core.optimizer.loss;

import com.tencent.angel.ml.core.network.layers.AngelGraph;
import com.tencent.angel.ml.core.optimizer.loss.LossFunc;
import com.tencent.angel.ml.core.optimizer.loss.LossFunc$class;
import com.tencent.angel.ml.math2.MFactory;
import com.tencent.angel.ml.math2.matrix.BlasDoubleMatrix;
import com.tencent.angel.ml.math2.matrix.BlasFloatMatrix;
import com.tencent.angel.ml.math2.matrix.Matrix;
import com.tencent.angel.ml.math2.ufuncs.LossFuncs;
import org.json4s.JsonAST;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/**
 * Cross-entropy loss function, reconstructed from the decompiled Scala class.
 *
 * <p>Batch loss and gradient are delegated to {@link LossFuncs}; {@link #loss(double, double)}
 * evaluates a single (prediction, label) pair with the prediction kept away from 0 and 1 by
 * {@link #eps()}; {@link #predict(Matrix, AngelGraph)} expands every model output value into a
 * three-column row.
 *
 * <p>The decompiler rendered the Scala closures in {@code predict} as anonymous
 * {@code Serializable} instances taking constructor arguments, which is not legal Java. They are
 * expressed here as plain indexed loops that write exactly the same values.
 */
@ScalaSignature(bytes="\u0006\u0001e3A!\u0001\u0002\u0001#\t\u00012I]8tg\u0016sGO]8qs2{7o\u001d\u0006\u0003\u0007\u0011\tA\u0001\\8tg*\u0011QAB\u0001\n_B$\u0018.\\5{KJT!a\u0002\u0005\u0002\t\r|'/\u001a\u0006\u0003\u0013)\t!!\u001c7\u000b\u0005-a\u0011!B1oO\u0016d'BA\u0007\u000f\u0003\u001d!XM\\2f]RT\u0011aD\u0001\u0004G>l7\u0001A\n\u0004\u0001IA\u0002CA\n\u0017\u001b\u0005!\"\"A\u000b\u0002\u000bM\u001c\u0017\r\\1\n\u0005]!\"AB!osJ+g\r\u0005\u0002\u001a55\t!!\u0003\u0002\u001c\u0005\tAAj\\:t\rVt7\rC\u0003\u001e\u0001\u0011\u0005a$\u0001\u0004=S:LGO\u0010\u000b\u0002?A\u0011\u0011\u0004\u0001\u0005\bC\u0001\u0011\r\u0011\"\u0001#\u0003\r)\u0007o]\u000b\u0002GA\u00111\u0003J\u0005\u0003KQ\u0011a\u0001R8vE2,\u0007BB\u0014\u0001A\u0003%1%\u0001\u0003faN\u0004\u0003\"B\u0015\u0001\t\u0003R\u0013aB2bY2{7o\u001d\u000b\u0004G-*\u0004\"\u0002\u0017)\u0001\u0004i\u0013\u0001C7pI\u0016dw*\u001e;\u0011\u00059\u001aT\"A\u0018\u000b\u0005A\n\u0014AB7biJL\u0007P\u0003\u00023\u0011\u0005)Q.\u0019;ie%\u0011Ag\f\u0002\u0007\u001b\u0006$(/\u001b=\t\u000bYB\u0003\u0019A\u001c\u0002\u000b\u001d\u0014\u0018\r\u001d5\u0011\u0005ajT\"A\u001d\u000b\u0005iZ\u0014A\u00027bs\u0016\u00148O\u0003\u0002=\r\u00059a.\u001a;x_J\\\u0017B\u0001 :\u0005)\tenZ3m\u000fJ\f\u0007\u000f\u001b\u0005\u0006\u0007\u0001!\t\u0005\u0011\u000b\u0004G\u0005\u001b\u0005\"\u0002\"@\u0001\u0004\u0019\u0013\u0001\u00029sK\u0012DQ\u0001R A\u0002\r\nQ\u0001\\1cK2DQA\u0012\u0001\u0005B\u001d\u000bqaY1m\u000fJ\fG\rF\u0002.\u0011&CQ\u0001L#A\u00025BQAN#A\u0002]BQa\u0013\u0001\u0005B1\u000bq\u0001\u001d:fI&\u001cG\u000fF\u0002.\u001b:CQ\u0001\f&A\u00025BQA\u000e&A\u0002]BQ\u0001\u0015\u0001\u0005BE\u000b\u0001\u0002^8TiJLgn\u001a\u000b\u0002%B\u00111K\u0016\b\u0003'QK!!\u0016\u000b\u0002\rA\u0013X\rZ3g\u0013\t9\u0006L\u0001\u0004TiJLgn\u001a\u0006\u0003+R\u0001")
public class CrossEntropyLoss
implements LossFunc {
    /** Number of output values written per input element by {@link #predict}. */
    private static final int PREDICT_COLS = 3;

    /** Lower clamp applied to {@code value / (1 - value)} before taking its logarithm. */
    private static final double MIN_ODDS = 1.0E-6;

    /**
     * Upper clamp applied to {@code value / (1 - value)} before taking its logarithm.
     * Numerically this is {@code Float.MAX_VALUE} widened to {@code double}.
     */
    private static final double MAX_ODDS = 3.4028234663852886E38;

    /** Clamp used by {@link #loss(double, double)}; set to 1.0E-9 by the constructor. */
    private final double eps;

    /**
     * Serializes this loss function by delegating to the trait implementation
     * {@code LossFunc$class.toJson}.
     *
     * @return the JSON object produced by the trait implementation
     */
    @Override
    public JsonAST.JObject toJson() {
        return LossFunc$class.toJson(this);
    }

    /**
     * Returns the clamp value used by {@link #loss(double, double)}.
     *
     * @return the value assigned in the constructor (1.0E-9)
     */
    public double eps() {
        return this.eps;
    }

    /**
     * Computes the batch loss as the average of {@code LossFuncs.entropyloss} applied to the
     * model output and the label held by the graph's placeholder.
     *
     * @param modelOut model output matrix
     * @param graph    graph whose placeholder supplies the label matrix
     * @return the averaged loss
     */
    @Override
    public double calLoss(Matrix modelOut, AngelGraph graph) {
        Matrix label = (Matrix)graph.placeHolder().getLabel();
        return LossFuncs.entropyloss(modelOut, label).average();
    }

    /**
     * Computes the loss for a single prediction.
     *
     * <p>For {@code label > 0} the result is {@code -log(pred)}, or {@code -log(eps)} when
     * {@code pred < eps}. Otherwise the result is {@code -log(1 - pred)}, or {@code -log(eps)}
     * when {@code pred > 1 - eps}.
     *
     * @param pred  predicted value (presumably a probability in [0, 1] - not enforced here)
     * @param label label; any value greater than 0 selects the first branch
     * @return the clamped negative log-likelihood term
     */
    @Override
    public double loss(double pred, double label) {
        final double epsilon = this.eps();
        if (label > 0.0) {
            return pred < epsilon ? -Math.log(epsilon) : -Math.log(pred);
        }
        return pred > 1.0 - epsilon ? -Math.log(epsilon) : -Math.log(1.0 - pred);
    }

    /**
     * Computes the gradient by delegating to {@code LossFuncs.gradentropyloss} with the model
     * output and the label held by the graph's placeholder.
     *
     * @param modelOut model output matrix
     * @param graph    graph whose placeholder supplies the label matrix
     * @return the gradient matrix returned by {@code LossFuncs.gradentropyloss}
     */
    @Override
    public Matrix calGrad(Matrix modelOut, AngelGraph graph) {
        Matrix label = (Matrix)graph.placeHolder().getLabel();
        return LossFuncs.gradentropyloss(modelOut, label);
    }

    /**
     * Expands every element of the model output into three consecutive values and returns them
     * as a dense matrix with one row per input element and three columns.
     *
     * <p>For an input element {@code value} the row holds, in order:
     * <ol>
     *   <li>{@code log(value / (1 - value))}, with the ratio clamped to
     *       [{@code 1.0E-6}, {@code Float.MAX_VALUE}] before the logarithm;</li>
     *   <li>{@code value} itself;</li>
     *   <li>{@code 1} when {@code value > 0.5}, otherwise {@code -1}.</li>
     * </ol>
     *
     * @param modelOut model output; must be a {@link BlasDoubleMatrix} or {@link BlasFloatMatrix}
     * @param graph    not used by this implementation
     * @return a dense matrix of the same element type as {@code modelOut}
     * @throws MatchError if {@code modelOut} is neither of the two supported matrix types
     */
    @Override
    public Matrix predict(Matrix modelOut, AngelGraph graph) {
        if (modelOut instanceof BlasDoubleMatrix) {
            return predictDouble(((BlasDoubleMatrix)modelOut).getData());
        }
        if (modelOut instanceof BlasFloatMatrix) {
            return predictFloat(((BlasFloatMatrix)modelOut).getData());
        }
        // Mirrors the non-exhaustive Scala pattern match of the original source.
        throw new MatchError((Object)modelOut);
    }

    /** Builds the three-column prediction matrix for double-precision model output. */
    private static Matrix predictDouble(double[] values) {
        double[] data = new double[values.length * PREDICT_COLS];
        for (int idx = 0; idx < values.length; ++idx) {
            double value = values[idx];
            double odds = value / (1.0 - value);
            double logOdds;
            if (odds < MIN_ODDS) {
                logOdds = Math.log(MIN_ODDS);
            } else if (odds > MAX_ODDS) {
                logOdds = Math.log(MAX_ODDS);
            } else {
                // Also reached for NaN, which then propagates through Math.log.
                logOdds = Math.log(odds);
            }
            int base = PREDICT_COLS * idx;
            data[base] = logOdds;
            data[base + 1] = value;
            data[base + 2] = value > 0.5 ? 1.0 : -1.0;
        }
        return MFactory.denseDoubleMatrix(values.length, PREDICT_COLS, data);
    }

    /** Builds the three-column prediction matrix for single-precision model output. */
    private static Matrix predictFloat(float[] values) {
        float[] data = new float[values.length * PREDICT_COLS];
        for (int idx = 0; idx < values.length; ++idx) {
            float value = values[idx];
            float odds = value / (1.0f - value);
            float logOdds;
            if ((double)odds < MIN_ODDS) {
                logOdds = (float)Math.log(MIN_ODDS);
            } else if (odds > Float.MAX_VALUE) {
                // Only positive infinity exceeds Float.MAX_VALUE in float arithmetic.
                logOdds = (float)Math.log(MAX_ODDS);
            } else {
                logOdds = (float)Math.log(odds);
            }
            int base = PREDICT_COLS * idx;
            data[base] = logOdds;
            data[base + 1] = value;
            data[base + 2] = (double)value > 0.5 ? 1.0f : -1.0f;
        }
        return MFactory.denseFloatMatrix(values.length, PREDICT_COLS, data);
    }

    /**
     * Returns the fixed name of this loss function.
     *
     * @return the string {@code "CrossEntropyLoss"}
     */
    @Override
    public String toString() {
        return "CrossEntropyLoss";
    }

    /**
     * Creates the loss function: runs the {@code LossFunc} trait initializer, then sets
     * {@code eps} to 1.0E-9 (the same order as the compiled Scala constructor).
     */
    public CrossEntropyLoss() {
        LossFunc$class.$init$(this);
        this.eps = 1.0E-9;
    }
}

