Source code for nncore.parallel.collate

# Copyright (c) Ye Liu. Licensed under the MIT License.

import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate

import nncore
from .container import DataContainer


def collate(batch, samples_per_gpu=-1):
    """
    A collate function for :obj:`DataLoader` with :obj:`DataContainer`
    support.

    Args:
        batch (any): The batch of data to be collated.
        samples_per_gpu (int, optional): Number of samples per GPU. ``-1``
            means moving all the data to a single GPU. Default: ``-1``.
    """
    if isinstance(batch[0], DataContainer):
        stacked = []
        if samples_per_gpu < 0:
            samples_per_gpu = len(batch)
        if batch[0].stack:
            # Stack the samples chunk by chunk, one chunk per GPU.
            for i in range(0, len(batch), samples_per_gpu):
                assert torch.is_tensor(batch[i].data)
                if batch[i].pad_dims is None:
                    stacked.append(
                        default_collate([
                            sample.data
                            for sample in batch[i:i + samples_per_gpu]
                        ]))
                else:
                    # Compute the maximum size of the last `pad_dims`
                    # dimensions within the chunk. The leading dimensions
                    # must match exactly across samples.
                    ndim = batch[i].dim()
                    max_shape = [0] * batch[i].pad_dims
                    for dim in range(1, batch[i].pad_dims + 1):
                        max_shape[dim - 1] = batch[i].size(-dim)
                    for sample in batch[i:i + samples_per_gpu]:
                        for dim in range(0, ndim - batch[i].pad_dims):
                            assert batch[i].size(dim) == sample.size(dim)
                        for dim in range(1, batch[i].pad_dims + 1):
                            max_shape[dim - 1] = max(max_shape[dim - 1],
                                                     sample.size(-dim))
                    # Pad every sample up to the maximum size (in F.pad's
                    # last-dim-first ordering) before stacking.
                    padded = []
                    for sample in batch[i:i + samples_per_gpu]:
                        pad = [0] * batch[i].pad_dims * 2
                        for dim in range(1, batch[i].pad_dims + 1):
                            pad[2 * dim -
                                1] = max_shape[dim - 1] - sample.size(-dim)
                        padded.append(
                            F.pad(sample.data, pad, value=sample.pad_value))
                    stacked.append(default_collate(padded))
        else:
            # Non-stackable data is simply grouped into per-GPU lists.
            for i in range(0, len(batch), samples_per_gpu):
                stacked.append(
                    [sample.data for sample in batch[i:i + samples_per_gpu]])
        return DataContainer(
            stacked,
            stack=batch[0].stack,
            pad_value=batch[0].pad_value,
            pad_dims=batch[0].pad_dims,
            cpu_only=batch[0].cpu_only)
    elif isinstance(batch[0], list):
        return collate(nncore.concat(batch), samples_per_gpu)
    elif isinstance(batch[0], tuple):
        transposed = zip(*batch)
        return [collate(samples, samples_per_gpu) for samples in transposed]
    elif isinstance(batch[0], dict):
        return {
            k: collate([d[k] for d in batch], samples_per_gpu)
            for k in batch[0]
        }
    else:
        return default_collate(batch)
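

# A minimal usage sketch (not part of the original module). It relies only
# on names visible in this file: ``collate``, ``DataContainer``, and the
# keyword arguments (``stack``, ``pad_dims``) that ``collate`` itself reads.
# The dataset, its length, and the ``feat`` field name are illustrative, and
# ``pad_value`` is assumed to have a sensible default in ``DataContainer``.
if __name__ == '__main__':
    from functools import partial

    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        """A hypothetical dataset yielding variable-length 2-D tensors."""

        def __len__(self):
            return 4

        def __getitem__(self, idx):
            feat = torch.rand(3, idx + 5)  # last dim varies per sample
            return dict(feat=DataContainer(feat, stack=True, pad_dims=1))

    # With batch_size=4 and samples_per_gpu=2, the collated ``feat`` field
    # holds a list of two stacked tensors (one chunk per GPU), each padded
    # to the largest last-dim size within its own chunk.
    loader = DataLoader(
        ToyDataset(),
        batch_size=4,
        collate_fn=partial(collate, samples_per_gpu=2))

    for data in loader:
        print([t.shape for t in data['feat'].data])
        # e.g. [torch.Size([2, 3, 6]), torch.Size([2, 3, 8])]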