CrissCrossAttention¶
- class mmcv.ops.CrissCrossAttention(in_channels: int)[source]¶
Criss-Cross Attention Module.
Note
Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch to a pure PyTorch and equivalent implementation. For more details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.
Speed comparison for one forward pass
- Input size: [2, 512, 97, 97]
- Device: 1 NVIDIA GeForce RTX 2080 Ti

|                         | PyTorch version | CUDA version | Relative speed |
| ----------------------- | --------------- | ------------ | -------------- |
| with torch.no_grad()    | 0.00554402 s    | 0.0299619 s  | 5.4x           |
| without torch.no_grad() | 0.00562803 s    | 0.0301349 s  | 5.4x           |
- Parameters
in_channels (int) – Channels of the input feature map.
- forward(x: torch.Tensor) → torch.Tensor[source]¶
Forward function of Criss-Cross Attention.
- Parameters
x (torch.Tensor) – Input feature with the shape of (batch_size, in_channels, height, width).
- Returns
Output of the layer, with the shape of (batch_size, in_channels, height, width).
- Return type
torch.Tensor