# Source code for nncore.nn.blocks.activation

# Copyright (c) Ye Liu. Licensed under the MIT License.

import torch
import torch.nn as nn
import torch.nn.functional as F

import nncore
from ..builder import ACTIVATIONS

class _MishImplementation(torch.autograd.Function):
    """
    Custom autograd function computing Mish, ``x * tanh(softplus(x))``.

    Only the input tensor is saved in ``forward``; ``backward`` recomputes
    the intermediate values from it to evaluate the analytic gradient.
    """

    @staticmethod
    def forward(ctx, i):
        """
        Compute ``i * tanh(softplus(i))`` and save ``i`` for ``backward``.

        Args:
            ctx: The autograd context object.
            i (:obj:`torch.Tensor`): The input tensor.

        Returns:
            :obj:`torch.Tensor`: The activated tensor.
        """
        # The input is the only tensor ``backward`` needs. Without this call
        # ``ctx.saved_tensors`` is empty and ``backward`` cannot run.
        ctx.save_for_backward(i)
        return i * F.softplus(i).tanh()

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute the gradient of Mish with respect to its input.

        Args:
            ctx: The autograd context object populated by ``forward``.
            grad_output (:obj:`torch.Tensor`): The gradient w.r.t. the output.

        Returns:
            :obj:`torch.Tensor`: The gradient w.r.t. the input.
        """
        i, = ctx.saved_tensors
        tanh_sp = F.softplus(i).tanh()
        # d/dx [x * tanh(sp(x))] = tanh(sp(x)) + x * sigmoid(x) * sech^2(sp(x))
        # where sech^2 = 1 - tanh^2 and d softplus / dx = sigmoid.
        grad_f = tanh_sp + i * i.sigmoid() * (1. - tanh_sp * tanh_sp)
        return grad_output * grad_f

class _SwishImplementation(torch.autograd.Function):
    """
    Custom autograd function computing Swish, ``x * sigmoid(x)``.

    Only the input tensor is saved in ``forward``; ``backward`` recomputes
    the sigmoid from it to evaluate the analytic gradient.
    """

    @staticmethod
    def forward(ctx, i):
        """
        Compute ``i * sigmoid(i)`` and save ``i`` for ``backward``.

        Args:
            ctx: The autograd context object.
            i (:obj:`torch.Tensor`): The input tensor.

        Returns:
            :obj:`torch.Tensor`: The activated tensor.
        """
        # The input is the only tensor ``backward`` needs. Without this call
        # ``ctx.saved_tensors`` is empty and ``backward`` cannot run.
        ctx.save_for_backward(i)
        return i * i.sigmoid()

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute the gradient of Swish with respect to its input.

        Args:
            ctx: The autograd context object populated by ``forward``.
            grad_output (:obj:`torch.Tensor`): The gradient w.r.t. the output.

        Returns:
            :obj:`torch.Tensor`: The gradient w.r.t. the input.
        """
        i, = ctx.saved_tensors
        sigmoid_i = i.sigmoid()
        # d/dx [x * s(x)] = s(x) + x * s(x) * (1 - s(x))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

@ACTIVATIONS.register()
class EffMish(nn.Module):
    """
    An efficient implementation of Mish activation layer introduced in [1].

    The computation is delegated to :obj:`_MishImplementation`, a custom
    autograd function with a hand-written backward pass.

    References:
        1. Misra et al. (https://arxiv.org/abs/1908.08681)
    """

    def forward(self, x):
        return _MishImplementation.apply(x)
@ACTIVATIONS.register()
class EffSwish(nn.Module):
    """
    An efficient implementation of Swish activation layer introduced in [1].

    The computation is delegated to :obj:`_SwishImplementation`, a custom
    autograd function with a hand-written backward pass.

    References:
        1. Ramachandran et al. (https://arxiv.org/abs/1710.05941)
    """

    def forward(self, x):
        return _SwishImplementation.apply(x)
@ACTIVATIONS.register()
class Mish(nn.Module):
    """
    Mish activation layer introduced in [1].

    Computes ``x * tanh(softplus(x))`` using regular differentiable tensor
    operations.

    References:
        1. Misra et al. (https://arxiv.org/abs/1908.08681)
    """

    def forward(self, x):
        return x * F.softplus(x).tanh()
@ACTIVATIONS.register()
class Swish(nn.Module):
    """
    Swish activation layer introduced in [1].

    Computes ``x * sigmoid(x)`` using regular differentiable tensor
    operations.

    References:
        1. Ramachandran et al. (https://arxiv.org/abs/1710.05941)
    """

    def forward(self, x):
        return x * x.sigmoid()
@ACTIVATIONS.register()
@nncore.bind_getter('min', 'max')
class Clamp(nn.Module):
    """
    Clamp activation layer.

    Args:
        min (float, optional): The lower-bound of the range. Default: ``-1``.
        max (float, optional): The upper-bound of the range. Default: ``1``.
    """

    def __init__(self, min=-1, max=1):
        super().__init__()
        # Stored under private names; ``nncore.bind_getter`` presumably exposes
        # them as read-only ``min`` / ``max`` attributes (confirm in nncore).
        self._min = min
        self._max = max

    def forward(self, x):
        return x.clamp(self._min, self._max)