
Source code for mmcv.ops.psa_mask

# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext',
                                 ['psamask_forward', 'psamask_backward'])
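# `ext_module` exposes the compiled kernels (loaded from mmcv's `_ext`
# extension) that implement the forward and backward passes of this op.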


class PSAMaskFunction(Function):

    @staticmethod
    def symbolic(g, input, psa_type, mask_size):
        return g.op(
            'mmcv::MMCVPSAMask',
            input,
            psa_type_i=psa_type,
            mask_size_i=mask_size)

    @staticmethod
    def forward(ctx, input: torch.Tensor, psa_type: int,
                mask_size: Union[int, tuple]) -> torch.Tensor:
        ctx.psa_type = psa_type
        ctx.mask_size = _pair(mask_size)
        ctx.save_for_backward(input)

        h_mask, w_mask = ctx.mask_size
        batch_size, channels, h_feature, w_feature = input.size()
        assert channels == h_mask * w_mask
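        # Each of the h_mask * w_mask input channels corresponds to one
        # relative offset in the attention window; the kernel expands them
        # into one mask channel per spatial position of the feature map.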
        output = input.new_zeros(
            (batch_size, h_feature * w_feature, h_feature, w_feature))

        ext_module.psamask_forward(
            input,
            output,
            psa_type=psa_type,
            num_=batch_size,
            h_feature=h_feature,
            w_feature=w_feature,
            h_mask=h_mask,
            w_mask=w_mask,
            half_h_mask=(h_mask - 1) // 2,
            half_w_mask=(w_mask - 1) // 2)
        return output

    @staticmethod
    def backward(
            ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, None, None, None]:
        input = ctx.saved_tensors[0]
        psa_type = ctx.psa_type
        h_mask, w_mask = ctx.mask_size
        batch_size, channels, h_feature, w_feature = input.size()
        grad_input = grad_output.new_zeros(
            (batch_size, channels, h_feature, w_feature))
        ext_module.psamask_backward(
            grad_output,
            grad_input,
            psa_type=psa_type,
            num_=batch_size,
            h_feature=h_feature,
            w_feature=w_feature,
            h_mask=h_mask,
            w_mask=w_mask,
            half_h_mask=(h_mask - 1) // 2,
            half_w_mask=(w_mask - 1) // 2)
        return grad_input, None, None, None


psa_mask = PSAMaskFunction.apply
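# `psa_mask` takes the integer enum (0: 'collect', 1: 'distribute') and an
# int or (h_mask, w_mask) mask size; the `PSAMask` module below wraps this
# functional interface behind string-typed arguments.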


class PSAMask(nn.Module):

    def __init__(self, psa_type: str, mask_size: Optional[tuple] = None):
        super().__init__()
        assert psa_type in ['collect', 'distribute']
        if psa_type == 'collect':
            psa_type_enum = 0
        else:
            psa_type_enum = 1
        self.psa_type_enum = psa_type_enum
        self.mask_size = mask_size
        self.psa_type = psa_type

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return psa_mask(input, self.psa_type_enum, self.mask_size)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(psa_type={self.psa_type}, '
        s += f'mask_size={self.mask_size})'
        return s
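
A minimal usage sketch (hypothetical, not part of the module): it assumes
mmcv's compiled `_ext` ops are available for the chosen device; older builds
may support this op on CUDA only.

if __name__ == '__main__':
    # Hypothetical example: with mask_size=(7, 7) the input needs 7 * 7 = 49
    # channels, and the output gains h_feature * w_feature mask channels.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    psa = PSAMask('collect', mask_size=(7, 7))
    x = torch.rand(2, 49, 13, 13, device=device)
    out = psa(x)
    print(out.shape)  # torch.Size([2, 169, 13, 13])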