From 1f7bc8adf77ee676db04eb3074e512f810d98ebb Mon Sep 17 00:00:00 2001
From: vieting <45091115+vieting@users.noreply.github.com>
Date: Wed, 26 Jul 2023 16:32:16 +0200
Subject: [PATCH] add GetBestPtCheckpointJob (#432)

---
 returnn/training.py | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/returnn/training.py b/returnn/training.py
index b98feaf1..83649f6a 100644
--- a/returnn/training.py
+++ b/returnn/training.py
@@ -3,6 +3,7 @@
     "Checkpoint",
     "GetBestEpochJob",
     "GetBestTFCheckpointJob",
+    "GetBestPtCheckpointJob",
     "PtCheckpoint",
     "ReturnnModel",
     "ReturnnTrainingFromFileJob",
@@ -797,6 +798,40 @@ def run(self):
             )
 
 
+class GetBestPtCheckpointJob(GetBestEpochJob):
+    """
+    Torch counterpart of GetBestTFCheckpointJob.
+
+    Once ``GetBestEpochJob.run`` has filled ``out_epoch``, the matching ``epoch.<NNN>.pt`` file
+    (epoch number zero-padded to three digits) from ``model_dir`` is exposed as ``out_checkpoint``.
+    """
+
+    def __init__(self, model_dir: tk.Path, learning_rates: tk.Path, key: str, index: int = 0):
+        """
+        :param Path model_dir: model_dir output from a ReturnnTrainingJob
+        :param Path learning_rates: learning_rates output from a ReturnnTrainingJob
+        :param str key: a key from the learning rate file that is used to sort the models
+            e.g. "dev_score_output/output_prob"
+        :param int index: index of the sorted list to access, 0 for the lowest, -1 for the highest score
+        """
+        super().__init__(model_dir, learning_rates, key, index)
+        self.out_checkpoint = PtCheckpoint(self.output_path("checkpoint.pt"))
+
+    def run(self):
+        """
+        Run the parent job's epoch selection, then hardlink (or, failing that, copy) the chosen checkpoint.
+        """
+        super().run()
+
+        src_checkpoint = os.path.join(self.model_dir.get_path(), "epoch.%.3d.pt" % self.out_epoch.get())
+        try:
+            os.link(src_checkpoint, self.out_checkpoint.path)
+        except OSError:
+            # hardlinking fails if the job was imported from a different filesystem,
+            # so fall back to copying the checkpoint file in that case
+            shutil.copy(src_checkpoint, self.out_checkpoint.path)
+
+
 class AverageTFCheckpointsJob(Job):
     """
     Compute the average of multiple specified Tensorflow checkpoints using the tf_avg_checkpoints script from Returnn