Shortcuts

Source code for mmcv.cnn.bricks.swish

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.registry import MODELS


@MODELS.register_module()
class Swish(nn.Module):
    """Swish activation module.

    Applies, element by element,

    .. math::
        Swish(x) = x * Sigmoid(x)

    The module holds no parameters or buffers of its own; the same function
    is also known as SiLU (``torch.nn.SiLU``).

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the swish function to ``x``.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``x`` multiplied element-wise by ``sigmoid(x)``,
            with the same shape as ``x``.
        """
        # The sigmoid acts as a soft, input-dependent gate on x itself.
        gate = torch.sigmoid(x)
        return x * gate