Skip to content

Commit

Permalink
Merge branch 'chemle:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
kzinovjev authored Apr 8, 2024
2 parents 5519af2 + 2b29a27 commit 5e639bc
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 41 deletions.
19 changes: 19 additions & 0 deletions bin/emle-server
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ save_settings = os.getenv("EMLE_SAVE_SETTINGS")
orca_template = os.getenv("EMLE_ORCA_TEMPLATE")
deepmd_model = os.getenv("EMLE_DEEPMD_MODEL")
deepmd_deviation = os.getenv("EMLE_DEEPMD_DEVIATION")
qm_xyz_file = os.getenv("EMLE_QM_XYZ_FILE")
try:
qm_xyz_frequency = int(os.getenv("EMLE_QM_XYZ_FREQUENCY"))
except:
qm_xyz_frequency = 0
rascal_model = os.getenv("EMLE_RASCAL_MODEL")
parm7 = os.getenv("EMLE_PARM7")
try:
Expand Down Expand Up @@ -127,6 +132,8 @@ env = {
"device": device,
"deepmd_model": deepmd_model,
"deepmd_deviation": deepmd_deviation,
"qm_xyz_file": qm_xyz_file,
"qm_xyz_frequency": qm_xyz_frequency,
"rascal_model": rascal_model,
"lambda_interpolate": lambda_interpolate,
"interpolate_steps": interpolate_steps,
Expand Down Expand Up @@ -227,6 +234,18 @@ parser.add_argument(
help="path to a file to write the max deviation between forces predicted with the DeePMD models",
required=False,
)
parser.add_argument(
"--qm-xyz-file",
type=str,
help="path to a file to write the QM region coordinates",
required=False,
)
parser.add_argument(
"--qm-xyz-frequency",
type=int,
help="the frequency of writing the QM region coordinates to file (0 to disable)",
required=False,
)
parser.add_argument(
"--rascal-model",
type=str,
Expand Down
46 changes: 37 additions & 9 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,8 @@ def __init__(
with the DeePMD models.
qm_xyz_file: str
Path to write out xyz trajectory of the QM region.
Path to an output file for writing the xyz trajectory of the QM
region.
qm_xyz_frequency: int
How often to write the xyz trajectory of the QM region. Zero turns
Expand Down Expand Up @@ -792,8 +793,34 @@ def __init__(
_logger.error(msg)
raise ValueError(msg)

self.qm_xyz_frequency = qm_xyz_frequency
self.qm_xyz_file = qm_xyz_file
# Set the deviation file to None in case it was spuriously set.
self._deepmd_deviation = None

# Validate the QM XYZ file options.

if qm_xyz_file is None:
qm_xyz_file = "qm.xyz"
else:
if not isinstance(qm_xyz_file, str):
msg = "'qm_xyz_file' must be of type 'str'"
_logger.error(msg)
raise TypeError(msg)
self._qm_xyz_file = qm_xyz_file

if qm_xyz_frequency is None:
qm_xyz_frequency = 0
else:
try:
qm_xyz_frequency = int(qm_xyz_frequency)
except:
msg = "'qm_xyz_frequency' must be of type 'int'"
_logger.error(msg)
raise TypeError(msg)
if qm_xyz_frequency < 0:
msg = "'qm_xyz_frequency' must be greater than or equal to 0"
_logger.error(msg)
raise ValueError(msg)
self._qm_xyz_frequency = qm_xyz_frequency

# Validate the QM method for SQM.
if backend == "sqm":
Expand Down Expand Up @@ -1501,12 +1528,12 @@ def run(self, path=None):
else:
f.write(f"{self._step:>10}{E_vac:22.12f}{E_tot:22.12f}\n")

if self.qm_xyz_frequency > 0 and self._step % self.qm_xyz_frequency == 0:
# Write out the QM region to xyz trajectory file
# Write out the QM region to the xyz trajectory file.
if self._qm_xyz_frequency > 0 and self._step % self._qm_xyz_frequency == 0:
atoms = _ase.Atoms(positions=xyz_qm, numbers=atomic_numbers)
if hasattr(self, 'max_f_std'):
atoms.info = {'max_f_std': self.max_f_std}
_ase_io.write(self.qm_xyz_file, atoms, append=True)
if hasattr(self, "_max_f_std"):
atoms.info = {"max_f_std": self._max_f_std}
_ase_io.write(self._qm_xyz_file, atoms, append=True)

# Increment the step counter.
if self._is_first_step:
Expand Down Expand Up @@ -2739,7 +2766,8 @@ def _run_deepmd(self, xyz, elements):
max_f_std = calc_model_devi_f(_np.array(f_list))[0][0]
with open(self._deepmd_deviation, "a") as f:
f.write(f"{max_f_std:12.5f}\n")
self.max_f_std = max_f_std # To be written to qm_xyz_file
# To be written to qm_xyz_file.
self._max_f_std = max_f_std

# Take averages and return. (Gradient equals minus the force.)
return (
Expand Down
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import pytest
import shlex
import subprocess


@pytest.fixture(autouse=True)
def wrapper():
"""
A wrapper function that clears the environment variables before each test
and stops the EMLE server after each test.
"""

# Clear the environment.

for env in os.environ:
if env.startswith("EMLE_"):
del os.environ[env]

yield

# Stop the EMLE server.
process = subprocess.run(
shlex.split("emle-stop"),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
16 changes: 0 additions & 16 deletions tests/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,6 @@
import tempfile


@pytest.fixture(autouse=True)
def teardown():
"""
Clean up the environment.
"""

yield

# Stop the EMLE server.
process = subprocess.run(
shlex.split("emle-stop"),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)


def test_external_local_directory():
"""
Make sure that the server can run using an external callback for the in
Expand Down
16 changes: 0 additions & 16 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,6 @@
import tempfile


@pytest.fixture(autouse=True)
def teardown():
"""
Clean up the environment.
"""

yield

# Stop the EMLE server.
process = subprocess.run(
shlex.split("emle-stop"),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)


def parse_mdinfo(mdinfo_file):
"""
Helper function to extract the total energy from AMBER mdinfo files.
Expand Down
46 changes: 46 additions & 0 deletions tests/test_qm_xyz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
import pytest
import shlex
import shutil
import subprocess
import tempfile


def test_qm_xyz():
"""
Make sure that an xyz file for the QM region is written when requested.
"""

with tempfile.TemporaryDirectory() as tmpdir:
# Copy files to temporary directory.
shutil.copyfile("tests/input/adp.parm7", tmpdir + "/adp.parm7")
shutil.copyfile("tests/input/adp.rst7", tmpdir + "/adp.rst7")
shutil.copyfile("tests/input/emle_prod.in", tmpdir + "/emle_prod.in")

# Set environment variables.
os.environ["EMLE_PORT"] = "12345"
os.environ["EMLE_QM_XYZ_FREQUENCY"] = "2"

# Create the sander command.
command = "sander -O -i emle_prod.in -p adp.parm7 -c adp.rst7 -o emle.out"

process = subprocess.run(
shlex.split(command),
cwd=tmpdir,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)

# Make sure that the process exited successfully.
assert process.returncode == 0

# Make sure that an xyz file was written.
assert os.path.isfile(tmpdir + "/qm.xyz")

# Make sure that the file contains the expected number of frames.
with open(tmpdir + "/qm.xyz", "r") as f:
num_frames = 0
for line in f:
if line.startswith("22"):
num_frames += 1
assert num_frames == 11

0 comments on commit 5e639bc

Please sign in to comment.