# Source code for nncore.utils.data

# Copyright (c) Ye Liu. Licensed under the MIT License.

import numpy as np


def swap_element(matrix, i, j, dim=0):
    """
    Swap two elements of an array or a tensor along a dimension.

    The swap is performed in place and the same object is returned.

    Args:
        matrix (:obj:`np.ndarray` | :obj:`torch.Tensor`): The array or tensor
            to be swapped.
        i (int | tuple): Index of the first element along ``dim``. A tuple
            selects several positions along ``dim`` at once.
        j (int | tuple): Index of the second element along ``dim``.
        dim (int, optional): The dimension to swap. Default: ``0``.

    Returns:
        :obj:`np.ndarray` | :obj:`torch.Tensor`: The swapped array or tensor.
    """
    import builtins

    # This module defines its own ``slice`` helper, which shadows the built-in
    # here, so the built-in must be referenced explicitly.
    full = [builtins.slice(None)] * dim

    # The index must be a tuple so that each entry addresses one axis; a list
    # of slices is interpreted as an (invalid) advanced index by NumPy.
    i_inds = tuple(full + [i])
    j_inds = tuple(full + [j])

    # Copy both sides first so the second assignment does not read data that
    # the first assignment already overwrote (indexing may return views).
    meth = 'copy' if isinstance(matrix, np.ndarray) else 'clone'
    m_i = getattr(matrix[i_inds], meth)()
    m_j = getattr(matrix[j_inds], meth)()

    matrix[i_inds] = m_j
    matrix[j_inds] = m_i
    return matrix
def is_seq_of(seq, item_type, seq_type=(list, tuple)):
    """
    Check whether ``seq`` is a container of the given type holding only items
    of the given type.

    Args:
        seq (Sequence): The object to be inspected.
        item_type (tuple[type] | type): Type(s) every item must be an
            instance of.
        seq_type (tuple[type] | type, optional): Type(s) the container itself
            must be an instance of. Default: ``(list, tuple)``.

    Returns:
        bool: ``True`` if ``seq`` is a ``seq_type`` instance and every item is
        an ``item_type`` instance (an empty container qualifies), otherwise
        ``False``.
    """
    if isinstance(seq, seq_type):
        return all(isinstance(element, item_type) for element in seq)
    return False
def is_list_of(seq, item_type):
    """
    Check whether ``seq`` is a ``list`` whose items are all ``item_type``.

    This is :obj:`is_seq_of` with ``seq_type`` fixed to ``list``.

    Args:
        seq (object): The object to be inspected.
        item_type (tuple[type] | type): Expected item type(s).

    Returns:
        bool: Whether ``seq`` is a list of ``item_type``.
    """
    container_type = list
    return is_seq_of(seq, item_type, container_type)
def is_tuple_of(seq, item_type):
    """
    Check whether ``seq`` is a ``tuple`` whose items are all ``item_type``.

    This is :obj:`is_seq_of` with ``seq_type`` fixed to ``tuple``.

    Args:
        seq (object): The object to be inspected.
        item_type (tuple[type] | type): Expected item type(s).

    Returns:
        bool: Whether ``seq`` is a tuple of ``item_type``.
    """
    container_type = tuple
    return is_seq_of(seq, item_type, container_type)
def slice(seq, length, type='list'):
    """
    Slice a sequence into several sub sequences by length.

    Note:
        This function shadows the built-in :obj:`slice` inside this module.

    Args:
        seq (list | tuple): The sequence to be sliced.
        length (list[int] | tuple[int] | int): Either the length of every
            sub sequence (a positive integer that must divide ``len(seq)``),
            or one non-negative length per sub sequence summing to
            ``len(seq)``.
        type (str, optional): The type of returned object. Expected values
            include ``'list'`` and ``'tuple'``. Default: ``'list'``.

    Returns:
        list | tuple: The sliced sequences. Each item is a slice of ``seq``
        and therefore has the same type as ``seq``.

    Raises:
        TypeError: If ``length`` is neither an integer nor a list/tuple.
        ValueError: If an integer ``length`` is not positive, a list entry is
            negative, or the lengths do not sum to ``len(seq)``.
    """
    assert type in ('list', 'tuple')

    if isinstance(length, int):
        # Without this guard a negative length silently yields ``[]`` and a
        # zero length raises ZeroDivisionError.
        if length <= 0:
            raise ValueError("'length' must be a positive integer")
        assert len(seq) % length == 0
        length = [length] * (len(seq) // length)
    elif not isinstance(length, (list, tuple)):
        raise TypeError("'length' must be an integer or a list of integers")
    elif any(size < 0 for size in length):
        raise ValueError("'length' must not contain negative values")
    elif sum(length) != len(seq):
        raise ValueError('the total length do not match the sequence length')

    out, idx = [], 0
    for size in length:
        out.append(seq[idx:idx + size])
        idx += size

    return tuple(out) if type == 'tuple' else out
def concat(seq):
    """
    Join the sub sequences of ``seq`` end to end.

    Args:
        seq (list | tuple): A sequence whose items are themselves iterable.

    Returns:
        list | tuple: All elements of the sub sequences in order, wrapped in
        the same container type as ``seq``.
    """
    merged = [element for sub_seq in seq for element in sub_seq]
    return type(seq)(merged)
def interleave(seq):
    """
    Interleave a sequence of sequences element by element.

    The first elements of every sub sequence come first, then the second
    elements, and so on. Iteration stops at the shortest sub sequence (the
    behavior of :obj:`zip`).

    Args:
        seq (list | tuple): The sequence of sub sequences to be interleaved.

    Returns:
        list | tuple: The interleaved elements, wrapped in the same container
        type as ``seq``.
    """
    result = []
    for group in zip(*seq):
        result.extend(group)
    return type(seq)(result)
def flatten(seq):
    """
    Recursively flatten nested lists and tuples into one sequence.

    Args:
        seq (list | tuple): The sequence to be flattened. Items that are
            lists or tuples (including tuple subclasses such as namedtuples)
            are expanded recursively; all other items are kept as-is.

    Returns:
        list | tuple: The leaf items in depth-first order, wrapped in the same
        container type as ``seq``.
    """
    out = []

    def _collect(items):
        # Append the leaves of ``items`` to ``out``. Inner containers are
        # never rebuilt, which avoids one allocation per nesting level and
        # works for tuple subclasses whose constructors do not take a single
        # iterable (e.g. namedtuples).
        for item in items:
            if isinstance(item, (list, tuple)):
                _collect(item)
            else:
                out.append(item)

    _collect(seq)
    return type(seq)(out)
def to_dict_of_list(in_list):
    """
    Convert a list of dicts to a dict of lists.

    Args:
        in_list (list): The list of dicts to be converted. All dicts must
            have the same keys.

    Returns:
        dict: A dict mapping each key to the list of its values, in the order
        of ``in_list``. An empty input yields an empty dict.

    Raises:
        ValueError: If the dicts do not all have the same keys.
    """
    if not in_list:
        return dict()

    # Key-set equality is transitive, so comparing every dict against the
    # first one is equivalent to comparing neighbours pairwise.
    keys = in_list[0].keys()
    for item in in_list:
        if item.keys() != keys:
            raise ValueError('dict keys are not consistent')

    return {key: [item[key] for item in in_list] for key in in_list[0]}
def to_list_of_dict(in_dict):
    """
    Convert a dict of lists to a list of dicts.

    Args:
        in_dict (dict): The dict of lists to be converted. All lists must
            have the same length.

    Returns:
        list: One dict per list position; the ``i``-th dict maps every key to
        the ``i``-th element of its list. An empty input yields an empty list.

    Raises:
        ValueError: If the lists do not all have the same length.
    """
    keys = list(in_dict)
    # ``dict.values()`` is a view and cannot be indexed, so materialize it.
    values = list(in_dict.values())

    if any(len(value) != len(values[0]) for value in values[1:]):
        raise ValueError('lengths of lists are not consistent')

    # One output dict per list position (not per key).
    return [dict(zip(keys, row)) for row in zip(*values)]