/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.dyad.learner.algorithm;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Static helpers that evaluate a loss and its gradient on the outputs produced
 * for one ordered sequence of dyads (a "dyad ranking").
 *
 * <p>Writing {@code v_0 ... v_{M-1}} for the entries of the row vector that is
 * passed in, the loss computed here is
 *
 * <pre>
 *   sum_{m=0}^{M-2} log( sum_{l=m}^{M-1} exp(v_l) )  -  sum_{m=0}^{M-2} v_m
 * </pre>
 *
 * <p>All exponentials are evaluated through a log-sum-exp formulation so that
 * large outputs do not overflow to {@code Infinity} / {@code NaN}.
 */
public class PLNetLoss {
    private PLNetLoss() {
        // Utility class; not meant to be instantiated.
    }

    /**
     * Computes the loss for one dyad ranking.
     *
     * @param plNetOutputs row vector holding the outputs in ranking order; must
     *                     contain at least two elements
     * @return a single-element array containing the loss value
     * @throws IllegalArgumentException if the input is not a row vector or has
     *                                  fewer than two elements
     */
    public static INDArray computeLoss(INDArray plNetOutputs) {
        if (!plNetOutputs.isRowVector() || plNetOutputs.size(1) < 2L) {
            throw new IllegalArgumentException("Input has to be a row vector of 2 or more elements.");
        }
        double[] suffixLogSumExp = PLNetLoss.suffixLogSumExp(plNetOutputs);
        double loss = 0.0;
        // The last position contributes nothing: its log-sum-exp term equals its own output.
        for (int m = 0; m <= suffixLogSumExp.length - 2; ++m) {
            loss += suffixLogSumExp[m] - plNetOutputs.getDouble((long) m);
        }
        return Nd4j.create(new double[]{loss});
    }

    /**
     * Computes the partial derivative of the loss of {@link #computeLoss(INDArray)}
     * with respect to the {@code k}-th output, i.e.
     *
     * <pre>
     *   sum_{m=0}^{k} exp(v_k) / sum_{l=m}^{M-1} exp(v_l)  -  1
     * </pre>
     *
     * @param plNetOutputs row vector holding the outputs in ranking order; must
     *                     contain at least two elements
     * @param k            index of the output to differentiate with respect to;
     *                     must be a valid index into {@code plNetOutputs}
     * @return a single-element array containing the gradient value
     * @throws IllegalArgumentException if the input is not a row vector, has fewer
     *                                  than two elements, or {@code k} is out of range
     */
    public static INDArray computeLossGradient(INDArray plNetOutputs, int k) {
        if (!plNetOutputs.isRowVector() || plNetOutputs.size(1) < 2L || k < 0 || (long) k >= plNetOutputs.size(1)) {
            throw new IllegalArgumentException("Input has to be a row vector of 2 or more elements. And k has to be a valid index of plNetOutputs.");
        }
        double[] suffixLogSumExp = PLNetLoss.suffixLogSumExp(plNetOutputs);
        double outputK = plNetOutputs.getDouble((long) k);
        double errorGradient = 0.0;
        for (int m = 0; m <= k; ++m) {
            // exp(v_k) / sum_{l>=m} exp(v_l), evaluated in log space. Because k >= m the
            // exponent is never positive, so this cannot overflow.
            errorGradient += Math.exp(outputK - suffixLogSumExp[m]);
        }
        return Nd4j.create(new double[]{errorGradient - 1.0});
    }

    /**
     * For every position m, computes log(sum_{l=m}^{M-1} exp(v_l)) in one backward pass.
     */
    private static double[] suffixLogSumExp(INDArray plNetOutputs) {
        int length = Math.toIntExact(plNetOutputs.size(1));
        double[] result = new double[length];
        double running = plNetOutputs.getDouble((long) (length - 1));
        result[length - 1] = running;
        for (int m = length - 2; m >= 0; --m) {
            running = PLNetLoss.logAddExp(plNetOutputs.getDouble((long) m), running);
            result[m] = running;
        }
        return result;
    }

    /**
     * Returns log(exp(a) + exp(b)) without forming either exponential directly.
     */
    private static double logAddExp(double a, double b) {
        double hi = Math.max(a, b);
        if (Double.isInfinite(hi)) {
            // Avoids (inf - inf) = NaN below; the larger operand dominates the sum.
            return hi;
        }
        double lo = Math.min(a, b);
        return hi + Math.log1p(Math.exp(lo - hi));
    }
}

