Source code for nncore.parallel.parallel

# Copyright (c) Ye Liu. Licensed under the MIT License.

import torch
from torch.autograd import Function
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.nn.parallel._functions import Scatter, _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple

from .container import DataContainer


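# Scatters a tensor (or a nested list of tensors) to the target GPUs, copying
# CPU inputs on per-device background streams and synchronizing the copies
# with each device's current stream afterwards.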
class _Scatter(Function):

    @staticmethod
    def forward(target_gpus, input):
        input_device = _get_input_device(input)
        streams = None
        if input_device == -1 and target_gpus != [-1]:
            streams = [_get_stream(device) for device in target_gpus]

        outputs = _scatter_stream(input, target_gpus, streams)
        if streams is not None:
            _sync_stream(outputs, target_gpus, streams)

        return tuple(outputs)


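# Returns the CUDA device index of a tensor (or nested list of tensors), or -1
# if the data lives on the CPU.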
def _get_input_device(input):
    if isinstance(input, list):
        for item in input:
            input_device = _get_input_device(item)
            if input_device != -1:
                return input_device
        return -1
    elif torch.is_tensor(input):
        return input.get_device() if input.is_cuda else -1
    else:
        raise TypeError('unknown type {}'.format(type(input)))


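# Recursively splits a list across the target devices and copies each tensor
# to its device, optionally on a dedicated copy stream.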
def _scatter_stream(input, devices, streams=None):
    if streams is None:
        streams = [None] * len(devices)

    if isinstance(input, list):
        chunk_size = (len(input) - 1) // len(devices) + 1
        outputs = [
            _scatter_stream(input[i], [devices[i // chunk_size]],
                            [streams[i // chunk_size]])
            for i in range(len(input))
        ]
        return outputs
    elif torch.is_tensor(input):
        output = input.contiguous()
        stream = streams[0] if output.numel() > 0 else None
        if devices != [-1]:
            with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
                output = output.cuda(devices[0], non_blocking=True)
        return output
    else:
        raise TypeError('unknown type {}'.format(type(input)))


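# Makes the current stream on each device wait for the corresponding copy
# stream and records the scattered tensors on it so their memory is not
# reused before the copies finish.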
def _sync_stream(output, devices, streams):
    if isinstance(output, list):
        chunk_size = len(output) // len(devices)
        for i in range(len(devices)):
            for j in range(chunk_size):
                _sync_stream(output[i * chunk_size + j], [devices[i]],
                             [streams[i]])
    elif torch.is_tensor(output):
        if output.numel() != 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                main_stream.wait_stream(streams[0])
                output.record_stream(main_stream)
    else:
        raise TypeError('unknown type {}'.format(type(output)))


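# A DataContainer-aware counterpart of torch.nn.parallel.scatter: tensors are
# sliced across the target GPUs, cpu_only containers contribute their data
# unchanged, and other containers have their underlying data scattered.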
def _scatter(inputs, target_gpus, dim=0):

    def _scatter_map(obj):
        if torch.is_tensor(obj):
            if target_gpus != [-1]:
                return Scatter.apply(target_gpus, None, dim, obj)
            else:
                return _Scatter.forward(target_gpus, obj)
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            else:
                return _Scatter.forward(target_gpus, obj.data)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(_scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(_scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(_scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(_scatter_map, obj.items()))]
        return [obj for _ in target_gpus]

    try:
        res = _scatter_map(inputs)
    finally:
        _scatter_map = None

    return res


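# A DataContainer-aware counterpart of scatter_kwargs that pads the shorter of
# inputs/kwargs with empty entries so both have the same length.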
def _scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    inputs = _scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = _scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend(() for _ in range(len(kwargs) - len(inputs)))
    elif len(kwargs) < len(inputs):
        kwargs.extend({} for _ in range(len(inputs) - len(kwargs)))
    return tuple(inputs), tuple(kwargs)


class NNDataParallel(DataParallel):
    """
    A :obj:`nn.DataParallel` class with :obj:`DataContainer` support. This
    class only bundles single-device modules.
    """

    def __init__(self, module, device_ids=None, dim=0, **kwargs):
        assert device_ids is None or len(device_ids) <= 1
        super(NNDataParallel, self).__init__(
            module,
            device_ids=[0] if device_ids is None else device_ids,
            **kwargs)
        self.dim = dim

    def scatter(self, inputs, kwargs, device_ids):
        return _scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        if self.device_ids:
            return super(NNDataParallel, self).forward(*inputs, **kwargs)
        else:
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])
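

# Usage sketch (illustrative, not part of the original module): wraps a toy
# module with NNDataParallel so that a batch mixing a plain tensor with a
# cpu_only DataContainer is scattered correctly. It assumes a single visible
# GPU and that DataContainer accepts its data plus a ``cpu_only`` keyword,
# with the data holding one entry per target device.
#
#     import torch.nn as nn
#
#     class ToyModel(nn.Module):
#
#         def forward(self, x, meta=None):
#             return x.sum()
#
#     model = NNDataParallel(ToyModel().cuda(), device_ids=[0])
#     batch = dict(
#         x=torch.randn(4, 8),
#         meta=DataContainer([[dict(idx=i) for i in range(4)]],
#                            cpu_only=True))
#     out = model(**batch)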


class NNDistributedDataParallel(DistributedDataParallel):
    """
    A :obj:`nn.DistributedDataParallel` class with :obj:`DataContainer`
    support. This class only bundles single-device modules.
    """

    def __init__(self,
                 module,
                 device_ids=None,
                 broadcast_buffers=False,
                 **kwargs):
        assert device_ids is None or len(device_ids) <= 1

        if device_ids is None:
            if torch.cuda.is_available():
                device_ids = [torch.cuda.current_device()]
                module = module.cuda()
        elif len(device_ids) == 1:
            module = module.to('cuda:{}'.format(device_ids[0]))

        super(NNDistributedDataParallel, self).__init__(
            module,
            device_ids=device_ids,
            broadcast_buffers=broadcast_buffers,
            **kwargs)

    def _run_ddp_forward(self, *inputs, **kwargs):
        if self._use_replicated_tensor_module:
            module = self._replicated_tensor_module
        else:
            module = self.module

        if self.device_ids:
            inputs, kwargs = _scatter_kwargs(
                inputs, kwargs, self.device_ids, dim=self.dim)
            with self._inside_ddp_forward():
                return module(*inputs[0], **kwargs[0])
        else:
            with self._inside_ddp_forward():
                return module(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        return _scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        if self.device_ids:
            return super(NNDistributedDataParallel,
                         self).forward(*inputs, **kwargs)
        else:
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])
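

# Usage sketch (illustrative, not part of the original module): assumes the
# process was launched with torchrun so that the default process group can be
# initialized and each process drives the GPU matching its local rank.
#
#     import os
#     import torch.distributed as dist
#     import torch.nn as nn
#
#     dist.init_process_group(backend='nccl')
#     local_rank = int(os.environ['LOCAL_RANK'])
#     torch.cuda.set_device(local_rank)
#
#     model = NNDistributedDataParallel(
#         nn.Linear(8, 2), device_ids=[local_rank])
#     out = model(torch.randn(4, 8))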