Source code for nncore.dataset.wrapper

# Copyright (c) Ye Liu. Licensed under the MIT License.

from torch.utils.data import Dataset

import nncore
from .builder import DATASETS, build_dataset


@DATASETS.register()
@nncore.bind_getter('times')
@nncore.bind_method('_dataset', ['set_state', 'evaluate'])
class RepeatDataset(Dataset):
    """
    A dataset wrapper for repeated samples.

    The length of the repeated dataset will be ``times`` larger than the
    original dataset. This is useful when the data loading time is long but
    the dataset is small. Using this class can reduce the data loading time
    among epochs.

    When the wrapped dataset reports a ``state`` of ``'val'`` or ``'test'``,
    repetition is disabled and the wrapper has the same length as the
    wrapped dataset.

    Args:
        dataset (:obj:`Dataset` | cfg | str): The dataset or config of
            dataset to be repeated. Anything that is not a :obj:`Dataset`
            instance is passed to ``build_dataset``.
        times (int): The number of repeat times. Must be at least 1.

    Raises:
        ValueError: If ``times`` is smaller than 1.
    """

    def __init__(self, dataset, times):
        # Fail early with a clear message: a negative value would otherwise
        # only surface later as an obscure error from ``len()``, and 0 would
        # silently produce an empty training set.
        if times < 1:
            raise ValueError(
                'times should be a positive integer, but got {}'.format(times))

        if not isinstance(dataset, Dataset):
            dataset = build_dataset(dataset)

        # Re-expose class names so code that inspects ``CLASSES`` can treat
        # the wrapper like the wrapped dataset.
        if hasattr(dataset, 'CLASSES'):
            self.CLASSES = dataset.CLASSES

        self._dataset = dataset
        # Read back through ``self.times``. This presumably relies on the
        # ``bind_getter('times')`` decorator exposing ``_times`` — confirm in
        # nncore.
        self._times = times

    def __getitem__(self, idx):
        """
        Return the sample at ``idx`` of the repeated dataset.

        Indices follow sequence semantics: negative indices count from the
        end, and indices outside ``[-len(self), len(self))`` raise
        :obj:`IndexError`. Raising here matters because this class defines no
        ``__iter__``. Python's fallback iteration protocol therefore relies on
        ``IndexError`` to stop, and plain wrap-around would iterate forever.

        Raises:
            IndexError: If ``idx`` is out of range (including any index into
                an empty wrapped dataset).
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError(
                'index out of range for RepeatDataset of length {}'.format(
                    length))
        # ``length`` is a whole multiple of the wrapped length and is > 0
        # here, so the modulo is safe and maps onto an original sample.
        return self._dataset[idx % len(self._dataset)]

    def __len__(self):
        # Repetition only applies outside evaluation: in 'val'/'test' state
        # (as reported by the wrapped dataset, if it has one) each sample is
        # visited once.
        state = getattr(self._dataset, 'state', None)
        times = 1 if state in ('val', 'test') else self.times
        return len(self._dataset) * times

    @property
    def dataset(self):
        """The wrapped dataset instance."""
        return self._dataset