
CrissCrossAttention

class mmcv.ops.CrissCrossAttention(in_channels: int) [source]

Criss-Cross Attention Module.
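To show what the "criss-cross" pattern means, the following loop-based sketch computes the attention aggregation for a single sample. It assumes the CCNet formulation: every position (h, w) attends to the H + W - 1 positions on its own row and column, and the centre pixel is counted once. `criss_cross_naive` is a hypothetical helper and is not part of mmcv. It takes query, key and value maps that have already been projected. It leaves out the learnable 1×1 projections and the residual connection that the CCNet design wraps around this step, and it is far too slow for real use.

```python
import torch


def criss_cross_naive(query: torch.Tensor, key: torch.Tensor,
                      value: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based reference for criss-cross aggregation (one sample).

    query, key: (C', H, W); value: (C, H, W); returns a (C, H, W) tensor.
    """
    _, height, width = query.shape
    out = torch.zeros_like(value)
    for h in range(height):
        for w in range(width):
            # Criss-cross neighbourhood: all of row h, plus column w without the centre.
            positions = [(h, j) for j in range(width)]
            positions += [(i, w) for i in range(height) if i != h]
            keys = torch.stack([key[:, i, j] for i, j in positions])      # (N, C')
            values = torch.stack([value[:, i, j] for i, j in positions])  # (N, C)
            # Softmax over the N = H + W - 1 dot-product energies.
            weights = torch.softmax(keys @ query[:, h, w], dim=0)         # (N,)
            out[:, h, w] = weights @ values                               # (C,)
    return out
```

With an all-zero query every energy is equal. Each output is then just the mean of the H + W - 1 values on that position's row and column. For a 3×4 map holding 0..11, position (0, 0) averages 0, 1, 2, 3, 4 and 8, which gives 3.0.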

Note

Before v1.3.13, we used a CUDA op. Since v1.3.13, we have switched to an equivalent implementation in pure PyTorch. For more details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.
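The sketch below shows one way to write a criss-cross attention module using only PyTorch tensor operations (`torch.einsum`). It is not mmcv's source. It assumes a CCNet-style layout: 1×1 convolutions produce query and key maps with `in_channels // 8` channels and a value map with `in_channels` channels, and a learnable scalar `gamma`, initialised to zero, scales the attended features before a residual addition. The class name `CrissCrossAttentionSketch` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrissCrossAttentionSketch(nn.Module):
    """Hypothetical pure-PyTorch criss-cross attention (sketch, not mmcv's code)."""

    def __init__(self, in_channels: int) -> None:
        super().__init__()
        # Assumed CCNet-style channel reduction; requires in_channels >= 8.
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # assumed zero-initialised scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, _, h, _ = x.shape
        query = self.query_conv(x)  # (B, C', H, W)
        key = self.key_conv(x)      # (B, C', H, W)
        value = self.value_conv(x)  # (B, C, H, W)

        # Column energies: (h, w) vs every (i, w) in the same column -> (B, H, W, H).
        energy_h = torch.einsum('bchw,bciw->bhwi', query, key)
        # Mask i == h so the centre pixel is counted only once (in the row branch).
        diag = torch.eye(h, device=x.device).bool().view(1, h, 1, h)
        energy_h = energy_h.masked_fill(diag, float('-inf'))
        # Row energies: (h, w) vs every (h, j) in the same row -> (B, H, W, W).
        energy_w = torch.einsum('bchw,bchj->bhwj', query, key)

        # Joint softmax over the H + W criss-cross slots (one of them masked).
        attn = F.softmax(torch.cat([energy_h, energy_w], dim=-1), dim=-1)
        attn_h, attn_w = attn[..., :h], attn[..., h:]

        # Weighted sums of values along the column and along the row.
        out = torch.einsum('bciw,bhwi->bchw', value, attn_h)
        out = out + torch.einsum('bchj,bhwj->bchw', value, attn_w)
        return self.gamma * out + x
```

Masking the diagonal of the column energies with -inf stops the centre pixel from being counted twice, because it already appears in the row branch. Every step is a tensor contraction, so no custom CUDA kernel is needed. `torch.einsum` typically lowers these contractions to batched matrix multiplications. Because `gamma` starts at zero, the freshly initialised module returns its input unchanged.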

Speed comparison for one forward pass

  • Input size: [2,512,97,97]

  • Device: 1 NVIDIA GeForce RTX 2080 Ti

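+-------------------------+-----------------+--------------+-----------------------------+
|                         | PyTorch version | CUDA version | Speedup (PyTorch over CUDA) |
+=========================+=================+==============+=============================+
| with torch.no_grad()    | 0.00554402 s    | 0.0299619 s  | 5.4x                        |
+-------------------------+-----------------+--------------+-----------------------------+
| without torch.no_grad() | 0.00562803 s    | 0.0301349 s  | 5.4x                        |
+-------------------------+-----------------+--------------+-----------------------------+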
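Timings like these depend on how they are measured. The sketch below shows one common way to take the mean time of a forward pass, with and without `torch.no_grad()`. `time_forward` and its `warmup`/`iters` defaults are hypothetical choices and are not the script that produced the table above. Warm-up iterations keep one-off costs out of the measurement. `torch.cuda.synchronize()` is called because CUDA kernels run asynchronously.

```python
import time

import torch
import torch.nn as nn


def time_forward(module: nn.Module, x: torch.Tensor, use_no_grad: bool,
                 warmup: int = 3, iters: int = 10) -> float:
    """Hypothetical helper: mean wall-clock seconds of one forward pass."""
    module.eval()

    def run() -> None:
        if use_no_grad:
            with torch.no_grad():  # no autograd graph is recorded
                module(x)
        else:
            module(x)  # autograd graph is built, then discarded

    for _ in range(warmup):  # exclude one-off costs such as CUDA context setup
        run()
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        run()
    if x.is_cuda:
        torch.cuda.synchronize()  # make sure all timed kernels have finished
    return (time.perf_counter() - start) / iters
```

To mirror the setup above, you could pass an instance of `mmcv.ops.CrissCrossAttention(512)` on a CUDA device with an input of shape `(2, 512, 97, 97)`. The absolute numbers will depend on your hardware and library versions.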

Parameters

in_channels (int) – Channels of the input feature map.

forward(x: torch.Tensor) → torch.Tensor [source]

Forward function of Criss-Cross Attention.

Parameters

x (torch.Tensor) – Input feature with the shape of (batch_size, in_channels, height, width).

Returns

Output of the layer, with the shape of (batch_size, in_channels, height, width).

Return type

torch.Tensor