Skip to content

Commit

Permalink
more consistent LR file reading (#551)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz authored Oct 21, 2024
1 parent 6968fdb commit 1a2cee5
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions returnn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import shutil
import subprocess as sp
import numpy as np
from typing import Dict, Sequence, Iterable, List, Optional, Union

from sisyphus import *
Expand Down Expand Up @@ -287,8 +288,6 @@ def _get_run_cmd(self):
return run_cmd

def info(self):
import numpy as np

def try_load_lr_log(file_path: str) -> Optional[dict]:
# Used in parsing the learning rates
@dataclass
Expand All @@ -298,7 +297,10 @@ class EpochData:

try:
with open(file_path, "rt") as file:
return eval(file.read().strip(), {"EpochData": EpochData, "np": np})
return eval(
file.read().strip(),
{"EpochData": EpochData, "nan": float("nan"), "inf": float("inf"), "np": np},
)
except FileExistsError:
return None
except FileNotFoundError:
Expand Down Expand Up @@ -394,7 +396,7 @@ def EpochData(learningRate, error):
with open(self.out_learning_rates.get_path(), "rt") as f:
text = f.read()

data = eval(text)
data = eval(text, {"EpochData": EpochData, "nan": float("nan"), "inf": float("inf"), "np": np})

epochs = list(sorted(data.keys()))
train_score_keys = [k for k in data[epochs[0]]["error"] if k.startswith("train_score")]
Expand Down Expand Up @@ -704,7 +706,7 @@ def EpochData(learningRate, error):
with open(self.learning_rates.get_path(), "rt") as f:
text = f.read()

data = eval(text, {"nan": float("nan"), "inf": float("inf"), "EpochData": EpochData})
data = eval(text, {"EpochData": EpochData, "nan": float("nan"), "inf": float("inf"), "np": np})

epochs = list(sorted(data.keys()))

Expand Down

0 comments on commit 1a2cee5

Please sign in to comment.