From 1d76f461f0807016aeb765291128f4e562a6f657 Mon Sep 17 00:00:00 2001 From: AndreasPlt Date: Wed, 30 Oct 2024 11:49:48 +0100 Subject: [PATCH 1/7] fix when using restore_file --- fairseq/training.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/fairseq/training.py b/fairseq/training.py index 5142e7a6..8a53bece 100755 --- a/fairseq/training.py +++ b/fairseq/training.py @@ -38,6 +38,13 @@ def __init__( self.package_name = package_name self.check_consistency() + @property + def data(self): + """ + Get the underlying data of the FairseqHydraConfig + """ + return self.config_dict + def write(self, path: str): config_dict = self.config_dict.copy() config_dict = util.update_nested_dict(config_dict, self.post_config_dict) @@ -170,6 +177,24 @@ def __init__( kwargs = locals() del kwargs["self"] + # check for start checkpoint + #if ( + # fairseq_hydra_config.data.get("checkpoint", {}).get("restore_file", None) is not None + # and fairseq_hydra_config.data["checkpoint"]["restore_file"] != "checkpoint_last.pt" + # and os.path.exists(os.path.join(self.output_path("checkpoints", directory=True), "checkpoint_last.pt")) + #): + # # start_checkpoint provided but checkpoint_last.pt exists: start_checkpoint will be ignored + # print( + # "Warning: start_checkpoint will be ignored as checkpoint_last.pt exists in output directory" + # ) + # fairseq_hydra_config.data["checkpoint"]["restore_file"] = "checkpoint_last.pt" + + + self.start_checkpoint = None + if fairseq_hydra_config.data.get("checkpoint", {}).get("restore_file") is not None: + self.start_checkpoint = fairseq_hydra_config.data["checkpoint"]["restore_file"] + fairseq_hydra_config.data["checkpoint"]["restore_file"] = "checkpoint_last.pt" + self.command_line_args = command_line_args or [] stored_epochs = list(range(save_interval, max_epoch, save_interval)) + [max_epoch] @@ -231,9 +256,26 @@ def create_fairseq_hydra_config(cls, fairseq_hydra_config, max_epoch, max_update } res.update(FairseqHydraConfig(config_dict, post_config_dict)) return res + + def _fairseq_prepare_checkpoint(self, start_checkpoint): + # rename the start checkpoint to checkpoint_last.pt if it is not None and checkpoint_last.pt does not exist + if start_checkpoint is None: + return + if not os.path.exists(start_checkpoint): + raise FileNotFoundError(f"Start checkpoint {start_checkpoint} does not exist") + if not os.path.exists(os.path.join(self.out_checkpoint_dir.get_path(), "checkpoint_last.pt")): + os.link( + start_checkpoint, + os.path.join(self.out_checkpoint_dir.get_path(), "checkpoint_last.pt") + ) + os.link( + start_checkpoint, + os.path.join(self.out_checkpoint_dir.get_path(), os.path.basename(start_checkpoint)) + ) def create_files(self): self.fairseq_hydra_config.write(self.out_fairseq_hydra_yaml.get_path()) + self._fairseq_prepare_checkpoint(self.start_checkpoint) util.create_executable("fairseq.sh", self._get_run_cmd()) def run(self): From a7817e9fecfb9c6d6d7acf41782b6ee5289b6181 Mon Sep 17 00:00:00 2001 From: AndreasPlt Date: Mon, 11 Nov 2024 16:24:10 +0100 Subject: [PATCH 2/7] add some prints --- fairseq/training.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fairseq/training.py b/fairseq/training.py index 8a53bece..a6a8cd5c 100755 --- a/fairseq/training.py +++ b/fairseq/training.py @@ -260,10 +260,12 @@ def create_fairseq_hydra_config(cls, fairseq_hydra_config, max_epoch, max_update def _fairseq_prepare_checkpoint(self, start_checkpoint): # rename the start checkpoint to checkpoint_last.pt if it is not None and checkpoint_last.pt does not exist if start_checkpoint is None: + print("No start checkpoint provided") return if not os.path.exists(start_checkpoint): raise FileNotFoundError(f"Start checkpoint {start_checkpoint} does not exist") if not os.path.exists(os.path.join(self.out_checkpoint_dir.get_path(), "checkpoint_last.pt")): + print(f"Linking {start_checkpoint} to {self.out_checkpoint_dir.get_path()}") os.link( start_checkpoint, os.path.join(self.out_checkpoint_dir.get_path(), "checkpoint_last.pt") From 6812f6fc91df934a60abb232a78c33140b9b3f85 Mon Sep 17 00:00:00 2001 From: Andreas Date: Mon, 11 Nov 2024 16:29:47 +0100 Subject: [PATCH 3/7] comment clean up --- fairseq/training.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/fairseq/training.py b/fairseq/training.py index a6a8cd5c..188d150b 100755 --- a/fairseq/training.py +++ b/fairseq/training.py @@ -177,19 +177,7 @@ def __init__( kwargs = locals() del kwargs["self"] - # check for start checkpoint - #if ( - # fairseq_hydra_config.data.get("checkpoint", {}).get("restore_file", None) is not None - # and fairseq_hydra_config.data["checkpoint"]["restore_file"] != "checkpoint_last.pt" - # and os.path.exists(os.path.join(self.output_path("checkpoints", directory=True), "checkpoint_last.pt")) - #): - # # start_checkpoint provided but checkpoint_last.pt exists: start_checkpoint will be ignored - # print( - # "Warning: start_checkpoint will be ignored as checkpoint_last.pt exists in output directory" - # ) - # fairseq_hydra_config.data["checkpoint"]["restore_file"] = "checkpoint_last.pt" - - + # save start checkpoint and rename to checkpoint_latest self.start_checkpoint = None if fairseq_hydra_config.data.get("checkpoint", {}).get("restore_file") is not None: self.start_checkpoint = fairseq_hydra_config.data["checkpoint"]["restore_file"] From 51882e1d86fca53e3df712c2b67ab1f680146585 Mon Sep 17 00:00:00 2001 From: Andreas Date: Thu, 14 Nov 2024 13:36:07 +0100 Subject: [PATCH 4/7] change link to symlink --- fairseq/training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/training.py b/fairseq/training.py index 188d150b..d86ee708 100755 --- a/fairseq/training.py +++ b/fairseq/training.py @@ -254,11 +254,11 @@ def _fairseq_prepare_checkpoint(self, start_checkpoint): raise FileNotFoundError(f"Start checkpoint {start_checkpoint} does not exist") if not os.path.exists(os.path.join(self.out_checkpoint_dir.get_path(), "checkpoint_last.pt")): print(f"Linking {start_checkpoint} to {self.out_checkpoint_dir.get_path()}") - os.link( + os.symlink( start_checkpoint, os.path.join(self.out_checkpoint_dir.get_path(), "checkpoint_last.pt") ) - os.link( + os.symlink( start_checkpoint, os.path.join(self.out_checkpoint_dir.get_path(), os.path.basename(start_checkpoint)) ) From e1191ee1e16ae133c305016852ce873a330453b9 Mon Sep 17 00:00:00 2001 From: AndreasPlt Date: Wed, 30 Oct 2024 11:49:48 +0100 Subject: [PATCH 5/7] fix when using restore_file --- fairseq/training.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/fairseq/training.py b/fairseq/training.py index d86ee708..8873fbd0 100755 --- a/fairseq/training.py +++ b/fairseq/training.py @@ -177,7 +177,23 @@ def __init__( kwargs = locals() del kwargs["self"] +<<<<<<< HEAD # save start checkpoint and rename to checkpoint_latest +======= + # check for start checkpoint + #if ( + # fairseq_hydra_config.data.get("checkpoint", {}).get("restore_file", None) is not None + # and fairseq_hydra_config.data["checkpoint"]["restore_file"] != "checkpoint_last.pt" + # and os.path.exists(os.path.join(self.output_path("checkpoints", directory=True), "checkpoint_last.pt")) + #): + # # start_checkpoint provided but checkpoint_last.pt exists: start_checkpoint will be ignored + # print( + # "Warning: start_checkpoint will be ignored as checkpoint_last.pt exists in output directory" + # ) + # fairseq_hydra_config.data["checkpoint"]["restore_file"] = "checkpoint_last.pt" + + +>>>>>>> c5d0a9f (fix when using restore_file) self.start_checkpoint = None if fairseq_hydra_config.data.get("checkpoint", {}).get("restore_file") is not None: self.start_checkpoint = fairseq_hydra_config.data["checkpoint"]["restore_file"] From a37e00654cc571e805a21e5a89591e7489b21563 Mon Sep 17 00:00:00 2001 From: Andreas Date: Mon, 11 Nov 2024 16:29:47 +0100 Subject: [PATCH 6/7] comment clean up --- fairseq/training.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/fairseq/training.py b/fairseq/training.py index 8873fbd0..d86ee708 100755 --- a/fairseq/training.py +++ b/fairseq/training.py @@ -177,23 +177,7 @@ def __init__( kwargs = locals() del kwargs["self"] -<<<<<<< HEAD # save start checkpoint and rename to checkpoint_latest -======= - # check for start checkpoint - #if ( - # fairseq_hydra_config.data.get("checkpoint", {}).get("restore_file", None) is not None - # and fairseq_hydra_config.data["checkpoint"]["restore_file"] != "checkpoint_last.pt" - # and os.path.exists(os.path.join(self.output_path("checkpoints", directory=True), "checkpoint_last.pt")) - #): - # # start_checkpoint provided but checkpoint_last.pt exists: start_checkpoint will be ignored - # print( - # "Warning: start_checkpoint will be ignored as checkpoint_last.pt exists in output directory" - # ) - # fairseq_hydra_config.data["checkpoint"]["restore_file"] = "checkpoint_last.pt" - - ->>>>>>> c5d0a9f (fix when using restore_file) self.start_checkpoint = None if fairseq_hydra_config.data.get("checkpoint", {}).get("restore_file") is not None: self.start_checkpoint = fairseq_hydra_config.data["checkpoint"]["restore_file"] From 61711824ac274bb5cf36a0370a69949db276da66 Mon Sep 17 00:00:00 2001 From: AndreasPlt <107055616+AndreasPlt@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:58:56 +0100 Subject: [PATCH 7/7] Update fairseq/training.py Co-authored-by: michelwi --- fairseq/training.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fairseq/training.py b/fairseq/training.py index d86ee708..54bc6496 100755 --- a/fairseq/training.py +++ b/fairseq/training.py @@ -178,10 +178,7 @@ def __init__( del kwargs["self"] # save start checkpoint and rename to checkpoint_latest - self.start_checkpoint = None - if fairseq_hydra_config.data.get("checkpoint", {}).get("restore_file") is not None: - self.start_checkpoint = fairseq_hydra_config.data["checkpoint"]["restore_file"] - fairseq_hydra_config.data["checkpoint"]["restore_file"] = "checkpoint_last.pt" + self.start_checkpoint = fairseq_hydra_config.data.get("checkpoint", {}).pop("restore_file") self.command_line_args = command_line_args or [] stored_epochs = list(range(save_interval, max_epoch, save_interval)) + [max_epoch]