Fix the multi-output, dict-input, parameter counting and calculation overflow problem. #165

Open

wants to merge 5 commits into master
Conversation


@cainmagi cainmagi commented Feb 27, 2021

Update report

  1. Fix the parameter-count calculation when a layer has more than one output variable, covering both the sequence case and the dict case (mentioned in Cannot get the summary #162).
  2. Split multiple output variables across multiple lines in the report.
  3. Remove the trailing line break from summary_string().
  4. Enable the argument device to accept both str and torch.device.
  5. Fix a bug when the model requires batch_size to be a specific number (a usage sketch for items 4 and 5 follows this list).
  6. Fix a bug caused by multiple-input cases when dtypes=None.
  7. Add automatic text wrapping when a layer name is too long.
  8. Support counting all parameters instead of only weight and bias (a different solution to Fix parameter count #142 and the package does not count "torch.nn.parameter" #148).
  9. Drop np.sum/np.prod to fix the overflow problem when calculating the total size (mentioned in RuntimeWarning: overflow encountered in long_scalars #158).
  10. Fix the bug caused by layers with dict input values (mentioned in Cannot get the summary #162).
  11. Add docstrings.
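
A minimal usage sketch for items 4 and 5 (not taken from this PR). It assumes the updated summary() keeps the upstream signature summary(model, input_size, batch_size=-1, device=...) and that a positive batch_size is also used for the dummy input:

import torch
import torch.nn as nn
from torchsummary import summary

# A small stand-in model; any CPU model works here.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
)

# Item 4: device may be passed either as a string or as a torch.device.
summary(model, (1, 16, 16), device="cpu")
summary(model, (1, 16, 16), device=torch.device("cpu"))

# Item 5 (assumption): a model that only accepts a fixed batch size can be
# summarized by passing that size explicitly instead of the default -1.
summary(model, (1, 16, 16), batch_size=4, device="cpu")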

Example for verifying this update

The following code is not compatible with the base repository:

import torch
import torch.nn as nn
from torchsummary import summary

class VeryLongNameSimpleMultiConv(nn.Module):
    def __init__(self):
        super(VeryLongNameSimpleMultiConv, self).__init__()
        self.features_1 = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.features_2 = nn.Sequential(
            nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        x1 = self.features_1(x)
        x2 = self.features_2(x)
        return x1, x2
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VeryLongNameSimpleMultiConv().to(device)

summary(model, (1, 16, 16))

Now the output is:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [-1, 1, 16, 16]              10
              ReLU-2            [-1, 1, 16, 16]               0
            Conv2d-3            [-1, 2, 16, 16]              20
              ReLU-4            [-1, 2, 16, 16]               0
VeryLong...ltiConv-5            [-1, 1, 16, 16]               0
                                [-1, 2, 16, 16]
================================================================
Total params: 30
Trainable params: 30
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 0.00
Estimated Total Size (MB): 0.02
----------------------------------------------------------------

cainmagi and others added 2 commits February 27, 2021 01:21
1. Fix the bug of parameter number calculation when there is more than one output variable, including both the sequence case and the dict case.
2. Split multiple output variables across multiple lines.
3. Remove the last line break of summary_string().
4. Enable argument "device" to accept both str and torch.device.
5. Fix a bug when the model requires "batch_size" to be a specific number.
6. Fix a bug caused by multiple-input cases when "dtypes=None".
7. Add text auto wrap when the layer name is too long.
8. Add docstring.
Support counting all parameters instead of `weight` and `bias`.
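
A minimal sketch (not part of the commit) of the kind of module this change targets: a parameter registered under a name other than weight or bias, which the base repository's per-attribute counting misses.

import torch
import torch.nn as nn
from torchsummary import summary

class ScaleLayer(nn.Module):
    """A layer whose only parameter is registered under a custom name."""
    def __init__(self):
        super(ScaleLayer, self).__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x * self.scale

# The base repository only inspects module.weight and module.bias, so this
# parameter is not counted; with this change it should appear in Total params.
summary(ScaleLayer(), (1, 16, 16), device="cpu")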
@cainmagi cainmagi changed the title Fix the multi-output problem. Fix the multi-output and parameter counting problem. Feb 28, 2021
Using numpy sum/prod to calculate the total size may cause an overflow problem. This modification drops numpy and uses the Python built-in methods to calculate the size.
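
For illustration only (not code from this commit): numpy's fixed-width integer scalars wrap around on overflow, which is the RuntimeWarning reported in #158, while Python integers have arbitrary precision.

import numpy as np

a, b = np.int64(2 ** 62), np.int64(4)
print(a * b)            # wraps around; numpy warns, e.g. "overflow encountered in long_scalars"
print(int(a) * int(b))  # 18446744073709551616, exact

# A pure-Python product in the spirit of this change (hypothetical helper):
def prod(values):
    total = 1
    for v in values:
        total *= v
    return total

print(prod([4, 3, 224, 224]))  # stays exact no matter how large the result grows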
@cainmagi cainmagi changed the title Fix the multi-output and parameter counting problem. Fix the multi-output, parameter counting and calculation overflow problem. Feb 28, 2021
Fix the bug caused by layers with dict input values.
@cainmagi cainmagi changed the title Fix the multi-output, parameter counting and calculation overflow problem. Fix the multi-output, dict-input, parameter counting and calculation overflow problem. Feb 28, 2021
Fix the data type of the output params_info from torch.tensor to int.