From a5f9f1727c3aafa9ffbbc618f155bd7ce8594b93 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Tue, 12 Dec 2023 14:39:10 +0100 Subject: [PATCH 1/2] Fixup --- trainer/io.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/trainer/io.py b/trainer/io.py index eb34082..9348981 100644 --- a/trainer/io.py +++ b/trainer/io.py @@ -176,14 +176,23 @@ def save_best_model( epoch, out_path, keep_all_best=False, - keep_after=10000, + keep_after=0, save_func=None, **kwargs, ): - use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None - if (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or ( - not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"] - ): + if isinstance(current_loss, dict): + use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None + is_save_model = (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or ( + not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"] + ) + else: + is_save_model = current_loss < best_loss + + if isinstance(keep_after, int) or isinstance(keep_after, float): + keep_after = int(keep_after) + is_save_model = is_save_model and current_step > keep_after + + if is_save_model: best_model_name = f"best_model_{current_step}.pth" checkpoint_path = os.path.join(out_path, best_model_name) logger.info(" > BEST MODEL : %s", checkpoint_path) From 0741bbdc849fa154d4850d151dffa45da11f0428 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Tue, 12 Dec 2023 14:43:25 +0100 Subject: [PATCH 2/2] Fixup lint --- trainer/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trainer/io.py b/trainer/io.py index 9348981..62381f1 100644 --- a/trainer/io.py +++ b/trainer/io.py @@ -188,7 +188,7 @@ def save_best_model( else: is_save_model = current_loss < best_loss - if isinstance(keep_after, int) or isinstance(keep_after, float): + if isinstance(keep_after, (int, float)): keep_after = int(keep_after) is_save_model = is_save_model and current_step > keep_after