CrissCrossAttention

class mmcv.ops.CrissCrossAttention(in_channels: int)

Criss-Cross Attention Module.

Note

Before v1.3.13, this module used a CUDA op. Since v1.3.13, it uses an equivalent pure PyTorch implementation. For details, see https://github.com/open-mmlab/mmcv/pull/1201.

Speed comparison for one forward pass

  • Input size: [2,512,97,97]

  • Device: 1 NVIDIA GeForce RTX 2080 Ti

                           PyTorch version   CUDA version   Relative speed
with torch.no_grad()       0.00554402 s      0.0299619 s    5.4x
without torch.no_grad()    0.00562803 s      0.0301349 s    5.4x
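
The benchmark script behind these numbers is not shown here. As a rough sketch of how a single forward pass could be timed with and without torch.no_grad(), the snippet below uses a hypothetical helper, time_forward, and a small stand-in nn.Conv2d layer on CPU so that it runs anywhere. To approximate the comparison above, one would presumably pass a CrissCrossAttention instance and a [2,512,97,97] input on a GPU instead; absolute timings will differ by hardware.

```python
import time

import torch
from torch import nn


def time_forward(module: nn.Module, x: torch.Tensor,
                 use_no_grad: bool, repeats: int = 10) -> float:
    """Hypothetical helper: mean wall-clock seconds per forward pass."""

    def run() -> torch.Tensor:
        if use_no_grad:
            with torch.no_grad():
                return module(x)
        return module(x)

    run()  # warm-up pass, excluded from the measurement
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for queued GPU work before timing
    start = time.perf_counter()
    for _ in range(repeats):
        run()
    if x.is_cuda:
        torch.cuda.synchronize()  # make sure all timed GPU work has finished
    return (time.perf_counter() - start) / repeats


# Stand-in module and input (assumptions chosen for illustration only).
layer = nn.Conv2d(8, 8, kernel_size=1)
x = torch.randn(2, 8, 16, 16)

t_no_grad = time_forward(layer, x, use_no_grad=True)
t_grad = time_forward(layer, x, use_no_grad=False)
print(f"with torch.no_grad():    {t_no_grad:.8f} s")
print(f"without torch.no_grad(): {t_grad:.8f} s")
```

The explicit torch.cuda.synchronize() calls matter on GPU because CUDA kernels launch asynchronously; without them the timer may stop before the work is done.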

Parameters

in_channels (int) – Channels of the input feature map.

forward(x: torch.Tensor) → torch.Tensor

Forward function of Criss-Cross Attention.

Parameters

x (torch.Tensor) – Input feature with the shape of (batch_size, in_channels, height, width).

Returns

Output of the layer, with the shape of (batch_size, in_channels, height, width).

Return type

torch.Tensor