Class PLNetLoss
- java.lang.Object
  - ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.PLNetLoss
public class PLNetLoss
extends java.lang.Object

Implements the negative log likelihood (NLL) loss function for PL networks as described in [1].

[1]: Dirk Schäfer, Eyke Hüllermeier (2018). Dyad ranking using Plackett-Luce models based on joint feature representations.
Method Summary
Modifier and Type | Method | Description
static org.nd4j.linalg.api.ndarray.INDArray | computeLoss(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs) | Computes the NLL for PL networks according to equation (27) in [1].
static org.nd4j.linalg.api.ndarray.INDArray | computeLossGradient(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs, int k) | Computes the gradient of the NLL for PL networks w.r.t. the k-th dyad according to equation (28) in [1].
Method Detail
computeLoss
public static org.nd4j.linalg.api.ndarray.INDArray computeLoss(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs)
Computes the NLL for PL networks according to equation (27) in [1].
Parameters:
plNetOutputs - The outputs for M_n dyads generated by a PLNet's output layer, in order of their ranking (from best to worst).
Returns:
The NLL loss for the given PLNet outputs.
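To illustrate what a loss of this kind measures, the following is a minimal plain-Java sketch of the negative log likelihood of a ranking under a standard Plackett-Luce model. It is not the library's implementation: the class `PlNllSketch` and its method `nll` are hypothetical names, it works on a `double[]` instead of an `INDArray`, and it assumes that the network outputs act as log-utilities and that the loss has the usual Plackett-Luce form. Whether this matches equation (27) in [1] term by term has not been checked here.

```java
// Hypothetical illustration only; not part of ai.libs.jaicore.
final class PlNllSketch {

    private PlNllSketch() {
    }

    /**
     * NLL of a ranking under a Plackett-Luce model (assumed form).
     * scores[0] belongs to the best-ranked dyad, scores[n - 1] to the worst.
     */
    static double nll(double[] scores) {
        int n = scores.length;
        if (n < 2) {
            throw new IllegalArgumentException("A ranking needs at least two scores.");
        }
        double loss = 0.0;
        // The last position contributes log(exp(s)) - s = 0, so it is skipped.
        for (int m = 0; m < n - 1; m++) {
            // Log-sum-exp over the tail scores[m..n-1], shifted by the
            // tail maximum for numerical stability.
            double max = Double.NEGATIVE_INFINITY;
            for (int l = m; l < n; l++) {
                max = Math.max(max, scores[l]);
            }
            double sum = 0.0;
            for (int l = m; l < n; l++) {
                sum += Math.exp(scores[l] - max);
            }
            loss += max + Math.log(sum) - scores[m];
        }
        return loss;
    }
}
```

Under these assumptions, a call such as `PlNllSketch.nll(new double[] {2.0, 1.0, 0.0})` yields a smaller loss than the same scores in reverse order, because the scores then agree with the given ranking.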
computeLossGradient
public static org.nd4j.linalg.api.ndarray.INDArray computeLossGradient(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs, int k)
Computes the gradient of the NLL for PL networks w.r.t. the k-th dyad according to equation (28) in [1].
Parameters:
plNetOutputs - The outputs for M_n dyads generated by a PLNet's output layer, in order of their ranking (from best to worst).
k - The ranking position with respect to which the gradient should be computed. Assumes zero-based indices, unlike the paper.
Returns:
The gradient of the NLL loss w.r.t. the k-th dyad in the ranking.
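The zero-based convention for k can be made concrete with a small plain-Java sketch of the derivative of a standard Plackett-Luce NLL with respect to the score at position k. As before, this is an assumption-laden illustration, not the library's code: `PlNllGradientSketch` and `gradient` are hypothetical names, the input is a `double[]` of log-utilities ordered from best to worst, the result is returned as a plain `double`, and agreement with equation (28) in [1] has not been verified.

```java
// Hypothetical illustration only; not part of ai.libs.jaicore.
final class PlNllGradientSketch {

    private PlNllGradientSketch() {
    }

    /**
     * Derivative of the (assumed) Plackett-Luce NLL w.r.t. scores[k].
     * k is zero-based: k = 0 refers to the best-ranked dyad.
     */
    static double gradient(double[] scores, int k) {
        int n = scores.length;
        if (n < 2) {
            throw new IllegalArgumentException("A ranking needs at least two scores.");
        }
        if (k < 0 || k >= n) {
            throw new IndexOutOfBoundsException("k must be in [0, " + (n - 1) + "]");
        }
        double grad = 0.0;
        // scores[k] appears in every tail sum that starts at a position m <= k.
        for (int m = 0; m <= k; m++) {
            // Shift by the tail maximum so the exponentials cannot overflow.
            double max = Double.NEGATIVE_INFINITY;
            for (int l = m; l < n; l++) {
                max = Math.max(max, scores[l]);
            }
            double sum = 0.0;
            for (int l = m; l < n; l++) {
                sum += Math.exp(scores[l] - max);
            }
            // Softmax weight of position k within the tail starting at m.
            grad += Math.exp(scores[k] - max) / sum;
        }
        // The linear term -scores[k] of the loss contributes -1.
        return grad - 1.0;
    }
}
```

One sanity check for a sketch like this is that the derivatives over all positions sum to zero, since adding a constant to every score leaves a Plackett-Luce likelihood unchanged.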