
Source code for mmcv.ops.cc_attention

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.registry import MODELS

from mmcv.cnn import Scale


def NEG_INF_DIAG(n: int, device: torch.device) -> torch.Tensor:
    """Returns a diagonal matrix of size [n, n].

    The diagonal are all "-inf". This is for avoiding calculating the
    overlapped element in the Criss-Cross twice.
    """
    return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)
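
The following is an illustrative snippet, not part of the mmcv source. It shows
what the helper returns and why -inf is used: after the softmax in
CrissCrossAttention.forward, the masked diagonal entries receive exactly zero
attention weight.

# Illustration only (not part of the mmcv source).
mask = NEG_INF_DIAG(3, torch.device('cpu'))
print(mask)
# tensor([[-inf, 0., 0.],
#         [0., -inf, 0.],
#         [0., 0., -inf]])
print(F.softmax(mask + torch.ones(3, 3), dim=-1))
# Each row sums to 1 and the diagonal weight is exactly 0.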


@MODELS.register_module()
class CrissCrossAttention(nn.Module):
    """Criss-Cross Attention Module.

    .. note::
        Before v1.3.13, we used a CUDA op. Since v1.3.13, we switched to a
        pure PyTorch and equivalent implementation. For more details, please
        refer to https://github.com/open-mmlab/mmcv/pull/1201.

        Speed comparison for one forward pass

        - Input size: [2, 512, 97, 97]
        - Device: 1 NVIDIA GeForce RTX 2080 Ti

        +-----------------------+---------------+------------+---------------+
        |                       |PyTorch version|CUDA version|Relative speed |
        +=======================+===============+============+===============+
        |with torch.no_grad()   |0.00554402 s   |0.0299619 s |5.4x           |
        +-----------------------+---------------+------------+---------------+
        |without torch.no_grad()|0.00562803 s   |0.0301349 s |5.4x           |
        +-----------------------+---------------+------------+---------------+

    Args:
        in_channels (int): Channels of the input feature map.
    """

    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = Scale(0.)
        self.in_channels = in_channels
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function of Criss-Cross Attention.

        Args:
            x (torch.Tensor): Input feature with the shape of
                (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output of the layer, with the shape of
                (batch_size, in_channels, height, width).
        """
        B, C, H, W = x.size()
        query = self.query_conv(x)
        key = self.key_conv(x)
        value = self.value_conv(x)
        # Attention energy along the vertical axis (same column); the diagonal
        # is masked so the query position is not counted twice.
        energy_H = torch.einsum('bchw,bciw->bwhi', query, key) + NEG_INF_DIAG(
            H, query.device)
        energy_H = energy_H.transpose(1, 2)
        # Attention energy along the horizontal axis (same row).
        energy_W = torch.einsum('bchw,bchj->bhwj', query, key)
        attn = F.softmax(
            torch.cat([energy_H, energy_W], dim=-1), dim=-1)  # [B,H,W,(H+W)]
        out = torch.einsum('bciw,bhwi->bchw', value, attn[..., :H])
        out += torch.einsum('bchj,bhwj->bchw', value, attn[..., H:])

        out = self.gamma(out) + x
        out = out.contiguous()

        return out
    def __repr__(self) -> str:
        s = self.__class__.__name__
        s += f'(in_channels={self.in_channels})'
        return s
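
A minimal usage sketch (illustration only, not part of the mmcv source): the
module is shape-preserving, and because self.gamma is initialised to 0 it
starts out as an identity mapping, with the attention branch learned during
training.

# Usage sketch (illustration only, not part of the mmcv source).
import torch
from mmcv.ops import CrissCrossAttention

cca = CrissCrossAttention(in_channels=64)
x = torch.rand(2, 64, 32, 48)  # (batch_size, in_channels, height, width)
with torch.no_grad():
    y = cca(x)
print(y.shape)  # torch.Size([2, 64, 32, 48])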