Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add GetBestPtCheckpointJob #432

Merged
merged 3 commits into from
Jul 26, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions returnn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"Checkpoint",
"GetBestEpochJob",
"GetBestTFCheckpointJob",
"GetBestPtCheckpointJob",
"PtCheckpoint",
"ReturnnModel",
"ReturnnTrainingFromFileJob",
Expand Down Expand Up @@ -797,6 +798,54 @@ def run(self):
)


class GetBestPtCheckpointJob(GetBestEpochJob):
    """
    Analogous to GetBestTFCheckpointJob, but for torch checkpoints.

    The checkpoint file of the selected epoch is made available as ``model/epoch.XXX.pt`` in the
    job output, together with a ``model/checkpoint.pt`` symlink pointing to it.
    """

    def __init__(self, model_dir: tk.Path, learning_rates: tk.Path, key: str, index: int = 0):
        """

        :param Path model_dir: model_dir output from a RETURNNTrainingJob
        :param Path learning_rates: learning_rates output from a RETURNNTrainingJob
        :param str key: a key from the learning rate file that is used to sort the models
            e.g. "dev_score_output/output_prob"
        :param int index: index of the sorted list to access, 0 for the lowest, -1 for the highest score
        """
        super().__init__(model_dir, learning_rates, key, index)
        self._out_model_dir = self.output_path("model", directory=True)

        # Note: checkpoint.pt (without epoch number) is only a symlink which is possibly resolved by RETURNN
        self.out_checkpoint = PtCheckpoint(self.output_path("model/checkpoint.pt"))

    def run(self):
        """
        Run the parent job, then place the checkpoint of the epoch stored in ``out_epoch``
        into the output model directory and create the ``checkpoint.pt`` symlink to it.
        """
        # NOTE(review): presumably the parent run() is what fills self.out_epoch - confirm in GetBestEpochJob
        super().run()

        # Build every path exactly once so that source, destination and symlink cannot drift apart.
        checkpoint_name = "epoch.%.3d.pt" % self.out_epoch.get()
        source_file = os.path.join(self.model_dir.get_path(), checkpoint_name)
        out_dir = self._out_model_dir.get_path()
        target_file = os.path.join(out_dir, checkpoint_name)
        symlink_file = os.path.join(out_dir, "checkpoint.pt")

        # Remove leftovers of an earlier (interrupted or repeated) run. Otherwise os.link and
        # os.symlink raise FileExistsError, and the copy fallback raises shutil.SameFileError
        # if the leftover is a hardlink to the source file.
        # lexists is used so that a dangling symlink is detected as well.
        for stale_file in (symlink_file, target_file):
            if os.path.lexists(stale_file):
                os.remove(stale_file)

        try:
            os.link(source_file, target_file)
        except OSError:
            # the hardlink will fail when there was an imported job on a different filesystem,
            # thus do a copy instead then
            shutil.copy(source_file, target_file)

        os.symlink(target_file, symlink_file)


class AverageTFCheckpointsJob(Job):
"""
Compute the average of multiple specified Tensorflow checkpoints using the tf_avg_checkpoints script from Returnn
Expand Down
Loading