Shortcuts

Source code for mmcv.transforms.formatting

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union

import mmengine
import numpy as np
import torch

from .base import BaseTransform
from .builder import TRANSFORMS


def to_tensor(
    data: Union[torch.Tensor, np.ndarray, Sequence, int,
                float]) -> torch.Tensor:
    """Turn a supported python object into a :obj:`torch.Tensor`.

    The checks run in a fixed order: :class:`torch.Tensor`,
    :class:`numpy.ndarray`, non-string :class:`Sequence`, :class:`int` and
    finally :class:`float`. The first match decides the conversion.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.

    Returns:
        torch.Tensor: the converted data. A tensor input is returned as the
        very same object; scalars become one-element tensors.

    Raises:
        TypeError: If ``data`` matches none of the supported types.
    """
    # Already a tensor: hand it back untouched.
    if isinstance(data, torch.Tensor):
        return data

    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)

    # Strings are sequences too, but they are not convertible data.
    if isinstance(data, Sequence) and not mmengine.is_str(data):
        return torch.tensor(data)

    # Python scalars are wrapped into a one-element tensor; ``int`` is tested
    # before ``float``, matching the order of the type checks above.
    scalar_wrappers = ((int, torch.LongTensor), (float, torch.FloatTensor))
    for scalar_type, tensor_cls in scalar_wrappers:
        if isinstance(data, scalar_type):
            return tensor_cls([data])

    raise TypeError(f'type {type(data)} cannot be converted to tensor.')


@TRANSFORMS.register_module()
class ToTensor(BaseTransform):
    """Convert selected entries of the results to :obj:`torch.Tensor`.

    A key may be a dotted path such as ``'a.b'``; it is then resolved through
    nested containers and only the innermost entry is converted.

    Required keys:

    - all these keys in `keys`

    Modified Keys:

    - all these keys in `keys`

    Args:
        keys (Sequence[str]): Keys that need to be converted to Tensor.
    """

    def __init__(self, keys: Sequence[str]) -> None:
        self.keys = keys

    def transform(self, results: dict) -> dict:
        """Convert the configured entries of ``results`` in place.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: The same ``results`` object, with every entry named in
            `keys` replaced by its tensor conversion.

        Raises:
            KeyError: If any component of a (dotted) key is missing.
        """
        for key in self.keys:
            # Everything before the last dot addresses nested containers;
            # the last component is the entry to convert.
            *parents, leaf = key.split('.')

            node = results
            for name in parents:
                if name not in node:
                    raise KeyError(f'Can not find key {key}')
                node = node[name]

            if leaf not in node:
                raise KeyError(f'Can not find key {key}')
            node[leaf] = to_tensor(node[leaf])

        return results

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(keys={self.keys})'
@TRANSFORMS.register_module()
class ImageToTensor(BaseTransform):
    """Convert image to :obj:`torch.Tensor` by given keys.

    The dimension order of input image is (H, W, C). The pipeline will convert
    it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
    (1, H, W).

    Required keys:

    - all these keys in `keys`

    Modified Keys:

    - all these keys in `keys`

    Args:
        keys (Sequence[str]): Key of images to be converted to Tensor.
    """

    # ``keys`` is only ever iterated, so it is annotated as a sequence of
    # key names, in line with the documented argument type.
    def __init__(self, keys: Sequence[str]) -> None:
        self.keys = keys

    def transform(self, results: dict) -> dict:
        """Transform function to convert image in results to
        :obj:`torch.Tensor` and transpose the channel order.

        Args:
            results (dict): Result dict contains the image data to convert.

        Returns:
            dict: The result dict contains the image converted
            to :obj:`torch.Tensor` and transposed to (C, H, W) order.
        """
        for key in self.keys:
            img = results[key]
            # Give a channel-less (H, W) image a trailing channel axis so the
            # three-axis transpose below applies.
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            # (H, W, C) -> (C, H, W); ``contiguous()`` is applied to the
            # tensor built from the transposed array.
            results[key] = (to_tensor(img.transpose(2, 0, 1))).contiguous()
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + f'(keys={self.keys})'