Source code for nncore.nn.modules.msg_pass

# Copyright (c) Ye Liu. Licensed under the MIT License.

import torch.nn as nn

import nncore
from ..builder import (MODULES, NORMS, build_act_layer, build_msg_pass_layer,
                       build_norm_layer)
from ..init import constant_init_

@MODULES.register()
@nncore.bind_getter('in_features', 'out_features', 'bias', 'order')
class MsgPassModule(nn.Module):
    """
    A module that bundles message passing, normalization, and activation
    layers.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (str | bool, optional): Whether to add the bias term in the
            message passing layer. If ``bias='auto'``, the module will decide
            it automatically based on whether it has a normalization layer.
            Default: ``'auto'``.
        msg_pass_cfg (dict | str | None, optional): The config or name of the
            message passing layer. If not specified, ``GCN`` will be used.
            Default: ``None``.
        norm_cfg (dict | str | None, optional): The config or name of the
            normalization layer. Default: ``None``.
        act_cfg (dict | str | None, optional): The config or name of the
            activation layer. Default: ``dict(type='ReLU', inplace=True)``.
        order (tuple[str], optional): The order of layers. It is expected to
            be a sequence of ``'msg_pass'``, ``'norm'``, and ``'act'``.
            Default: ``('msg_pass', 'norm', 'act')``.
        **kwargs: Extra keyword arguments forwarded unchanged to the message
            passing layer builder.
    """

    # NOTE(review): ``act_cfg`` defaults to a dict literal that is shared by
    # every call relying on the default; this is only safe as long as
    # ``build_act_layer`` does not mutate its argument -- confirm.
    def __init__(self,
                 in_features,
                 out_features,
                 bias='auto',
                 msg_pass_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 order=('msg_pass', 'norm', 'act'),
                 **kwargs):
        super(MsgPassModule, self).__init__()

        self._in_features = in_features
        self._out_features = out_features

        # Drop the layers whose config is ``None``; the message passing layer
        # itself is always kept.
        _map = dict(msg_pass=True, norm=norm_cfg, act=act_cfg)
        self._order = tuple(o for o in order if _map[o] is not None)

        if self.with_norm:
            # Signed distance from the message passing layer to the norm
            # layer: positive when the norm comes after it, ``1`` when it
            # follows immediately.
            _pos = self._order.index('norm') - self._order.index('msg_pass')

        if bias != 'auto':
            self._bias = bias
        elif self.with_norm:
            _typ = norm_cfg['type'] if isinstance(norm_cfg, dict) else norm_cfg
            # Keep the bias when the norm layer is registered in the 'drop'
            # group of ``NORMS`` (presumably dropout-like layers -- TODO
            # confirm) or when the norm does not directly follow the message
            # passing layer; otherwise the bias is turned off.
            self._bias = _typ in NORMS.group('drop') or _pos != 1
        else:
            self._bias = True

        for layer in self._order:
            if layer == 'msg_pass':
                self.msg_pass = build_msg_pass_layer(
                    msg_pass_cfg or 'GCN',
                    in_features=in_features,
                    out_features=out_features,
                    bias=self._bias,
                    **kwargs)
            elif layer == 'norm':
                # A norm placed after the message passing layer sees
                # ``out_features`` channels, one placed before it sees
                # ``in_features`` channels.
                self.norm = build_norm_layer(
                    norm_cfg,
                    dims=out_features if _pos > 0 else in_features)
            else:
                self.act = build_act_layer(act_cfg)

        self.init_weights()

    @property
    def with_norm(self):
        """bool: Whether a normalization layer is part of the module."""
        return 'norm' in self._order

    @property
    def with_act(self):
        """bool: Whether an activation layer is part of the module."""
        return 'act' in self._order

    def init_weights(self):
        """Apply ``constant_init_`` to the normalization layer, if any."""
        if self.with_norm:
            constant_init_(self.norm)

    def forward(self, x, graph):
        """
        Args:
            x (:obj:`torch.Tensor[N, M]`): The input node features.
            graph (:obj:`torch.Tensor[N, N]`): The graph structure where
                ``graph[i, j] == n (n > 0)`` means there is a link with
                weight ``n`` from node ``i`` to node ``j`` while
                ``graph[i, j] == 0`` means not.
        """
        for layer in self._order:
            if layer == 'msg_pass':
                # Only the message passing layer consumes the graph.
                x = self.msg_pass(x, graph)
            else:
                x = getattr(self, layer)(x)
        return x
def build_msg_pass_modules(dims,
                           last_norm=False,
                           last_act=False,
                           default=None,
                           **kwargs):
    """
    Build a module list containing message passing, normalization, and
    activation layers.

    Args:
        dims (list[int]): The sequence of numbers of dimensions of features.
        last_norm (bool, optional): Whether to add a normalization layer after
            the last message passing layer. Default: ``False``.
        last_act (bool, optional): Whether to add an activation layer after
            the last message passing layer. Default: ``False``.
        default (any, optional): The default value when the ``dims`` is not
            valid. Default: ``None``.
        **kwargs: Keyword arguments passed to every :obj:`MsgPassModule`. Any
            value that is a ``list`` is treated as a per-layer setting and
            must contain exactly ``len(dims) - 1`` entries.

    Returns:
        :obj:`nn.ModuleList`: The constructed module list, or ``default`` if
        ``dims`` is not a sequence of ints.
    """
    if not nncore.is_seq_of(dims, int):
        return default

    num_layers = len(dims) - 1
    _kwargs = kwargs.copy()

    # Layers to strip from the tail of the last module's order.
    removable = set()
    if not last_norm:
        removable.add('norm')
    if not last_act:
        removable.add('act')

    # Split list-valued arguments out of the shared kwargs: entry ``i`` of
    # each list is applied to the ``i``-th module only.
    per_layer = {}
    for key in list(_kwargs):
        if isinstance(_kwargs[key], list):
            assert len(_kwargs[key]) == num_layers, (
                "per-layer argument '{}' must have {} entries, but got {}".
                format(key, num_layers, len(_kwargs[key])))
            per_layer[key] = _kwargs.pop(key)

    layers = []
    for i in range(num_layers):
        if i == num_layers - 1:
            order = list(_kwargs.get('order', ['msg_pass', 'norm', 'act']))
            # Guard against an order made up solely of removable layers,
            # which would otherwise index into an empty list.
            while order and order[-1] in removable:
                order.pop()
            _kwargs['order'] = tuple(order)

            # The layer config may be a dict or a plain type name; resolve
            # the type for both forms instead of indexing into a string.
            _cfg = _kwargs.get('msg_pass_cfg')
            _typ = _cfg['type'] if isinstance(_cfg, dict) else _cfg
            if _typ == 'GAT':
                # Forwarded through ``MsgPassModule`` to the GAT layer
                # (presumably disables concatenation of attention heads for
                # the output layer -- TODO confirm).
                _kwargs['concat'] = False

        _kwargs.update({k: v[i] for k, v in per_layer.items()})
        layers.append(MsgPassModule(dims[i], dims[i + 1], **_kwargs))

    return nn.ModuleList(layers)