diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b7526e3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. 
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +# Results directory +results/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c7a2d66 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-added-large-files + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.2 + hooks: + - id: ruff + types_or: [ python, pyi, jupyter ] + args: [ --fix ] + - id: ruff-format + types_or: [ python, pyi, jupyter ] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..5c5d1d7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Juno Nam and Rafael Gomez-Bombarelli + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/README.md b/README.md index 490cb2f..15b5a0b 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,90 @@ # Alchemical MLIP +[![arXiv](https://img.shields.io/badge/arXiv-2404.10746-84cc16)](https://arxiv.org/abs/2404.10746) +[![Zenodo](https://img.shields.io/badge/Zenodo-10.5281/zenodo.11081395-14b8a6.svg)](https://zenodo.org/doi/10.5281/zenodo.11081395) +[![MIT](https://img.shields.io/badge/License-MIT-3b82f6.svg)](https://opensource.org/license/mit) + This repository contains the code to modify machine learning interatomic potentials (MLIPs) to enable continuous and differentiable alchemical transformations. +Currently, we provide the alchemical modification for the [MACE](https://github.com/ACEsuit/mace) model. The details of the method are described in the paper: [Interpolation and differentiation of alchemical degrees of freedom in machine learning interatomic potentials](https://arxiv.org/abs/2404.10746). -> **Note:** The code is currently under final preparation and will be uploaded here shortly. -> We anticipate having the code ready by the end of this week. -> Please stay tuned for updates! +## Installation +We tested the code with Python 3.10 and the packages in `requirements.txt`. +For example, you can create a conda environment and install the required packages as follows (assuming CUDA 11.8): +```bash +conda create -n alchemical-mlip python=3.10 +conda activate alchemical-mlip +pip install torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118 +pip install -r requirements.txt +pip install -e . +``` + +## Static calculations +We provide the jupyter notebooks for the lattice parameter calculations (Fig. 2 in the paper) and the compositional optimization (Fig. 3) in the `notebook` directory. +``` +notebook/ +├── 1_solid_solution.ipynb +└── 2_compositional_optimization.ipynb +``` + +## Free energy calculations +We provide the scripts for the free energy calculations for the vacancy (Fig. 
4) and perovskites (Fig. 5) in the `scripts` directory. +``` +scripts/ +├── vacancy_frenkel_ladd.py +├── perovskite_frenkel_ladd.py +└── perovskite_alchemy.py +``` + +The arguments for the scripts are as follows: +```bash +# Vacancy Frenkel-Ladd calculation +python vacancy_frenkel_ladd.py \ + --structure-file data/structures/Fe.cif \ + --supercell 5 5 5 \ + --temperature 100 \ + --output-dir data/results/vacancy/Fe_5x5x5_100K/0 + +# Perovskite Frenkel-Ladd calculation (alpha phase) +python perovskite_frenkel_ladd.py \ + --structure-file data/structures/CsPbI3_alpha.cif \ + --supercell 6 6 6 \ + --temperature 400 \ + --output-dir data/results/perovskite/frenkel_ladd/CsPbI3_alpha_6x6x6_400K/0 + +# Perovskite Frenkel-Ladd calculation (delta phase) +python perovskite_frenkel_ladd.py \ + --structure-file data/structures/CsPbI3_delta.cif \ + --supercell 6 3 3 \ + --temperature 400 \ + --output-dir data/results/perovskite/frenkel_ladd/CsPbI3_delta_6x3x3_400K/0 + +# Perovskite alchemy calculation (alpha phase) +python -u perovskite_alchemy.py \ + --structure-file data/structures/CsPbI3_alpha.cif \ + --supercell 6 6 6 \ + --switch-pair Pb Sn \ + --temperature 400 \ + --output-dir data/results/perovskite/alchemy/CsPbI3_CsSnI3_alpha_400K/0 + +# Perovskite alchemy calculation (delta phase) +python -u perovskite_alchemy.py \ + --structure-file data/structures/CsPbI3_delta.cif \ + --supercell 6 3 3 \ + --switch-pair Pb Sn \ + --temperature 400 \ + --output-dir data/results/perovskite/alchemy/CsPbI3_CsSnI3_delta_400K/0 +``` + +The result files are large and not included in the repository. +If you want to reproduce the results without running the calculations, the result files are uploaded in the [Zenodo repository](https://zenodo.org/doi/10.5281/zenodo.11081395). +Please download the files and place them in the `data/results` directory. + +The post-processing scripts for the free energy calculations are provided in the `notebook` directory. 
+``` +notebook/ +├── 3_vacancy_analysis.ipynb +└── 4_perovskite_analysis.ipynb +``` ## Citation ``` diff --git a/THIRD-PARTY-LICENSES b/THIRD-PARTY-LICENSES new file mode 100644 index 0000000..859c10b --- /dev/null +++ b/THIRD-PARTY-LICENSES @@ -0,0 +1,60 @@ +Code in alchemical_mace/{calculator,model}.py is adapted from +https://github.com/ACEsuit/mace + +MIT License + +Copyright (c) 2022 ACEsuit/mace + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +------------------------------------------------------------------------------- + +Code in alchemical_mace/utils.py is adapted from +https://github.com/CederGroupHub/chgnet + +Crystal Hamiltonian Graph neural Network (CHGNet) Copyright (c) 2023, The Regents +of the University of California, through Lawrence Berkeley National +Laboratory (subject to receipt of any required approvals from the U.S. +Dept. of Energy) and the University of California, Berkeley. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +(1) Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. + +(2) Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. 
+ +(3) Neither the name of the University of California, Lawrence Berkeley +National Laboratory, U.S. Dept. of Energy, University of California, +Berkeley nor the names of its contributors may be used to endorse or +promote products derived from this software without specific prior written +permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +You are under no obligation whatsoever to provide any bug fixes, patches, +or upgrades to the features, functionality or performance of the source +code ("Enhancements") to anyone; however, if you choose to make your +Enhancements available either publicly, or directly to Lawrence Berkeley +National Laboratory, without imposing a separate written license agreement +for such Enhancements, then you hereby grant the following license: a +non-exclusive, royalty-free perpetual license to install, use, modify, +prepare derivative works, incorporate into other computer software, +distribute, and sublicense such enhancements or derivative works thereof, +in binary and source code form. 
# Source file: alchemical_mace/calculator.py
from typing import Sequence, Tuple

import ase
import numpy as np
import torch
import torch.nn.functional as F
from ase.calculators.calculator import Calculator, all_changes
from ase.constraints import ExpCellFilter
from ase.optimize import FIRE
from ase.stress import full_3x3_to_voigt_6_stress
from mace import data
from mace.calculators import mace_mp
from mace.tools import torch_geometric

from alchemical_mace.model import (
    AlchemicalPair,
    AlchemyManager,
    alchemical_mace_mp,
    get_z_table_and_r_max,
)

################################################################################
# Alchemical MACE calculator
################################################################################


class AlchemicalMACECalculator(Calculator):
    """
    Alchemical MACE calculator for ASE.

    Besides the standard ASE results (energy, free_energy, forces, stress),
    every call to ``calculate`` stores ``results["alchemical_grad"]``: the
    gradient with respect to the alchemical weights when
    ``calculate_alchemical_grad`` is True, otherwise an array of zeros.
    """

    def __init__(
        self,
        atoms: ase.Atoms,
        alchemical_pairs: Sequence[Sequence[Tuple[int, int]]],
        alchemical_weights: Sequence[float],
        device: str = "cpu",
        model: str = "medium",
    ):
        """
        Initialize the Alchemical MACE calculator.

        Args:
            atoms (ase.Atoms): Atoms object.
            alchemical_pairs (Sequence[Sequence[Tuple[int, int]]]): List of
                alchemical pairs. Each pair is a tuple of the atom index and
                atomic number of an alchemical atom.
            alchemical_weights (Sequence[float]): List of alchemical weights.
            device (str): Device to run the calculations on.
            model (str): Model to use for the MACE calculator.
        """
        Calculator.__init__(self)
        self.results = {}
        self.implemented_properties = ["energy", "free_energy", "forces", "stress"]

        # Build the alchemical MACE model
        self.device = device
        self.model = alchemical_mace_mp(
            model=model, device=device, default_dtype="float32"
        )
        # The model parameters are never trained here; freeze them.
        for param in self.model.parameters():
            param.requires_grad = False

        # Set AlchemyManager
        z_table, r_max = get_z_table_and_r_max(self.model)
        alchemical_weights = torch.tensor(alchemical_weights, dtype=torch.float32)
        self.alchemy_manager = AlchemyManager(
            atoms=atoms,
            alchemical_pairs=alchemical_pairs,
            alchemical_weights=alchemical_weights,
            z_table=z_table,
            r_max=r_max,
        ).to(self.device)

        # Disable alchemical weights gradients by default
        self.alchemy_manager.alchemical_weights.requires_grad = False
        self.calculate_alchemical_grad = False

        self.num_atoms = len(atoms)

    def set_alchemical_weights(self, alchemical_weights: Sequence[float]):
        """Overwrite the alchemical weights held by the alchemy manager."""
        alchemical_weights = torch.tensor(
            alchemical_weights,
            dtype=torch.float32,
            device=self.device,
        )
        self.alchemy_manager.alchemical_weights.data = alchemical_weights

    def get_alchemical_atomic_masses(self) -> np.ndarray:
        """
        Return per-atom masses as the weight-averaged masses of the species
        assigned to each atom.

        Returns:
            np.ndarray: float32 array of length ``num_atoms``.
        """
        # Get atomic masses for alchemical atoms
        node_masses = ase.data.atomic_masses[self.alchemy_manager.atomic_numbers]
        weights = self.alchemy_manager.alchemical_weights.data
        # Prepend the constant weight 1.0 used for non-alchemical nodes
        # (weight index 0).
        weights = F.pad(weights, (1, 0), "constant", 1.0).cpu().numpy()
        node_weights = weights[self.alchemy_manager.weight_indices]

        # Scatter sum to get the atomic masses
        atom_masses = np.zeros(self.num_atoms, dtype=np.float32)
        np.add.at(
            atom_masses, self.alchemy_manager.atom_indices, node_masses * node_weights
        )
        return atom_masses

    # pylint: disable=dangerous-default-value
    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Run the alchemical MACE model and populate ``self.results``.

        Args:
            atoms: Atoms to evaluate. When None, the atoms already attached
                to the calculator (``self.atoms``) are used.
            properties: Unused; all results are always computed.
            system_changes: Unused; kept for ASE interface compatibility.
        """
        # call to base-class to set atoms attribute
        Calculator.calculate(self, atoms)
        if atoms is None:
            # The signature allows atoms=None; fall back to the stored atoms
            # instead of failing on ``None.get_positions()``.
            atoms = self.atoms

        # prepare data
        tensor_kwargs = {"dtype": torch.float32, "device": self.device}
        positions = torch.tensor(atoms.get_positions(), **tensor_kwargs)
        cell = torch.tensor(atoms.get_cell().array, **tensor_kwargs)

        weights = self.alchemy_manager.alchemical_weights
        want_grad = self.calculate_alchemical_grad
        if want_grad:
            weights.requires_grad = True
        try:
            batch = self.alchemy_manager(positions, cell).to(self.device)

            # get outputs
            if want_grad:
                out = self.model(
                    batch, compute_stress=True, compute_alchemical_grad=True
                )
                # Chain rule: push the node/edge weight gradients returned by
                # the model back onto the alchemical weights.
                (grad,) = torch.autograd.grad(
                    outputs=[batch["node_weights"], batch["edge_weights"]],
                    inputs=[weights],
                    grad_outputs=[out["node_grad"], out["edge_grad"]],
                    retain_graph=False,
                    create_graph=False,
                )
                grad = grad.cpu().numpy()
            else:
                out = self.model(batch, retain_graph=False, compute_stress=True)
                grad = np.zeros(weights.shape[0], dtype=np.float32)
        finally:
            # Always restore the "no gradient" default, even if the model
            # call raised, so later calls do not silently track gradients.
            if want_grad:
                weights.requires_grad = False

        # store results
        self.results = {}
        self.results["energy"] = out["energy"][0].item()
        self.results["free_energy"] = self.results["energy"]
        self.results["forces"] = out["forces"].detach().cpu().numpy()
        self.results["stress"] = full_3x3_to_voigt_6_stress(
            out["stress"][0].detach().cpu().numpy()
        )
        self.results["alchemical_grad"] = grad
class NVTMACECalculator(Calculator):
    """
    Plain MACE calculator for fixed-cell (NVT) simulations.

    Only energy, free energy and forces are computed; the model is always
    called with ``compute_stress=False``.
    """

    def __init__(self, model: str = "medium", device: str = "cuda"):
        """
        Initialize the calculator.

        Args:
            model (str): Model to use for the MACE calculator.
            device (str): Device to run the calculations on.
        """
        Calculator.__init__(self)
        self.results = {}
        # Stress is never computed by calculate(), so it is not advertised
        # (same as FrenkelLaddCalculator).
        self.implemented_properties = ["energy", "free_energy", "forces"]
        self.device = device
        self.model = mace_mp(
            model=model, device=device, default_dtype="float32"
        ).models[0]
        self.z_table, self.r_max = get_z_table_and_r_max(self.model)
        # Inference only: freeze all model parameters.
        for param in self.model.parameters():
            param.requires_grad = False

    # pylint: disable=dangerous-default-value
    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Evaluate energy and forces and store them in ``self.results``.

        Args:
            atoms: Atoms to evaluate. When None, the atoms already attached
                to the calculator (``self.atoms``) are used.
            properties: Unused; energy and forces are always computed.
            system_changes: Unused; kept for ASE interface compatibility.
        """
        # call to base-class to set atoms attribute
        Calculator.calculate(self, atoms)
        if atoms is None:
            atoms = self.atoms

        # prepare data: wrap the single configuration in a one-element batch
        config = data.config_from_atoms(atoms)
        atomic_data = data.AtomicData.from_config(
            config, z_table=self.z_table, cutoff=self.r_max
        )
        data_loader = torch_geometric.dataloader.DataLoader(
            dataset=[atomic_data],
            batch_size=1,
            shuffle=False,
            drop_last=False,
        )
        batch = next(iter(data_loader)).to(self.device)

        out = self.model(batch, compute_stress=False)
        self.results = {}
        self.results["energy"] = out["energy"][0].item()
        self.results["free_energy"] = self.results["energy"]
        self.results["forces"] = out["forces"].detach().cpu().numpy()
class FrenkelLaddCalculator(Calculator):
    """
    Frenkel-Ladd calculator for ASE.

    Mixes a harmonic (spring) reference with the MACE potential:
    ``E = weights[0] * E_spring + weights[1] * E_mace``. The value
    ``E_mace - E_spring`` is stored in ``results["energy_diff"]``.
    """

    def __init__(
        self,
        spring_constants: np.ndarray,
        initial_positions: np.ndarray,
        device: str,
        model: str = "medium",
    ):
        """
        Initialize the Frenkel-Ladd calculator.

        Args:
            spring_constants (np.ndarray): Spring constants for each atom.
            initial_positions (np.ndarray): Initial positions of the atoms.
            device (str): Device to run the calculations on.
            model (str): Model to use for the MACE calculator.
        """
        Calculator.__init__(self)
        self.results = {}
        self.implemented_properties = ["energy", "free_energy", "forces"]
        self.device = device
        self.model = mace_mp(
            model=model, device=device, default_dtype="float32"
        ).models[0]
        self.z_table, self.r_max = get_z_table_and_r_max(self.model)
        # Inference only: freeze all model parameters.
        for param in self.model.parameters():
            param.requires_grad = False

        # Spring constants
        # np.asarray is a no-op for ndarray input and lets plain sequences
        # work with the ``[:, None]`` broadcasting used in calculate().
        self.spring_constants = np.asarray(spring_constants)
        self.initial_positions = np.asarray(initial_positions)

        # Reversible scaling factor: [spring weight, MACE weight]
        self.weights = [1.0, 0.0]
        self.compute_mace = True

    def set_weights(self, lambda_value: float):
        """Set the mixing weights to ``[1 - lambda_value, lambda_value]``."""
        self.weights = [1.0 - lambda_value, lambda_value]

    # pylint: disable=dangerous-default-value
    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Evaluate the mixed spring/MACE energy and forces.

        Args:
            atoms: Atoms to evaluate. When None, the atoms already attached
                to the calculator (``self.atoms``) are used.
            properties: Unused; all results are always computed.
            system_changes: Unused; kept for ASE interface compatibility.
        """
        # call to base-class to set atoms attribute
        Calculator.calculate(self, atoms)
        if atoms is None:
            atoms = self.atoms

        # Get MACE results if needed
        if self.compute_mace:
            config = data.config_from_atoms(atoms)
            atomic_data = data.AtomicData.from_config(
                config, z_table=self.z_table, cutoff=self.r_max
            )
            data_loader = torch_geometric.dataloader.DataLoader(
                dataset=[atomic_data],
                batch_size=1,
                shuffle=False,
                drop_last=False,
            )
            batch = next(iter(data_loader)).to(self.device)
            out = self.model(batch, compute_stress=False)  # Frenkel-Ladd is NVT
            mace_energy = out["energy"][0].item()
            mace_forces = out["forces"].detach().cpu().numpy()
        else:
            # MACE evaluation switched off: contribute nothing.
            mace_energy = 0.0
            mace_forces = np.zeros((len(atoms), 3), dtype=np.float32)

        # Get spring energy and forces
        displacement = atoms.get_positions() - self.initial_positions
        spring_energy = 0.5 * np.sum(
            self.spring_constants * np.sum(displacement**2, axis=1)
        )
        spring_forces = -self.spring_constants[:, None] * displacement

        # Combine energies and forces
        total_energy = self.weights[0] * spring_energy + self.weights[1] * mace_energy
        total_forces = self.weights[0] * spring_forces + self.weights[1] * mace_forces
        if self.compute_mace:
            energy_diff = mace_energy - spring_energy
        else:
            energy_diff = 0.0

        self.results = {}
        self.results["energy"] = total_energy
        self.results["free_energy"] = total_energy
        self.results["forces"] = total_forces
        self.results["energy_diff"] = energy_diff
class DefectFrenkelLaddCalculator(Calculator):
    """
    Frenkel-Ladd calculator for ASE, for a crystal with a defect.

    A single atom (``defect_index``) is made alchemical. Its alchemical
    weight is ``1 - lambda``; a spring tethering it to the cell center is
    switched on with weight ``lambda``.
    """

    def __init__(
        self,
        atoms: ase.Atoms,
        spring_constant: float,
        defect_index: int,
        device: str = "cpu",
        model: str = "medium",
    ):
        """
        Initialize the Frenkel-Ladd calculator.

        Args:
            atoms (ase.Atoms): Atoms object.
            spring_constant (float): Spring constant for the defect atom.
            defect_index (int): Index of the defect atom.
            device (str): Device to run the calculations on.
            model (str): Model to use for the MACE calculator.
        """
        Calculator.__init__(self)
        self.results = {}
        self.implemented_properties = ["energy", "free_energy", "forces", "stress"]

        # Build the alchemical MACE model
        self.device = device
        self.model = alchemical_mace_mp(
            model=model, device=device, default_dtype="float32"
        )
        # Inference only: freeze all model parameters.
        for param in self.model.parameters():
            param.requires_grad = False

        # Set AlchemyManager: one weight controlling the defect atom, which
        # keeps its original atomic number.
        z_table, r_max = get_z_table_and_r_max(self.model)
        alchemical_weights = torch.tensor([1.0], dtype=torch.float32)
        atomic_number = atoms.get_atomic_numbers()[defect_index]
        alchemical_pairs = [[AlchemicalPair(defect_index, atomic_number)]]
        self.alchemy_manager = AlchemyManager(
            atoms=atoms,
            alchemical_pairs=alchemical_pairs,
            alchemical_weights=alchemical_weights,
            z_table=z_table,
            r_max=r_max,
        ).to(self.device)

        # Disable alchemical weights gradients by default
        self.alchemy_manager.alchemical_weights.requires_grad = False
        self.calculate_alchemical_grad = False

        self.num_atoms = len(atoms)

        # Switching
        self.defect_index = defect_index
        self.spring_constant = spring_constant

    def set_alchemical_weight(self, alchemical_weight: float):
        """Set lambda; the stored alchemical weight becomes ``1 - lambda``."""
        # Set alchemical weights
        alchemical_weights = torch.tensor(
            [1.0 - alchemical_weight],  # initial = original atoms = 1 - 0
            dtype=torch.float32,
            device=self.device,
        )
        self.alchemy_manager.alchemical_weights.data = alchemical_weights

    # pylint: disable=dangerous-default-value
    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Evaluate the alchemical MACE energy plus the defect spring term.

        Args:
            atoms: Atoms to evaluate. When None, the atoms already attached
                to the calculator (``self.atoms``) are used.
            properties: Unused; all results are always computed.
            system_changes: Unused; kept for ASE interface compatibility.
        """
        # call to base-class to set atoms attribute
        Calculator.calculate(self, atoms)
        if atoms is None:
            atoms = self.atoms

        # prepare data
        tensor_kwargs = {"dtype": torch.float32, "device": self.device}
        positions = torch.tensor(atoms.get_positions(), **tensor_kwargs)
        cell = torch.tensor(atoms.get_cell().array, **tensor_kwargs)

        weights = self.alchemy_manager.alchemical_weights
        want_grad = self.calculate_alchemical_grad
        if want_grad:
            weights.requires_grad = True
        try:
            batch = self.alchemy_manager(positions, cell).to(self.device)

            # get outputs
            if want_grad:
                # retain_graph=True keeps the graph alive after the model has
                # taken its own gradients, so energy can be backpropagated
                # to the alchemical weight.
                out = self.model(batch, retain_graph=True, compute_stress=True)
                out["energy"].backward()
                grad = weights.grad.item()
                # Clear the accumulated gradient for the next call.
                weights.grad.zero_()
            else:
                out = self.model(batch, retain_graph=False, compute_stress=True)
                grad = 0.0
        finally:
            # Restore the "no gradient" default even if the model call raised.
            if want_grad:
                weights.requires_grad = False

        mace_energy = out["energy"][0].item()
        mace_forces = out["forces"].detach().cpu().numpy()
        mace_stress = out["stress"][0].detach().cpu().numpy()

        # Get spring energy and forces (spring anchored at the cell center)
        cell_center = np.array([0.5, 0.5, 0.5]) @ atoms.get_cell().array
        displacement = atoms.get_positions()[self.defect_index] - cell_center
        spring_energy = 0.5 * self.spring_constant * np.sum(displacement**2)
        spring_forces = -self.spring_constant * displacement

        # Combine energies and forces
        # Note: weight here is 1 - lambda, and we're not weighting the mace
        # energy because it's already weighted by the alchemical weight
        weight = self.alchemy_manager.alchemical_weights.item()
        total_energy = mace_energy + (1 - weight) * spring_energy
        # total_forces aliases mace_forces; the spring force is added in place.
        total_forces = mace_forces
        total_forces[self.defect_index] += (1 - weight) * spring_forces
        if self.calculate_alchemical_grad:
            # H(lambda) = E(1 - lambda) + lambda * spring_energy
            # dH/d(lambda) = -dE/d(1 - lambda) + spring_energy
            grad = -grad + spring_energy

        # store results
        # NOTE(review): the stress is the MACE stress only; no spring
        # contribution is added — confirm this is intended.
        self.results = {}
        self.results["energy"] = total_energy
        self.results["free_energy"] = total_energy
        self.results["forces"] = total_forces
        self.results["stress"] = full_3x3_to_voigt_6_stress(mace_stress)
        self.results["alchemical_grad"] = grad
def get_alchemical_optimized_cellpar(
    atoms: ase.Atoms,
    alchemical_pairs: Sequence[Sequence[Tuple[int, int]]],
    alchemical_weights: Sequence[float],
    model: str = "medium",
    device: str = "cpu",
    **kwargs,
):
    """
    Optimize the cell parameters of a crystal with alchemical atoms using the
    Alchemical MACE calculator.

    Args:
        atoms (ase.Atoms): Atoms object. It is copied; the input is not
            modified.
        alchemical_pairs (Sequence[Sequence[Tuple[int, int]]]): List of
            alchemical pairs. Each pair is a tuple of the atom index and
            atomic number of an alchemical atom.
        alchemical_weights (Sequence[float]): List of alchemical weights.
        model (str): Model to use for the MACE calculator.
        device (str): Device to run the calculations on.
        **kwargs: Optional optimizer settings. ``fmax`` (default 0.01) and
            ``steps`` (default 500) are forwarded to ``FIRE.run``; any other
            keyword is ignored.

    Returns:
        np.ndarray: Optimized cell parameters.
    """
    # Make a copy of the atoms object
    atoms = atoms.copy()

    # Load Alchemical MACE calculator and relax the structure
    calc = AlchemicalMACECalculator(
        atoms, alchemical_pairs, alchemical_weights, device=device, model=model
    )
    # Attribute assignment replaces the deprecated Atoms.set_calculator().
    atoms.calc = calc
    atoms.set_masses(calc.get_alchemical_atomic_masses())

    # Relax positions and cell together through the cell filter.
    filtered_atoms = ExpCellFilter(atoms)
    optimizer = FIRE(filtered_atoms)
    optimizer.run(fmax=kwargs.get("fmax", 0.01), steps=kwargs.get("steps", 500))

    # Return the optimized cell parameters
    return filtered_atoms.atoms.get_cell().cellpar()
# Source file: alchemical_mace/model.py
import ast
from collections import namedtuple
from typing import Dict, List, Optional, Sequence, Tuple

import ase
import numpy as np
import torch
import torch.nn.functional as F
from e3nn import o3
from e3nn.util.jit import compile_mode
from mace import modules, tools
from mace.calculators import mace_mp
from mace.data.neighborhood import get_neighborhood
from mace.modules import RealAgnosticResidualInteractionBlock, ScaleShiftMACE
from mace.modules.utils import get_edge_vectors_and_lengths, get_symmetric_displacement
from mace.tools import (
    AtomicNumberTable,
    atomic_numbers_to_indices,
    to_one_hot,
    torch_geometric,
    utils,
)
from mace.tools.scatter import scatter_sum

################################################################################
# Alchemy manager class for handling alchemical weights
################################################################################

AlchemicalPair = namedtuple("AlchemicalPair", ["atom_index", "atomic_number"])


class AlchemyManager(torch.nn.Module):
    """
    Class for managing alchemical weights and building alchemical graphs for MACE.

    Every (atom, species) assignment becomes one graph node. Non-alchemical
    atoms contribute a single node with weight index 0 (constant weight 1);
    alchemical atoms contribute one node per listed species, each tied to
    a 1-based weight index.
    """

    def __init__(
        self,
        atoms: ase.Atoms,
        alchemical_pairs: Sequence[Sequence[Tuple[int, int]]],
        alchemical_weights: torch.Tensor,
        z_table: AtomicNumberTable,
        r_max: float,
    ):
        """
        Initialize the alchemy manager.

        Args:
            atoms: ASE atoms object
            alchemical_pairs: List of lists of tuples, where each tuple contains
                the atom index and atomic number of an alchemical atom.
                ``AlchemicalPair`` instances and plain 2-tuples are both
                accepted.
            alchemical_weights: Tensor of alchemical weights
            z_table: Atomic number table
            r_max: Maximum cutoff radius for the alchemical graph
        """
        super().__init__()
        self.alchemical_weights = torch.nn.Parameter(alchemical_weights)
        self.r_max = r_max

        # Process alchemical pairs into atom indices and atomic numbers
        # Alchemical weights are 1-indexed, 0 is reserved for non-alchemical atoms
        alchemical_atom_indices = []
        alchemical_atomic_numbers = []
        alchemical_weight_indices = []

        for weight_idx, pairs in enumerate(alchemical_pairs):
            for pair in pairs:
                # Positional unpacking works for AlchemicalPair and for the
                # plain tuples allowed by the type annotation.
                atom_index, atomic_number = pair
                alchemical_atom_indices.append(atom_index)
                alchemical_atomic_numbers.append(atomic_number)
                alchemical_weight_indices.append(weight_idx + 1)

        # Set lookup keeps this linear in the number of atoms.
        alchemical_atom_set = set(alchemical_atom_indices)
        non_alchemical_atom_indices = [
            i for i in range(len(atoms)) if i not in alchemical_atom_set
        ]
        non_alchemical_atomic_numbers = atoms.get_atomic_numbers()[
            non_alchemical_atom_indices
        ].tolist()
        non_alchemical_weight_indices = [0] * len(non_alchemical_atom_indices)

        self.atom_indices = alchemical_atom_indices + non_alchemical_atom_indices
        self.atomic_numbers = alchemical_atomic_numbers + non_alchemical_atomic_numbers
        self.weight_indices = alchemical_weight_indices + non_alchemical_weight_indices

        self.atom_indices = np.array(self.atom_indices)
        self.atomic_numbers = np.array(self.atomic_numbers)
        self.weight_indices = np.array(self.weight_indices)

        # Order the nodes by the original atom index they belong to.
        sort_idx = np.argsort(self.atom_indices)
        self.atom_indices = self.atom_indices[sort_idx]
        self.atomic_numbers = self.atomic_numbers[sort_idx]
        self.weight_indices = self.weight_indices[sort_idx]

        # Array to map original atom indices to alchemical indices
        # -1 means the atom does not have a corresponding alchemical atom
        # [n_atoms, n_weights + 1]
        self.original_to_alchemical_index = np.full(
            (len(atoms), len(alchemical_pairs) + 1), -1
        )
        for i, (atom_idx, weight_idx) in enumerate(
            zip(self.atom_indices, self.weight_indices)
        ):
            self.original_to_alchemical_index[atom_idx, weight_idx] = i

        # An original atom is alchemical if any non-zero weight column is set.
        self.is_original_atom_alchemical = np.any(
            self.original_to_alchemical_index[:, 1:] != -1, axis=1
        )

        # Extract common node features
        z_indices = atomic_numbers_to_indices(self.atomic_numbers, z_table=z_table)
        node_attrs = to_one_hot(
            torch.tensor(z_indices, dtype=torch.long).unsqueeze(-1),
            num_classes=len(z_table),
        )
        self.register_buffer("node_attrs", node_attrs)
        self.pbc = atoms.get_pbc()

    def forward(
        self,
        positions: torch.Tensor,
        cell: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Build an alchemical graph for the given positions and cell.

        Args:
            positions: Tensor of atomic positions
            cell: Tensor of cell vectors

        Returns:
            Dictionary containing the alchemical graph data
        """

        # Build original atom graph
        orig_edge_index, shifts, unit_shifts = get_neighborhood(
            positions=positions.detach().cpu().numpy(),
            cutoff=self.r_max,
            pbc=self.pbc,
            cell=cell.detach().cpu().numpy(),
        )

        # Extend edges to alchemical pairs
        edge_index = []
        orig_edge_loc = []  # position of each new edge in the original edge list
        edge_weight_indices = []

        # Classify every original edge by whether its endpoints are alchemical.
        is_alchemical = self.is_original_atom_alchemical[orig_edge_index]
        src_non_dst_non = ~is_alchemical[0] & ~is_alchemical[1]
        src_non_dst_alch = ~is_alchemical[0] & is_alchemical[1]
        src_alch_dst_non = is_alchemical[0] & ~is_alchemical[1]
        src_alch_dst_alch = is_alchemical[0] & is_alchemical[1]

        # Both non-alchemical: keep as is
        _orig_edge_index = orig_edge_index[:, src_non_dst_non]
        edge_index.append(self.original_to_alchemical_index[_orig_edge_index, 0])
        orig_edge_loc.append(np.where(src_non_dst_non)[0])
        edge_weight_indices.append(np.zeros_like(_orig_edge_index[0]))

        # Source non-alchemical, destination alchemical: pair all, weights are 1
        _src, _dst = orig_edge_index[:, src_non_dst_alch]
        _orig_edge_loc = np.where(src_non_dst_alch)[0]
        _src = self.original_to_alchemical_index[_src, 0]
        _dst = self.original_to_alchemical_index[_dst, :]
        _dst_mask = _dst != -1
        _dst = _dst[_dst_mask]
        # One copy of the source per alchemical node of the destination.
        _repeat = _dst_mask.sum(axis=1)
        _src = np.repeat(_src, _repeat)
        edge_index.append(np.stack((_src, _dst), axis=0))
        orig_edge_loc.append(np.repeat(_orig_edge_loc, _repeat))
        edge_weight_indices.append(np.zeros_like(_src))

        # Source alchemical, destination non-alchemical: pair all, follow src weights
        _src, _dst = orig_edge_index[:, src_alch_dst_non]
        _orig_edge_loc = np.where(src_alch_dst_non)[0]
        _src = self.original_to_alchemical_index[_src, :]
        _dst = self.original_to_alchemical_index[_dst, 0]
        _src_mask = _src != -1
        _src = _src[_src_mask]
        _repeat = _src_mask.sum(axis=1)
        _dst = np.repeat(_dst, _repeat)
        edge_index.append(np.stack((_src, _dst), axis=0))
        orig_edge_loc.append(np.repeat(_orig_edge_loc, _repeat))
        # The column index of each valid source slot is its weight index.
        edge_weight_indices.append(np.where(_src_mask)[1])

        # Both alchemical: pair according to alchemical indices, weights are 1
        _orig_edge_index = orig_edge_index[:, src_alch_dst_alch]
        _orig_edge_loc = np.where(src_alch_dst_alch)[0]
        _alch_edge_index = self.original_to_alchemical_index[_orig_edge_index, :]
        # Keep only (edge, weight column) slots where both endpoints have a node.
        _idx = np.where((_alch_edge_index != -1).all(axis=0))
        edge_index.append(_alch_edge_index[:, _idx[0], _idx[1]])
        orig_edge_loc.append(_orig_edge_loc[_idx[0]])
        edge_weight_indices.append(np.zeros_like(_idx[0]))

        # Collect all edges
        edge_index = np.concatenate(edge_index, axis=1)
        orig_edge_loc = np.concatenate(orig_edge_loc)
        edge_weight_indices = np.concatenate(edge_weight_indices)

        # Convert to torch tensors
        edge_index = torch.tensor(edge_index, dtype=torch.long)
        shifts = torch.tensor(shifts[orig_edge_loc], dtype=torch.float32)
        unit_shifts = torch.tensor(unit_shifts[orig_edge_loc], dtype=torch.float32)

        # Alchemical weights for nodes and edges
        # Index 0 is the prepended constant 1.0 for non-alchemical entries.
        weights = F.pad(self.alchemical_weights, (1, 0), "constant", 1.0)
        node_weights = weights[self.weight_indices]
        edge_weights = weights[edge_weight_indices]

        # Build data batch
        atomic_data = torch_geometric.data.Data(
            num_nodes=len(self.atom_indices),
            edge_index=edge_index,
            node_attrs=self.node_attrs,
            positions=positions[self.atom_indices],
            shifts=shifts,
            unit_shifts=unit_shifts,
            cell=cell,
            node_weights=node_weights,
            edge_weights=edge_weights,
            node_atom_indices=torch.tensor(self.atom_indices, dtype=torch.long),
        )
        data_loader = torch_geometric.dataloader.DataLoader(
            dataset=[atomic_data],
            batch_size=1,
            shuffle=False,
            drop_last=False,
        )
        batch = next(iter(data_loader))

        return batch


################################################################################
# Alchemical MACE model
+################################################################################
+
+# get_outputs function from mace.modules.utils is modified to calculate also
+# the alchemical gradients
+
+
+def get_outputs(
+    energy: torch.Tensor,
+    positions: torch.Tensor,
+    displacement: torch.Tensor,
+    cell: torch.Tensor,
+    node_weights: torch.Tensor,
+    edge_weights: torch.Tensor,
+    retain_graph: bool = False,
+    create_graph: bool = False,
+    compute_force: bool = True,
+    compute_stress: bool = False,
+    compute_alchemical_grad: bool = False,
+) -> Tuple[
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+]:
+    """
+    Differentiate the energy with respect to positions, strain and alchemical weights.
+
+    A single ``torch.autograd.grad`` call is made with ``positions`` always in
+    the input list, ``displacement`` appended when ``compute_stress`` is set and
+    ``node_weights`` / ``edge_weights`` appended when ``compute_alchemical_grad``
+    is set.
+
+    Args:
+        energy: Energy tensor to differentiate
+        positions: Positions tensor the energy was computed from
+        displacement: Displacement tensor the energy was computed from
+        cell: Cell vectors; viewed as (-1, 3, 3) to compute the volume
+        node_weights: Alchemical node weights
+        edge_weights: Alchemical edge weights
+        retain_graph: Forwarded to ``torch.autograd.grad``
+        create_graph: Forwarded to ``torch.autograd.grad``
+        compute_force: If False, nothing is differentiated and every returned
+            value is None (including the alchemical gradients)
+        compute_stress: Whether to differentiate with respect to ``displacement``
+        compute_alchemical_grad: Whether to differentiate with respect to the
+            node and edge weights
+
+    Returns:
+        Tuple ``(forces, virials, stress, node_grad, edge_grad)``. Forces and
+        virials are the negated gradients; stress is the (un-negated)
+        displacement gradient divided by the cell volume, or zeros shaped like
+        ``displacement`` when it is not computed.
+    """
+    if not compute_force:
+        return None, None, None, None, None
+    grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
+    inputs = [positions]
+    if compute_stress:
+        inputs.append(displacement)
+    if compute_alchemical_grad:
+        inputs.extend([node_weights, edge_weights])
+    gradients = torch.autograd.grad(
+        outputs=[energy],
+        inputs=inputs,
+        grad_outputs=grad_outputs,
+        retain_graph=retain_graph,
+        create_graph=create_graph,
+        allow_unused=True,
+    )
+
+    forces = gradients[0]
+    stress = torch.zeros_like(displacement)
+    virials = gradients[1] if compute_stress else None
+    if compute_alchemical_grad:
+        # The alchemical inputs are always appended last
+        node_grad, edge_grad = gradients[-2], gradients[-1]
+    else:
+        node_grad, edge_grad = None, None
+    if compute_stress and virials is not None:
+        cell = cell.view(-1, 3, 3)
+        # The triple product a . (b x c) is signed: it is negative for a
+        # left-handed cell, which would flip the sign of the stress. Use its
+        # magnitude as the volume.
+        volume = torch.abs(
+            torch.einsum(
+                "zi,zi->z",
+                cell[:, 0, :],
+                torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
+            )
+        ).unsqueeze(-1)
+        stress = virials / volume.view(-1, 1, 1)
+
+    # allow_unused=True means any gradient may come back as None
+    if forces is not None:
+        forces = -1 * forces
+    if virials is not None:
+        virials = -1 * virials
+    return forces, virials, stress, node_grad, edge_grad
+
+
+@compile_mode("script")
+class AlchemicalResidualInteractionBlock(RealAgnosticResidualInteractionBlock):
+    """Residual interaction block whose per-edge weights are scaled by alchemical weights."""
+
+    def forward(
+        self,
+        node_attrs: torch.Tensor,
+        node_feats: torch.Tensor,
+        edge_attrs: torch.Tensor,
+        edge_feats: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_weights: torch.Tensor,  # alchemy
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute messages with each edge's tensor-product weights scaled by ``edge_weights``.
+
+        Returns:
+            Tuple of the reshaped aggregated message and the skip connection
+            computed from the incoming node features.
+        """
+        sender = edge_index[0]
+        receiver = edge_index[1]
+        num_nodes = node_feats.shape[0]
+        # Skip connection uses the node features before linear_up is applied
+        sc = self.skip_tp(node_feats, node_attrs)
+        node_feats = self.linear_up(node_feats)
+        tp_weights = self.conv_tp_weights(edge_feats)
+        tp_weights = tp_weights * edge_weights[:, None]  # alchemy
+        mji = self.conv_tp(
+            node_feats[sender], edge_attrs, tp_weights
+        )  # [n_edges, irreps]
+        message = scatter_sum(
+            src=mji, index=receiver, dim=0, dim_size=num_nodes
+        )  # [n_nodes, irreps]
+        message = self.linear(message) / self.avg_num_neighbors
+        return (
+            self.reshape(message),
+            sc,
+        )  # [n_nodes, channels, (lmax + 1)**2]
+
+
+@compile_mode("script")
+class AlchemicalMACE(ScaleShiftMACE):
+    """ScaleShiftMACE variant whose node energies and edge messages carry alchemical weights."""
+
+    def forward(
+        self,
+        data: Dict[str, torch.Tensor],
+        retain_graph: bool = False,  # alchemy
+        create_graph: bool = False,  # alchemy
+        compute_force: bool = True,
+        compute_stress: bool = False,
+        compute_displacement: bool = False,
+        compute_alchemical_grad: bool = False,  # alchemy
+        map_to_original_atoms: bool = True,  # alchemy
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Evaluate energies and their gradients for an alchemical graph batch.
+
+        Args:
+            data: Batch dictionary. Keys read here: "positions", "node_attrs",
+                "ptr", "batch", "cell", "edge_index", "shifts", "unit_shifts",
+                "node_weights", "edge_weights" and, when
+                ``map_to_original_atoms`` is set, "node_atom_indices".
+                Note: "positions" and "node_attrs" get ``requires_grad_(True)``
+                in place, and "positions" / "shifts" are overwritten when a
+                symmetric displacement is applied.
+            retain_graph: Forwarded to ``get_outputs``
+            create_graph: Forwarded to ``get_outputs``
+            compute_force: Forwarded to ``get_outputs``
+            compute_stress: Forwarded to ``get_outputs``; also enables the
+                symmetric displacement
+            compute_displacement: Enables the symmetric displacement
+            compute_alchemical_grad: Forwarded to ``get_outputs``
+            map_to_original_atoms: If set, node energies (and forces, when
+                computed) are summed over nodes sharing a "node_atom_indices"
+                entry
+
+        Returns:
+            Dictionary with keys "energy", "node_energy", "interaction_energy",
+            "forces", "virials", "stress", "displacement", "node_grad" and
+            "edge_grad".
+        """
+        # Setup
+        data["positions"].requires_grad_(True)
+        data["node_attrs"].requires_grad_(True)
+        num_graphs = data["ptr"].numel() - 1
+        displacement = torch.zeros(
+            (num_graphs, 3, 3),
+            dtype=data["positions"].dtype,
+            device=data["positions"].device,
+        )
+        if compute_stress or compute_displacement:
+            (
+                data["positions"],
+                data["shifts"],
+                displacement,
+            ) = get_symmetric_displacement(
+                positions=data["positions"],
+                unit_shifts=data["unit_shifts"],
+                cell=data["cell"],
+                edge_index=data["edge_index"],
+                num_graphs=num_graphs,
+                batch=data["batch"],
+            )
+
+        # Atomic energies
+        node_e0 = self.atomic_energies_fn(data["node_attrs"])
+        node_e0 = node_e0 * data["node_weights"]  # alchemy
+        e0 = scatter_sum(
+            src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs
+        )  # [n_graphs,]
+
+        # Embeddings
+        node_feats = self.node_embedding(data["node_attrs"])
+        vectors, lengths = get_edge_vectors_and_lengths(
+            positions=data["positions"],
+            edge_index=data["edge_index"],
+            shifts=data["shifts"],
+        )
+        edge_attrs = self.spherical_harmonics(vectors)
+        edge_feats = self.radial_embedding(lengths)
+
+        # Interactions
+        node_es_list = []
+        node_feats_list = []
+        for interaction, product, readout in zip(
+            self.interactions, self.products, self.readouts
+        ):
+            node_feats, sc = interaction(
+                node_attrs=data["node_attrs"],
+                node_feats=node_feats,
+                edge_attrs=edge_attrs,
+                edge_feats=edge_feats,
+                edge_index=data["edge_index"],
+                edge_weights=data["edge_weights"],  # alchemy
+            )
+            node_feats = product(
+                node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"]
+            )
+            node_feats_list.append(node_feats)
+            node_es_list.append(readout(node_feats).squeeze(-1))  # {[n_nodes, ], }
+
+        # Concatenate node features
+        # node_feats_out = torch.cat(node_feats_list, dim=-1)
+
+        # Sum over interactions
+        node_inter_es = torch.sum(
+            torch.stack(node_es_list, dim=0), dim=0
+        )  # [n_nodes, ]
+        node_inter_es = self.scale_shift(node_inter_es)
+        node_inter_es = node_inter_es * data["node_weights"]  # alchemy
+
+        # Sum over nodes in graph
+        inter_e = scatter_sum(
+            src=node_inter_es, index=data["batch"], dim=-1, dim_size=num_graphs
+        )  # [n_graphs,]
+
+        # Add E_0 and (scaled) interaction energy
+        total_energy = e0 + inter_e
+        node_energy = node_e0 + node_inter_es
+
+        forces, virials, stress, node_grad, edge_grad = get_outputs(
+            energy=total_energy,  # alchemy
+            positions=data["positions"],
+            displacement=displacement,
+            cell=data["cell"],
+            node_weights=data["node_weights"],  # alchemy
+            edge_weights=data["edge_weights"],  # alchemy
+            retain_graph=retain_graph,  # alchemy
+            create_graph=create_graph,  # alchemy
+            compute_force=compute_force,
+            # compute_virials=compute_virials,  # alchemy
+            compute_stress=compute_stress,
+            compute_alchemical_grad=compute_alchemical_grad,  # alchemy
+        )
+
+        # Map to original atoms (node energies and forces): alchemy
+        if map_to_original_atoms:
+            # Note: we're not giving the dim_size, as we assume that all
+            # original atoms are present in the batch
+            node_index = data["node_atom_indices"]
+            node_energy = scatter_sum(src=node_energy, dim=0, index=node_index)
+            # forces is None when compute_force is False, so only scatter then
+            if compute_force:
+                forces = scatter_sum(src=forces, dim=0, index=node_index)
+
+        output = {
+            "energy": total_energy,
+            "node_energy": node_energy,
+            "interaction_energy": inter_e,
+            "forces": forces,
+            "virials": virials,
+            "stress": stress,
+            "displacement": displacement,
+            "node_grad": node_grad,
+            "edge_grad": edge_grad,
+        }
+
+        return output
+
+
+################################################################################
+# Alchemical MACE universal model
+################################################################################
+
+
+def alchemical_mace_mp(
+    model: str,
+    device: str,
+    default_dtype: str = "float32",
+):
+    """
+    Load a pre-trained alchemical MACE model.
+
+    Args:
+        model: Model size (small, medium)
+        device: Device to load the model onto
+        default_dtype: Default data type for the model
+
+    Returns:
+        Alchemical MACE model
+    """
+
+    # Load foundation MACE model and extract initial parameters
+    assert model in ("small", "medium")  # TODO: support large model
+    mace = mace_mp(model=model, device=device, default_dtype=default_dtype).models[0]
+    atomic_energies = mace.atomic_energies_fn.atomic_energies.detach().clone()
+    z_table = utils.AtomicNumberTable([int(z) for z in mace.atomic_numbers])
+    atomic_inter_scale = mace.scale_shift.scale.detach().clone()
+    atomic_inter_shift = mace.scale_shift.shift.detach().clone()
+
+    # Prepare arguments for building the model
+    # (the default parser is only used as a source of default hyperparameters;
+    # the placeholder values fill the options passed on its command line)
+    placeholder_args = ["--name", "None", "--train_file", "None"]
+    args = tools.build_default_arg_parser().parse_args(placeholder_args)
+    args.max_L = {"small": 0, "medium": 1, "large": 2}[model]
+    args.num_channels = 128
+    args.hidden_irreps = o3.Irreps(
+        (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L))
+        .sort()
+        .irreps.simplify()
+    )
+
+    # Build the alchemical MACE model
+    # NOTE: from here on `model` refers to the built module, not the size string
+    model = AlchemicalMACE(
+        r_max=6.0,
+        num_bessel=10,
+        num_polynomial_cutoff=5,
+        max_ell=3,
+        interaction_cls=AlchemicalResidualInteractionBlock,
+        interaction_cls_first=AlchemicalResidualInteractionBlock,
+        num_interactions=2,
+        num_elements=len(z_table),
+        hidden_irreps=o3.Irreps(args.hidden_irreps),
+        MLP_irreps=o3.Irreps(args.MLP_irreps),
+        atomic_energies=atomic_energies,
+        avg_num_neighbors=args.avg_num_neighbors,
+        atomic_numbers=z_table.zs,
+        correlation=args.correlation,
+        gate=modules.gate_dict[args.gate],
+        radial_MLP=ast.literal_eval(args.radial_MLP),
+        radial_type=args.radial_type,
+        atomic_inter_scale=atomic_inter_scale,
+        atomic_inter_shift=atomic_inter_shift,
+    )
+
+    # Load foundation model parameters
+    model.load_state_dict(mace.state_dict(), strict=True)
+    # avg_num_neighbors is copied by hand for every interaction
+    # (presumably it is not carried by the state dict -- TODO confirm)
+    for i in range(int(model.num_interactions)):
+        model.interactions[i].avg_num_neighbors = mace.interactions[i].avg_num_neighbors
+    model = model.to(device)
+    return model
+
+
+def get_z_table_and_r_max(model: ScaleShiftMACE) -> Tuple[AtomicNumberTable, float]:
+    """Extract the atomic number table and maximum cutoff radius from a MACE model."""
+    z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers])
+    r_max = model.r_max.item()
+    return z_table, r_max
diff --git a/alchemical_mace/optimize.py b/alchemical_mace/optimize.py
new file mode 100644
index 0000000..87171be
--- /dev/null
+++ b/alchemical_mace/optimize.py
@@ -0,0 +1,50 @@
+import torch
+from typing import Union, Iterable, Dict, Any, Callable, Optional
+
+
+class ExponentiatedGradientDescent(torch.optim.Optimizer):
+    """
+    Implements Exponentiated Gradient Descent.
+
+    Each step multiplies every parameter element-wise by ``exp(-lr * grad)`` and
+    then divides it by the sum of its own entries (plus ``eps``).
+
+    Args:
+        params (iterable of torch.Tensor or dict): iterable of parameters to optimize or
+            dicts defining parameter groups.
+        lr (float, optional): learning rate. Defaults to 1e-3.
+        eps (float, optional): small constant for numerical stability. Defaults to 1e-8.
+    """
+    def __init__(
+        self,
+        params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
+        lr: float = 1e-3,
+        eps: float = 1e-8,
+    ):
+        super().__init__(params, defaults=dict(lr=lr, eps=eps))
+
+    @torch.no_grad()
+    def step(self, closure: Optional[Callable[[], Any]] = None):
+        """
+        Perform a single optimization step.
+
+        Args:
+            closure (callable, optional): re-evaluates the model and returns the
+                loss; it is run with gradients enabled before the update.
+
+        Returns:
+            The value returned by ``closure``, or None if no closure was given.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                # Multiplicative update, then renormalize the whole tensor by
+                # the sum of its entries; both are done in place.
+                p.mul_(torch.exp(-group["lr"] * p.grad))
+                p.div_(p.sum() + group["eps"])
+        return loss
diff --git a/alchemical_mace/utils.py b/alchemical_mace/utils.py
new file mode 100644
index 0000000..d37cf9c
--- /dev/null
+++ b/alchemical_mace/utils.py
@@ -0,0 +1,43 @@
+import os
+from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout
+
+from ase import Atoms
+
+
+@contextmanager
+def suppress_print(out: bool = True, err: bool = False):
+    """Suppress stdout and/or stderr."""
+
+    with ExitStack() as stack:
+        devnull = stack.enter_context(open(os.devnull, "w"))
+        if out:
+            stack.enter_context(redirect_stdout(devnull))
+        if err:
+            stack.enter_context(redirect_stderr(devnull))
+        yield
+
+
+# From CHGNet
+def upper_triangular_cell(atoms: Atoms):
+    """Transform to upper-triangular cell."""
+    import numpy as np
+    from ase.md.npt import NPT
+
+    if NPT._isuppertriangular(atoms.get_cell()):
+        return
+
+    a, b, c, alpha, beta, gamma = atoms.cell.cellpar()
+    angles = np.radians((alpha, beta, gamma))
+    sin_a, sin_b, _sin_g = np.sin(angles)
+    cos_a, cos_b, cos_g = np.cos(angles)
+    cos_p = (cos_g - cos_a * cos_b) / (sin_a * sin_b)
+    cos_p = np.clip(cos_p, -1, 1)
+    sin_p = (1 - cos_p**2) ** 0.5
+
+    new_basis = [
+        (a * sin_b * sin_p, a * sin_b * cos_p, a * cos_b),
+        (0, b * sin_a, b * cos_a),
+        (0, 0, c),
+    ]
+
+    atoms.set_cell(new_basis, scale_atoms=True)
diff --git a/data/structures/AlN_hex.cif b/data/structures/AlN_hex.cif new file mode 100644 index 0000000..ef9bf10 --- /dev/null +++ b/data/structures/AlN_hex.cif @@ -0,0 +1,30 @@ +# generated using pymatgen +data_AlN +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 3.12664153 +_cell_length_b
3.12664143 +_cell_length_c 5.00715332 +_cell_angle_alpha 90.00000043 +_cell_angle_beta 89.99999986 +_cell_angle_gamma 119.99999965 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural AlN +_chemical_formula_sum 'Al2 N2' +_cell_volume 42.39139351 +_cell_formula_units_Z 2 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Al Al0 1 0.66666667 0.33333333 0.49932157 1 + Al Al1 1 0.33333333 0.66666667 0.99932156 1 + N N2 1 0.66666667 0.33333333 0.88067843 1 + N N3 1 0.33333333 0.66666667 0.38067844 1 diff --git a/data/structures/BiSBr.cif b/data/structures/BiSBr.cif new file mode 100644 index 0000000..e9bbad3 --- /dev/null +++ b/data/structures/BiSBr.cif @@ -0,0 +1,38 @@ +# generated using pymatgen +data_BiSBr +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 4.10679210 +_cell_length_b 8.22484633 +_cell_length_c 11.05701800 +_cell_angle_alpha 89.99999573 +_cell_angle_beta 90.00000347 +_cell_angle_gamma 90.00001971 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural BiSBr +_chemical_formula_sum 'Bi4 S4 Br4' +_cell_volume 373.48101173 +_cell_formula_units_Z 4 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Bi Bi0 1 0.24999992 0.88556832 0.87177467 1 + Bi Bi1 1 0.75000010 0.11443170 0.12822534 1 + Bi Bi2 1 0.24999988 0.38556830 0.62822532 1 + Bi Bi3 1 0.75000009 0.61443169 0.37177466 1 + S S4 1 0.74999990 0.82391374 0.03167750 1 + S S5 1 0.25000008 0.67608625 0.53167751 1 + S S6 1 0.74999991 0.32391375 0.46832249 1 + S S7 1 0.25000010 0.17608625 0.96832251 1 + Br Br8 1 0.24999990 0.96479764 0.30006197 1 + Br Br9 1 0.75000009 0.53520235 0.80006197 1 + 
Br Br10 1 0.24999991 0.46479765 0.19993803 1 + Br Br11 1 0.75000013 0.03520235 0.69993804 1 diff --git a/data/structures/CeO2.cif b/data/structures/CeO2.cif new file mode 100644 index 0000000..ad4b557 --- /dev/null +++ b/data/structures/CeO2.cif @@ -0,0 +1,38 @@ +# generated using pymatgen +data_CeO2 +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 5.46789061 +_cell_length_b 5.46789009 +_cell_length_c 5.46788984 +_cell_angle_alpha 90.00000180 +_cell_angle_beta 90.00000282 +_cell_angle_gamma 90.00000072 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural CeO2 +_chemical_formula_sum 'Ce4 O8' +_cell_volume 163.47801292 +_cell_formula_units_Z 4 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Ce Ce0 1 0.00000000 -0.00000000 -0.00000000 1 + Ce Ce1 1 0.00000000 0.50000000 0.50000000 1 + Ce Ce2 1 0.50000000 0.00000000 0.50000000 1 + Ce Ce3 1 0.50000000 0.50000000 0.00000000 1 + O O4 1 0.25000000 0.75000000 0.74999999 1 + O O5 1 0.74999999 0.25000000 0.25000000 1 + O O6 1 0.25000000 0.25000000 0.25000000 1 + O O7 1 0.75000000 0.75000000 0.75000000 1 + O O8 1 0.74999999 0.75000000 0.25000000 1 + O O9 1 0.25000000 0.25000000 0.74999999 1 + O O10 1 0.75000000 0.25000000 0.75000000 1 + O O11 1 0.25000000 0.75000000 0.25000000 1 diff --git a/data/structures/CsPbI3_alpha.cif b/data/structures/CsPbI3_alpha.cif new file mode 100644 index 0000000..8f0984d --- /dev/null +++ b/data/structures/CsPbI3_alpha.cif @@ -0,0 +1,31 @@ +# generated using pymatgen +data_CsPbI3 +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 6.37904246 +_cell_length_b 6.37904251 +_cell_length_c 6.37904247 +_cell_angle_alpha 90.00000013 +_cell_angle_beta 89.99999979 +_cell_angle_gamma 89.99999953 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural CsPbI3 +_chemical_formula_sum 'Cs1 
Pb1 I3' +_cell_volume 259.57716340 +_cell_formula_units_Z 1 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Cs Cs0 1 0.00000000 0.00000000 -0.00000000 1 + I I1 1 0.50000000 0.50000000 -0.00000000 1 + I I2 1 0.50000000 -0.00000000 0.50000000 1 + I I3 1 0.00000000 0.50000000 0.50000000 1 + Pb Pb4 1 0.50000000 0.50000000 0.50000000 1 diff --git a/data/structures/CsPbI3_delta.cif b/data/structures/CsPbI3_delta.cif new file mode 100644 index 0000000..1313893 --- /dev/null +++ b/data/structures/CsPbI3_delta.cif @@ -0,0 +1,46 @@ +# generated using pymatgen +data_CsPbI3 +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 4.91187765 +_cell_length_b 10.78602159 +_cell_length_c 18.12941751 +_cell_angle_alpha 90.00000734 +_cell_angle_beta 89.99999876 +_cell_angle_gamma 89.99999900 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural CsPbI3 +_chemical_formula_sum 'Cs4 Pb4 I12' +_cell_volume 960.48962177 +_cell_formula_units_Z 4 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Cs Cs0 1 0.75000002 0.57436836 0.17329066 1 + Cs Cs1 1 0.25000000 0.42563165 0.82670935 1 + Cs Cs2 1 0.75000002 0.07436835 0.32670935 1 + Cs Cs3 1 0.24999999 0.92563166 0.67329065 1 + I I4 1 0.74999999 0.20706041 0.71324789 1 + I I5 1 0.25000000 0.79293960 0.28675211 1 + I I6 1 0.75000001 0.97331566 0.10968932 1 + I I7 1 0.24999998 0.02668433 0.89031069 1 + I I8 1 0.74999999 0.47331567 0.39031069 1 + I I9 1 0.25000001 0.52668433 0.60968931 1 + I I10 1 0.25000000 0.66460269 0.00429655 1 + I I11 1 0.25000000 0.16460269 0.49570344 1 + I I12 1 0.25000000 0.29293960 0.21324788 1 + I I13 1 
0.75000001 0.83539729 0.50429655 1 + I I14 1 0.75000000 0.33539731 0.99570345 1 + I I15 1 0.75000001 0.70706040 0.78675212 1 + Pb Pb16 1 0.74999998 0.83743203 0.94239546 1 + Pb Pb17 1 0.24999999 0.16256797 0.05760453 1 + Pb Pb18 1 0.75000002 0.33743205 0.55760454 1 + Pb Pb19 1 0.24999998 0.66256796 0.44239545 1 diff --git a/data/structures/CsSnI3_alpha.cif b/data/structures/CsSnI3_alpha.cif new file mode 100644 index 0000000..87f7b76 --- /dev/null +++ b/data/structures/CsSnI3_alpha.cif @@ -0,0 +1,31 @@ +# generated using pymatgen +data_CsSnI3 +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 6.27081715 +_cell_length_b 6.27081724 +_cell_length_c 6.27081711 +_cell_angle_alpha 89.99999964 +_cell_angle_beta 89.99999969 +_cell_angle_gamma 89.99999972 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural CsSnI3 +_chemical_formula_sum 'Cs1 Sn1 I3' +_cell_volume 246.58827085 +_cell_formula_units_Z 1 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Cs Cs0 1 0.00000000 0.00000000 -0.00000000 1 + I I1 1 0.50000000 0.50000000 0.00000000 1 + I I2 1 0.50000000 -0.00000000 0.50000000 1 + I I3 1 -0.00000000 0.50000000 0.50000000 1 + Sn Sn4 1 0.50000000 0.50000000 0.50000000 1 diff --git a/data/structures/CsSnI3_delta.cif b/data/structures/CsSnI3_delta.cif new file mode 100644 index 0000000..e7054a0 --- /dev/null +++ b/data/structures/CsSnI3_delta.cif @@ -0,0 +1,46 @@ +# generated using pymatgen +data_CsSnI3 +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 4.84918880 +_cell_length_b 10.69167480 +_cell_length_c 18.19616575 +_cell_angle_alpha 89.99999030 +_cell_angle_beta 90.00000262 +_cell_angle_gamma 89.99999762 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural CsSnI3 +_chemical_formula_sum 'Cs4 Sn4 I12' +_cell_volume 943.39749344 +_cell_formula_units_Z 
4 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Cs Cs0 1 0.75000000 0.57552416 0.17186806 1 + Cs Cs1 1 0.25000001 0.42447584 0.82813194 1 + Cs Cs2 1 0.75000000 0.07552414 0.32813195 1 + Cs Cs3 1 0.25000000 0.92447584 0.67186806 1 + I I4 1 0.75000000 0.21397041 0.70606567 1 + I I5 1 0.25000000 0.78602961 0.29393431 1 + I I6 1 0.75000000 0.97123818 0.11227438 1 + I I7 1 0.25000000 0.02876181 0.88772566 1 + I I8 1 0.74999999 0.47123818 0.38772564 1 + I I9 1 0.25000001 0.52876182 0.61227437 1 + I I10 1 0.25000002 0.66813711 0.00020614 1 + I I11 1 0.24999999 0.16813710 0.49979388 1 + I I12 1 0.25000001 0.28602959 0.20606567 1 + I I13 1 0.74999999 0.83186292 0.50020614 1 + I I14 1 0.75000000 0.33186292 0.99979386 1 + I I15 1 0.74999999 0.71397042 0.79393433 1 + Sn Sn16 1 0.75000000 0.84420702 0.94393743 1 + Sn Sn17 1 0.25000000 0.15579296 0.05606252 1 + Sn Sn18 1 0.74999998 0.34420702 0.55606253 1 + Sn Sn19 1 0.24999999 0.65579297 0.44393746 1 diff --git a/data/structures/Fe.cif b/data/structures/Fe.cif new file mode 100644 index 0000000..18422ee --- /dev/null +++ b/data/structures/Fe.cif @@ -0,0 +1,28 @@ +# generated using pymatgen +data_Fe +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 2.86106543 +_cell_length_b 2.86106544 +_cell_length_c 2.86106538 +_cell_angle_alpha 90.00000018 +_cell_angle_beta 89.99999992 +_cell_angle_gamma 90.00000009 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural Fe +_chemical_formula_sum Fe2 +_cell_volume 23.41980977 +_cell_formula_units_Z 2 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Fe Fe0 1 0.00000000 
-0.00000000 0.00000000 1 + Fe Fe1 1 0.50000000 0.50000000 0.50000000 1 diff --git a/data/structures/GaN_hex.cif b/data/structures/GaN_hex.cif new file mode 100644 index 0000000..86d63dc --- /dev/null +++ b/data/structures/GaN_hex.cif @@ -0,0 +1,30 @@ +# generated using pymatgen +data_GaN +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 3.21192371 +_cell_length_b 3.21192377 +_cell_length_c 5.21628467 +_cell_angle_alpha 90.00000055 +_cell_angle_beta 90.00000024 +_cell_angle_gamma 119.99999991 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural GaN +_chemical_formula_sum 'Ga2 N2' +_cell_volume 46.60391121 +_cell_formula_units_Z 2 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Ga Ga0 1 0.66666667 0.33333333 0.49900050 1 + Ga Ga1 1 0.33333333 0.66666667 0.99900050 1 + N N2 1 0.66666667 0.33333333 0.87599950 1 + N N3 1 0.33333333 0.66666667 0.37599950 1 diff --git a/data/structures/NaCl.cif b/data/structures/NaCl.cif new file mode 100644 index 0000000..ece9ba4 --- /dev/null +++ b/data/structures/NaCl.cif @@ -0,0 +1,34 @@ +# generated using pymatgen +data_NaCl +_symmetry_space_group_name_H-M 'P 1' +_cell_length_a 5.68304678 +_cell_length_b 5.68304679 +_cell_length_c 5.68304672 +_cell_angle_alpha 90.00000009 +_cell_angle_beta 90.00000015 +_cell_angle_gamma 90.00000005 +_symmetry_Int_Tables_number 1 +_chemical_formula_structural NaCl +_chemical_formula_sum 'Na4 Cl4' +_cell_volume 183.54547783 +_cell_formula_units_Z 4 +loop_ + _symmetry_equiv_pos_site_id + _symmetry_equiv_pos_as_xyz + 1 'x, y, z' +loop_ + _atom_site_type_symbol + _atom_site_label + _atom_site_symmetry_multiplicity + _atom_site_fract_x + _atom_site_fract_y + _atom_site_fract_z + _atom_site_occupancy + Na Na0 1 -0.00000000 -0.00000000 -0.00000000 1 + Na Na1 1 -0.00000000 0.50000000 0.50000000 
1 + Na Na2 1 0.50000000 -0.00000000 0.50000000 1 + Na Na3 1 0.50000000 0.50000000 -0.00000000 1 + Cl Cl4 1 0.00000000 0.00000000 0.50000000 1 + Cl Cl5 1 -0.00000000 0.50000000 0.00000000 1 + Cl Cl6 1 0.50000000 0.00000000 -0.00000000 1 + Cl Cl7 1 0.50000000 0.50000000 0.50000000 1 diff --git a/notebooks/1_solid_solution.ipynb b/notebooks/1_solid_solution.ipynb new file mode 100644 index 0000000..3c2101b --- /dev/null +++ b/notebooks/1_solid_solution.ipynb @@ -0,0 +1,253 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import ase\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "from alchemical_mace.calculator import get_alchemical_optimized_cellpar\n", + "from alchemical_mace.model import AlchemicalPair\n", + "from alchemical_mace.utils import suppress_print\n", + "\n", + "plt.style.use(\"default\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CeO2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 21/21 [01:15<00:00, 3.60s/it]\n" + ] + } + ], + "source": [ + "# Default settings\n", + "model = \"medium\"\n", + "device = \"cpu\"\n", + "\n", + "# Load structure\n", + "atoms = ase.io.read(\"../data/structures/CeO2.cif\")\n", + "alch_elements = [\"Ce\", \"Sn\"]\n", + "alch_idx = [i for i, atom in enumerate(atoms) if atom.symbol in alch_elements]\n", + "alch_atomic_numbers = [ase.Atoms(el).numbers[0] for el in alch_elements]\n", + "alchemical_pairs = [\n", + " [AlchemicalPair(atom_index=idx, atomic_number=z) for idx in alch_idx]\n", + " for z in alch_atomic_numbers\n", + "]\n", + "\n", + "comp_grid = [[1 - x, x] for x in np.linspace(0, 0.5, 21)]\n", + "lat_params_CeSn = []\n", + "for comp in tqdm(comp_grid):\n", + " with suppress_print(out=True, err=True):\n", + " cellpar = 
get_alchemical_optimized_cellpar(\n", + " atoms, alchemical_pairs, comp, model=model, device=device\n", + " )\n", + " lat_params_CeSn.append(cellpar[:3])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 21/21 [01:35<00:00, 4.52s/it]\n" + ] + } + ], + "source": [ + "# Load structure\n", + "alch_elements = [\"Ce\", \"Zr\"]\n", + "alch_idx = [i for i, atom in enumerate(atoms) if atom.symbol in alch_elements]\n", + "alch_atomic_numbers = [ase.Atoms(el).numbers[0] for el in alch_elements]\n", + "alchemical_pairs = [\n", + " [AlchemicalPair(atom_index=idx, atomic_number=z) for idx in alch_idx]\n", + " for z in alch_atomic_numbers\n", + "]\n", + "\n", + "comp_grid = [[1 - x, x] for x in np.linspace(0, 0.5, 21)]\n", + "lat_params_CeZr = []\n", + "for comp in tqdm(comp_grid):\n", + " with suppress_print(out=True, err=True):\n", + " cellpar = get_alchemical_optimized_cellpar(\n", + " atoms, alchemical_pairs, comp, model=model, device=device\n", + " )\n", + " lat_params_CeZr.append(cellpar[:3])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3, 2.5))\n", + "ax.plot(\n", + " [x[1] for x in comp_grid],\n", + " [x[0] for x in lat_params_CeSn],\n", + " label=\"Ce$_{1-x}$Sn$_x$O$_2$\",\n", + ")\n", + "ax.plot(\n", + " [x[1] for x in comp_grid],\n", + " [x[0] for x in lat_params_CeZr],\n", + " label=\"Ce$_{1-x}$Zr$_x$O$_2$\",\n", + ")\n", + "ax.set_xlabel(\"$x$\")\n", + "ax.set_ylabel(\"a [Å]\")\n", + "ax.legend()\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### BiSBr" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Load structure\n", + "atoms = ase.io.read(\"../data/structures/BiSBr.cif\")\n", + "halide_elements = [\"Cl\", \"Br\", \"I\"]\n", + "halide_idx = [i for i, atom in enumerate(atoms) if atom.symbol in halide_elements]\n", + "halide_atomic_numbers = [ase.Atoms(el).numbers[0] for el in halide_elements]\n", + "alchemical_pairs = [\n", + " [AlchemicalPair(atom_index=idx, atomic_number=z) for idx in halide_idx]\n", + " for z in halide_atomic_numbers\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 21/21 [08:12<00:00, 23.46s/it]\n", + "100%|██████████| 21/21 [04:21<00:00, 12.45s/it]\n", + "100%|██████████| 21/21 [05:35<00:00, 15.95s/it]\n" + ] + } + ], + "source": [ + "comp_grid = [[1 - x, x, 0] for x in np.linspace(0, 1, 21)]\n", + "lat_params_ClBr = []\n", + "for comp in tqdm(comp_grid):\n", + " with suppress_print(out=True, err=True):\n", + " cellpar = get_alchemical_optimized_cellpar(\n", + " atoms, alchemical_pairs, comp, model=model, device=device\n", + " )\n", + " lat_params_ClBr.append(cellpar[:3])\n", + "\n", + "comp_grid = [[0, 1 - x, x] for x in np.linspace(0, 1, 21)]\n", + "lat_params_BrI = []\n", + "for comp in 
tqdm(comp_grid):\n", + " with suppress_print(out=True, err=True):\n", + " cellpar = get_alchemical_optimized_cellpar(\n", + " atoms, alchemical_pairs, comp, model=model, device=device\n", + " )\n", + " lat_params_BrI.append(cellpar[:3])\n", + "\n", + "\n", + "comp_grid = [[1 - x, 0, x] for x in np.linspace(0, 1, 21)]\n", + "lat_params_ClI = []\n", + "for comp in tqdm(comp_grid):\n", + " with suppress_print(out=True, err=True):\n", + " cellpar = get_alchemical_optimized_cellpar(\n", + " atoms, alchemical_pairs, comp, model=model, device=device\n", + " )\n", + " lat_params_ClI.append(cellpar[:3])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3.5, 2.5), layout=\"constrained\")\n", + "comp = np.linspace(0, 1, 21)\n", + "idx, param = 2, \"c\"\n", + "ax.plot(comp, [lat[idx] for lat in lat_params_ClBr], label=\"BiSCl$_{1-x}$Br$_x$\")\n", + "ax.plot(comp, [lat[idx] for lat in lat_params_BrI], label=\"BiSBr$_{1-x}$I$_x$\")\n", + "ax.plot(comp, [lat[idx] for lat in lat_params_ClI], label=\"BiSCl$_{1-x}$I$_x$\")\n", + "ax.legend()\n", + "ax.set_xlabel(\"$x$\")\n", + "ax.set_ylabel(f\"{param} [Å]\")\n", + "fig.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "chgnet", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/2_compositional_optimization.ipynb b/notebooks/2_compositional_optimization.ipynb new file mode 100644 index 0000000..34cee6c --- /dev/null +++ b/notebooks/2_compositional_optimization.ipynb @@ -0,0 +1,368 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import ase\n", + "import matplotlib.patheffects as pe\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from mpltern.datasets import get_triangular_grid\n", + "from tqdm import tqdm\n", + "\n", + "from alchemical_mace.calculator import get_alchemical_optimized_cellpar\n", + "from alchemical_mace.model import (\n", + " AlchemicalPair,\n", + " AlchemyManager,\n", + " alchemical_mace_mp,\n", + " get_z_table_and_r_max,\n", + ")\n", + "from alchemical_mace.optimize import ExponentiatedGradientDescent\n", + "from alchemical_mace.utils import suppress_print\n", + "\n", + 
"plt.style.use(\"default\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cell parameter scan" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Default settings\n", + "model = \"medium\"\n", + "device = \"cpu\"\n", + "\n", + "# Load pre-optimized structure (medium model)\n", + "atoms = ase.io.read(\"../data/structures/NaCl.cif\")\n", + "\n", + "# Construct the alchemical pairs\n", + "alkali_elements = [\"Li\", \"Na\", \"K\"]\n", + "alkali_idx = [i for i, atom in enumerate(atoms) if atom.symbol in alkali_elements]\n", + "alkali_atomic_numbers = [ase.Atoms(el).numbers[0] for el in alkali_elements]\n", + "alchemical_pairs = [\n", + " [AlchemicalPair(atom_index=idx, atomic_number=z) for idx in alkali_idx]\n", + " for z in alkali_atomic_numbers\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 105/105 [05:25<00:00, 3.10s/it]\n" + ] + } + ], + "source": [ + "# Compute lattice parameters for a triangular grid of compositions\n", + "comp_grid = np.array(get_triangular_grid(14))\n", + "lat_params = []\n", + "for comp in tqdm(comp_grid.T):\n", + " with suppress_print(out=True, err=True):\n", + " cellpar = get_alchemical_optimized_cellpar(\n", + " atoms, alchemical_pairs, comp, model=model, device=device\n", + " )\n", + " lat_params.append(np.mean(cellpar[:3])) # a = b = c" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = plt.figure(figsize=(3, 3), layout=\"constrained\")\n", + "ax = plt.subplot(projection=\"ternary\")\n", + "\n", + "cs = ax.tripcolor(*comp_grid, lat_params, cmap=\"Purples\", shading=\"gouraud\")\n", + "cax = ax.inset_axes([1.1, 0.3, 0.075, 0.8], transform=ax.transAxes)\n", + "cbar = fig.colorbar(cs, cax=cax)\n", + "cbar.set_label(\"Lattice const. [Å]\", rotation=270, va=\"baseline\")\n", + "\n", + "cs = ax.tricontour(*comp_grid, lat_params, colors=\"k\", linewidths=0.5, levels=6)\n", + "clabels = ax.clabel(cs)\n", + "for txt in clabels:\n", + " txt.set_fontsize(8)\n", + " txt.set_path_effects([pe.Stroke(linewidth=1.5, foreground=\"white\"), pe.Normal()])\n", + "ax.set_tlabel(\"Li\", weight=\"bold\", fontsize=12)\n", + "ax.set_llabel(\"Na\", weight=\"bold\", fontsize=12)\n", + "ax.set_rlabel(\"K\", weight=\"bold\", fontsize=12)\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Stress gradient visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# Load alchemical MACE model\n", + "with suppress_print(out=True, err=True):\n", + " mace = alchemical_mace_mp(model=model, device=device, default_dtype=\"float32\")\n", + "for param in mace.parameters():\n", + " param.requires_grad = False\n", + "\n", + "# Set AlchemyManager\n", + "z_table, r_max = get_z_table_and_r_max(mace)\n", + "alchemical_weights = torch.ones(3, dtype=torch.float32) / 3\n", + "alchemy_manager = AlchemyManager(\n", + " atoms=atoms,\n", + " alchemical_pairs=alchemical_pairs,\n", + " alchemical_weights=alchemical_weights,\n", + " z_table=z_table,\n", + " r_max=r_max,\n", + ").to(device)\n", + "\n", + "# Common inputs\n", + "tensor_kwargs = {\"dtype\": torch.float32, \"device\": device}\n", + "positions = torch.tensor(atoms.get_positions(), **tensor_kwargs)\n", + "cell = 
torch.tensor(atoms.get_cell().array, **tensor_kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 105/105 [00:07<00:00, 14.57it/s]\n", + "100%|██████████| 55/55 [00:09<00:00, 5.53it/s]\n" + ] + } + ], + "source": [ + "# Calculate hydrostatic stresses for a triangular grid of compositions\n", + "comp_grid = np.array(get_triangular_grid(14))\n", + "stress_list = []\n", + "for comp in tqdm(comp_grid.T):\n", + " # Set alchemical weights\n", + " alchemical_weights = torch.tensor(comp, dtype=torch.float32)\n", + " alchemy_manager.alchemical_weights.data = alchemical_weights\n", + " batch = alchemy_manager(positions, cell).to(device)\n", + "\n", + " # Get hydrostatic stress\n", + " out = mace(batch, retain_graph=True, create_graph=True, compute_stress=True)\n", + " stress = torch.abs(torch.trace(out[\"stress\"][0])) / 3\n", + " stress_list.append(stress.item())\n", + "\n", + "# Calculate gradients for a small triangular grid of compositions\n", + "comp_grid_small = np.array(get_triangular_grid(10))\n", + "grad_list = []\n", + "for comp in tqdm(comp_grid_small.T):\n", + " # Set alchemical weights\n", + " alchemical_weights = torch.tensor(comp, dtype=torch.float32)\n", + " alchemy_manager.alchemical_weights.data = alchemical_weights\n", + " batch = alchemy_manager(positions, cell).to(device)\n", + "\n", + " # Get hydrostatic stress\n", + " out = mace(batch, retain_graph=True, create_graph=True, compute_stress=True)\n", + " stress = torch.abs(torch.trace(out[\"stress\"][0])) / 3 # hydrostatic stress\n", + " stress.backward()\n", + " grad = alchemy_manager.alchemical_weights.grad\n", + " grad_list.append(grad.clone().cpu().numpy())\n", + " grad.data.zero_()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "stress_list = np.array(stress_list)\n", + "grad_list = np.array(grad_list)\n", + "\n", + "fig = plt.figure(figsize=(3, 3), layout=\"constrained\")\n", + "ax = plt.subplot(projection=\"ternary\")\n", + "cs = ax.tripcolor(*comp_grid, stress_list * 1e3, cmap=\"Blues\", shading=\"gouraud\")\n", + "cax = ax.inset_axes([1.1, 0.3, 0.075, 0.8], transform=ax.transAxes)\n", + "cbar = fig.colorbar(cs, cax=cax)\n", + "cbar.set_label(\"Stress [meV/Å$^3$]\", rotation=270, va=\"baseline\")\n", + "ax.quiver(\n", + " *comp_grid_small,\n", + " -grad_list[:, 0],\n", + " -grad_list[:, 1],\n", + " -grad_list[:, 2],\n", + " color=\"k\",\n", + " scale=1.2,\n", + " pivot=\"mid\",\n", + " linewidth=0.3,\n", + " width=0.01,\n", + " headlength=4,\n", + " headaxislength=3.5,\n", + ")\n", + "ax.set_tlabel(\"Li\", weight=\"bold\", fontsize=12)\n", + "ax.set_llabel(\"Na\", weight=\"bold\", fontsize=12)\n", + "ax.set_rlabel(\"K\", weight=\"bold\", fontsize=12)\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Al(Sc,Y)N/GaN optimization" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Load structure\n", + "atoms = ase.io.read(\"../data/structures/GaN_hex.cif\")\n", + "alch_elements = [\"Al\", \"Sc\"]\n", + "Ga_idx = [i for i, atom in enumerate(atoms) if atom.symbol == \"Ga\"]\n", + "alch_numbers = [ase.Atoms(el).numbers[0] for el in alch_elements]\n", + "alchemical_pairs = [\n", + " [AlchemicalPair(atom_index=idx, atomic_number=z) for idx in Ga_idx]\n", + " for z in alch_numbers\n", + "]\n", + "\n", + "# Set AlchemyManager\n", + "z_table, r_max = get_z_table_and_r_max(mace)\n", + "alchemical_weights = torch.tensor([0.999, 0.001], dtype=torch.float32)\n", + "alchemy_manager = AlchemyManager(\n", + " atoms=atoms,\n", + " alchemical_pairs=alchemical_pairs,\n", + " 
alchemical_weights=alchemical_weights,\n", + " z_table=z_table,\n", + " r_max=r_max,\n", + ").to(device)\n", + "\n", + "# Common inputs\n", + "tensor_kwargs = {\"dtype\": torch.float32, \"device\": device}\n", + "cell = torch.tensor(atoms.get_cell().array, **tensor_kwargs, requires_grad=True)\n", + "frac_coords = torch.tensor(\n", + " atoms.get_scaled_positions(), **tensor_kwargs, requires_grad=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "lr_weights = 5e-3\n", + "lr_cell = 1e-2\n", + "max_steps = 500\n", + "early_stop_loss = 5e-4\n", + "\n", + "for i in range(max_steps):\n", + " # Prepare data\n", + " positions = frac_coords @ cell\n", + " batch = alchemy_manager(positions, cell).to(device)\n", + "\n", + " # Get stress loss\n", + " out = mace(batch, retain_graph=True, create_graph=True, compute_stress=True)\n", + " stress = out[\"stress\"][0]\n", + " loss = torch.abs(stress[0, 0] + stress[1, 1])\n", + " loss.backward()\n", + "\n", + " # Gradient update\n", + " weights = alchemy_manager.alchemical_weights\n", + " weights.grad -= weights.grad.mean()\n", + " weights.data -= lr_weights * weights.grad\n", + " weights.grad.zero_()\n", + "\n", + " c_update = cell.grad[2, 2].detach().clone()\n", + " cell.grad.zero_()\n", + " cell.grad[2, 2] = c_update # only update the c component\n", + " cell.data -= lr_cell * cell.grad\n", + " cell.grad.zero_()\n", + "\n", + " if loss < early_stop_loss:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0.8017931 0.19820715]\n" + ] + } + ], + "source": [ + "print(weights.detach().cpu().numpy())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "chgnet", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + 
"mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/3_vacancy_analysis.ipynb b/notebooks/3_vacancy_analysis.ipynb new file mode 100644 index 0000000..d41610f --- /dev/null +++ b/notebooks/3_vacancy_analysis.ipynb @@ -0,0 +1,165 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from ase import units\n", + "from scipy.integrate import trapezoid" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def integrate_switching(\n", + " df_log: pd.DataFrame,\n", + " equil_time: int = 20000,\n", + " switch_time: int = 30000,\n", + " return_E_diss: bool = False,\n", + "):\n", + " fwd_start, fwd_end = equil_time, equil_time + switch_time\n", + " rev_start, rev_end = 2 * equil_time + switch_time, 2 * equil_time + 2 * switch_time\n", + " grad, lamda = df_log[\"lambda_grad\"], df_log[\"lambda\"]\n", + " W_fwd = trapezoid(grad[fwd_start:fwd_end], lamda[fwd_start:fwd_end])\n", + " W_rev = trapezoid(grad[rev_start:rev_end], lamda[rev_start:rev_end])\n", + " if return_E_diss:\n", + " return (W_fwd - W_rev) / 2, (W_fwd + W_rev) / 2\n", + " return (W_fwd - W_rev) / 2 # free energy difference\n", + "\n", + "\n", + "def analyze_frenkel_ladd(\n", + " base_path: Path,\n", + " temp: float,\n", + " equil_time: int = 20000,\n", + " switch_time: int = 30000,\n", + " verbose: bool = False,\n", + "):\n", + " T = temp\n", + " k = np.load(base_path / \"spring_constants.npy\")\n", + "\n", + " mass = np.load(base_path / \"masses.npy\")\n", + " omega = np.sqrt(k / mass)\n", + " n_atoms = len(mass)\n", + "\n", + " # 1. 
Perfect crystal\n", + " df_log = pd.read_csv(base_path / \"observables.csv\")\n", + " volume = df_log[\"volume\"].values[0]\n", + " if verbose:\n", + " _, E_diss_perfect = integrate_switching(\n", + " df_log, equil_time, switch_time, return_E_diss=True\n", + " )\n", + " delta_F = integrate_switching(df_log, equil_time, switch_time)\n", + " F_E = 3 * units.kB * T * np.mean(np.log(units._hbar * omega / (units.kB * T)))\n", + " PV = volume * 1.01325 * units.bar\n", + " G_perfect = delta_F + F_E + PV\n", + "\n", + " # 2. Defective crystal\n", + " df_log = pd.read_csv(base_path / \"observables_defect.csv\")\n", + " volume = df_log[\"volume\"].values[0]\n", + " if verbose:\n", + " _, E_diss_defect = integrate_switching(\n", + " df_log, equil_time, switch_time, return_E_diss=True\n", + " )\n", + " delta_F = integrate_switching(df_log, equil_time, switch_time)\n", + " F_E = 3 * units.kB * T * np.mean(np.log(units._hbar * omega / (units.kB * T)))\n", + " PV = volume * 1.01325 * units.bar\n", + " G_defect = delta_F + F_E + PV\n", + " G_v = G_defect * (n_atoms - 1) - G_perfect * (n_atoms - 1)\n", + "\n", + " # 3. 
Partial FL\n", + " df_log = pd.read_csv(base_path / \"observables_FL.csv\")\n", + " if verbose:\n", + " _, E_diss_FL = integrate_switching(\n", + " df_log, equil_time, switch_time, return_E_diss=True\n", + " )\n", + " delta_G = integrate_switching(df_log, equil_time, switch_time)\n", + " # delta_G * N = (G_defect * N-1 + F_E) - G_perfect * N\n", + " G_defect_alt = (delta_G * n_atoms - F_E + G_perfect * n_atoms) / (n_atoms - 1)\n", + " G_v_alt = G_defect_alt * (n_atoms - 1) - G_perfect * (n_atoms - 1)\n", + "\n", + " if verbose:\n", + " return G_perfect, G_defect, delta_G, E_diss_perfect, E_diss_defect, E_diss_FL\n", + " return G_v, G_v_alt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "G_v (50 K) = 1.5513 ± 0.0063 eV\n", + "G_v FL (50 K) = 1.5582 ± 0.0014 eV\n", + "G_v (100 K) = 1.5475 ± 0.0060 eV\n", + "G_v FL (100 K) = 1.5405 ± 0.0018 eV\n", + "G_v (150 K) = 1.5232 ± 0.0276 eV\n", + "G_v FL (150 K) = 1.5188 ± 0.0029 eV\n", + "G_v (200 K) = 1.5110 ± 0.0283 eV\n", + "G_v FL (200 K) = 1.5003 ± 0.0072 eV\n" + ] + } + ], + "source": [ + "result_path = Path(\"../data/results/vacancy\")\n", + "temp_range = [50, 100, 150, 200]\n", + "\n", + "G_v_all, G_v_std_all = [], []\n", + "G_v_alt_all, G_v_alt_std_all = [], []\n", + "for temp in temp_range:\n", + " G_v_list = []\n", + " G_v_alt_list = []\n", + " E_diss_perfect_list = []\n", + " E_diss_defect_list = []\n", + " E_diss_FL_list = []\n", + " for i in range(4):\n", + " base_path = result_path / f\"Fe_5x5x5_{temp}K/{i}\"\n", + " G_v, G_v_alt = analyze_frenkel_ladd(base_path, temp=temp, verbose=False)\n", + " G_v_list.append(G_v)\n", + " G_v_alt_list.append(G_v_alt)\n", + " G_v = np.mean(G_v_list)\n", + " G_v_std = np.std(G_v_list)\n", + " G_v_alt = np.mean(G_v_alt_list)\n", + " G_v_alt_std = np.std(G_v_alt_list)\n", + " print(f\"G_v ({temp} K) = {G_v:.4f} ± {G_v_std:.4f} eV\")\n", + " print(f\"G_v FL ({temp} K) 
= {G_v_alt:.4f} ± {G_v_alt_std:.4f} eV\")\n", + " G_v_all.append(G_v)\n", + " G_v_std_all.append(G_v_std)\n", + " G_v_alt_all.append(G_v_alt)\n", + " G_v_alt_std_all.append(G_v_alt_std)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "chgnet", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/4_perovskite_analysis.ipynb b/notebooks/4_perovskite_analysis.ipynb new file mode 100644 index 0000000..700e82c --- /dev/null +++ b/notebooks/4_perovskite_analysis.ipynb @@ -0,0 +1,269 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "from ase import units\n", + "from scipy.integrate import trapezoid" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Frenkel–Ladd path" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def integrate_switching(\n", + " df_log: pd.DataFrame,\n", + " equil_time: int = 20000,\n", + " switch_time: int = 30000,\n", + " return_E_diss: bool = False,\n", + "):\n", + " fwd_start, fwd_end = equil_time, equil_time + switch_time\n", + " rev_start, rev_end = 2 * equil_time + switch_time, 2 * equil_time + 2 * switch_time\n", + " grad, lamda = df_log[\"lambda_grad\"], df_log[\"lambda\"]\n", + " W_fwd = trapezoid(grad[fwd_start:fwd_end], lamda[fwd_start:fwd_end])\n", + " W_rev = trapezoid(grad[rev_start:rev_end], lamda[rev_start:rev_end])\n", + " if return_E_diss:\n", + " return (W_fwd - W_rev) / 2, (W_fwd + W_rev) / 2\n", + " return (W_fwd - W_rev) / 2 # free energy 
difference\n", + "\n", + "\n", + "def analyze_frenkel_ladd(\n", + " base_path: Path,\n", + " temp: float,\n", + " equil_time: int = 20000,\n", + " switch_time: int = 30000,\n", + "):\n", + " T = temp\n", + " df_log = pd.read_csv(base_path / \"observables.csv\")\n", + " k = np.load(base_path / \"spring_constants.npy\")\n", + " mass = np.load(base_path / \"masses.npy\")\n", + " omega = np.sqrt(k / mass)\n", + " volume = df_log[\"volume\"].values[0]\n", + "\n", + " delta_F = integrate_switching(df_log, equil_time, switch_time)\n", + " F_E = 3 * units.kB * T * np.mean(np.log(units._hbar * omega / (units.kB * T)))\n", + " PV = volume * 1.01325 * units.bar\n", + " delta_G = delta_F + F_E + PV\n", + "\n", + " return delta_G\n", + "\n", + "\n", + "def analyze_alchemical_switching(\n", + " base_path: Path,\n", + " temp: float,\n", + " equil_time: int = 20000,\n", + " switch_time: int = 30000,\n", + "):\n", + " T = temp\n", + " df_log = pd.read_csv(base_path / \"observables.csv\")\n", + " mass_init = np.load(base_path / \"masses_init.npy\")\n", + " mass_final = np.load(base_path / \"masses_final.npy\")\n", + "\n", + " work = integrate_switching(df_log, equil_time, switch_time)\n", + " G_mass = 1.5 * units.kB * T * np.mean(np.log(mass_init / mass_final))\n", + " delta_G = work + G_mass\n", + "\n", + " return delta_G" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CsPbI3 Alpha G (300 K) = -8.8514 ± 0.0003 eV/atom\n", + "CsPbI3 Alpha G (350 K) = -9.8642 ± 0.0004 eV/atom\n", + "CsPbI3 Alpha G (400 K) = -10.8783 ± 0.0006 eV/atom\n", + "CsPbI3 Alpha G (450 K) = -11.8943 ± 0.0003 eV/atom\n", + "CsPbI3 Alpha G (500 K) = -12.9119 ± 0.0004 eV/atom\n", + "CsPbI3 Delta G (300 K) = -8.8560 ± 0.0001 eV/atom\n", + "CsPbI3 Delta G (350 K) = -9.8660 ± 0.0001 eV/atom\n", + "CsPbI3 Delta G (400 K) = -10.8782 ± 0.0002 eV/atom\n", + "CsPbI3 Delta G (450 K) = -11.8921 ± 0.0002 
eV/atom\n", + "CsPbI3 Delta G (500 K) = -12.9072 ± 0.0002 eV/atom\n" + ] + } + ], + "source": [ + "result_path = Path(\"../data/results/perovskite/frenkel_ladd\")\n", + "temp_range = [300, 350, 400, 450, 500]\n", + "\n", + "G_alpha = []\n", + "G_alpha_std = []\n", + "for temp in temp_range:\n", + " G_list = []\n", + " for i in range(4):\n", + " base_path = result_path / f\"CsPbI3_alpha_6x6x6_{temp}K/{i}\"\n", + " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n", + " G = np.mean(G_list)\n", + " G_std = np.std(G_list)\n", + " print(f\"CsPbI3 Alpha G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", + " G_alpha.append(G)\n", + " G_alpha_std.append(G_std)\n", + "\n", + "G_delta = []\n", + "G_delta_std = []\n", + "for temp in temp_range:\n", + " G_list = []\n", + " for i in range(4):\n", + " base_path = result_path / f\"CsPbI3_delta_6x3x3_{temp}K/{i}\"\n", + " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n", + " G = np.mean(G_list)\n", + " G_std = np.std(G_list)\n", + " print(f\"CsPbI3 Delta G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", + " G_delta.append(G)\n", + " G_delta_std.append(G_std)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CsSnI3 Alpha G (300 K) = -8.8297 ± 0.0002 eV/atom\n", + "CsSnI3 Alpha G (350 K) = -9.8413 ± 0.0002 eV/atom\n", + "CsSnI3 Alpha G (400 K) = -10.8544 ± 0.0002 eV/atom\n", + "CsSnI3 Alpha G (450 K) = -11.8695 ± 0.0003 eV/atom\n", + "CsSnI3 Alpha G (500 K) = -12.8863 ± 0.0003 eV/atom\n", + "CsSnI3 Delta G (300 K) = -8.8289 ± 0.0001 eV/atom\n", + "CsSnI3 Delta G (350 K) = -9.8381 ± 0.0000 eV/atom\n", + "CsSnI3 Delta G (400 K) = -10.8494 ± 0.0003 eV/atom\n", + "CsSnI3 Delta G (450 K) = -11.8627 ± 0.0003 eV/atom\n", + "CsSnI3 Delta G (500 K) = -12.8771 ± 0.0003 eV/atom\n" + ] + } + ], + "source": [ + "G_CsSnI3_alpha = []\n", + "G_CsSnI3_alpha_std = []\n", + "for temp in temp_range:\n", + " G_list = 
[]\n", + " for i in range(4):\n", + " base_path = result_path / f\"CsSnI3_alpha_6x6x6_{temp}K/{i}\"\n", + " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n", + " G = np.mean(G_list)\n", + " G_std = np.std(G_list)\n", + " print(f\"CsSnI3 Alpha G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", + " G_CsSnI3_alpha.append(G)\n", + " G_CsSnI3_alpha_std.append(G_std)\n", + "\n", + "G_CsSnI3_delta = []\n", + "G_CsSnI3_delta_std = []\n", + "for temp in temp_range:\n", + " G_list = []\n", + " for i in range(4):\n", + " base_path = result_path / f\"CsSnI3_delta_6x3x3_{temp}K/{i}\"\n", + " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n", + " G = np.mean(G_list)\n", + " G_std = np.std(G_list)\n", + " print(f\"CsSnI3 Delta G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", + " G_CsSnI3_delta.append(G)\n", + " G_CsSnI3_delta_std.append(G_std)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Alchemical path" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Alpha ΔG (300 K) = 0.0233 ± 0.0001 eV/atom\n", + "Alpha ΔG (350 K) = 0.0236 ± 0.0001 eV/atom\n", + "Alpha ΔG (400 K) = 0.0241 ± 0.0001 eV/atom\n", + "Alpha ΔG (450 K) = 0.0249 ± 0.0000 eV/atom\n", + "Alpha ΔG (500 K) = 0.0258 ± 0.0000 eV/atom\n", + "Delta ΔG (300 K) = 0.0271 ± 0.0000 eV/atom\n", + "Delta ΔG (350 K) = 0.0279 ± 0.0000 eV/atom\n", + "Delta ΔG (400 K) = 0.0286 ± 0.0000 eV/atom\n", + "Delta ΔG (450 K) = 0.0294 ± 0.0000 eV/atom\n", + "Delta ΔG (500 K) = 0.0301 ± 0.0000 eV/atom\n" + ] + } + ], + "source": [ + "result_path = Path(\"../data/results/perovskite/alchemy\")\n", + "\n", + "G_alpha = []\n", + "G_alpha_std = []\n", + "for temp in temp_range:\n", + " G_list = []\n", + " for i in range(4):\n", + " base_path = result_path / f\"CsPbI3_CsSnI3_alpha_{temp}K/{i}\"\n", + " G_list.append(analyze_alchemical_switching(base_path, temp=temp))\n", + " G = 
np.mean(G_list)\n", + " G_std = np.std(G_list)\n", + " print(f\"Alpha ΔG ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", + " G_alpha.append(G)\n", + " G_alpha_std.append(G_std)\n", + "\n", + "G_delta = []\n", + "G_delta_std = []\n", + "for temp in temp_range:\n", + " G_list = []\n", + " for i in range(4):\n", + " base_path = result_path / f\"CsPbI3_CsSnI3_delta_{temp}K/{i}\"\n", + " G_list.append(analyze_alchemical_switching(base_path, temp=temp))\n", + " G = np.mean(G_list)\n", + " G_std = np.std(G_list)\n", + " print(f\"Delta ΔG ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", + " G_delta.append(G)\n", + " G_delta_std.append(G_std)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "chgnet", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9096539 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "alchemical_mace" +authors = [ + { name = "Juno Nam", email = "junonam@mit.edu" }, +] +description = "Alchemical MACE model" +readme = "README.md" +requires-python = ">=3.9" +version = "0.1.0" + +[tool.setuptools] +packages = ["alchemical_mace"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..19e8df1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +torch==2.0.1 +e3nn==0.4.4 +mace-torch==0.3.4 +ase==3.22.1 +pymatgen==2024.3.1 +numpy==1.25.2 +scipy==1.11.2 +pandas==2.2.2 +matplotlib==3.8.0 +mpltern==1.0.2 +tqdm==4.66.1 +ipykernel==6.25.2 \ No newline at end of file diff --git a/scripts/perovskite_alchemy.py 
b/scripts/perovskite_alchemy.py new file mode 100644 index 0000000..c58422c --- /dev/null +++ b/scripts/perovskite_alchemy.py @@ -0,0 +1,192 @@ +import argparse +from pathlib import Path + +import ase +import numpy as np +import pandas as pd +from ase import units +from ase.build import make_supercell +from ase.constraints import ExpCellFilter +from ase.md.npt import NPT +from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary +from ase.optimize import FIRE +from mace.calculators import mace_mp +from tqdm import tqdm + +from alchemical_mace.calculator import AlchemicalMACECalculator +from alchemical_mace.model import AlchemicalPair +from alchemical_mace.utils import upper_triangular_cell + + +# Arguments +parser = argparse.ArgumentParser() + +# Structure +parser.add_argument("--structure-file", type=str) +parser.add_argument("--supercell", type=int, nargs=3, default=[6, 6, 6]) + +# Alchemy +parser.add_argument("--switch-pair", type=str, nargs=2, default=["Pb", "Sn"]) + +# Molecular dynamics: general +parser.add_argument("--temperature", type=float, default=300.0) +parser.add_argument("--pressure", type=float, default=1.0) +parser.add_argument("--timestep", type=float, default=2.0) +parser.add_argument("--ttime", type=float, default=25.0) +parser.add_argument("--ptime", type=int, default=75.0) + +# Molecular dynamics: timesteps +parser.add_argument("--npt-equil-stpes", type=int, default=10000) +parser.add_argument("--alchemy-equil-steps", type=int, default=20000) +parser.add_argument("--alchemy-switch-steps", type=int, default=30000) + +# Molecular dynamics: output control +parser.add_argument("--output-dir", type=Path, default=Path("results")) +parser.add_argument("--log-interval", type=int, default=1) + +# MACE model +parser.add_argument("--device", type=str, default="cuda") +parser.add_argument("--model", type=str, default="small") + +args = parser.parse_args() 
+args.output_dir.mkdir(exist_ok=True, parents=True) + +# Load structure +atoms = ase.io.read(args.structure_file) +atoms = make_supercell(atoms, np.diag(args.supercell)) + +# Load universal MACE calculator and relax the structure +mace_calc = mace_mp(model=args.model, device=args.device, default_dtype="float32") +atoms.calc = mace_calc +atoms = ExpCellFilter(atoms) +optimizer = FIRE(atoms) +optimizer.run(fmax=0.01, steps=500) +atoms = atoms.atoms # get the relaxed structure +initial_atoms = atoms.copy() # save the initial structure + + +################################################################################ +# Cell volume equilibration +################################################################################ + +atoms = initial_atoms.copy() +atoms.set_calculator(mace_calc) +bulk_modulus = 100.0 * units.GPa + +# NPT equilibration +dyn = Inhomogeneous_NPTBerendsen( + atoms, + timestep=args.timestep * units.fs, + temperature_K=args.temperature, + pressure_au=args.pressure * 1.01325 * units.bar, + taut=args.ttime * units.fs, + taup=args.ptime * units.fs, + compressibility_au=1.0 / bulk_modulus, +) +MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature) +Stationary(atoms) + +# NPT equilibration and volume relaxation +for step in tqdm(range(args.npt_equil_stpes), desc="NPT equil"): + dyn.run(steps=1) + + +################################################################################ +# Alchemical switching +################################################################################ + +# Define alchemical transformation +src_elem, dst_elem = args.switch_pair +src_Z, dst_Z = ase.data.atomic_numbers[src_elem], ase.data.atomic_numbers[dst_elem] +src_idx = np.where(atoms.get_atomic_numbers() == src_Z)[0] +alchemical_pairs = [ + [AlchemicalPair(atom_index=idx, atomic_number=Z) for idx in src_idx] + for Z in [src_Z, dst_Z] +] + +# Set up the alchemical MACE calculator +calc = AlchemicalMACECalculator( + atoms=atoms, + 
alchemical_pairs=alchemical_pairs, + alchemical_weights=[1.0, 0.0], + device=args.device, + model=args.model, +) +atoms.set_calculator(calc) +upper_triangular_cell(atoms) # for ASE NPT + +# NPT alchemical switching +ptime = args.ptime * units.fs +pfactor = bulk_modulus * ptime * ptime + +dyn = NPT( + atoms, + timestep=args.timestep * units.fs, + temperature_K=args.temperature, + externalstress=args.pressure * 1.01325 * units.bar, + ttime=args.ttime * units.fs, + pfactor=pfactor, +) + +# Define alchemical path +t = np.linspace(0.0, 1.0, args.alchemy_switch_steps) +lambda_steps = t ** 5 * (70 * t ** 4 - 315 * t ** 3 + 540 * t ** 2 - 420 * t + 126) +lambda_values = [ + np.zeros(args.alchemy_equil_steps), + lambda_steps, + np.ones(args.alchemy_equil_steps), + lambda_steps[::-1], +] +lambda_values = np.concatenate(lambda_values) + +calculate_gradients = [ + np.zeros(args.alchemy_equil_steps, dtype=bool), + np.ones(args.alchemy_switch_steps, dtype=bool), + np.zeros(args.alchemy_equil_steps, dtype=bool), + np.ones(args.alchemy_switch_steps, dtype=bool), +] +calculate_gradients = np.concatenate(calculate_gradients) + + +def get_observables(dynamics, time, lambda_value): + num_atoms = len(dynamics.atoms) + alchemical_grad = dynamics.atoms._calc.results["alchemical_grad"] + lambda_grad = (alchemical_grad[1] - alchemical_grad[0]) / num_atoms + return { + "time": time, + "potential": dynamics.atoms.get_potential_energy() / num_atoms, + "temperature": dynamics.atoms.get_temperature(), + "volume": dynamics.atoms.get_volume() / num_atoms, + "lambda": lambda_value, + "lambda_grad": lambda_grad, + } + + +# Simulation loop +total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps + +observables = [] +for step in (tqdm(range(total_steps), desc="Alchemical switching")): + lambda_value = lambda_values[step] + grad_enabled = calculate_gradients[step] + + # Set alchemical weights and atomic masses + calc.set_alchemical_weights([1 - lambda_value, lambda_value]) + 
atoms.set_masses(calc.get_alchemical_atomic_masses()) + calc.calculate_alchemical_grad = grad_enabled + + dyn.run(steps=1) + if step % args.log_interval == 0: + time = (step + 1) * args.timestep + observables.append(get_observables(dyn, time, lambda_value)) + +# Save observables +df = pd.DataFrame(observables) +df.to_csv(args.output_dir / "observables.csv", index=False) + +# Save masses for post-processing +calc.set_alchemical_weights([1.0, 0.0]) +np.save(args.output_dir / "masses_init.npy", calc.get_alchemical_atomic_masses()) +calc.set_alchemical_weights([0.0, 1.0]) +np.save(args.output_dir / "masses_final.npy", calc.get_alchemical_atomic_masses()) diff --git a/scripts/perovskite_frenkel_ladd.py b/scripts/perovskite_frenkel_ladd.py new file mode 100644 index 0000000..9d33f3f --- /dev/null +++ b/scripts/perovskite_frenkel_ladd.py @@ -0,0 +1,217 @@ +import argparse +from pathlib import Path + +import ase +import numpy as np +import pandas as pd +from ase import units +from ase.build import make_supercell +from ase.constraints import ExpCellFilter +from ase.md.langevin import Langevin +from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary +from ase.optimize import FIRE +from mace.calculators import mace_mp +from pymatgen.io.ase import AseAtomsAdaptor +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +from tqdm import tqdm + +from alchemical_mace.calculator import FrenkelLaddCalculator, NVTMACECalculator + + +# Arguments +parser = argparse.ArgumentParser() + +# Structure +parser.add_argument("--structure-file", type=str) +parser.add_argument("--supercell", type=int, nargs=3, default=[6, 6, 6]) + +# Molecular dynamics: general +parser.add_argument("--temperature", type=float, default=300.0) +parser.add_argument("--pressure", type=float, default=1.0) +parser.add_argument("--timestep", type=float, default=2.0) +parser.add_argument("--ttime", type=float, default=25.0) 
+# ptime is a time constant in fs like ttime; parse it as float so that
+# fractional values are accepted (the default itself is a float).
+parser.add_argument("--ptime", type=float, default=75.0)
+
+# Molecular dynamics: timesteps
+parser.add_argument(
+    "--npt-equil-steps",
+    "--npt-equil-stpes",  # original misspelling, kept for backward compatibility
+    dest="npt_equil_stpes",  # attribute name the rest of the script reads
+    type=int,
+    default=10000,
+)
+parser.add_argument("--npt-prod-steps", type=int, default=20000)
+parser.add_argument("--nvt-equil-steps", type=int, default=20000)
+parser.add_argument("--nvt-prod-steps", type=int, default=30000)
+parser.add_argument("--alchemy-equil-steps", type=int, default=20000)
+parser.add_argument("--alchemy-switch-steps", type=int, default=30000)
+
+# Molecular dynamics: output control
+parser.add_argument("--output-dir", type=Path, default=Path("results"))
+parser.add_argument("--log-interval", type=int, default=1)
+
+# MACE model
+parser.add_argument("--device", type=str, default="cuda")
+parser.add_argument("--model", type=str, default="small")
+
+args = parser.parse_args()
+args.output_dir.mkdir(exist_ok=True, parents=True)
+
+# Load structure
+atoms = ase.io.read(args.structure_file)
+atoms = make_supercell(atoms, np.diag(args.supercell))
+
+# Load universal MACE calculator and relax the structure
+mace_calc = mace_mp(model=args.model, device=args.device, default_dtype="float32")
+atoms.calc = mace_calc
+atoms = ExpCellFilter(atoms)  # wrap so FIRE relaxes the cell together with the positions
+optimizer = FIRE(atoms)
+optimizer.run(fmax=0.01, steps=500)
+atoms = atoms.atoms  # get the relaxed structure
+initial_atoms = atoms.copy()  # save the initial structure
+
+
+################################################################################
+# Cell volume equilibration
+################################################################################
+
+atoms = initial_atoms.copy()
+atoms.calc = mace_calc  # the copy does not carry the calculator over
+bulk_modulus = 100.0 * units.GPa
+
+# Equilibration and volume calculation
+# 1.01325 bar = 1 atm, so --pressure is presumably given in atm.
+dyn = Inhomogeneous_NPTBerendsen(
+    atoms,
+    timestep=args.timestep * units.fs,
+    temperature_K=args.temperature,
+    pressure_au=args.pressure * 1.01325 * units.bar,
+    taut=args.ttime * units.fs,
+    taup=args.ptime * units.fs,
+    compressibility_au=1.0 / bulk_modulus,
+)
+MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
+Stationary(atoms)
+
+# NPT equilibration and volume relaxation
+cellpar_traj = []
+for step in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
+    dyn.run(steps=1)
+for step in tqdm(range(args.npt_prod_steps), desc="NPT prod"):
+    dyn.run(steps=1)
+    if step % args.log_interval == 0:
+        cellpar_traj.append(atoms.get_cell().cellpar())
+# Keep only the averaged a, b, c lengths; the cell angles are dropped.
+abc_new = np.mean(cellpar_traj, axis=0)[:3]
+
+# Scale the initial cell to match the average volume
+# (initial_atoms is rescaled in place here and re-copied after the relaxation)
+atoms = initial_atoms
+atoms.set_cell(np.diag(abc_new), scale_atoms=True)
+atoms.calc = mace_calc
+
+# Relax the atomic positions
+optimizer = FIRE(atoms)
+optimizer.run(fmax=0.01, steps=500)
+initial_atoms = atoms.copy()  # save the initial structure
+
+
+################################################################################
+# MSD calculation
+################################################################################
+
+initial_positions = atoms.get_positions()
+# Using the reversible scaling MACE calculator with fixed scale of 1.0
+# since we can turn off the stress calculation
+calc = NVTMACECalculator(device=args.device, model=args.model)
+atoms.calc = calc
+
+# NVT MSD calculation
+dyn = Langevin(
+    atoms,
+    timestep=args.timestep * units.fs,
+    temperature_K=args.temperature,
+    friction=1 / (args.ttime * units.fs),
+)
+MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
+Stationary(atoms)
+
+for step in tqdm(range(args.nvt_equil_steps), desc="NVT equil"):
+    dyn.run(steps=1)
+# Accumulate each atom's squared displacement from its relaxed position.
+squared_disp = np.zeros(len(atoms))
+for step in tqdm(range(args.nvt_prod_steps), desc="NVT prod"):
+    dyn.run(steps=1)
+    squared_disp += np.sum((atoms.get_positions() - initial_positions) ** 2, axis=1)
+mean_squared_disp = squared_disp / args.nvt_prod_steps
+
+# Calculate spring constants and average over symmetrically equivalent atoms
+spring_constants = 3.0 * units.kB * args.temperature / mean_squared_disp
+structure = AseAtomsAdaptor.get_structure(initial_atoms)
+sga = SpacegroupAnalyzer(structure)
+equivalent_indices = sga.get_symmetrized_structure().equivalent_indices
+for indices in equivalent_indices:
+    spring_constants[indices] = np.mean(spring_constants[indices])
+
+np.save(args.output_dir / "spring_constants.npy", spring_constants)
+np.save(args.output_dir / "masses.npy", atoms.get_masses())
+
+
+################################################################################
+# Frenkel-Ladd calculation
+################################################################################
+
+atoms = initial_atoms.copy()
+calc = FrenkelLaddCalculator(
+    spring_constants=spring_constants,
+    initial_positions=initial_positions,
+    device=args.device,
+    model=args.model,
+)
+atoms.calc = calc
+
+# NVT Frenkel-Ladd calculation
+dyn = Langevin(
+    atoms,
+    timestep=args.timestep * units.fs,
+    temperature_K=args.temperature,
+    friction=1 / (args.ttime * units.fs),
+)
+MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
+Stationary(atoms)
+
+# Define Frenkel-Ladd path
+# Schedule: hold at 0, switch 0 -> 1, hold at 1, switch back 1 -> 0.
+# The polynomial evaluates to 0 at t = 0 and to 1 at t = 1.
+t = np.linspace(0.0, 1.0, args.alchemy_switch_steps)
+lambda_steps = t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126)
+lambda_values = [
+    np.zeros(args.alchemy_equil_steps),
+    lambda_steps,
+    np.ones(args.alchemy_equil_steps),
+    lambda_steps[::-1],
+]
+lambda_values = np.concatenate(lambda_values)
+
+
+def get_observables(dynamics, time, lambda_value):
+    """Collect one row of per-atom observables from the current MD state.
+
+    ``lambda_grad`` is the calculator's ``energy_diff`` result divided by the
+    number of atoms.
+    """
+    num_atoms = len(dynamics.atoms)
+    return {
+        "time": time,
+        "potential": dynamics.atoms.get_potential_energy() / num_atoms,
+        "temperature": dynamics.atoms.get_temperature(),
+        "volume": dynamics.atoms.get_volume() / num_atoms,
+        "lambda": lambda_value,
+        "lambda_grad": dynamics.atoms.calc.results["energy_diff"] / num_atoms,
+    }
+
+
+# Simulation loop
+calc.compute_mace = False
+total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps
+
+observables = []
+for step in tqdm(range(total_steps), desc="Frenkel-Ladd"):
+    if step == args.alchemy_equil_steps:  # turn on MACE after spring equilibration
+        calc.compute_mace = True
+    lambda_value = lambda_values[step]
+    calc.set_weights(lambda_value)
+
+    dyn.run(steps=1)
+    if step % args.log_interval == 0:
+        time = (step + 1) * args.timestep  # same units as --timestep
+        observables.append(get_observables(dyn, time, lambda_value))
+
+# Save observables
+df = pd.DataFrame(observables)
+df.to_csv(args.output_dir / "observables.csv", index=False)
diff --git a/scripts/vacancy_frenkel_ladd.py b/scripts/vacancy_frenkel_ladd.py
new file mode 100644
index 0000000..ca84e46
--- /dev/null
+++ b/scripts/vacancy_frenkel_ladd.py
@@ -0,0 +1,427 @@
+import argparse
+from pathlib import Path
+
+import ase
+import ase.io  # ase.io.read is used below; import the submodule explicitly
+import numpy as np
+import pandas as pd
+from ase import units
+from ase.build import make_supercell
+from ase.constraints import ExpCellFilter
+from ase.md.langevin import Langevin
+from ase.md.npt import NPT
+from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen
+from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
+from ase.optimize import FIRE
+from mace.calculators import mace_mp
+from pymatgen.io.ase import AseAtomsAdaptor
+from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
+from tqdm import tqdm
+
+from alchemical_mace.calculator import (
+    DefectFrenkelLaddCalculator,
+    FrenkelLaddCalculator,
+    NVTMACECalculator,
+)
+from alchemical_mace.utils import upper_triangular_cell
+
+
+# Arguments
+parser = argparse.ArgumentParser()
+
+# Structure
+# The structure file is always read, so fail early with a clear argparse
+# error instead of passing None to the reader.
+parser.add_argument("--structure-file", type=str, required=True)
+parser.add_argument("--supercell", type=int, nargs=3, default=[5, 5, 5])
+
+# Molecular dynamics: general
+# timestep / ttime / ptime are multiplied by units.fs where they are used.
+parser.add_argument("--temperature", type=float, default=300.0)
+parser.add_argument("--pressure", type=float, default=1.0)
+parser.add_argument("--timestep", type=float, default=2.0)
+parser.add_argument("--ttime", type=float, default=25.0)
+# ptime is parsed as float so that fractional values are accepted.
+parser.add_argument("--ptime", type=float, default=75.0)
+
+# Molecular dynamics: timesteps
+parser.add_argument(
+    "--npt-equil-steps",
+    "--npt-equil-stpes",  # original misspelling, kept for backward compatibility
+    dest="npt_equil_stpes",  # attribute name the rest of the script reads
+    type=int,
+    default=10000,
+)
+parser.add_argument("--npt-prod-steps", type=int, default=20000)
+parser.add_argument("--nvt-equil-steps", type=int, default=20000)
+parser.add_argument("--nvt-prod-steps", type=int, default=30000)
+parser.add_argument("--alchemy-equil-steps", type=int, default=20000)
+parser.add_argument("--alchemy-switch-steps", type=int, default=30000)
+
+# Molecular dynamics: output control
+parser.add_argument("--output-dir", type=Path, default=Path("results"))
+parser.add_argument("--log-interval", type=int, default=1)
+
+# MACE model
+parser.add_argument("--device", type=str, default="cuda")
+parser.add_argument("--model", type=str, default="small")
+
+args = parser.parse_args()
+args.output_dir.mkdir(exist_ok=True, parents=True)
+
+
+################################################################################
+# Energy minimization: defect-free structure
+################################################################################
+
+# Load structure
+atoms = ase.io.read(args.structure_file)
+atoms = make_supercell(atoms, np.diag(args.supercell))
+
+# Load universal MACE calculator and relax the structure
+mace_calc = mace_mp(model=args.model, device=args.device, default_dtype="float32")
+atoms.calc = mace_calc
+atoms = ExpCellFilter(atoms)  # wrap so FIRE relaxes the cell together with the positions
+optimizer = FIRE(atoms)
+optimizer.run(fmax=0.01, steps=500)
+atoms = atoms.atoms  # get the relaxed structure
+initial_atoms = atoms.copy()  # save the initial structure
+
+
+################################################################################
+# Cell volume equilibration: defect-free structure
+################################################################################
+
+atoms = initial_atoms.copy()
+atoms.calc = mace_calc  # the copy does not carry the calculator over
+bulk_modulus = 100.0 * units.GPa
+
+# Equilibration and volume calculation
+# 1.01325 bar = 1 atm, so --pressure is presumably given in atm.
+dyn = Inhomogeneous_NPTBerendsen(
+    atoms,
+    timestep=args.timestep * units.fs,
+    temperature_K=args.temperature,
+    pressure_au=args.pressure * 1.01325 * units.bar,
+    taut=args.ttime * units.fs,
+    taup=args.ptime * units.fs,
+    compressibility_au=1.0 / bulk_modulus,
+)
+MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
+Stationary(atoms)
+
+# NPT equilibration and volume relaxation
+cellpar_traj = []
+for step in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
+    dyn.run(steps=1)
+for step in tqdm(range(args.npt_prod_steps), desc="NPT prod"):
+    dyn.run(steps=1)
+    if step % args.log_interval == 0:
+        cellpar_traj.append(atoms.get_cell().cellpar())
+# Keep only the averaged a, b, c lengths; the cell angles are dropped.
+abc_new = np.mean(cellpar_traj, axis=0)[:3]
+
+# Scale the initial cell to match the average volume
+# (initial_atoms is rescaled in place here and re-copied after the relaxation)
+atoms = initial_atoms
+atoms.set_cell(np.diag(abc_new), scale_atoms=True)
+atoms.calc = mace_calc
+
+# Relax the atomic positions
+optimizer = FIRE(atoms)
+optimizer.run(fmax=0.01, steps=500)
+initial_atoms = atoms.copy()  # save the initial structure
+
+
+################################################################################
+# MSD calculation: defect-free structure
+################################################################################
+
+initial_positions = atoms.get_positions()
+# Using the reversible scaling MACE calculator with fixed scale of 1.0
+# since we can turn off the stress calculation
+calc = NVTMACECalculator(device=args.device, model=args.model)
+atoms.calc = calc
+
+# NVT MSD calculation
+dyn = Langevin(
+    atoms,
+    timestep=args.timestep * units.fs,
+    temperature_K=args.temperature,
+    friction=1 / (args.ttime * units.fs),
+)
+MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
+Stationary(atoms)
+
+for step in tqdm(range(args.nvt_equil_steps), desc="NVT equil"):
+    dyn.run(steps=1)
+# Accumulate each atom's squared displacement from its relaxed position.
+squared_disp = np.zeros(len(atoms))
+for step in tqdm(range(args.nvt_prod_steps), desc="NVT prod"):
+    dyn.run(steps=1)
+    squared_disp += np.sum((atoms.get_positions() - initial_positions) ** 2, axis=1)
+mean_squared_disp = squared_disp / args.nvt_prod_steps
+
+# Calculate spring constants and average over symmetrically equivalent atoms
+spring_constants = 3.0 * units.kB * args.temperature / mean_squared_disp
+structure = AseAtomsAdaptor.get_structure(initial_atoms)
+sga = SpacegroupAnalyzer(structure)
+equivalent_indices = sga.get_symmetrized_structure().equivalent_indices
+for indices in equivalent_indices:
+    spring_constants[indices] = np.mean(spring_constants[indices])
+
+np.save(args.output_dir / "spring_constants.npy", spring_constants)
+np.save(args.output_dir / "masses.npy", atoms.get_masses())
+
+
+################################################################################
+# Frenkel-Ladd calculation: defect-free structure
+################################################################################
+
+atoms = initial_atoms.copy()
+calc = FrenkelLaddCalculator(
+    spring_constants=spring_constants,
+    initial_positions=initial_positions,
+    device=args.device,
+    model=args.model,
+)
+atoms.calc = calc
+
+# NVT Frenkel-Ladd calculation
+dyn = Langevin(
+    atoms,
+    timestep=args.timestep * units.fs,
+    temperature_K=args.temperature,
+    friction=1 / (args.ttime * units.fs),
+)
+MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
+Stationary(atoms)
+
+# Define Frenkel-Ladd path
+# Schedule: hold at 0, switch 0 -> 1, hold at 1, switch back 1 -> 0.
+# The polynomial evaluates to 0 at t = 0 and to 1 at t = 1.
+t = np.linspace(0.0, 1.0, args.alchemy_switch_steps)
+lambda_steps = t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126)
+lambda_values = [
+    np.zeros(args.alchemy_equil_steps),
+    lambda_steps,
+    np.ones(args.alchemy_equil_steps),
+    lambda_steps[::-1],
+]
+lambda_values = np.concatenate(lambda_values)
+
+
+def get_observables(dynamics, time, lambda_value):
+    """Collect one row of per-atom observables from the current MD state.
+
+    ``lambda_grad`` is the calculator's ``energy_diff`` result divided by the
+    number of atoms.
+    """
+    num_atoms = len(dynamics.atoms)
+    return {
+        "time": time,
+        "potential": dynamics.atoms.get_potential_energy() / num_atoms,
+        "temperature": dynamics.atoms.get_temperature(),
+        "volume": dynamics.atoms.get_volume() / num_atoms,
+        "lambda": lambda_value,
+        "lambda_grad": dynamics.atoms.calc.results["energy_diff"] / num_atoms,
+    }
+
+
+# Simulation loop
+calc.compute_mace = False
+total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps
+
+observables = []
+for step in tqdm(range(total_steps), desc="Frenkel-Ladd"):
+    if step == args.alchemy_equil_steps:  # turn on MACE after spring equilibration
+        calc.compute_mace = True
+    lambda_value = lambda_values[step]
+    calc.set_weights(lambda_value)
+
+    dyn.run(steps=1)
+    if step % args.log_interval == 0:
+        time = (step + 1) * args.timestep  # same units as --timestep
+        observables.append(get_observables(dyn, time, lambda_value))
+
+# Save observables
+df = pd.DataFrame(observables)
+df.to_csv(args.output_dir / "observables.csv", index=False)
+
+
+################################################################################
+# Cell volume equilibration: structure with a defect
+################################################################################
+
+atoms = initial_atoms.copy()
+
+# Create a vacancy at the center of the supercell
+vacancy_index = len(atoms) // 2
+# Boolean mask selecting every atom except the removed one; used later to
+# slice the per-atom arrays of the defect-free structure.
+atom_mask = np.ones(len(atoms), dtype=bool)
+atom_mask[vacancy_index] = False
+del atoms[vacancy_index]
+
+atoms.calc = mace_calc
+
+# Equilibration and volume calculation
+dyn = Inhomogeneous_NPTBerendsen(
+    atoms,
+    timestep=args.timestep * units.fs,
+    temperature_K=args.temperature,
+    pressure_au=args.pressure * 1.01325 * units.bar,
+    taut=args.ttime * units.fs,
+    taup=args.ptime * units.fs,
+    compressibility_au=1.0 / bulk_modulus,
+)
+MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
+Stationary(atoms)
+
+# NPT equilibration and volume relaxation
+cellpar_traj = []
+for step in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
+    dyn.run(steps=1)
+for step in tqdm(range(args.npt_prod_steps), desc="NPT prod"):
+    dyn.run(steps=1)
+    if step % args.log_interval == 0:
+        cellpar_traj.append(atoms.get_cell().cellpar())
+# Keep only the averaged a, b, c lengths; the cell angles are dropped.
+abc_new = np.mean(cellpar_traj, axis=0)[:3]
+
+# Scale the initial cell to match the average volume
+atoms = initial_atoms.copy()
+atoms.set_cell(np.diag(abc_new), scale_atoms=True)
+del atoms[vacancy_index]
+atoms.calc = mace_calc
+
+# Relax the atomic positions
+optimizer = FIRE(atoms)
+optimizer.run(fmax=0.01, steps=500)
+
+
+################################################################################
+# Frenkel-Ladd calculation: structure with a defect
+################################################################################
+
+# NOTE(review): the spring anchors below are the defect-free positions with the
+# vacancy site masked out, not the positions relaxed just above in the rescaled
+# defect cell -- confirm this is intended.
+calc = FrenkelLaddCalculator(
+    spring_constants=spring_constants[atom_mask],
+    initial_positions=initial_positions[atom_mask],
+    device=args.device,
+    model=args.model,
+)
+atoms.calc = calc
+
+# NVT Frenkel-Ladd calculation
+dyn = Langevin(
+    atoms,
+    timestep=args.timestep * units.fs,
+    temperature_K=args.temperature,
+    friction=1 / (args.ttime * units.fs),
+)
+MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
+Stationary(atoms)
+
+# Simulation loop
+# Reuses lambda_values and get_observables defined for the defect-free run.
+calc.compute_mace = False
+total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps
+
+observables = []
+for step in tqdm(range(total_steps), desc="Frenkel-Ladd"):
+    if step == args.alchemy_equil_steps:  # turn on MACE after spring equilibration
+        calc.compute_mace = True
+    lambda_value = lambda_values[step]
+    calc.set_weights(lambda_value)
+
+    dyn.run(steps=1)
+    if step % args.log_interval == 0:
+        time = (step + 1) * args.timestep  # same units as --timestep
+        observables.append(get_observables(dyn, time, lambda_value))
+
+# Save observables
+df = pd.DataFrame(observables)
+df.to_csv(args.output_dir / "observables_defect.csv", index=False)
+
+
+################################################################################
+# Cell volume equilibration: partial Frenkel-Ladd calculation
+################################################################################
+
+# Start again from the defect-free structure.
+atoms = initial_atoms.copy()
+atoms.calc = mace_calc
+
+# Equilibration and volume calculation
+dyn = Inhomogeneous_NPTBerendsen(
+    atoms,
+    timestep=args.timestep * units.fs,
+    temperature_K=args.temperature,
+    pressure_au=args.pressure * 1.01325 * units.bar,
+    taut=args.ttime * units.fs,
+    taup=args.ptime * units.fs,
+    compressibility_au=1.0 / bulk_modulus,
+)
+MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
+Stationary(atoms)
+
+# NPT equilibration and volume relaxation
+for step in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
+    dyn.run(steps=1)
+
+
+################################################################################
+# Alchemical switching
+################################################################################
+
+# Set up the partial Frenkel-Ladd calculation
+calc = DefectFrenkelLaddCalculator(
+    atoms=atoms,
+    spring_constant=spring_constants[vacancy_index],
+    defect_index=vacancy_index,
+    device=args.device,
+    model=args.model,
+)
+atoms.calc = calc
+upper_triangular_cell(atoms)  # for ASE NPT
+
+# NPT alchemical switching
+ptime = args.ptime * units.fs
+pfactor = bulk_modulus * ptime * ptime
+
+dyn = NPT(
+    atoms,
+    timestep=args.timestep * units.fs,
+    temperature_K=args.temperature,
+    externalstress=args.pressure * 1.01325 * units.bar,
+    ttime=args.ttime * units.fs,
+    pfactor=pfactor,
+)
+
+# Define alchemical path
+# Schedule: hold at 0, switch 0 -> 1, hold at 1, switch back 1 -> 0.
+t = np.linspace(0.0, 1.0, args.alchemy_switch_steps)
+lambda_steps = t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126)
+lambda_values = [
+    np.zeros(args.alchemy_equil_steps),
+    lambda_steps,
+    np.ones(args.alchemy_equil_steps),
+    lambda_steps[::-1],
+]
+lambda_values = np.concatenate(lambda_values)
+
+# Gradients are requested only during the two switching segments.
+calculate_gradients = [
+    np.zeros(args.alchemy_equil_steps, dtype=bool),
+    np.ones(args.alchemy_switch_steps, dtype=bool),
+    np.zeros(args.alchemy_equil_steps, dtype=bool),
+    np.ones(args.alchemy_switch_steps, dtype=bool),
+]
+calculate_gradients = np.concatenate(calculate_gradients)
+
+
+def get_observables(dynamics, time, lambda_value):
+    """Collect one row of per-atom observables from the current MD state.
+
+    Replaces the earlier definition: ``lambda_grad`` is now the calculator's
+    ``alchemical_grad`` result divided by the number of atoms.
+    """
+    num_atoms = len(dynamics.atoms)
+    alchemical_grad = dynamics.atoms.calc.results["alchemical_grad"]
+    return {
+        "time": time,
+        "potential": dynamics.atoms.get_potential_energy() / num_atoms,
+        "temperature": dynamics.atoms.get_temperature(),
+        "volume": dynamics.atoms.get_volume() / num_atoms,
+        "lambda": lambda_value,
+        "lambda_grad": alchemical_grad / num_atoms,
+    }
+
+
+# Simulation loop
+total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps
+
+observables = []
+for step in tqdm(range(total_steps), desc="Alchemical switching"):
+    lambda_value = lambda_values[step]
+    grad_enabled = calculate_gradients[step]
+
+    # Set the alchemical weight and toggle the gradient calculation
+    calc.set_alchemical_weight(lambda_value)
+    calc.calculate_alchemical_grad = grad_enabled
+
+    dyn.run(steps=1)
+    if step % args.log_interval == 0:
+        time = (step + 1) * args.timestep  # same units as --timestep
+        observables.append(get_observables(dyn, time, lambda_value))
+
+# Save observables
+df = pd.DataFrame(observables)
+df.to_csv(args.output_dir / "observables_FL.csv", index=False)