Class AbstractCompositeLoss

    • Field Detail

      • components

        protected java.util.List<Loss> components
    • Constructor Detail

      • AbstractCompositeLoss

        public AbstractCompositeLoss(java.lang.String name)
        Constructs a composite loss with the given name.
        Parameters:
        name - the display name of the loss
    • Method Detail

      • inputForComponent

        protected abstract ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex,
                                                                                   NDList labels,
                                                                                   NDList predictions)
        Returns the labels and predictions on which the given component loss is evaluated.
        Parameters:
        componentIndex - the index of the component loss
        labels - the label input to the composite loss
        predictions - the predictions input to the composite loss
        Returns:
        a pair of the (labels, predictions) inputs to the component loss
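
        A minimal sketch of one possible routing strategy, in which component i receives only the i-th label and the i-th prediction. It uses plain Java stand-ins so that it runs without DJL: List<float[]> stands in for NDList and Map.Entry for ai.djl.util.Pair. The class name InputRoutingSketch and the index-based routing are assumptions for illustration, not the library's implementation.

```java
import java.util.List;
import java.util.Map;

public class InputRoutingSketch {

    /**
     * Hypothetical stand-in for inputForComponent: selects the label and
     * prediction at componentIndex and wraps each in a single-element list.
     * The entry key holds the labels, the entry value holds the predictions.
     */
    public static Map.Entry<List<float[]>, List<float[]>> inputForComponent(
            int componentIndex, List<float[]> labels, List<float[]> predictions) {
        return Map.entry(
                List.of(labels.get(componentIndex)),
                List.of(predictions.get(componentIndex)));
    }

    public static void main(String[] args) {
        List<float[]> labels = List.of(new float[] {1f}, new float[] {2f, 3f});
        List<float[]> predictions = List.of(new float[] {0.5f}, new float[] {2.5f, 3.5f});

        // Component 1 sees only the second label/prediction pair.
        Map.Entry<List<float[]>, List<float[]>> inputs =
                inputForComponent(1, labels, predictions);
        System.out.println("labels routed to component 1: " + inputs.getKey().size());
        System.out.println("first routed label length: " + inputs.getKey().get(0).length);
    }
}
```

        Other subclasses could just as well pass the full labels and predictions to every component; the abstract method leaves that choice to the implementer.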
      • getComponents

        public java.util.List<Loss> getComponents()
        Returns the component losses that make up the composite loss.
        Returns:
        the component losses that make up the composite loss
      • evaluate

        public NDArray evaluate(NDList labels,
                                NDList predictions)
        Evaluates the predictions against the labels and returns the resulting loss.
        Specified by:
        evaluate in class Evaluator
        Parameters:
        labels - the correct values
        predictions - the predicted values
        Returns:
        the evaluation result
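
        The sketch below illustrates how a composite loss might combine its components: each component is evaluated on its own inputs and the results are added. This page does not state how the component results are combined, so the summation is an assumption. ComponentLoss, meanAbsoluteError and meanSquaredError are hypothetical stand-ins operating on double arrays rather than NDList and NDArray.

```java
import java.util.List;

public class CompositeEvaluateSketch {

    /** Hypothetical stand-in for a component Loss working on plain arrays. */
    public interface ComponentLoss {
        double evaluate(double[] labels, double[] predictions);
    }

    /** Mean of the absolute differences between labels and predictions. */
    public static double meanAbsoluteError(double[] labels, double[] predictions) {
        double sum = 0;
        for (int i = 0; i < labels.length; i++) {
            sum += Math.abs(labels[i] - predictions[i]);
        }
        return sum / labels.length;
    }

    /** Mean of the squared differences between labels and predictions. */
    public static double meanSquaredError(double[] labels, double[] predictions) {
        double sum = 0;
        for (int i = 0; i < labels.length; i++) {
            double diff = labels[i] - predictions[i];
            sum += diff * diff;
        }
        return sum / labels.length;
    }

    /**
     * Evaluates component i on labels[i] and predictions[i], then adds the
     * component results. Summing is an assumption made for this sketch.
     */
    public static double evaluate(
            List<ComponentLoss> components, double[][] labels, double[][] predictions) {
        double total = 0;
        for (int i = 0; i < components.size(); i++) {
            total += components.get(i).evaluate(labels[i], predictions[i]);
        }
        return total;
    }

    public static void main(String[] args) {
        List<ComponentLoss> components = List.of(
                CompositeEvaluateSketch::meanAbsoluteError,
                CompositeEvaluateSketch::meanSquaredError);
        double[][] labels = {{1, 2}, {3}};
        double[][] predictions = {{2, 4}, {5}};
        // Component 0 contributes (1 + 2) / 2 = 1.5, component 1 contributes 4 / 1 = 4.0.
        System.out.println(evaluate(components, labels, predictions));
    }
}
```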
      • addAccumulator

        public void addAccumulator(java.lang.String key)
        Adds an accumulator for the results of the evaluation with the given key.
        Overrides:
        addAccumulator in class Loss
        Parameters:
        key - the key for the new accumulator
      • updateAccumulator

        public void updateAccumulator(java.lang.String key,
                                      NDList labels,
                                      NDList predictions)
        Updates the accumulator with the given key based on an NDList of labels and an NDList of predictions.

        This is a synchronized operation. You should only call it at the end of a batch or epoch.

        Overrides:
        updateAccumulator in class Loss
        Parameters:
        key - the key of the accumulator to update
        labels - an NDList of labels
        predictions - an NDList of predictions
      • resetAccumulator

        public void resetAccumulator(java.lang.String key)
        Resets the accumulated value for the given key.
        Overrides:
        resetAccumulator in class Loss
        Parameters:
        key - the key of the accumulator to reset
      • getAccumulator

        public float getAccumulator(java.lang.String key)
        Returns the accumulated value for the given key.
        Overrides:
        getAccumulator in class Loss
        Parameters:
        key - the key of the accumulator to get
        Returns:
        the accumulated value
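
        The four accumulator methods are meant to be used as a lifecycle: add a key, update it as batches complete, read the value, and reset it when a new epoch starts. The sketch below mimics that call order with a hypothetical AccumulatorSketch class. Unlike the real methods, its updateAccumulator takes an already computed batch loss instead of labels and predictions, and both the running mean and the NaN result for an empty accumulator are assumptions made for illustration.

```java
import java.util.HashMap;
import java.util.Map;

public class AccumulatorSketch {

    // key -> {running sum of batch losses, number of batches seen}
    private final Map<String, double[]> totals = new HashMap<>();

    /** Registers a new, empty accumulator under the given key. */
    public synchronized void addAccumulator(String key) {
        totals.put(key, new double[2]);
    }

    /** Adds one batch loss to the accumulator (simplified signature, see lead-in). */
    public synchronized void updateAccumulator(String key, double batchLoss) {
        double[] total = require(key);
        total[0] += batchLoss;
        total[1] += 1;
    }

    /** Clears the accumulated value while keeping the key registered. */
    public synchronized void resetAccumulator(String key) {
        require(key);
        totals.put(key, new double[2]);
    }

    /** Returns the mean batch loss, or NaN if nothing has been accumulated yet. */
    public synchronized float getAccumulator(String key) {
        double[] total = require(key);
        return total[1] == 0 ? Float.NaN : (float) (total[0] / total[1]);
    }

    private double[] require(String key) {
        double[] total = totals.get(key);
        if (total == null) {
            throw new IllegalArgumentException("No accumulator was added for key: " + key);
        }
        return total;
    }

    public static void main(String[] args) {
        AccumulatorSketch loss = new AccumulatorSketch();
        loss.addAccumulator("train");
        loss.updateAccumulator("train", 1.0);
        loss.updateAccumulator("train", 3.0);
        System.out.println(loss.getAccumulator("train"));
        loss.resetAccumulator("train");
        System.out.println(Float.isNaN(loss.getAccumulator("train")));
    }
}
```

        Marking every method synchronized keeps the sketch safe when several threads report batch results under the same key, which mirrors the note above that updateAccumulator is a synchronized operation.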