# Source code for mmcv.ops.tin_shift

# Copyright (c) OpenMMLab. All rights reserved.
# Code reference from "Temporal Interlacing Network"
# Hao Shao, Shengju Qian, Yu Liu

import torch
import torch.nn as nn
from torch.autograd import Function

from ..utils import ext_loader

# Compiled ops this module calls from the `_ext` extension.
_TIN_SHIFT_OPS = ['tin_shift_forward', 'tin_shift_backward']
ext_module = ext_loader.load_ext('_ext', _TIN_SHIFT_OPS)

class TINShiftFunction(Function):
    """Autograd function wrapping the compiled temporal interlace shift ops.

    The forward and backward passes delegate to ``tin_shift_forward`` and
    ``tin_shift_backward`` from the ``_ext`` extension. The gradient returned
    for ``shift`` is always zero, i.e. ``shift`` is treated as a
    non-differentiable input.
    """

    @staticmethod
    def forward(ctx, input, shift):
        """Shift ``input`` according to ``shift`` via the compiled kernel.

        Args:
            ctx: Autograd context; ``shift`` is saved on it for backward.
            input (Tensor): Feature map, documented by ``TINShift`` as
                shape [N, num_segments, C, H * W].
            shift (Tensor): Shift tensor, documented as shape
                [N, num_segments].

        Returns:
            Tensor: Shifted feature map with the same shape, dtype and device
            as ``input``.

        Raises:
            ValueError: If the first (batch) dims of ``input`` and ``shift``
                differ, or if ``C`` (``input.size(2)``) is not a positive
                multiple of ``shift.size(1)``.
        """
        if input.size(0) != shift.size(0):
            raise ValueError(
                'The first dim (batch) of `input` and `shift` should be '
                f'same, but got {input.size(0)} and {shift.size(0)}.')
        channels = input.size(2)
        num_segments = shift.size(1)
        # NOTE(review): upstream mmcv rejects this case because the extension
        # splits the C channels into shift.size(1) equal groups. Testing
        # num_segments first also avoids a ZeroDivisionError on empty shift.
        if (num_segments <= 0 or channels < num_segments
                or channels % num_segments != 0):
            raise ValueError('C should be a multiple of num_segments, '
                             f'but got C={channels} and '
                             f'num_segments={num_segments}.')

        # backward() reads this back through ctx.saved_tensors.
        ctx.save_for_backward(shift)

        out = torch.zeros_like(input)
        ext_module.tin_shift_forward(input, shift, out)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back to ``input``.

        Args:
            ctx: Autograd context holding the ``shift`` saved in forward.
            grad_output (Tensor): Gradient w.r.t. the forward output.

        Returns:
            tuple[Tensor, Tensor]: Gradient w.r.t. ``input`` (computed by
            the extension) and an all-zero gradient w.r.t. ``shift``.
        """
        shift = ctx.saved_tensors[0]
        # Fresh zero-filled buffers matching the dtype/device of the source
        # tensors; the extension writes the input gradient in place.
        data_grad_input = grad_output.new_zeros(grad_output.size())
        shift_grad_input = shift.new_zeros(shift.size())
        ext_module.tin_shift_backward(grad_output, shift, data_grad_input)

        return data_grad_input, shift_grad_input


# Functional entry point: tin_shift(input, shift) -> shifted feature map.
tin_shift = TINShiftFunction.apply

class TINShift(nn.Module):
    """Temporal Interlace Shift.

    Temporal Interlace shift is a differentiable temporal-wise frame shifting
    which is proposed in "Temporal Interlacing Network" (Hao Shao, Shengju
    Qian, Yu Liu). Please refer to that paper for more details.

    This module holds no parameters; it is a thin ``nn.Module`` wrapper
    around the ``tin_shift`` autograd function. Gradients flow back to
    ``input``; ``shift`` receives an all-zero gradient.
    """

    def forward(self, input, shift):
        """Perform temporal interlace shift.

        Args:
            input (Tensor): Feature map with shape [N, num_segments, C, H * W].
            shift (Tensor): Shift tensor with shape [N, num_segments].

        Returns:
            Tensor: Feature map after temporal interlace shift, with the same
            shape as ``input``.
        """
        return tin_shift(input, shift)