From f331b3b26976b9e21d4bb67ff3248c84390e5dbb Mon Sep 17 00:00:00 2001 From: KevinMusgrave Date: Mon, 4 Jul 2022 15:06:35 -0400 Subject: [PATCH] Fixes issue #74 (a newer version of pytorch-ignite simplifies the loading function --- setup.py | 2 +- src/pytorch_adapt/__init__.py | 2 +- .../frameworks/ignite/checkpoint_utils.py | 22 ++++++------------- 3 files changed, 9 insertions(+), 17 deletions(-) diff --git a/setup.py b/setup.py index 1e33b148..a8d8dd22 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ long_description = fh.read() -extras_require_ignite = ["pytorch-ignite == 0.5.0.dev20220221"] +extras_require_ignite = ["pytorch-ignite >= 0.4.9"] extras_require_lightning = ["pytorch-lightning"] extras_require_record_keeper = ["record-keeper >= 0.9.31"] extras_require_timm = ["timm"] diff --git a/src/pytorch_adapt/__init__.py b/src/pytorch_adapt/__init__.py index 7715123d..a981d111 100644 --- a/src/pytorch_adapt/__init__.py +++ b/src/pytorch_adapt/__init__.py @@ -1 +1 @@ -__version__ = "0.0.75" +__version__ = "0.0.76" diff --git a/src/pytorch_adapt/frameworks/ignite/checkpoint_utils.py b/src/pytorch_adapt/frameworks/ignite/checkpoint_utils.py index 72459018..cd5bab11 100644 --- a/src/pytorch_adapt/frameworks/ignite/checkpoint_utils.py +++ b/src/pytorch_adapt/frameworks/ignite/checkpoint_utils.py @@ -83,22 +83,14 @@ def fn(engine): return fn - def load_objects(self, to_load, checkpoint=None, global_step=None, **kwargs): - # This can be simplified once this issue is resolved https://github.com/pytorch/ignite/issues/2480 - if global_step is not None: - dirname = self.objs.save_handler.dirname - filename_dict = { - "filename_prefix": self.objs.filename_prefix, - "name": "checkpoint", - "ext": self.objs.ext, - "score_name": self.objs.score_name, - "global_step": global_step, - } - filename = self.objs.filename_pattern.format(**filename_dict) - checkpoint = os.path.join(dirname, filename) - + def load_objects(self, to_load, checkpoint=None, global_step=None): to_load = {k: v for k, v in to_load.items() if v} - self.objs.load_objects(to_load, str(checkpoint), **kwargs) + if global_step is not None: + self.objs.reload_objects( + to_load, name="checkpoint", global_step=global_step + ) + else: + self.objs.load_objects(to_load, str(checkpoint)) def load_best_checkpoint(self, to_load): last_checkpoint = self.get_best_checkpoint()