Shortcuts

mmcv.cnn.utils.flops_counter 源代码

# Modified from flops-counter.pytorch by Vladislav Sovrasov
# original repo: https://github.com/sovrasov/flops-counter.pytorch

# MIT License

# Copyright (c) 2018 Vladislav Sovrasov

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import sys
from functools import partial

import numpy as np
import torch
import torch.nn as nn

import mmcv


[文档]def get_model_complexity_info(model, input_shape, print_per_layer_stat=True, as_strings=True, input_constructor=None, flush=False, ost=sys.stdout): """Get complexity information of a model. This method can calculate FLOPs and parameter counts of a model with corresponding input shape. It can also print complexity information for each layer in a model. Supported layers are listed as below: - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``. - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``, ``nn.ReLU6``. - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``, ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``, ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``, ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``, ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``. - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``, ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``, ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``. - Linear: ``nn.Linear``. - Deconvolution: ``nn.ConvTranspose2d``. - Upsample: ``nn.Upsample``. Args: model (nn.Module): The model for complexity calculation. input_shape (tuple): Input shape used for calculation. print_per_layer_stat (bool): Whether to print complexity information for each layer in a model. Default: True. as_strings (bool): Output FLOPs and params counts in a string form. Default: True. input_constructor (None | callable): If specified, it takes a callable method that generates input. otherwise, it will generate a random tensor with input shape to calculate FLOPs. Default: None. flush (bool): same as that in :func:`print`. Default: False. ost (stream): same as ``file`` param in :func:`print`. Default: sys.stdout. Returns: tuple[float | str]: If ``as_strings`` is set to True, it will return FLOPs and parameter counts in a string format. otherwise, it will return those in a float number format. """ assert type(input_shape) is tuple assert len(input_shape) >= 1 assert isinstance(model, nn.Module) flops_model = add_flops_counting_methods(model) flops_model.eval() flops_model.start_flops_count() if input_constructor: input = input_constructor(input_shape) _ = flops_model(**input) else: try: batch = torch.ones(()).new_empty( (1, *input_shape), dtype=next(flops_model.parameters()).dtype, device=next(flops_model.parameters()).device) except StopIteration: # Avoid StopIteration for models which have no parameters, # like `nn.Relu()`, `nn.AvgPool2d`, etc. batch = torch.ones(()).new_empty((1, *input_shape)) _ = flops_model(batch) flops_count, params_count = flops_model.compute_average_flops_cost() if print_per_layer_stat: print_model_with_flops( flops_model, flops_count, params_count, ost=ost, flush=flush) flops_model.stop_flops_count() if as_strings: return flops_to_string(flops_count), params_to_string(params_count) return flops_count, params_count
def flops_to_string(flops, units='GFLOPs', precision=2): """Convert FLOPs number into a string. Note that Here we take a multiply-add counts as one FLOP. Args: flops (float): FLOPs number to be converted. units (str | None): Converted FLOPs units. Options are None, 'GFLOPs', 'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically choose the most suitable unit for FLOPs. Default: 'GFLOPs'. precision (int): Digit number after the decimal point. Default: 2. Returns: str: The converted FLOPs number with units. Examples: >>> flops_to_string(1e9) '1.0 GFLOPs' >>> flops_to_string(2e5, 'MFLOPs') '0.2 MFLOPs' >>> flops_to_string(3e-9, None) '3e-09 FLOPs' """ if units is None: if flops // 10**9 > 0: return str(round(flops / 10.**9, precision)) + ' GFLOPs' elif flops // 10**6 > 0: return str(round(flops / 10.**6, precision)) + ' MFLOPs' elif flops // 10**3 > 0: return str(round(flops / 10.**3, precision)) + ' KFLOPs' else: return str(flops) + ' FLOPs' else: if units == 'GFLOPs': return str(round(flops / 10.**9, precision)) + ' ' + units elif units == 'MFLOPs': return str(round(flops / 10.**6, precision)) + ' ' + units elif units == 'KFLOPs': return str(round(flops / 10.**3, precision)) + ' ' + units else: return str(flops) + ' FLOPs' def params_to_string(num_params, units=None, precision=2): """Convert parameter number into a string. Args: num_params (float): Parameter number to be converted. units (str | None): Converted FLOPs units. Options are None, 'M', 'K' and ''. If set to None, it will automatically choose the most suitable unit for Parameter number. Default: None. precision (int): Digit number after the decimal point. Default: 2. Returns: str: The converted parameter number with units. Examples: >>> params_to_string(1e9) '1000.0 M' >>> params_to_string(2e5) '200.0 k' >>> params_to_string(3e-9) '3e-09' """ if units is None: if num_params // 10**6 > 0: return str(round(num_params / 10**6, precision)) + ' M' elif num_params // 10**3: return str(round(num_params / 10**3, precision)) + ' k' else: return str(num_params) else: if units == 'M': return str(round(num_params / 10.**6, precision)) + ' ' + units elif units == 'K': return str(round(num_params / 10.**3, precision)) + ' ' + units else: return str(num_params) def print_model_with_flops(model, total_flops, total_params, units='GFLOPs', precision=3, ost=sys.stdout, flush=False): """Print a model with FLOPs for each layer. Args: model (nn.Module): The model to be printed. total_flops (float): Total FLOPs of the model. total_params (float): Total parameter counts of the model. units (str | None): Converted FLOPs units. Default: 'GFLOPs'. precision (int): Digit number after the decimal point. Default: 3. ost (stream): same as `file` param in :func:`print`. Default: sys.stdout. flush (bool): same as that in :func:`print`. Default: False. Example: >>> class ExampleModel(nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.conv1 = nn.Conv2d(3, 8, 3) >>> self.conv2 = nn.Conv2d(8, 256, 3) >>> self.conv3 = nn.Conv2d(256, 8, 3) >>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) >>> self.flatten = nn.Flatten() >>> self.fc = nn.Linear(8, 1) >>> def forward(self, x): >>> x = self.conv1(x) >>> x = self.conv2(x) >>> x = self.conv3(x) >>> x = self.avg_pool(x) >>> x = self.flatten(x) >>> x = self.fc(x) >>> return x >>> model = ExampleModel() >>> x = (3, 16, 16) to print the complexity information state for each layer, you can use >>> get_model_complexity_info(model, x) or directly use >>> print_model_with_flops(model, 4579784.0, 37361) ExampleModel( 0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs, (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1)) # noqa: E501 (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1)) (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1)) (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1)) (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, ) (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True) ) """ def accumulate_params(self): if is_supported_instance(self): return self.__params__ else: sum = 0 for m in self.children(): sum += m.accumulate_params() return sum def accumulate_flops(self): if is_supported_instance(self): return self.__flops__ / model.__batch_counter__ else: sum = 0 for m in self.children(): sum += m.accumulate_flops() return sum def flops_repr(self): accumulated_num_params = self.accumulate_params() accumulated_flops_cost = self.accumulate_flops() return ', '.join([ params_to_string( accumulated_num_params, units='M', precision=precision), '{:.3%} Params'.format(accumulated_num_params / total_params), flops_to_string( accumulated_flops_cost, units=units, precision=precision), '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops), self.original_extra_repr() ]) def add_extra_repr(m): m.accumulate_flops = accumulate_flops.__get__(m) m.accumulate_params = accumulate_params.__get__(m) flops_extra_repr = flops_repr.__get__(m) if m.extra_repr != flops_extra_repr: m.original_extra_repr = m.extra_repr m.extra_repr = flops_extra_repr assert m.extra_repr != m.original_extra_repr def del_extra_repr(m): if hasattr(m, 'original_extra_repr'): m.extra_repr = m.original_extra_repr del m.original_extra_repr if hasattr(m, 'accumulate_flops'): del m.accumulate_flops model.apply(add_extra_repr) print(model, file=ost, flush=flush) model.apply(del_extra_repr) def get_model_parameters_number(model): """Calculate parameter number of a model. Args: model (nn.module): The model for parameter number calculation. Returns: float: Parameter number of the model. """ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) return num_params def add_flops_counting_methods(net_main_module): # adding additional methods to the existing module object, # this is done this way so that each function has access to self object net_main_module.start_flops_count = start_flops_count.__get__( net_main_module) net_main_module.stop_flops_count = stop_flops_count.__get__( net_main_module) net_main_module.reset_flops_count = reset_flops_count.__get__( net_main_module) net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # noqa: E501 net_main_module) net_main_module.reset_flops_count() return net_main_module def compute_average_flops_cost(self): """Compute average FLOPs cost. A method to compute average FLOPs cost, which will be available after `add_flops_counting_methods()` is called on a desired net object. Returns: float: Current mean flops consumption per image. """ batches_count = self.__batch_counter__ flops_sum = 0 for module in self.modules(): if is_supported_instance(module): flops_sum += module.__flops__ params_sum = get_model_parameters_number(self) return flops_sum / batches_count, params_sum def start_flops_count(self): """Activate the computation of mean flops consumption per image. A method to activate the computation of mean flops consumption per image. which will be available after ``add_flops_counting_methods()`` is called on a desired net object. It should be called before running the network. """ add_batch_counter_hook_function(self) def add_flops_counter_hook_function(module): if is_supported_instance(module): if hasattr(module, '__flops_handle__'): return else: handle = module.register_forward_hook( get_modules_mapping()[type(module)]) module.__flops_handle__ = handle self.apply(partial(add_flops_counter_hook_function)) def stop_flops_count(self): """Stop computing the mean flops consumption per image. A method to stop computing the mean flops consumption per image, which will be available after ``add_flops_counting_methods()`` is called on a desired net object. It can be called to pause the computation whenever. """ remove_batch_counter_hook_function(self) self.apply(remove_flops_counter_hook_function) def reset_flops_count(self): """Reset statistics computed so far. A method to Reset computed statistics, which will be available after `add_flops_counting_methods()` is called on a desired net object. """ add_batch_counter_variables_or_reset(self) self.apply(add_flops_counter_variable_or_reset) # ---- Internal functions def empty_flops_counter_hook(module, input, output): module.__flops__ += 0 def upsample_flops_counter_hook(module, input, output): output_size = output[0] batch_size = output_size.shape[0] output_elements_count = batch_size for val in output_size.shape[1:]: output_elements_count *= val module.__flops__ += int(output_elements_count) def relu_flops_counter_hook(module, input, output): active_elements_count = output.numel() module.__flops__ += int(active_elements_count) def linear_flops_counter_hook(module, input, output): input = input[0] output_last_dim = output.shape[ -1] # pytorch checks dimensions, so here we don't care much module.__flops__ += int(np.prod(input.shape) * output_last_dim) def pool_flops_counter_hook(module, input, output): input = input[0] module.__flops__ += int(np.prod(input.shape)) def norm_flops_counter_hook(module, input, output): input = input[0] batch_flops = np.prod(input.shape) if (getattr(module, 'affine', False) or getattr(module, 'elementwise_affine', False)): batch_flops *= 2 module.__flops__ += int(batch_flops) def deconv_flops_counter_hook(conv_module, input, output): # Can have multiple inputs, getting the first one input = input[0] batch_size = input.shape[0] input_height, input_width = input.shape[2:] kernel_height, kernel_width = conv_module.kernel_size in_channels = conv_module.in_channels out_channels = conv_module.out_channels groups = conv_module.groups filters_per_channel = out_channels // groups conv_per_position_flops = ( kernel_height * kernel_width * in_channels * filters_per_channel) active_elements_count = batch_size * input_height * input_width overall_conv_flops = conv_per_position_flops * active_elements_count bias_flops = 0 if conv_module.bias is not None: output_height, output_width = output.shape[2:] bias_flops = out_channels * batch_size * output_height * output_height overall_flops = overall_conv_flops + bias_flops conv_module.__flops__ += int(overall_flops) def conv_flops_counter_hook(conv_module, input, output): # Can have multiple inputs, getting the first one input = input[0] batch_size = input.shape[0] output_dims = list(output.shape[2:]) kernel_dims = list(conv_module.kernel_size) in_channels = conv_module.in_channels out_channels = conv_module.out_channels groups = conv_module.groups filters_per_channel = out_channels // groups conv_per_position_flops = int( np.prod(kernel_dims)) * in_channels * filters_per_channel active_elements_count = batch_size * int(np.prod(output_dims)) overall_conv_flops = conv_per_position_flops * active_elements_count bias_flops = 0 if conv_module.bias is not None: bias_flops = out_channels * active_elements_count overall_flops = overall_conv_flops + bias_flops conv_module.__flops__ += int(overall_flops) def batch_counter_hook(module, input, output): batch_size = 1 if len(input) > 0: # Can have multiple inputs, getting the first one input = input[0] batch_size = len(input) else: pass print('Warning! No positional inputs found for a module, ' 'assuming batch size is 1.') module.__batch_counter__ += batch_size def add_batch_counter_variables_or_reset(module): module.__batch_counter__ = 0 def add_batch_counter_hook_function(module): if hasattr(module, '__batch_counter_handle__'): return handle = module.register_forward_hook(batch_counter_hook) module.__batch_counter_handle__ = handle def remove_batch_counter_hook_function(module): if hasattr(module, '__batch_counter_handle__'): module.__batch_counter_handle__.remove() del module.__batch_counter_handle__ def add_flops_counter_variable_or_reset(module): if is_supported_instance(module): if hasattr(module, '__flops__') or hasattr(module, '__params__'): print('Warning: variables __flops__ or __params__ are already ' 'defined for the module' + type(module).__name__ + ' ptflops can affect your code!') module.__flops__ = 0 module.__params__ = get_model_parameters_number(module) def is_supported_instance(module): if type(module) in get_modules_mapping(): return True return False def remove_flops_counter_hook_function(module): if is_supported_instance(module): if hasattr(module, '__flops_handle__'): module.__flops_handle__.remove() del module.__flops_handle__ def get_modules_mapping(): return { # convolutions nn.Conv1d: conv_flops_counter_hook, nn.Conv2d: conv_flops_counter_hook, mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook, nn.Conv3d: conv_flops_counter_hook, mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook, # activations nn.ReLU: relu_flops_counter_hook, nn.PReLU: relu_flops_counter_hook, nn.ELU: relu_flops_counter_hook, nn.LeakyReLU: relu_flops_counter_hook, nn.ReLU6: relu_flops_counter_hook, # poolings nn.MaxPool1d: pool_flops_counter_hook, nn.AvgPool1d: pool_flops_counter_hook, nn.AvgPool2d: pool_flops_counter_hook, nn.MaxPool2d: pool_flops_counter_hook, mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook, nn.MaxPool3d: pool_flops_counter_hook, mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook, nn.AvgPool3d: pool_flops_counter_hook, nn.AdaptiveMaxPool1d: pool_flops_counter_hook, nn.AdaptiveAvgPool1d: pool_flops_counter_hook, nn.AdaptiveMaxPool2d: pool_flops_counter_hook, nn.AdaptiveAvgPool2d: pool_flops_counter_hook, nn.AdaptiveMaxPool3d: pool_flops_counter_hook, nn.AdaptiveAvgPool3d: pool_flops_counter_hook, # normalizations nn.BatchNorm1d: norm_flops_counter_hook, nn.BatchNorm2d: norm_flops_counter_hook, nn.BatchNorm3d: norm_flops_counter_hook, nn.GroupNorm: norm_flops_counter_hook, nn.InstanceNorm1d: norm_flops_counter_hook, nn.InstanceNorm2d: norm_flops_counter_hook, nn.InstanceNorm3d: norm_flops_counter_hook, nn.LayerNorm: norm_flops_counter_hook, # FC nn.Linear: linear_flops_counter_hook, mmcv.cnn.bricks.Linear: linear_flops_counter_hook, # Upscale nn.Upsample: upsample_flops_counter_hook, # Deconvolution nn.ConvTranspose2d: deconv_flops_counter_hook, mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook, }
Read the Docs v: v1.4.0
Versions
latest
stable
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.