Skip to content

Commit

Permalink
Merge pull request #343 from libAtoms/md_logger_stdout
Browse files Browse the repository at this point in the history
Improve generator.md logging
  • Loading branch information
bernstei authored Oct 11, 2024
2 parents 3a5c4e5 + 849f0e1 commit 80a5a23
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 9 deletions.
41 changes: 38 additions & 3 deletions tests/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,15 @@ def test_md_abort_function(cu_slab):
assert len(list(atoms_traj)) < 501


def test_md_attach_logger(cu_slab):
def test_md_attach_logger(cu_slab, tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)

calc = EMT()
autopara_info = autoparainfo.AutoparaInfo(num_python_subprocesses=2, num_inputs_per_python_subprocess=1, skip_failed=False)

inputs = ConfigSet([cu_slab, cu_slab])
outputs = OutputSpec()

logger_kwargs = {
"logger" : MDLogger,
"logfile" : "test_log",
Expand All @@ -245,4 +246,38 @@ def test_md_attach_logger(cu_slab):

assert len(atoms_traj) == 602
assert all([Path(workdir / "test_log.item_0").is_file(), Path(workdir / "test_log.item_1").is_file()])



def test_md_attach_logger_stdout(cu_slab, tmp_path, monkeypatch, capsys):
monkeypatch.chdir(tmp_path)

calc = EMT()
autopara_info = autoparainfo.AutoparaInfo(num_python_subprocesses=2, num_inputs_per_python_subprocess=1, skip_failed=False)

inputs = ConfigSet([cu_slab, cu_slab])
outputs = OutputSpec()

logger_kwargs = {
"logger" : MDLogger,
"logfile" : "-",
}

atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="Langevin", steps=300, dt=1.0,
temperature=500.0, temperature_tau=100/fs, logger_kwargs=logger_kwargs, logger_interval=1,
rng=np.random.default_rng(1), autopara_info=autopara_info,)

atoms_traj = list(atoms_traj)
atoms_final = atoms_traj[-1]

workdir = Path(os.getcwd())

assert len(atoms_traj) == 602

# make sure normal log files were not written
assert len(list(Path(workdir).glob("*"))) == 0

captured = capsys.readouterr()
n_0 = sum(['item 0 ' in captured.out.splitlines()])
n_1 = sum(['item 1 ' in captured.out.splitlines()])
if n_0 != 301 or n_1 != 301:
pytest.xfail("capsys fails to capture stdout to check for logger output")
21 changes: 15 additions & 6 deletions wfl/generate/md/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ def _sample_autopara_wrappable(atoms, calculator, steps, dt, integrator="NVTBere
random number generator to use (needed for pressure sampling, initial temperature, or Langevin dynamics)
logger_kwargs: dict, default None
kwargs to MDLogger to attach to each MD run, including "logfile" as string to which
config number will be appended. User defined ase.md.MDLogger derived class can be provided with "logger" as key.
config number will be appended. If logfile is "-", stdout will be used, and config number
will be prepended to each outout line. User defined ase.md.MDLogger derived class can be provided with "logger" as key.
logger_interval: int, default None
interval for logger
Enable logging at this interval
_autopara_per_item_info: dict
INTERNALLY used by autoparallelization framework to make runs reproducible (see
wfl.autoparallelize.autoparallelize() docs)
Expand All @@ -106,9 +107,11 @@ def _sample_autopara_wrappable(atoms, calculator, steps, dt, integrator="NVTBere
else:
logfile = None

if logger_kwargs is not None:
if logger_interval is not None and logger_interval > 0:
if logger_kwargs is None:
logger_kwargs = {}
logger_constructor = logger_kwargs.pop("logger", MDLogger)
logger_logfile = logger_kwargs["logfile"]
logger_logfile = logger_kwargs.get("logfile", "-")

if temperature_tau is None and (temperature is not None and not isinstance(temperature, (float, int))):
raise RuntimeError(f'NVE (temperature_tau is None) can only accept temperature=float for initial T, got {type(temperature)}')
Expand Down Expand Up @@ -248,11 +251,17 @@ def process_step(interval):
md = md_constructor(at, **stage_kwargs)

md.attach(process_step, 1, traj_step_interval)
if logger_kwargs is not None:
logger_kwargs["logfile"] = f"{logger_logfile}.item_{item_i}"
if logger_interval is not None and logger_interval > 0:
if logger_logfile == "-":
logger_kwargs["logfile"] = "-"
else:
logger_kwargs["logfile"] = f"{logger_logfile}.item_{item_i}"
logger_kwargs["dyn"] = md
logger_kwargs["atoms"] = at
logger = logger_constructor(**logger_kwargs)
if logger_logfile == "-":
# add prefix to each line
logger.fmt = f"item {item_i} " + logger.fmt
md.attach(logger, logger_interval)

if stage_i > 0:
Expand Down

0 comments on commit 80a5a23

Please sign in to comment.