Source code for mmcv.transforms.formatting
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union
import mmengine
import numpy as np
import torch
from .base import BaseTransform
from .builder import TRANSFORMS
def to_tensor(
    data: Union[torch.Tensor, np.ndarray, Sequence, int,
                float]) -> torch.Tensor:
    """Turn a supported Python / NumPy object into a :obj:`torch.Tensor`.

    The input type decides how the tensor is produced:

    - :class:`torch.Tensor`: returned as is (the very same object).
    - :class:`numpy.ndarray`: wrapped with :func:`torch.from_numpy`.
    - non-string :class:`Sequence`: passed to :func:`torch.tensor`.
    - :class:`int`: a one-element ``torch.LongTensor``.
    - :class:`float`: a one-element ``torch.FloatTensor``.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.

    Returns:
        torch.Tensor: the converted data.

    Raises:
        TypeError: If ``data`` matches none of the supported types.
    """
    if isinstance(data, torch.Tensor):
        return data

    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)

    # Strings also satisfy ``Sequence``; they are filtered out here so they
    # fall through to the final ``TypeError`` instead of ``torch.tensor``.
    is_plain_sequence = (
        isinstance(data, Sequence) and not mmengine.is_str(data))
    if is_plain_sequence:
        return torch.tensor(data)

    # Scalars are wrapped in a list, so the result has shape ``(1,)``.
    if isinstance(data, int):
        return torch.LongTensor([data])

    if isinstance(data, float):
        return torch.FloatTensor([data])

    raise TypeError(f'type {type(data)} cannot be converted to tensor.')
@TRANSFORMS.register_module()
class ToTensor(BaseTransform):
    """Replace selected entries of the results dict with :obj:`torch.Tensor`.

    Each key may be a dotted path (for example ``'a.b'``). The path is split
    on ``'.'``, every component but the last is used to descend into nested
    containers, and the entry named by the last component is converted in
    place with :func:`to_tensor`.

    Required keys:

    - all these keys in `keys`

    Modified Keys:

    - all these keys in `keys`

    Args:
        keys (Sequence[str]): Keys that need to be converted to Tensor.
    """

    def __init__(self, keys: Sequence[str]) -> None:
        self.keys = keys

    def transform(self, results: dict) -> dict:
        """Convert the configured entries of ``results`` to `torch.Tensor`.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: The same ``results`` object, with `keys` updated.

        Raises:
            KeyError: If any component of a (dotted) key is missing.
        """
        for key in self.keys:
            # ``str.split`` always yields at least one element, so the
            # star-unpacking below always finds a leaf name.
            *parent_names, leaf_name = key.split('.')

            # Walk down to the container that directly holds the leaf.
            node = results
            for name in parent_names:
                if name not in node:
                    raise KeyError(f'Can not find key {key}')
                node = node[name]

            if leaf_name not in node:
                raise KeyError(f'Can not find key {key}')
            node[leaf_name] = to_tensor(node[leaf_name])
        return results

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return cls_name + f'(keys={self.keys})'
@TRANSFORMS.register_module()
class ImageToTensor(BaseTransform):
    """Convert image to :obj:`torch.Tensor` by given keys.

    The dimension order of input image is (H, W, C). The pipeline will convert
    it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
    (1, H, W).

    Required keys:

    - all these keys in `keys`

    Modified Keys:

    - all these keys in `keys`

    Args:
        keys (Sequence[str]): Key of images to be converted to Tensor.
    """

    # ``keys`` is iterated as a collection of result-dict keys, so it is
    # annotated as ``Sequence[str]`` to match the documented type.
    def __init__(self, keys: Sequence[str]) -> None:
        self.keys = keys

    def transform(self, results: dict) -> dict:
        """Transform function to convert image in results to
        :obj:`torch.Tensor` and transpose the channel order.

        Args:
            results (dict): Result dict contains the image data to convert.

        Returns:
            dict: The result dict contains the image converted
            to :obj:`torch.Tensor` and transposed to (C, H, W) order.
        """
        for key in self.keys:
            img = results[key]
            # Give images with fewer than 3 dims a trailing axis so the
            # transpose below, which names three axes, can be applied.
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            # Move the last axis to the front, then call ``contiguous()`` on
            # the resulting tensor because the transposed array is only a
            # strided view.
            results[key] = (to_tensor(img.transpose(2, 0, 1))).contiguous()
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + f'(keys={self.keys})'