Skip to content

Commit

Permalink
Merge pull request #77 from KevinMusgrave/dev
Browse files Browse the repository at this point in the history
v0.0.76
  • Loading branch information
Kevin Musgrave authored Jul 4, 2022
2 parents 448c7d5 + f331b3b commit 9d54fb0
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 17 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
long_description = fh.read()


extras_require_ignite = ["pytorch-ignite == 0.5.0.dev20220221"]
extras_require_ignite = ["pytorch-ignite >= 0.4.9"]
extras_require_lightning = ["pytorch-lightning"]
extras_require_record_keeper = ["record-keeper >= 0.9.31"]
extras_require_timm = ["timm"]
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_adapt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.75"
__version__ = "0.0.76"
22 changes: 7 additions & 15 deletions src/pytorch_adapt/frameworks/ignite/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,14 @@ def fn(engine):

return fn

def load_objects(self, to_load, checkpoint=None, global_step=None, **kwargs):
# This can be simplified once this issue is resolved https://github.com/pytorch/ignite/issues/2480
if global_step is not None:
dirname = self.objs.save_handler.dirname
filename_dict = {
"filename_prefix": self.objs.filename_prefix,
"name": "checkpoint",
"ext": self.objs.ext,
"score_name": self.objs.score_name,
"global_step": global_step,
}
filename = self.objs.filename_pattern.format(**filename_dict)
checkpoint = os.path.join(dirname, filename)

def load_objects(self, to_load, checkpoint=None, global_step=None):
to_load = {k: v for k, v in to_load.items() if v}
self.objs.load_objects(to_load, str(checkpoint), **kwargs)
if global_step is not None:
self.objs.reload_objects(
to_load, name="checkpoint", global_step=global_step
)
else:
self.objs.load_objects(to_load, str(checkpoint))

def load_best_checkpoint(self, to_load):
last_checkpoint = self.get_best_checkpoint()
Expand Down

0 comments on commit 9d54fb0

Please sign in to comment.