diff --git a/fairseq/training.py b/fairseq/training.py index 5142e7a6..54bc6496 100755 --- a/fairseq/training.py +++ b/fairseq/training.py @@ -38,6 +38,13 @@ def __init__( self.package_name = package_name self.check_consistency() + @property + def data(self): + """ + Get the underlying data of the FairseqHydraConfig + """ + return self.config_dict + def write(self, path: str): config_dict = self.config_dict.copy() config_dict = util.update_nested_dict(config_dict, self.post_config_dict) @@ -170,6 +177,9 @@ def __init__( kwargs = locals() del kwargs["self"] + # save start checkpoint and rename to checkpoint_latest + self.start_checkpoint = fairseq_hydra_config.data.get("checkpoint", {}).pop("restore_file") + self.command_line_args = command_line_args or [] stored_epochs = list(range(save_interval, max_epoch, save_interval)) + [max_epoch] @@ -231,9 +241,28 @@ def create_fairseq_hydra_config(cls, fairseq_hydra_config, max_epoch, max_update } res.update(FairseqHydraConfig(config_dict, post_config_dict)) return res + + def _fairseq_prepare_checkpoint(self, start_checkpoint): + # rename the start checkpoint to checkpoint_last.pt if it is not None and checkpoint_last.pt does not exist + if start_checkpoint is None: + print("No start checkpoint provided") + return + if not os.path.exists(start_checkpoint): + raise FileNotFoundError(f"Start checkpoint {start_checkpoint} does not exist") + if not os.path.exists(os.path.join(self.out_checkpoint_dir.get_path(), "checkpoint_last.pt")): + print(f"Linking {start_checkpoint} to {self.out_checkpoint_dir.get_path()}") + os.symlink( + start_checkpoint, + os.path.join(self.out_checkpoint_dir.get_path(), "checkpoint_last.pt") + ) + os.symlink( + start_checkpoint, + os.path.join(self.out_checkpoint_dir.get_path(), os.path.basename(start_checkpoint)) + ) def create_files(self): self.fairseq_hydra_config.write(self.out_fairseq_hydra_yaml.get_path()) + self._fairseq_prepare_checkpoint(self.start_checkpoint) util.create_executable("fairseq.sh", self._get_run_cmd()) def run(self):