Class PLNetLoss


  • public class PLNetLoss
    extends java.lang.Object
    Implements the negative log-likelihood (NLL) loss function for Plackett-Luce (PL) networks (PLNet), as described in [1].
    [1] Dirk Schäfer, Eyke Hüllermeier (2018). Dyad ranking using Plackett-Luce models based on joint feature representations.
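For reference, here is a sketch of the loss and its gradient under the standard Plackett-Luce parameterization. It assumes that each output $o_m$ of the PLNet output layer is the log-skill of the dyad at rank $m$, so that dyad's PL skill is $\exp(o_m)$. Dyads are indexed from best ($m = 1$) to worst ($m = M_n$). The forms should correspond to equations (27) and (28) in [1], but the symbols $o_m$ and $\mathcal{L}$ are introduced here only for illustration:

$$
P(\text{ranking}) = \prod_{m=1}^{M_n} \frac{\exp(o_m)}{\sum_{l=m}^{M_n} \exp(o_l)},
\qquad
\mathcal{L} = -\log P(\text{ranking}) = \sum_{m=1}^{M_n} \left( \log \sum_{l=m}^{M_n} \exp(o_l) - o_m \right)
$$

$$
\frac{\partial \mathcal{L}}{\partial o_k} = \sum_{m=1}^{k} \frac{\exp(o_k)}{\sum_{l=m}^{M_n} \exp(o_l)} - 1
$$

Only the first $k$ denominators contain $\exp(o_k)$, which is why the sum in the gradient stops at $m = k$.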
    • Method Summary

      Modifier and Type | Method | Description
      static org.nd4j.linalg.api.ndarray.INDArray computeLoss(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs)
      Computes the NLL for PL networks according to equation (27) in [1].
      static org.nd4j.linalg.api.ndarray.INDArray computeLossGradient(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs, int k)
      Computes the gradient of the NLL for PL networks w.r.t. the k-th dyad according to equation (28) in [1].
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
    • Method Detail

      • computeLoss

        public static org.nd4j.linalg.api.ndarray.INDArray computeLoss(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs)
        Computes the NLL for PL networks according to equation (27) in [1].
        Parameters:
        plNetOutputs - The outputs for M_n dyads generated by a PLNet's output layer in order of their ranking (from best to worst).
        Returns:
        The NLL loss for the given PLNet outputs.
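The following is a minimal sketch of how the loss could be computed, written with plain `double[]` instead of `INDArray` so that it stays self-contained. It is not the library's implementation. The class `PlNllSketch` and its methods are hypothetical. The sketch assumes that `outputs[m]` is the log-skill of the dyad at zero-based rank `m` and that dyads are ordered from best to worst. It accumulates a suffix log-sum-exp from the worst rank upwards, which avoids overflow when outputs are large.

```java
// Hypothetical sketch: plain double[] stands in for INDArray; this is not the library's code.
public class PlNllSketch {

    /**
     * PL negative log-likelihood of a ranking.
     * outputs[m] is assumed to be the log-skill of the dyad at rank m (0 = best).
     */
    public static double nll(double[] outputs) {
        double loss = 0.0;
        // Log-sum-exp of the suffix outputs[m..n-1], accumulated from the worst rank upwards.
        double suffixLse = Double.NEGATIVE_INFINITY;
        for (int m = outputs.length - 1; m >= 0; m--) {
            suffixLse = logAddExp(suffixLse, outputs[m]);
            // -log of the probability that dyad m is chosen first among dyads m..n-1.
            loss += suffixLse - outputs[m];
        }
        return loss;
    }

    /** Numerically stable log(exp(a) + exp(b)). */
    static double logAddExp(double a, double b) {
        if (a == Double.NEGATIVE_INFINITY) return b;
        if (b == Double.NEGATIVE_INFINITY) return a;
        double max = Math.max(a, b);
        double min = Math.min(a, b);
        return max + Math.log1p(Math.exp(min - max));
    }

    public static void main(String[] args) {
        // Outputs that agree with the ranking should give a lower loss than reversed ones.
        System.out.println(nll(new double[] {2.0, 1.0, 0.0}));
        System.out.println(nll(new double[] {0.0, 1.0, 2.0}));
    }
}
```

With all outputs equal, every ranking of $M_n$ dyads is equally likely. The loss is therefore $\log(M_n!)$, which makes a convenient sanity check.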
      • computeLossGradient

        public static org.nd4j.linalg.api.ndarray.INDArray computeLossGradient(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs,
                                                                               int k)
        Computes the gradient of the NLL for PL networks w.r.t. the k-th dyad according to equation (28) in [1].
        Parameters:
        plNetOutputs - The outputs for M_n dyads generated by a PLNet's output layer in order of their ranking (from best to worst).
        k - The ranking position with respect to which the gradient is computed. Indices are zero-based, unlike the one-based indices used in [1].
        Returns:
        The gradient of the NLL loss w.r.t. the k-th dyad in the ranking.
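Below is a similarly hedged sketch of the gradient with respect to the output at zero-based rank `k`. Again, plain `double[]` stands in for `INDArray`, and it is not the library's implementation. The class `PlNllGradientSketch` is hypothetical. It uses the same assumption as before: `outputs[m]` is a log-skill, and dyads are ordered from best to worst. Each ratio $\exp(o_k) / \sum_{l \ge m} \exp(o_l)$ is evaluated as `exp(o_k - suffixLse[m])`, which is always at most 1 and so cannot overflow.

```java
// Hypothetical sketch: plain double[] stands in for INDArray; this is not the library's code.
public class PlNllGradientSketch {

    /** dL/do_k for zero-based rank k, with outputs ordered from best to worst. */
    public static double gradient(double[] outputs, int k) {
        int n = outputs.length;
        if (k < 0 || k >= n) {
            throw new IllegalArgumentException("k out of range: " + k);
        }
        // suffixLse[m] = log(sum_{l=m}^{n-1} exp(outputs[l]))
        double[] suffixLse = new double[n];
        double running = Double.NEGATIVE_INFINITY;
        for (int m = n - 1; m >= 0; m--) {
            running = logAddExp(running, outputs[m]);
            suffixLse[m] = running;
        }
        // The -1 comes from the "- o_k" term; exp(o_k) appears only in denominators m = 0..k.
        double grad = -1.0;
        for (int m = 0; m <= k; m++) {
            grad += Math.exp(outputs[k] - suffixLse[m]);
        }
        return grad;
    }

    /** Numerically stable log(exp(a) + exp(b)). */
    static double logAddExp(double a, double b) {
        if (a == Double.NEGATIVE_INFINITY) return b;
        if (b == Double.NEGATIVE_INFINITY) return a;
        double max = Math.max(a, b);
        double min = Math.min(a, b);
        return max + Math.log1p(Math.exp(min - max));
    }

    public static void main(String[] args) {
        double[] outputs = {2.0, 1.0, 0.0};
        for (int k = 0; k < outputs.length; k++) {
            System.out.println("dL/do_" + k + " = " + gradient(outputs, k));
        }
    }
}
```

Adding the same constant to every output leaves the loss unchanged. As a consequence, the gradient components over all k sum to zero, which gives a quick check for any implementation.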