diff --git a/trainer/io.py b/trainer/io.py index eb34082..62381f1 100644 --- a/trainer/io.py +++ b/trainer/io.py @@ -176,14 +176,23 @@ def save_best_model( epoch, out_path, keep_all_best=False, - keep_after=10000, + keep_after=0, save_func=None, **kwargs, ): - use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None - if (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or ( - not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"] - ): + if isinstance(current_loss, dict): + use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None + is_save_model = (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or ( + not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"] + ) + else: + is_save_model = current_loss < best_loss + + if isinstance(keep_after, (int, float)): + keep_after = int(keep_after) + is_save_model = is_save_model and current_step > keep_after + + if is_save_model: best_model_name = f"best_model_{current_step}.pth" checkpoint_path = os.path.join(out_path, best_model_name) logger.info(" > BEST MODEL : %s", checkpoint_path)