From fe497f6b0dee27fe0f7f1c435e8142a79968da41 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Tue, 25 Jul 2023 17:25:58 +0200 Subject: [PATCH 1/3] add GetBestPtCheckpointJob --- returnn/training.py | 49 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/returnn/training.py b/returnn/training.py index b98feaf1..99efd9ab 100644 --- a/returnn/training.py +++ b/returnn/training.py @@ -3,6 +3,7 @@ "Checkpoint", "GetBestEpochJob", "GetBestTFCheckpointJob", + "GetBestPtCheckpointJob", "PtCheckpoint", "ReturnnModel", "ReturnnTrainingFromFileJob", @@ -797,6 +798,54 @@ def run(self): ) +class GetBestPtCheckpointJob(GetBestEpochJob): + """ + Analog to GetBestTFCheckpointJob, just for torch checkpoints. + """ + + def __init__(self, model_dir: tk.Path, learning_rates: tk.Path, key: str, index: int = 0): + """ + + :param Path model_dir: model_dir output from a RETURNNTrainingJob + :param Path learning_rates: learning_rates output from a RETURNNTrainingJob + :param str key: a key from the learning rate file that is used to sort the models + e.g. "dev_score_output/output_prob" + :param int index: index of the sorted list to access, 0 for the lowest, -1 for the highest score + """ + super().__init__(model_dir, learning_rates, key, index) + self._out_model_dir = self.output_path("model", directory=True) + + # Note: checkpoint.pt (without epoch number) is only a symlink which is possibly resolved by RETURNN + self.out_checkpoint = PtCheckpoint(self.output_path("model/checkpoint.pt")) + + def run(self): + super().run() + + try: + os.link( + os.path.join(self.model_dir.get_path(), "epoch.%.3d.pt" % self.out_epoch.get()), + os.path.join( + self._out_model_dir.get_path(), + "epoch.%.3d.pt" % self.out_epoch.get(), + ), + ) + except OSError: + # the hardlink will fail when there was an imported job on a different filesystem, + # thus do a copy instead then + shutil.copy( + os.path.join(self.model_dir.get_path(), "epoch.%.3d.pt" % self.out_epoch.get()), + os.path.join( + self._out_model_dir.get_path(), + "epoch.%.3d.pt" % self.out_epoch.get(), + ), + ) + + os.symlink( + os.path.join(self._out_model_dir.get_path(), "epoch.%.3d.pt" % self.out_epoch.get()), + os.path.join(self._out_model_dir.get_path(), "checkpoint.pt"), + ) + + class AverageTFCheckpointsJob(Job): """ Compute the average of multiple specified Tensorflow checkpoints using the tf_avg_checkpoints script from Returnn From 5eee48f41f4f3f48d0cb3f4a71bab4a719191084 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Tue, 25 Jul 2023 18:19:00 +0200 Subject: [PATCH 2/3] simplify torch checkpoint logic --- returnn/training.py | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/returnn/training.py b/returnn/training.py index 99efd9ab..0c612eb9 100644 --- a/returnn/training.py +++ b/returnn/training.py @@ -806,17 +806,14 @@ class GetBestPtCheckpointJob(GetBestEpochJob): def __init__(self, model_dir: tk.Path, learning_rates: tk.Path, key: str, index: int = 0): """ - :param Path model_dir: model_dir output from a RETURNNTrainingJob - :param Path learning_rates: learning_rates output from a RETURNNTrainingJob + :param Path model_dir: model_dir output from a ReturnnTrainingJob + :param Path learning_rates: learning_rates output from a ReturnnTrainingJob :param str key: a key from the learning rate file that is used to sort the models e.g. "dev_score_output/output_prob" :param int index: index of the sorted list to access, 0 for the lowest, -1 for the highest score """ super().__init__(model_dir, learning_rates, key, index) - self._out_model_dir = self.output_path("model", directory=True) - - # Note: checkpoint.pt (without epoch number) is only a symlink which is possibly resolved by RETURNN - self.out_checkpoint = PtCheckpoint(self.output_path("model/checkpoint.pt")) + self.out_checkpoint = PtCheckpoint(self.output_path("checkpoint.pt")) def run(self): super().run() @@ -824,27 +821,16 @@ def run(self): try: os.link( os.path.join(self.model_dir.get_path(), "epoch.%.3d.pt" % self.out_epoch.get()), - os.path.join( - self._out_model_dir.get_path(), - "epoch.%.3d.pt" % self.out_epoch.get(), - ), + self.out_checkpoint.path ) except OSError: # the hardlink will fail when there was an imported job on a different filesystem, # thus do a copy instead then shutil.copy( os.path.join(self.model_dir.get_path(), "epoch.%.3d.pt" % self.out_epoch.get()), - os.path.join( - self._out_model_dir.get_path(), - "epoch.%.3d.pt" % self.out_epoch.get(), - ), + self.out_checkpoint.path ) - os.symlink( - os.path.join(self._out_model_dir.get_path(), "epoch.%.3d.pt" % self.out_epoch.get()), - os.path.join(self._out_model_dir.get_path(), "checkpoint.pt"), - ) - class AverageTFCheckpointsJob(Job): """ From d83364cb87dba8b41871c80f307ec003e1a4bc5c Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Tue, 25 Jul 2023 18:21:38 +0200 Subject: [PATCH 3/3] black --- returnn/training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/returnn/training.py b/returnn/training.py index 0c612eb9..83649f6a 100644 --- a/returnn/training.py +++ b/returnn/training.py @@ -821,14 +821,14 @@ def run(self): try: os.link( os.path.join(self.model_dir.get_path(), "epoch.%.3d.pt" % self.out_epoch.get()), - self.out_checkpoint.path + self.out_checkpoint.path, ) except OSError: # the hardlink will fail when there was an imported job on a different filesystem, # thus do a copy instead then shutil.copy( os.path.join(self.model_dir.get_path(), "epoch.%.3d.pt" % self.out_epoch.get()), - self.out_checkpoint.path + self.out_checkpoint.path, )