Skip to content

Commit

Permalink
hubconf: Passing exportable parameter to timm model
Browse files Browse the repository at this point in the history
  • Loading branch information
paulgavrikov committed Feb 2, 2021
1 parent 9674c1d commit 1ea4eaf
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
'mealv2_efficientnet_b0':'tf_efficientnet_b0'
}

def meal_v2(model_name, pretrained=True, progress=True):
def meal_v2(model_name, pretrained=True, progress=True, exportable=False):
""" MEAL V2 models from
`"MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks" <https://arxiv.org/pdf/2009.08453.pdf>`_
Expand All @@ -39,7 +39,7 @@ def meal_v2(model_name, pretrained=True, progress=True):
progress (bool): If True, displays a progress bar of the download to stderr
"""

model = timm.create_model(mapping[model_name.lower()], pretrained=False)
model = timm.create_model(mapping[model_name.lower()], pretrained=False, exportable=exportable)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(model_urls[model_name.lower()], progress=progress)
model = torch.nn.DataParallel(model).cuda()
Expand Down

0 comments on commit 1ea4eaf

Please sign in to comment.