This version is modified from https://github.com/Lyken17/pytorch-OpCounter.git. I replaced the original list/tuple input with a dict() input, which makes more complicated cases easier to handle, for example:
# the model's original forward() signature
def forward(self, num_modalities, img_meta, return_loss=True, **kwargs):
    ...

# count the FLOPs with dict()-format input (keys match forward()'s argument names)
flops, params = profile(model, inputs=dict(return_loss=False, **data))
PS: this method makes it much easier to count FLOPs and parameters in "mmaction"- or "mmdet"-style code.
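To see why a dict input helps with a forward() like the one above, here is a minimal pure-Python sketch. It does not use thop; `call_with_inputs` is a hypothetical helper, and it assumes the dict-input version ends up calling the model as model(**inputs) while the original tuple version calls model(*inputs).

```python
from typing import Any, Callable, Mapping, Sequence, Union


def call_with_inputs(fn: Callable[..., Any],
                     inputs: Union[Sequence[Any], Mapping[str, Any]]) -> Any:
    """Hypothetical helper: call fn with positional (tuple) or keyword (dict) inputs."""
    if isinstance(inputs, Mapping):
        # dict input: each key is bound to the forward() argument of the same name
        return fn(**inputs)
    # tuple/list input: values are bound to arguments by position only
    return fn(*inputs)


def forward(num_modalities, img_meta, return_loss=True, **kwargs):
    # stand-in for a model's forward(); reports what it received
    return num_modalities, img_meta, return_loss, sorted(kwargs)


data = {"num_modalities": 2, "img_meta": [{"id": 0}], "imgs": "fake-tensor"}

# dict input: can override return_loss and pass extra keywords such as imgs
print(call_with_inputs(forward, dict(return_loss=False, **data)))
# -> (2, [{'id': 0}], False, ['imgs'])

# tuple input: positional values only, so extra keywords cannot be supplied
print(call_with_inputs(forward, (2, [{"id": 0}], False)))
# -> (2, [{'id': 0}], False, [])
```

With the tuple form there is no way to pass `imgs` by name, which is the typical situation for mmaction/mmdet-style models whose forward() collects inputs through **kwargs.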
To install:
git clone https://github.com/ziming-liu/pytorch-OpCounter-dict-input
cd pytorch-OpCounter-dict-input
chmod +x build.sh
./build.sh
# if you get an error, pip install the missing package(s) and retry
python setup.py develop
pip install .
-
Basic usage
import torch
from torchvision.models import resnet50
from thop import profile

model = resnet50()
input = torch.randn(1, 3, 224, 224)
# torchvision's ResNet.forward() takes a single argument named x
macs, params = profile(model, inputs=dict(x=input))
-
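Define the rule for a 3rd-party module.
import torch
import torch.nn as nn
from thop import profile

class YourModule(nn.Module):
    # your definition
    ...

def count_your_model(model, x, y):
    # your rule here
    ...

input = torch.randn(1, 3, 224, 224)
# model: a network containing YourModule; the dict keys must match its forward() arguments
macs, params = profile(model, inputs=dict(x=input), custom_ops={YourModule: count_your_model})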
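The body of a counting rule is usually plain arithmetic on tensor shapes. As a sketch, here is the standard multiply-accumulate (MAC) count for a 2-D convolution, written as a hypothetical helper `conv2d_macs` that a rule could call. The commented wiring at the end assumes upstream thop's convention that a rule receives (module, inputs, output) and adds its count to `module.total_ops`, and that YourModule stores `kernel_size` like nn.Conv2d; check this fork's source before relying on either.

```python
def conv2d_macs(batch: int, c_in: int, c_out: int, h_out: int, w_out: int,
                k_h: int, k_w: int, groups: int = 1) -> int:
    """Hypothetical helper: MACs of a 2-D convolution, bias ignored.

    Every output element needs (c_in / groups) * k_h * k_w multiply-accumulates.
    """
    if c_in % groups or c_out % groups:
        raise ValueError("c_in and c_out must be divisible by groups")
    return batch * c_out * h_out * w_out * (c_in // groups) * k_h * k_w


# ResNet-50's first layer: 3 -> 64 channels, 7x7 kernel, 224x224 input -> 112x112 output
print(conv2d_macs(1, 3, 64, 112, 112, 7, 7))  # 118013952

# Inside a rule this might look like (assumed upstream-thop convention, needs torch):
#   def count_your_model(m, x, y):
#       n, c_out, h_out, w_out = y.shape
#       k_h, k_w = m.kernel_size  # assumes YourModule has a Conv2d-like kernel_size
#       macs = conv2d_macs(n, x[0].shape[1], c_out, h_out, w_out, k_h, k_w)
#       m.total_ops += torch.DoubleTensor([macs])
```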
-
Improve the output readability
Call thop.clever_format to make the output more readable:
from thop import clever_format
macs, params = clever_format([macs, params], "%.3f")
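clever_format only rescales the raw counts into a readable unit. The sketch below shows that idea with a hypothetical `format_count` function; it is not thop's code, and the real clever_format's thresholds, suffixes and return type may differ.

```python
from typing import Sequence, Tuple


def format_count(nums: Sequence[float], fmt: str = "%.3f") -> Tuple[str, ...]:
    """Hypothetical stand-in for clever_format: scale raw counts to T/G/M/K."""
    units = [(1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K")]
    out = []
    for num in nums:
        for scale, suffix in units:
            if num >= scale:
                # use the largest unit the number reaches
                out.append(fmt % (num / scale) + suffix)
                break
        else:
            # below 1K: print the raw number
            out.append(fmt % num)
    return tuple(out)


print(format_count([4.1e9, 25.6e6]))  # ('4.100G', '25.600M')
```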
The model implementations are adapted from torchvision. The following results can be obtained using benchmark/evaluate_famours_models.py.