
CrissCrossAttention

class mmcv.ops.CrissCrossAttention(in_channels: int)[source]

Criss-Cross Attention Module.

Note

Before v1.3.13, we used a CUDA op. Since v1.3.13, we have switched to a pure PyTorch implementation that is numerically equivalent. For more details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.

Speed comparison for one forward pass

  • Input size: [2,512,97,97]

  • Device: 1 NVIDIA GeForce RTX 2080 Ti

|                         | PyTorch version | CUDA version | Relative speed |
|-------------------------|-----------------|--------------|----------------|
| with torch.no_grad()    | 0.00554402 s    | 0.0299619 s  | 5.4x           |
| without torch.no_grad() | 0.00562803 s    | 0.0301349 s  | 5.4x           |
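To illustrate what a pure PyTorch formulation of criss-cross attention looks like, here is a minimal sketch following the standard CCNet recipe (1x1 query/key/value convolutions, a joint softmax over each pixel's row and column, and a learnable residual scale `gamma` initialized to zero). This is an illustrative assumption-laden sketch, not the mmcv source; the class name `CrissCrossAttentionSketch` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrissCrossAttentionSketch(nn.Module):
    """Illustrative pure-PyTorch criss-cross attention (NOT the mmcv code)."""

    def __init__(self, in_channels: int):
        super().__init__()
        # Standard CCNet choice: reduce query/key channels by 8x.
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
        # Residual scale; zero init means the module starts as identity.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, h, w = x.shape
        q = self.query_conv(x)  # (B, C', H, W)
        k = self.key_conv(x)    # (B, C', H, W)
        v = self.value_conv(x)  # (B, C,  H, W)

        # Affinity of each pixel (i, j) with its column ...
        energy_h = torch.einsum('bcij,bckj->bijk', q, k)  # (B, H, W, H)
        # ... and with its row.
        energy_w = torch.einsum('bcij,bcik->bijk', q, k)  # (B, H, W, W)

        # The "INF trick": mask the self position in the column branch so
        # each pixel attends to itself only once (via the row branch).
        inf_diag = torch.diag(torch.full((h,), float('inf'), device=x.device))
        energy_h = energy_h - inf_diag.view(1, h, 1, h)

        # Joint softmax over the H + W cross positions.
        attn = F.softmax(torch.cat([energy_h, energy_w], dim=-1), dim=-1)
        attn_h, attn_w = attn[..., :h], attn[..., h:]

        # Aggregate values along the column and the row.
        out_h = torch.einsum('bijk,bckj->bcij', attn_h, v)
        out_w = torch.einsum('bijk,bcik->bcij', attn_w, v)
        return self.gamma * (out_h + out_w) + x


# Shape check: output matches input.
cca = CrissCrossAttentionSketch(64)
out = cca(torch.randn(2, 64, 9, 9))
assert out.shape == (2, 64, 9, 9)
```

Because `gamma` starts at zero, a freshly constructed module behaves as the identity mapping, which is a common trick for safely inserting attention into a pretrained backbone.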

Parameters

in_channels (int) – Channels of the input feature map.

forward(x: torch.Tensor) → torch.Tensor[source]

Forward function of Criss-Cross Attention.

Parameters

x (torch.Tensor) – Input feature with the shape of (batch_size, in_channels, height, width).

Returns

Output of the layer, with the shape of (batch_size, in_channels, height, width).

Return type

torch.Tensor
