From 65181794ae47dc87e54c7869025524b71419dd32 Mon Sep 17 00:00:00 2001
From: leriomaggio
Date: Mon, 9 Mar 2020 09:51:28 +0000
Subject: [PATCH] Removed unused import, PEP8 reformatting, and more robust
 checking of None for dtype

---
 torchsummary/.idea/.gitignore |  2 ++
 torchsummary/torchsummary.py  | 11 ++++-------
 2 files changed, 6 insertions(+), 7 deletions(-)
 create mode 100644 torchsummary/.idea/.gitignore

diff --git a/torchsummary/.idea/.gitignore b/torchsummary/.idea/.gitignore
new file mode 100644
index 0000000..e7e9d11
--- /dev/null
+++ b/torchsummary/.idea/.gitignore
@@ -0,0 +1,2 @@
+# Default ignored files
+/workspace.xml
diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py
index 1ed065f..08ddc19 100644
--- a/torchsummary/torchsummary.py
+++ b/torchsummary/torchsummary.py
@@ -1,7 +1,5 @@
 import torch
 import torch.nn as nn
-from torch.autograd import Variable
-
 from collections import OrderedDict
 import numpy as np
 
@@ -10,13 +8,12 @@ def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dty
     result, params_info = summary_string(
         model, input_size, batch_size, device, dtypes)
     print(result)
-
     return params_info
 
 
 def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
-    if dtypes == None:
-        dtypes = [torch.FloatTensor]*len(input_size)
+    if dtypes is None:
+        dtypes = [torch.FloatTensor] * len(input_size)
 
     summary_str = ''
 
@@ -46,8 +43,8 @@ def hook(module, input, output):
             summary[m_key]["nb_params"] = params
 
         if (
-            not isinstance(module, nn.Sequential)
-            and not isinstance(module, nn.ModuleList)
+                not isinstance(module, nn.Sequential)
+                and not isinstance(module, nn.ModuleList)
         ):
             hooks.append(module.register_forward_hook(hook))
 
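
Side note (illustrative sketch, not part of the patch): the switch from "== None" to "is None" is the robustness fix named in the subject. Equality is dispatched to the left operand's __eq__, so objects that overload comparison, such as the NumPy arrays this module already imports, can make an equality test against None broadcast or raise, whereas the identity test always yields a plain bool:

    import numpy as np

    a = np.array([1, 2, 3])

    # Equality with None is dispatched to ndarray.__eq__ and broadcasts
    # element-wise instead of answering the "is it missing?" question.
    print(a == None)       # [False False False]

    try:
        if a == None:      # truth value of a multi-element array is ambiguous
            pass
    except ValueError as err:
        print(err)

    # Identity is a single, unambiguous bool regardless of the operand type.
    print(a is None)       # False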