CrissCrossAttention
- class mmcv.ops.CrissCrossAttention(in_channels: int)
Criss-Cross Attention Module.
Note
Before v1.3.13, we used a CUDA op. Since v1.3.13, we have switched to a pure PyTorch, equivalent implementation. For more details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.
Speed comparison for one forward pass

- Input size: [2, 512, 97, 97]
- Device: 1 NVIDIA GeForce RTX 2080 Ti

| | PyTorch version | CUDA version | Relative speed |
|---|---|---|---|
| with torch.no_grad() | 0.00554402 s | 0.0299619 s | 5.4x |
| without torch.no_grad() | 0.00562803 s | 0.0301349 s | 5.4x |
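A comparison like the one above can be reproduced with a simple wall-clock harness along these lines. This is only a sketch: `nn.Conv2d` stands in for the attention module and the input is shrunk so it runs anywhere; the real benchmark used `mmcv.ops.CrissCrossAttention` on a `[2, 512, 97, 97]` input on a GPU.

```python
import time

import torch
import torch.nn as nn

# Stand-in module and a small input; swap in the real module and
# sizes from the table to reproduce the benchmark.
module = nn.Conv2d(8, 8, 3, padding=1)
x = torch.randn(2, 8, 32, 32)


def time_forward(no_grad: bool, iters: int = 5) -> float:
    """Average wall-clock seconds per forward pass."""
    ctx = torch.no_grad() if no_grad else torch.enable_grad()
    with ctx:
        start = time.perf_counter()
        for _ in range(iters):
            module(x)
        return (time.perf_counter() - start) / iters


t_no_grad = time_forward(no_grad=True)
t_grad = time_forward(no_grad=False)
```

On CUDA devices, kernel launches are asynchronous, so a faithful GPU benchmark would also call `torch.cuda.synchronize()` before reading the clock.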
- Parameters
in_channels (int) – Channels of the input feature map.
- forward(x: torch.Tensor) → torch.Tensor
Forward function of Criss-Cross Attention.
- Parameters
x (torch.Tensor) – Input feature with the shape of (batch_size, in_channels, height, width).
- Returns
Output of the layer, with the shape of (batch_size, in_channels, height, width).
- Return type
torch.Tensor
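To make the criss-cross idea concrete, here is a minimal pure-PyTorch sketch of the technique: each pixel attends only to the pixels in its own row and column. The class name and internal layout are illustrative assumptions for exposition, not mmcv's actual code; see the PR linked above for the real implementation.

```python
import torch
import torch.nn as nn


class CrissCrossAttentionSketch(nn.Module):
    """Illustrative sketch of criss-cross attention (not mmcv's code)."""

    def __init__(self, in_channels: int):
        super().__init__()
        # Query/key channels are reduced to in_channels // 8, as in CCNet.
        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))  # residual scale, starts at 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, h, w = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)

        # Attention energies along each column: (b, w, h, h).
        energy_col = torch.einsum('bchw,bcyw->bwhy', q, k)
        # Mask the diagonal so the query pixel is not counted twice
        # (it also appears in its own row).
        mask = torch.diag(torch.full((h,), float('-inf'), device=x.device))
        energy_col = energy_col + mask

        # Attention energies along each row: (b, h, w, w).
        energy_row = torch.einsum('bchw,bchy->bhwy', q, k)

        # Softmax jointly over the h column entries and w row entries.
        energy = torch.cat([energy_col.permute(0, 2, 1, 3), energy_row], dim=-1)
        attn = torch.softmax(energy, dim=-1)          # (b, h, w, h + w)
        attn_col = attn[..., :h].permute(0, 2, 1, 3)  # (b, w, h, h)
        attn_row = attn[..., h:]                      # (b, h, w, w)

        # Aggregate values along the criss-cross path and add the residual.
        out_col = torch.einsum('bwhy,bcyw->bchw', attn_col, v)
        out_row = torch.einsum('bhwy,bchy->bchw', attn_row, v)
        return self.gamma * (out_col + out_row) + x
```

Because `gamma` is initialised to zero, the module starts as an identity mapping, so the output shape always matches the input shape, as documented for `forward` above.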