Commit ceaca2c9 authored by Yifan Zhao

Improved FLOPs summary implementation

parent ad1fc835
from collections import OrderedDict
from typing import Iterable, Tuple
import pandas
import torch
import torch.nn as nn
from .indexing import ModuleIndexer
_summary_used = False
def get_flops(module: nn.Module, input_shape, output_shape):
    if output_shape is None:
        return None
    # Partially following the impl here:
    # https://github.com/juliagusak/flopco-pytorch/blob/c9679785d802f4984c9c5e5d47958e3b82044ce9/flopco/compute_layer_flops.py
    from torchvision.models.detection.transform import GeneralizedRCNNTransform

    def linear_flops():
        m, n = input_shape
        k, n_ = module.weight.shape
        assert n == n_
        # One multiply-accumulate per (batch row, input, output) triple.
        return m * n * k

    def conv2d_flops():
        # Each output pixel costs one full application of the kernel.
        _, _, h, w = output_shape
        return module.weight.numel() * h * w

    def pool2d_flops():
        ksize = module.kernel_size
        if isinstance(ksize, int):
            ksize = ksize, ksize
        k_area = ksize[0] * ksize[1]
        return k_area * _get_numel(output_shape)

    def ntimes_input_numel(n: int):
        return lambda: n * _get_numel(input_shape)

    def ntimes_output_numel(n: int):
        return lambda: n * _get_numel(output_shape)

    type_dispatch = {
        nn.Linear: linear_flops,
        nn.Conv2d: conv2d_flops,
        # Normalization costs roughly 6 FLOPs per output element.
        nn.BatchNorm2d: ntimes_output_numel(6),
        nn.ReLU: ntimes_output_numel(1),
        nn.AvgPool2d: pool2d_flops,
        nn.MaxPool2d: pool2d_flops,
        # Resizing likely costs more than this, but approximate it as 2x input size.
        GeneralizedRCNNTransform: ntimes_input_numel(2),
    }
    handler = type_dispatch.get(type(module))
    if not handler:
        if not list(module.children()):
            _print_once(f"Leaf module {module} cannot be handled")
        return None
    try:
        return handler()
    except (RuntimeError, TypeError) as e:  # shapes may be None or ragged
        _print_once(f'Error "{e}" when handling {module}')
        return None
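
# A quick sketch of what get_flops reports; the layers and shapes below are
# illustrative, not part of this module:
#
#   fc = nn.Linear(128, 10)
#   get_flops(fc, [32, 128], [32, 10])
#   # -> 32 * 128 * 10 = 40960
#
#   conv = nn.Conv2d(3, 16, 3)
#   get_flops(conv, [1, 3, 32, 32], [1, 16, 30, 30])
#   # -> conv.weight.numel() * 30 * 30 = 432 * 900 = 388800
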
def get_summary(model: nn.Module, model_args: Tuple) -> pandas.DataFrame:
    from torchvision.ops.feature_pyramid_network import LastLevelMaxPool

    include = lambda m: (
        not isinstance(m, nn.Sequential)
        and not isinstance(m, nn.ModuleList)
        and not (m == model)
    )
    indexed = ModuleIndexer(model, include, lambda m: True)
    find_by_module = lambda m: indexed.find_by_module(m)[0]
    summary = OrderedDict()
    hooks = []
    # Modules whose inputs/outputs need bespoke shape handling.
    special_ops = {LastLevelMaxPool: last_level_max_pool_io}

    def hook(module: nn.Module, inputs, outputs):
        module_name = find_by_module(module)
        special_handler = special_ops.get(type(module))
        if special_handler:
            input_shape, output_shape, flops = special_handler(module, inputs, outputs)
        else:
            input_shape, output_shape, flops = default_io(module, inputs, outputs)
        n_params = sum(param.numel() for param in module.parameters())
        trainable = any(param.requires_grad for param in module.parameters())
        is_leaf = not list(module.children())
        summary[module_name] = OrderedDict(
            type=module.__class__.__name__,
            input_shape=input_shape,
            output_shape=output_shape,
            params=n_params,
            flops=flops,
            trainable=trainable,
            is_leaf=is_leaf,
        )
    def register_hook(module: nn.Module):
        ...  # body elided in this diff view

    # remove these hooks
    for h in hooks:
        h.remove()

    global _summary_used
    _summary_used = True  # Prevent further error printing
    return pandas.DataFrame(summary).T
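
# Usage sketch (the model and input are illustrative; this assumes the elided
# code above registers forward hooks on included modules and runs
# model(*model_args) before the hooks are removed):
#
#   import torchvision.models as models
#   model = models.resnet18().eval()
#   df = get_summary(model, (torch.randn(1, 3, 224, 224),))
#   print(df[["type", "params", "flops"]])
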
def last_level_max_pool_io(_, inputs, outputs):
    input_shapes = [list(i.size()) for i in inputs[0]]
    output_shapes = [list(o.size()) for o in outputs[0]]
    # Max pooling reads each input element roughly once.
    total_numel = sum(_get_numel(s) for s in input_shapes)
    return input_shapes, output_shapes, total_numel
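
# For LastLevelMaxPool (torchvision FPN), inputs[0] and outputs[0] are lists
# of feature maps, so the shapes here are lists of lists, e.g. (illustrative):
#   input_shapes = [[1, 256, 64, 64], [1, 256, 32, 32]]
#   total_numel  = 1*256*64*64 + 1*256*32*32
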
def default_handle_sizes(value):
    try:
        if isinstance(value, torch.Tensor):
            return list(value.size())
        if isinstance(value, dict):
            return {k: list(v.size()) for k, v in value.items()}
        if isinstance(value, Iterable):
            return [list(i.size()) for i in value]
    except AttributeError as e:
        _print_once(f"Cannot handle {type(value)}: error {e}")
        return None
    _print_once(f"Cannot handle {type(value)}")
    return None
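
# What this returns for a few kinds of value (hypothetical tensors):
#   default_handle_sizes(torch.zeros(2, 3))                 -> [2, 3]
#   default_handle_sizes({"boxes": torch.zeros(5, 4)})      -> {"boxes": [5, 4]}
#   default_handle_sizes([torch.zeros(1), torch.zeros(2)])  -> [[1], [2]]
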
def default_io(module: nn.Module, inputs, outputs):
    input_shape = default_handle_sizes(inputs[0])
    output_shape = default_handle_sizes(outputs)
    return input_shape, output_shape, get_flops(module, input_shape, output_shape)

def _get_numel(shape):
    return torch.prod(torch.tensor(shape)).item()
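
# e.g. _get_numel([1, 3, 224, 224]) == 1 * 3 * 224 * 224 == 150528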

def _print_once(*args, **kwargs):
    if _summary_used:
        return
    print(*args, **kwargs)