Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend elastic properties to MLPs, and include elastic property analysis #2693

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/quacc/atoms/deformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def make_deformations_from_bulk(
norm_strains: Sequence[float] = (-0.01, -0.005, 0.005, 0.01),
shear_strains: Sequence[float] = (-0.06, -0.03, 0.03, 0.06),
symmetry: bool = False,
) -> list[Atoms]:
) -> DeformedStructureSet:
"""
Function to generate deformed structures from a bulk atoms object.

Expand All @@ -38,13 +38,11 @@ def make_deformations_from_bulk(
list[Atoms]
All generated deformed structures
"""
struct = AseAtomsAdaptor.get_structure(atoms)
struct = AseAtomsAdaptor.get_structure(atoms) # type: ignore

deformed_set = DeformedStructureSet(
return DeformedStructureSet(
struct,
norm_strains=norm_strains,
shear_strains=shear_strains,
symmetry=symmetry,
)

return [structure.to_ase_atoms() for structure in deformed_set]
81 changes: 69 additions & 12 deletions src/quacc/recipes/common/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,66 @@

from typing import TYPE_CHECKING

from quacc import subflow
from ase import units
from ase.stress import voigt_6_to_full_3x3_stress
from emmet.core.elasticity import ElasticityDoc
from emmet.core.mpid import MPID
from pymatgen.analysis.elasticity.stress import Stress
from pymatgen.io.ase import AseAtomsAdaptor

from quacc import job, subflow
from quacc.atoms.deformation import make_deformations_from_bulk

if TYPE_CHECKING:
from typing import Any

from ase.atoms import Atoms
from pymatgen.analysis.elasticity.strain import DeformedStructureSet

from quacc import Job
from quacc.types import ElasticSchema, OptSchema, RunSchema


@job
def deformations_to_elastic_tensor(
undeformed_result: OptSchema | RunSchema,
deformed_structure_set: DeformedStructureSet,
results: list[dict],
) -> ElasticityDoc:
structure = AseAtomsAdaptor.get_structure(undeformed_result["atoms"]) # type: ignore
return ElasticityDoc.from_deformations_and_stresses(
structure,
material_id=MPID("quacc-00"),
deformations=deformed_structure_set.deformations,
equilibrium_stress=Stress(
(
voigt_6_to_full_3x3_stress(undeformed_result["results"]["stress"])
if len(undeformed_result["results"]["stress"]) == 6
else undeformed_result["results"]["stress"]
)
/ units.GPa
),
stresses=[
Stress(
(
voigt_6_to_full_3x3_stress(relax_result["results"]["stress"])
if len(relax_result["results"]["stress"]) == 6
else relax_result["results"]["stress"]
)
/ units.GPa
)
for relax_result in results
],
)


@subflow
def bulk_to_deformations_subflow(
atoms: Atoms,
undeformed_result: OptSchema | RunSchema,
relax_job: Job,
static_job: Job | None = None,
static_job: Job,
run_static: bool = False,
deform_kwargs: dict[str, Any] | None = None,
) -> list[dict]:
) -> ElasticSchema:
"""
Workflow consisting of:

Expand All @@ -33,10 +75,12 @@ def bulk_to_deformations_subflow(

Parameters
----------
atoms
Atoms object
undeformed_result
Result of a static or optimization calculation
relax_job
The relaxation function.
static_job
The static function
static_job
The static function.
deform_kwargs
Expand All @@ -50,15 +94,28 @@ def bulk_to_deformations_subflow(
"""
deform_kwargs = deform_kwargs or {}

deformations = make_deformations_from_bulk(atoms, **deform_kwargs)
deformed_structure_set = make_deformations_from_bulk(
undeformed_result["atoms"], **deform_kwargs
)

results = []
for deformed in deformations:
result = relax_job(deformed)
for deformed in deformed_structure_set:
result = relax_job(deformed.to_ase_atoms())

if static_job is not None:
if run_static:
result = static_job(result["atoms"])

results.append(result)

return results
elasticity_doc = deformations_to_elastic_tensor(
undeformed_result=undeformed_result,
deformed_structure_set=deformed_structure_set,
results=results,
)

return {
"deformed_structure_set": deformed_structure_set,
"deformed_results": results,
"undeformed_result": undeformed_result,
"elasticity_doc": elasticity_doc,
}
21 changes: 14 additions & 7 deletions src/quacc/recipes/emt/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@

from ase.atoms import Atoms

from quacc.types import OptSchema, RunSchema
from quacc.types import ElasticSchema


@flow
def bulk_to_deformations_flow(
atoms: Atoms,
run_static: bool = True,
run_static: bool = False,
pre_relax: bool = True,
deform_kwargs: dict[str, Any] | None = None,
job_params: dict[str, dict[str, Any]] | None = None,
job_decorators: dict[str, Callable | None] | None = None,
) -> list[RunSchema | OptSchema]:
) -> ElasticSchema:
"""
Workflow consisting of:

Expand Down Expand Up @@ -66,11 +67,17 @@ def bulk_to_deformations_flow(
[relax_job, static_job],
param_swaps=job_params,
decorators=job_decorators,
)
) # type: ignore

if pre_relax:
undeformed_result = relax_job_(atoms, relax_cell=True)
else:
undeformed_result = static_job_(atoms)

return bulk_to_deformations_subflow(
atoms,
relax_job_,
static_job=static_job_ if run_static else None,
undeformed_result=undeformed_result,
relax_job=relax_job_,
static_job=static_job_,
run_static=run_static,
deform_kwargs=deform_kwargs,
)
85 changes: 85 additions & 0 deletions src/quacc/recipes/mlp/elastic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Elastic constants recipes for MLPs."""

from __future__ import annotations

from typing import TYPE_CHECKING

from quacc import flow
from quacc.recipes.common.elastic import bulk_to_deformations_subflow
from quacc.recipes.mlp.core import relax_job, static_job
from quacc.wflow_tools.customizers import customize_funcs

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any

from ase.atoms import Atoms

from quacc.types import ElasticSchema


@flow
def bulk_to_deformations_flow(
atoms: Atoms,
run_static: bool = False,
pre_relax: bool = True,
deform_kwargs: dict[str, Any] | None = None,
job_params: dict[str, dict[str, Any]] | None = None,
job_decorators: dict[str, Callable | None] | None = None,
) -> ElasticSchema:
"""
Workflow consisting of:

1. Deformed structures generation

2. Deformed structures relaxations
- name: "relax_job"
- job: [quacc.recipes.mlp.core.relax_job][]

3. Deformed structures statics (optional)
- name: "static_job"
- job: [quacc.recipes.mlp.core.static_job][]

Parameters
----------
atoms
Atoms object
run_static
Whether to run static calculations after the relaxations
pre_relax
Whether to pre-relax the input atoms as is common
deform_kwargs
Additional keyword arguments to pass to [quacc.atoms.deformation.make_deformations_from_bulk][]
job_params
Custom parameters to pass to each Job in the Flow. This is a dictionary where
the keys are the names of the jobs and the values are dictionaries of parameters.
job_decorators
Custom decorators to apply to each Job in the Flow. This is a dictionary where
the keys are the names of the jobs and the values are decorators.

Returns
-------
list[RunSchema | OptSchema]
[RunSchema][quacc.schemas.ase.Summarize.run] or
[OptSchema][quacc.schemas.ase.Summarize.opt] for each deformation.
See the return type-hint for the data structure.
"""
relax_job_, static_job_ = customize_funcs(
["relax_job", "static_job"],
[relax_job, static_job],
param_swaps=job_params,
decorators=job_decorators,
) # type: ignore

if pre_relax:
undeformed_result = relax_job_(atoms, relax_cell=True)
else:
undeformed_result = static_job_(atoms)

return bulk_to_deformations_subflow(
undeformed_result=undeformed_result,
relax_job=relax_job_,
static_job=static_job_,
run_static=run_static,
deform_kwargs=deform_kwargs,
)
8 changes: 8 additions & 0 deletions src/quacc/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class DefaultSetting(BaseSettings):
from ase.atoms import Atoms
from ase.md.md import MolecularDynamics
from ase.optimize.optimize import Dynamics
from emmet.core.elasticity import ElasticityDoc
from emmet.core.math import ListMatrix3D, Matrix3D, Vector3D
from emmet.core.symmetry import CrystalSystem
from emmet.core.vasp.calc_types import CalcType
Expand All @@ -32,6 +33,7 @@ class DefaultSetting(BaseSettings):
from emmet.core.vasp.task_valid import TaskState
from numpy.random import Generator
from numpy.typing import ArrayLike, NDArray
from pymatgen.analysis.elasticity.strain import DeformedStructureSet
from pymatgen.core.composition import Composition
from pymatgen.core.lattice import Lattice
from pymatgen.core.periodic_table import Element
Expand Down Expand Up @@ -528,6 +530,12 @@ class ThermoSchema(AtomsSchema):
parameters_thermo: ParametersThermo
results: ThermoResults

class ElasticSchema(TypedDict):
deformed_structure_set: DeformedStructureSet
deformed_results: list[RunSchema | OptSchema]
undeformed_result: RunSchema | OptSchema
elasticity_doc: ElasticityDoc

class VibThermoSchema(VibSchema, ThermoSchema):
"""Combined Vibrations and Thermo schema"""

Expand Down
13 changes: 9 additions & 4 deletions tests/core/atoms/test_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@ def test_make_deformations_from_bulk():
atoms.info["test"] = "hi"
deformations = make_deformations_from_bulk(atoms)
assert len(deformations) == 24
assert deformations[0].get_volume() != pytest.approx(atoms.get_volume())
assert deformations[0].to_ase_atoms().get_volume() != pytest.approx(
atoms.get_volume()
)
for deformation in deformations:
assert_equal(deformation.get_atomic_numbers(), [30, 30, 30, 30, 52, 52, 52, 52])
assert_equal(deformation.get_chemical_formula(), "Te4Zn4")
assert deformation.info["test"] == "hi"
assert_equal(
deformation.to_ase_atoms().get_atomic_numbers(),
[30, 30, 30, 30, 52, 52, 52, 52],
)
assert_equal(deformation.to_ase_atoms().get_chemical_formula(), "Te4Zn4")
assert deformation.to_ase_atoms().info["test"] == "hi"
20 changes: 14 additions & 6 deletions tests/core/recipes/emt_recipes/test_emt_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,29 @@ def test_elastic_jobs(tmp_path, monkeypatch):
atoms = bulk("Cu")

outputs = bulk_to_deformations_flow(atoms, run_static=False)
assert outputs[0]["atoms"].get_volume() != pytest.approx(atoms.get_volume())
for output in outputs:
assert outputs["deformed_results"][0]["atoms"].get_volume() != pytest.approx(
atoms.get_volume()
)
assert outputs["elasticity_doc"].bulk_modulus.voigt == pytest.approx(134.579)
for output in outputs["deformed_results"]:
assert output["parameters"]["asap_cutoff"] is False
assert output["name"] == "EMT Relax"
assert output["nelements"] == 1
assert output["nsites"] == 1
assert len(outputs) == 24
assert len(outputs["deformed_results"]) == 24

outputs = bulk_to_deformations_flow(
atoms, run_static=True, job_params={"static_job": {"asap_cutoff": True}}
)
assert outputs[0]["atoms"].get_volume() != pytest.approx(atoms.get_volume())
for output in outputs:
assert outputs["deformed_results"][0]["atoms"].get_volume() != pytest.approx(
atoms.get_volume()
)
assert outputs["deformed_results"][0]["atoms"].get_volume() != pytest.approx(
atoms.get_volume()
)
for output in outputs["deformed_results"]:
assert output["parameters"]["asap_cutoff"] is True
assert output["name"] == "EMT Static"
assert output["nelements"] == 1
assert output["nsites"] == 1
assert len(outputs) == 24
assert len(outputs["deformed_results"]) == 24
12 changes: 6 additions & 6 deletions tests/core/recipes/mlp_recipes/test_core_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@
methods.append("fairchem")


def _set_dtype(size, type_="float"):
globals()[f"{type_}_th"] = getattr(torch, f"{type_}{size}")
globals()[f"{type_}_np"] = getattr(np, f"{type_}{size}")
torch.set_default_dtype(getattr(torch, f"float{size}"))


@pytest.mark.skipif(has_chgnet is None, reason="chgnet not installed")
def test_bad_method():
atoms = bulk("Cu")
with pytest.raises(ValueError, match="Unrecognized method='bad_method'"):
static_job(atoms, method="bad_method")


def _set_dtype(size, type_="float"):
globals()[f"{type_}_th"] = getattr(torch, f"{type_}{size}")
globals()[f"{type_}_np"] = getattr(np, f"{type_}{size}")
torch.set_default_dtype(getattr(torch, f"float{size}"))


@pytest.mark.parametrize("method", methods)
def test_static_job(tmp_path, monkeypatch, method):
monkeypatch.chdir(tmp_path)
Expand Down
Loading
Loading