# Source code for nncore.nn.losses.contrastive

# Copyright (c) Ye Liu. Licensed under the MIT License.

from math import log

import torch
import torch.nn as nn
import torch.nn.functional as F

import nncore
from nncore.ops import cosine_similarity
from ..builder import LOSSES
from ..bundle import Parameter
from .utils import weighted_loss


@weighted_loss
def infonce_loss(a, b, temperature=0.07, scale=None, max_scale=100):
    """
    InfoNCE Loss introduced in [1].

    Sample ``a[..., i, :]`` is treated as the positive pair of
    ``b[..., i, :]``; every other sample in the same group of ``N`` acts as a
    negative. The loss is averaged over both matching directions
    (``a -> b`` and ``b -> a``).

    Args:
        a (:obj:`torch.Tensor`): The first group of samples with shape
            ``(*, N, C)``.
        b (:obj:`torch.Tensor`): The second group of samples with shape
            ``(*, N, C)``. Its leading dims must be broadcastable with those
            of ``a`` (as in :func:`torch.matmul`).
        temperature (float, optional): The temperature for softmax. Only used
            when ``scale`` is ``None``. Default: ``0.07``.
        scale (:obj:`torch.Tensor` | None, optional): The logit scale in log
            space (it is exponentiated before use). If not specified, it is
            set to ``log(1 / temperature)``. Default: ``None``.
        max_scale (float, optional): The maximum logit scale value, applied
            after exponentiation. Default: ``100``.

    Returns:
        :obj:`torch.Tensor`: The mean of the two directional cross-entropy
        losses.

    Raises:
        ValueError: If ``a`` and ``b`` hold a different number of samples.

    References:
        1. Oord et al. (https://arxiv.org/abs/1807.03748)
    """
    if a.size(-2) != b.size(-2):
        # Fail early: a mismatch would otherwise surface as an out-of-range
        # target index inside cross_entropy (a device assert on CUDA).
        raise ValueError(
            'a and b must contain the same number of samples, but got {} '
            'and {}'.format(a.size(-2), b.size(-2)))

    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)

    if scale is None:
        scale = a.new_tensor([log(1 / temperature)])
    # ``scale`` lives in log space; clamping after exp caps the logit scale.
    scale = torch.clamp(scale.exp(), max=max_scale)

    # logits[..., i, j] is the scaled similarity between a_i and b_j.
    logits = torch.matmul(a, b.transpose(-1, -2)) * scale
    num_samples = logits.size(-1)

    # F.cross_entropy treats dim 1 as the class dim for inputs with more
    # than two dims, so flatten every leading (batch) dim into rows and keep
    # the candidates along the last dim. The target of row i is index i.
    target = torch.arange(
        num_samples, device=logits.device).expand(logits.shape[:-1]).flatten()

    a_loss = F.cross_entropy(logits.flatten(0, -2), target)
    b_loss = F.cross_entropy(
        logits.transpose(-1, -2).flatten(0, -2), target)

    loss = (a_loss + b_loss) / 2
    return loss


@weighted_loss
def triplet_loss(pos, neg, anchor, margin=0.5):
    """
    Triplet Loss.

    The loss is ``relu(margin - sim(pos, anchor) + sim(neg, anchor))``,
    where ``sim`` is :func:`nncore.ops.cosine_similarity`.

    Args:
        pos (:obj:`torch.Tensor`): Positive samples.
        neg (:obj:`torch.Tensor`): Negative samples.
        anchor (:obj:`torch.Tensor`): Anchors for distance calculation.
        margin (float, optional): The margin between positive and negative
            samples. Default: ``0.5``.

    Returns:
        :obj:`torch.Tensor`: The loss tensor.
    """
    # Build the hinge argument left-to-right, in the same evaluation order
    # as ``margin - pos_sim + neg_sim``.
    hinge = margin - cosine_similarity(pos, anchor)
    hinge = hinge + cosine_similarity(neg, anchor)
    return hinge.relu()


@LOSSES.register()
@nncore.bind_getter('temperature', 'max_scale', 'learnable', 'loss_weight')
class InfoNCELoss(nn.Module):
    """
    InfoNCE Loss introduced in [1].

    Args:
        temperature (float, optional): The initial temperature for softmax.
            When ``learnable`` is ``True`` it only initializes the learnable
            log-scale to ``log(1 / temperature)``. Default: ``0.07``.
        max_scale (float, optional): The upper bound applied to the
            exponentiated logit scale, whether or not the scale is learnable.
            Default: ``100``.
        learnable (bool, optional): Whether the logit scale is learnable.
            Default: ``True``.
        loss_weight (float, optional): Weight of the loss. Default: ``1.0``.

    References:
        1. Oord et al. (https://arxiv.org/abs/1807.03748)
    """

    def __init__(self,
                 temperature=0.07,
                 max_scale=100,
                 learnable=True,
                 loss_weight=1.0):
        super().__init__()

        self._temperature = temperature
        self._max_scale = max_scale
        self._learnable = learnable
        self._loss_weight = loss_weight

        # The scale is kept in log space. With ``None``, ``infonce_loss``
        # derives a fixed scale from ``temperature`` on every call.
        self.scale = Parameter(log(1 / temperature)) if learnable else None

    def extra_repr(self):
        settings = (self._temperature, self._max_scale, self._learnable,
                    self._loss_weight)
        return ('temperature={}, max_scale={}, learnable={}, loss_weight={}'
                .format(*settings))

    def forward(self, a, b, weight=None, avg_factor=None):
        loss = infonce_loss(
            a,
            b,
            temperature=self._temperature,
            scale=self.scale,
            max_scale=self._max_scale,
            weight=weight,
            avg_factor=avg_factor)
        return loss * self._loss_weight


@LOSSES.register()
@nncore.bind_getter('margin', 'reduction', 'loss_weight')
class TripletLoss(nn.Module):
    """
    Triplet Loss.

    Args:
        margin (float, optional): The margin between positive and negative
            samples. Default: ``0.5``.
        reduction (str, optional): Reduction method. Currently supported
            values include ``'mean'``, ``'sum'``, and ``'none'``.
            Default: ``'mean'``.
        loss_weight (float, optional): Weight of the loss. Default: ``1.0``.
    """

    def __init__(self, margin=0.5, reduction='mean', loss_weight=1.0):
        super().__init__()

        self._margin = margin
        self._reduction = reduction
        self._loss_weight = loss_weight

    def extra_repr(self):
        settings = (self._margin, self._reduction, self._loss_weight)
        return 'margin={}, reduction={}, loss_weight={}'.format(*settings)

    def forward(self, pos, neg, anchor, weight=None, avg_factor=None):
        loss = triplet_loss(
            pos,
            neg,
            anchor,
            margin=self._margin,
            weight=weight,
            reduction=self._reduction,
            avg_factor=avg_factor)
        return loss * self._loss_weight