Shortcuts

Source code for mmcv.cnn.bricks.upsample

# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import xavier_init
from mmengine.registry import MODELS

# Expose ``nn.Upsample`` under the names of its interpolation modes so that a
# config can pick the interpolation mode through its ``type`` field.
for _upsample_mode in ('nearest', 'bilinear'):
    MODELS.register_module(name=_upsample_mode, module=nn.Upsample)
del _upsample_mode


@MODELS.register_module(name='pixel_shuffle')
class PixelShufflePack(nn.Module):
    """Upsample a feature map with a convolution followed by pixel shuffle.

    A ``nn.Conv2d`` first widens the channel dimension from ``in_channels``
    to ``out_channels * scale_factor * scale_factor``. ``F.pixel_shuffle``
    then rearranges those extra channels into space, enlarging each spatial
    axis by ``scale_factor``.

    Args:
        in_channels (int): Number of channels of the input feature map.
        out_channels (int): Number of channels of the output feature map.
        scale_factor (int): Upsampling ratio applied to height and width.
        upsample_kernel (int): Kernel size of the channel-expanding conv.
            Its padding is ``(upsample_kernel - 1) // 2`` (stride 1), so only
            odd kernel sizes keep the pre-shuffle spatial size unchanged.
    """

    def __init__(self, in_channels: int, out_channels: int, scale_factor: int,
                 upsample_kernel: int):
        super().__init__()
        # Keep the constructor arguments available as public attributes.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel

        # pixel_shuffle folds scale_factor * scale_factor channels into each
        # output channel, so the conv must produce that many per output.
        expanded_channels = out_channels * scale_factor * scale_factor
        same_padding = (upsample_kernel - 1) // 2
        self.upsample_conv = nn.Conv2d(
            in_channels,
            expanded_channels,
            upsample_kernel,
            padding=same_padding)
        self.init_weights()

    def init_weights(self):
        """Initialise the expansion conv with uniform Xavier initialisation."""
        xavier_init(self.upsample_conv, distribution='uniform')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand channels with the conv, then shuffle them into space.

        Args:
            x (torch.Tensor): Input feature map fed to ``self.upsample_conv``.

        Returns:
            torch.Tensor: The pixel-shuffled (upsampled) feature map.
        """
        expanded = self.upsample_conv(x)
        return F.pixel_shuffle(expanded, self.scale_factor)


def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
    """Build upsample layer.

    Args:
        cfg (dict): The upsample layer config, which should contain:

            - type (str | type): Layer type. Either a name registered in the
              ``MODELS`` registry or a class that is instantiated directly.
            - scale_factor (int): Upsample ratio, which is not applicable to
              deconv.
            - layer args: Args needed to instantiate a upsample layer.

        args (argument list): Positional arguments passed to the
            ``__init__`` method of the corresponding upsample layer.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the corresponding upsample layer.

    Returns:
        nn.Module: Created upsample layer.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``cfg`` has no ``"type"`` key, or if the name given in
            ``cfg["type"]`` cannot be found in the registry.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')
    # Work on a shallow copy so the caller's config is left untouched.
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if inspect.isclass(layer_type):
        upsample = layer_type
    else:
        # Switch registry to the target scope. If `upsample` cannot be found
        # in the registry, fallback to search `upsample` in the
        # mmengine.MODELS.
        with MODELS.switch_scope_and_registry(None) as registry:
            upsample = registry.get(layer_type)
        if upsample is None:
            # Report the requested name: `upsample` is always None here, so
            # interpolating it would only ever print "Cannot find None".
            raise KeyError(f'Cannot find {layer_type} in registry under '
                           f'scope name {registry.scope}')
        # `nn.Upsample` is registered under interpolation-mode names
        # ('nearest', 'bilinear'), so the registry key doubles as `mode`.
        if upsample is nn.Upsample:
            cfg_['mode'] = layer_type
    layer = upsample(*args, **kwargs, **cfg_)
    return layer
Read the Docs v: latest
Versions
master
latest
2.x
1.x
v1.7.0
v1.6.2
v1.6.1
v1.6.0
v1.5.3
v1.5.2_a
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
v1.3.12
v1.3.11
v1.3.10
v1.3.9
v1.3.8
v1.3.7
v1.3.6
v1.3.5
v1.3.4
v1.3.3
v1.3.2
v1.3.1
v1.3.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.