Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
svekars committed Aug 27, 2024
1 parent 4c36093 commit dc0de9e
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion beginner_source/basics/saveloadrun_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
##########################
# To load model weights, you need to create an instance of the same model first, and then load the parameters
# using ``load_state_dict()`` method.
#
# In the code below, we set ``weights_only=True`` to limit the
# functions executed during unpickling to only those necessary for
# loading weights. Using ``weights_only=True`` is considered
# a best practice when loading weights.

model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
Expand All @@ -50,7 +55,12 @@
torch.save(model, 'model.pth')

########################
# We can then load the model like this:
# We can then load the model as demonstrated below.
#
# As described in `Saving and loading torch.nn.Modules <pytorch.org/docs/main/notes/serialization.html#saving-and-loading-torch-nn-modules>`__,
# saving ``state_dict``s is considered the best practice. However,
# below we use ``weights_only=False`` because this involves loading the
# model, which is a legacy use case for ``torch.save``.

model = torch.load('model.pth', weights_only=False),

Expand Down

0 comments on commit dc0de9e

Please sign in to comment.