From 1a2cee58ad55baddb94d65bd5b14d74c839d59a9 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 21 Oct 2024 12:13:03 +0200 Subject: [PATCH] more consistent LR file reading (#551) --- returnn/training.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/returnn/training.py b/returnn/training.py index 57020ca2..03e17127 100644 --- a/returnn/training.py +++ b/returnn/training.py @@ -17,6 +17,7 @@ import os import shutil import subprocess as sp +import numpy as np from typing import Dict, Sequence, Iterable, List, Optional, Union from sisyphus import * @@ -287,8 +288,6 @@ def _get_run_cmd(self): return run_cmd def info(self): - import numpy as np - def try_load_lr_log(file_path: str) -> Optional[dict]: # Used in parsing the learning rates @dataclass @@ -298,7 +297,10 @@ class EpochData: try: with open(file_path, "rt") as file: - return eval(file.read().strip(), {"EpochData": EpochData, "np": np}) + return eval( + file.read().strip(), + {"EpochData": EpochData, "nan": float("nan"), "inf": float("inf"), "np": np}, + ) except FileExistsError: return None except FileNotFoundError: @@ -394,7 +396,7 @@ def EpochData(learningRate, error): with open(self.out_learning_rates.get_path(), "rt") as f: text = f.read() - data = eval(text) + data = eval(text, {"EpochData": EpochData, "nan": float("nan"), "inf": float("inf"), "np": np}) epochs = list(sorted(data.keys())) train_score_keys = [k for k in data[epochs[0]]["error"] if k.startswith("train_score")] @@ -704,7 +706,7 @@ def EpochData(learningRate, error): with open(self.learning_rates.get_path(), "rt") as f: text = f.read() - data = eval(text, {"nan": float("nan"), "inf": float("inf"), "EpochData": EpochData}) + data = eval(text, {"EpochData": EpochData, "nan": float("nan"), "inf": float("inf"), "np": np}) epochs = list(sorted(data.keys()))