# Copyright (c) Ye Liu. Licensed under the MIT License.
import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.nn.parallel._functions import Scatter, _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple
import nncore
from .container import DataContainer
class _Scatter(torch.autograd.Function):
    """
    Scatter helper for the payload of a :obj:`DataContainer`.

    Although this subclasses :obj:`torch.autograd.Function`, ``forward`` is
    a plain static method that ``_scatter`` in this file calls directly.
    """

    @staticmethod
    def forward(target_gpus, input):
        """
        Copy ``input`` to the target devices.

        Args:
            target_gpus (list[int]): Target device ids. ``[-1]`` means CPU.
            input (:obj:`torch.Tensor` | list): A tensor or a (nested) list
                of tensors.

        Returns:
            tuple: The scattered outputs converted to a tuple.
        """
        source_device = _get_input_device(input)

        # Dedicated copy streams are only requested when the data currently
        # lives on CPU and has to be moved to GPUs.
        use_streams = source_device == -1 and target_gpus != [-1]
        if use_streams:
            streams = [
                _get_stream(torch.device('cuda', gpu_id))
                for gpu_id in target_gpus
            ]
        else:
            streams = None

        outputs = _scatter_stream(input, target_gpus, streams)

        # Make the current streams wait for the copies issued above.
        if use_streams:
            _sync_stream(outputs, target_gpus, streams)

        return tuple(outputs)
def _get_input_device(input):
    """
    Find the device id of the first CUDA tensor in ``input``.

    Args:
        input (:obj:`torch.Tensor` | list): A tensor or a (nested) list of
            tensors.

    Returns:
        int: The device id of the first CUDA tensor found in depth-first
        order, or ``-1`` if there is none (including an empty list).

    Raises:
        TypeError: If ``input`` (or a nested item) is neither a list nor a
            tensor.
    """
    if torch.is_tensor(input):
        return input.get_device() if input.is_cuda else -1

    if not isinstance(input, list):
        raise TypeError('unknown type {}'.format(type(input)))

    # The generator is lazy, so the search stops at the first CUDA tensor.
    candidates = (_get_input_device(item) for item in input)
    return next((dev for dev in candidates if dev != -1), -1)
def _scatter_stream(input, devices, streams=None):
    """
    Copy a tensor, or a (nested) list of tensors, to the given devices.

    Args:
        input (:obj:`torch.Tensor` | list): The data to be scattered.
        devices (list[int]): Target device ids. ``[-1]`` means CPU, in which
            case tensors are only made contiguous.
        streams (list | None, optional): One CUDA stream (or ``None``) per
            device used for the copies. Default: ``None``.

    Returns:
        :obj:`torch.Tensor` | list: The scattered data with the same nesting
        as ``input``.

    Raises:
        TypeError: If ``input`` (or a nested item) is neither a list nor a
            tensor.
    """
    if streams is None:
        streams = [None] * len(devices)

    if isinstance(input, list):
        # Ceil division: consecutive items are grouped into chunks and the
        # k-th chunk goes to the k-th device.
        per_device = (len(input) - 1) // len(devices) + 1
        scattered = []
        for index, item in enumerate(input):
            slot = index // per_device
            scattered.append(
                _scatter_stream(item, [devices[slot]], [streams[slot]]))
        return scattered

    if not torch.is_tensor(input):
        raise TypeError('unknown type {}'.format(type(input)))

    output = input.contiguous()
    # Empty tensors are copied without a dedicated stream.
    stream = streams[0] if output.numel() > 0 else None
    if devices == [-1]:
        return output

    with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
        output = output.cuda(devices[0], non_blocking=True)
    return output
def _sync_stream(output, devices, streams):
    """
    Make the current CUDA streams wait for the copies of ``output``.

    Args:
        output (:obj:`torch.Tensor` | list): The scattered tensor or (nested)
            list of tensors, as returned by ``_scatter_stream``.
        devices (list[int]): The device ids the data was scattered to.
        streams (list): The copy streams, one per device.

    Raises:
        TypeError: If ``output`` (or a nested item) is neither a list nor a
            tensor.
    """
    if isinstance(output, list):
        # Use the same ceil-based chunking as ``_scatter_stream`` so every
        # element is paired with the device/stream it was copied on. Floor
        # division would skip trailing elements (and mismatch devices)
        # whenever the list length is not a multiple of the device count.
        chunk_size = (len(output) - 1) // len(devices) + 1
        for index, item in enumerate(output):
            slot = index // chunk_size
            _sync_stream(item, [devices[slot]], [streams[slot]])
    elif torch.is_tensor(output):
        # Empty tensors are skipped; no stream work is issued for them here.
        if output.numel() != 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                main_stream.wait_stream(streams[0])
                # Tell the allocator that the tensor is used on this stream.
                output.record_stream(main_stream)
    else:
        raise TypeError('unknown type {}'.format(type(output)))
def _scatter(inputs, target_gpus, dim=0):
    """
    Scatter arbitrarily nested inputs to the target devices.

    Tensors are split with :obj:`Scatter` (unless running on CPU), the
    payload of a :obj:`DataContainer` is either returned as-is (when
    ``cpu_only`` is set) or scattered with ``_Scatter``, containers
    (namedtuple, tuple, list, dict) are processed recursively, and any other
    object is replicated once per target device.

    Args:
        inputs (any): The object to be scattered.
        target_gpus (list[int]): Target device ids. ``[-1]`` means CPU.
        dim (int, optional): The dimension along which tensors are split.
            Default: ``0``.

    Returns:
        list | tuple: The scattered objects.
    """

    def _dispatch(obj):
        if torch.is_tensor(obj) and target_gpus != [-1]:
            return Scatter.apply(target_gpus, None, dim, obj)

        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            return _Scatter.forward(target_gpus, obj.data)

        if _is_namedtuple(obj):
            return [
                type(obj)(*fields) for fields in zip(*map(_dispatch, obj))
            ]

        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(_dispatch, obj)))

        if isinstance(obj, list) and len(obj) > 0:
            return [list(group) for group in zip(*map(_dispatch, obj))]

        if isinstance(obj, dict) and len(obj) > 0:
            return [
                type(obj)(pairs)
                for pairs in zip(*map(_dispatch, obj.items()))
            ]

        # Anything else (including empty containers, and tensors on the CPU
        # path) is replicated once per target.
        return [obj for _ in target_gpus]

    try:
        return _dispatch(inputs)
    finally:
        # The recursive closure references itself through its own cell;
        # clearing the name breaks that reference cycle.
        _dispatch = None
def _scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    """
    Scatter positional and keyword arguments to the target devices.

    Args:
        inputs (tuple): The positional arguments.
        kwargs (dict): The keyword arguments.
        target_gpus (list[int]): Target device ids. ``[-1]`` means CPU.
        dim (int, optional): The dimension along which tensors are split.
            Default: ``0``.

    Returns:
        tuple[tuple, tuple]: The scattered positional and keyword arguments.
        The shorter of the two is padded with ``()`` or ``{}`` so that both
        have the same length.
    """
    scattered_args = _scatter(inputs, target_gpus, dim) if inputs else []
    scattered_kwargs = _scatter(kwargs, target_gpus, dim) if kwargs else []

    # Positive: positional args are short; negative: keyword args are short.
    shortfall = len(scattered_kwargs) - len(scattered_args)
    if shortfall > 0:
        scattered_args.extend(() for _ in range(shortfall))
    elif shortfall < 0:
        scattered_kwargs.extend({} for _ in range(-shortfall))

    return tuple(scattered_args), tuple(scattered_kwargs)
def _get_device(device_id=None):
    """
    Resolve the device id to be used.

    Args:
        device_id (int | None, optional): The requested device id. Default:
            ``None``.

    Returns:
        int: ``device_id`` itself when it is given, otherwise the current
        CUDA device if CUDA is available, otherwise ``-1``.
    """
    if device_id is None:
        cuda_ready = torch.cuda.is_available()
        device_id = torch.cuda.current_device() if cuda_ready else -1
    return device_id
# [docs]
class NNDataParallel(DataParallel):
    """
    A :obj:`nn.DataParallel` class with :obj:`DataContainer` support. This
    class only bundles single-device modules.

    Args:
        module (:obj:`nn.Module`): The module to be bundled.
        device_id (int | None, optional): The device id to be used. ``None``
            means using the default device, and ``-1`` means CPU. Default:
            ``None``.
        dim (int, optional): The dimension along which tensors are
            scattered. Default: ``0``.
        **kwargs: Extra arguments passed to :obj:`nn.DataParallel` when
            running on GPU. ``device_ids`` and ``output_device`` are not
            allowed. These are ignored when running on CPU.
    """

    def __init__(self, module, device_id=None, dim=0, **kwargs):
        assert isinstance(device_id, int) or device_id is None
        assert 'device_ids' not in kwargs and 'output_device' not in kwargs

        device_id = _get_device(device_id)

        if device_id >= 0:
            # ``dim`` is a named parameter here, so it never shows up in
            # ``kwargs`` and must be forwarded explicitly.
            super(NNDataParallel, self).__init__(
                module,
                device_ids=[device_id],
                output_device=device_id,
                dim=dim,
                **kwargs)
        else:
            logger = nncore.get_logger()
            logger.warning('{} is running on CPU'.format(
                self.__class__.__name__))
            # Skip ``DataParallel.__init__`` and initialize the bare
            # ``nn.Module`` instead, then set the attributes used by this
            # class by hand.
            super(DataParallel, self).__init__()
            torch._C._log_api_usage_once('torch.nn.parallel.DataParallel')
            self.module = module
            self.device_ids = []
            self.dim = dim

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter the arguments with :obj:`DataContainer` support."""
        return _scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        """
        Run the wrapped module on the scattered arguments.

        On GPU this defers to :obj:`nn.DataParallel`. On CPU the arguments
        are scattered with ``[-1]`` as the target and the module is called
        directly.
        """
        if self.device_ids:
            return super(NNDataParallel, self).forward(*inputs, **kwargs)

        if not inputs and not kwargs:
            # Scattering nothing yields empty tuples, so indexing the first
            # chunk below would raise IndexError.
            return self.module()

        inputs, kwargs = self.scatter(inputs, kwargs, [-1])
        return self.module(*inputs[0], **kwargs[0])
# [docs]
class NNDistributedDataParallel(DistributedDataParallel):
    """
    A :obj:`nn.DistributedDataParallel` class with :obj:`DataContainer`
    support. This class only bundles single-device modules.

    Args:
        module (:obj:`nn.Module`): The module to be bundled.
        device_id (int | None, optional): The device id to be used. ``None``
            means using the default device, and ``-1`` means CPU. Default:
            ``None``.
        **kwargs: Extra arguments passed to
            :obj:`nn.DistributedDataParallel`. ``device_ids`` and
            ``output_device`` are not allowed.
    """

    def __init__(self, module, device_id=None, **kwargs):
        assert isinstance(device_id, int) or device_id is None
        assert 'device_ids' not in kwargs and 'output_device' not in kwargs

        device_id = _get_device(device_id)

        if device_id >= 0:
            module = module.to('cuda:{}'.format(device_id))
            device_ids = [device_id]
        else:
            logger = nncore.get_logger()
            logger.warning('{} is running on CPU'.format(
                self.__class__.__name__))
            device_ids = None

        super(NNDistributedDataParallel, self).__init__(
            module, device_ids=device_ids, **kwargs)

    def _pre_forward(self, *inputs, **kwargs):
        """
        Scatter the arguments before delegating to the parent
        ``_pre_forward``.
        """
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if not inputs and not kwargs:
                # Nothing was scattered: fall back to a single empty chunk
                # so that taking the first chunk below cannot fail.
                inputs, kwargs = ((), ), ({}, )
            inputs, kwargs = inputs[0], kwargs[0]
        return super(NNDistributedDataParallel,
                     self)._pre_forward(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter the arguments with :obj:`DataContainer` support."""
        return _scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        """
        Run the wrapped module on the scattered arguments.

        On GPU this defers to :obj:`nn.DistributedDataParallel`. On CPU the
        arguments are scattered with ``[-1]`` as the target and the module is
        called directly.
        """
        if self.device_ids:
            return super(NNDistributedDataParallel,
                         self).forward(*inputs, **kwargs)

        # NOTE(review): the CPU path below calls the wrapped module directly
        # and does not go through ``DistributedDataParallel.forward`` --
        # confirm that skipping the parent's forward logic is intended.
        if not inputs and not kwargs:
            # Scattering nothing yields empty tuples, so indexing the first
            # chunk below would raise IndexError.
            return self.module()

        inputs, kwargs = self.scatter(inputs, kwargs, [-1])
        return self.module(*inputs[0], **kwargs[0])