Shortcuts

Source code for mmcv.ops.tin_shift

# Copyright (c) OpenMMLab. All rights reserved.
# Code reference from "Temporal Interlacing Network"
# https://github.com/deepcs233/TIN/blob/master/cuda_shift/rtc_wrap.py
# Hao Shao, Shengju Qian, Yu Liu
# shaoh19@mails.tsinghua.edu.cn, sjqian@cse.cuhk.edu.hk, yuliu@ee.cuhk.edu.hk

import torch
import torch.nn as nn
from torch.autograd import Function

from ..utils import ext_loader

# Handle to the `_ext` extension obtained through `ext_loader.load_ext`; the
# two op names requested here are the ones invoked by TINShiftFunction below.
ext_module = ext_loader.load_ext('_ext',
                                 ['tin_shift_forward', 'tin_shift_backward'])


class TINShiftFunction(Function):
    """Autograd function that dispatches to the ``tin_shift`` extension ops.

    ``forward`` validates the arguments and calls
    ``ext_module.tin_shift_forward``; ``backward`` calls
    ``ext_module.tin_shift_backward`` to obtain the gradient of ``input``.
    """

    @staticmethod
    def forward(ctx, input, shift):
        """Validate the arguments and run the forward extension op.

        Args:
            ctx: Autograd context; ``shift`` is saved on it for ``backward``.
            input (torch.Tensor): Tensor to shift. Dim 0 is the batch and
                dim 2 is the channel dim ``C``.
            shift (torch.Tensor): Shift tensor. Dim 0 is the batch and dim 1
                is ``num_segments``.

        Returns:
            torch.Tensor: A tensor created with ``torch.zeros_like(input)``
            and filled in by ``ext_module.tin_shift_forward``.

        Raises:
            ValueError: If the batch sizes of ``input`` and ``shift`` differ,
                or if ``C`` is not a positive multiple of ``num_segments``
                (this includes ``num_segments == 0``).
        """
        if input.size(0) != shift.size(0):
            raise ValueError(
                'The first dim (batch) of `input` and `shift` should be '
                f'same, but got {input.size(0)} and {shift.size(0)}.')
        C = input.size(2)
        num_segments = shift.size(1)
        # ``num_segments <= 0`` is tested first so that an empty second dim
        # of ``shift`` is reported as a ValueError instead of escaping as a
        # ZeroDivisionError from the ``//`` and ``%`` that follow.
        if (num_segments <= 0 or C // num_segments <= 0
                or C % num_segments != 0):
            raise ValueError('C should be a multiple of num_segments, '
                             f'but got C={C} and num_segments={num_segments}.')

        ctx.save_for_backward(shift)

        # The extension op writes its result into the pre-allocated ``out``.
        out = torch.zeros_like(input)
        ext_module.tin_shift_forward(input, shift, out)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Run the backward extension op.

        Args:
            ctx: Autograd context holding the ``shift`` saved by ``forward``.
            grad_output (torch.Tensor): Gradient w.r.t. the forward output.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The gradient w.r.t. ``input``
            (filled in by ``ext_module.tin_shift_backward``) and an all-zero
            tensor with the shape of ``shift``.
        """
        shift = ctx.saved_tensors[0]
        # ``new_zeros`` keeps the dtype and device of the source tensor, like
        # the legacy ``tensor.new(*size).zero_()`` spelling it replaces.
        data_grad_input = grad_output.new_zeros(grad_output.size())
        # No gradient is computed for ``shift``; zeros are returned for it.
        shift_grad_input = shift.new_zeros(shift.size())
        ext_module.tin_shift_backward(grad_output, shift, data_grad_input)

        return data_grad_input, shift_grad_input


# Functional entry point: ``tin_shift(input, shift)`` invokes
# TINShiftFunction through ``Function.apply``.
tin_shift = TINShiftFunction.apply


class TINShift(nn.Module):
    """Temporal Interlace Shift.

    Temporal Interlace shift is a differentiable temporal-wise frame shifting
    which is proposed in "Temporal Interlacing Network"

    Please refer to `Temporal Interlacing Network
    <https://arxiv.org/abs/2001.06499>`_ for more details.

    Code is modified from https://github.com/mit-han-lab/temporal-shift-module
    """

    def forward(self, input, shift):
        """Perform temporal interlace shift.

        Args:
            input (torch.Tensor): Feature map with shape
                [N, num_segments, C, H * W].
            shift (torch.Tensor): Shift tensor with shape [N, num_segments].

        Returns:
            Feature map after temporal interlace shift.
        """
        return tin_shift(input, shift)
Read the Docs v: 2.x
Versions
master
latest
2.x
1.x
v1.7.0
v1.6.2
v1.6.1
v1.6.0
v1.5.3
v1.5.2_a
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
v1.3.12
v1.3.11
v1.3.10
v1.3.9
v1.3.8
v1.3.7
v1.3.6
v1.3.5
v1.3.4
v1.3.3
v1.3.2
v1.3.1
v1.3.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.