Source code for nncore.nn.losses.utils

# Copyright (c) Ye Liu. Licensed under the MIT License.

from functools import wraps


[docs] def weighted_loss(func): """ A syntactic sugar for loss functions with dynamic weights and average factors. This method is expected to be used as a decorator. """ @wraps(func) def _wrapper(*args, weight=None, reduction='mean', avg_factor=None, **kwargs): assert reduction in ('mean', 'sum', 'none') loss = func(*args, **kwargs) if weight is not None: loss = loss * weight if reduction == 'mean': if avg_factor is None: loss = loss.mean() else: loss = loss.sum() / avg_factor elif reduction == 'sum': loss = loss.sum() return loss return _wrapper