public class PLNetLoss
extends java.lang.Object
| Modifier and Type | Method and Description |
|---|---|
| `static org.nd4j.linalg.api.ndarray.INDArray` | `computeLoss(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs)` Computes the NLL for PL networks according to equation (27) in [1]. |
| `static org.nd4j.linalg.api.ndarray.INDArray` | `computeLossGradient(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs, int k)` Computes the gradient of the NLL for PL networks w.r.t. the k-th dyad according to equation (28) in [1]. |
public static org.nd4j.linalg.api.ndarray.INDArray computeLoss(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs)
plNetOutputs - The outputs for M_n dyads generated by a PLNet's output layer in order of their ranking (from best to worst).

public static org.nd4j.linalg.api.ndarray.INDArray computeLossGradient(org.nd4j.linalg.api.ndarray.INDArray plNetOutputs, int k)
plNetOutputs - The outputs for M_n dyads generated by a PLNet's output layer in order of their ranking (from best to worst).
k - The ranking position with respect to which the gradient should be computed. Assumes zero-based indices, unlike the paper.
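
To illustrate what these two methods compute, here is a minimal sketch of the standard Plackett-Luce negative log-likelihood and its gradient with respect to one score. It is written under several assumptions and is not the library's implementation: the class name `PlNllSketch` and its method names are hypothetical, plain `double[]` arrays stand in for `INDArray`, and the formulas are the textbook PL forms, which may differ in detail from equations (27) and (28) in [1]. As with `plNetOutputs`, the scores are assumed to be ordered from best to worst, and `k` is zero-based.

```java
// Hypothetical illustration class; not part of the documented API.
final class PlNllSketch {

    /**
     * Plackett-Luce NLL of a ranking: sum over positions m of
     * log(sum_{l >= m} exp(u[l])) - u[m], where u[0] is the best-ranked score.
     */
    static double nll(double[] u) {
        double loss = 0.0;
        for (int m = 0; m < u.length; m++) {
            double tailSum = 0.0;
            for (int l = m; l < u.length; l++) {
                tailSum += Math.exp(u[l]);
            }
            loss += Math.log(tailSum) - u[m];
        }
        return loss;
    }

    /**
     * Partial derivative of the NLL above w.r.t. u[k] (k is zero-based):
     * sum over m <= k of exp(u[k]) / sum_{l >= m} exp(u[l]), minus 1.
     */
    static double nllGradient(double[] u, int k) {
        double grad = -1.0;
        for (int m = 0; m <= k; m++) {
            double tailSum = 0.0;
            for (int l = m; l < u.length; l++) {
                tailSum += Math.exp(u[l]);
            }
            grad += Math.exp(u[k]) / tailSum;
        }
        return grad;
    }
}
```

For two dyads with equal scores, the sketch gives a loss of log 2, and the gradients for positions 0 and 1 cancel each other, which reflects that the PL likelihood is unchanged when all scores are shifted by the same constant.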