mmcv.cnn.bricks.scale 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn


class Scale(nn.Module):
    """A learnable scale parameter.

    This layer scales the input by a learnable factor. It multiplies a
    learnable scalar parameter (a 0-dim tensor, i.e. shape ``()``) with an
    input of any shape; standard PyTorch broadcasting applies.

    Args:
        scale (float): Initial value of the scale factor. Default: 1.0.
    """

    def __init__(self, scale: float = 1.0) -> None:
        super().__init__()
        # ``torch.tensor`` on a Python float yields a 0-dim tensor. It is kept
        # 0-dim (not shape (1,)) so the parameter shape matches state dicts
        # previously saved from this module.
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply ``x`` element-wise by the learnable scale.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: ``x * self.scale``, same shape as ``x``.
        """
        return x * self.scale