Skip to content

Commit

Permalink
import solvers/trainers only when needed
Browse files Browse the repository at this point in the history
  • Loading branch information
yomichi committed Oct 25, 2024
1 parent e3ff354 commit f887d03
Show file tree
Hide file tree
Showing 17 changed files with 363 additions and 352 deletions.
31 changes: 19 additions & 12 deletions abics/applications/latgas_abinitio_interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,24 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

# from .default_observer import *
from .default_observer import *
from .map2perflat import *
from .aenet_trainer import *
from .nequip_trainer import *
from .mlip_3_trainer import *
# from .aenet_trainer import *
# from .nequip_trainer import *
# from .mlip_3_trainer import *

from .vasp import VASPSolver
from .qe import QESolver
from .aenet import AenetSolver
from .aenet_pylammps import AenetPyLammpsSolver
from .nequip import NequipSolver
from .mlip_3 import MLIP_3_Solver
from .openmx import OpenMXSolver
from .user_function_solver import UserFunctionSolver
from .base_solver import register_solver
from .base_trainer import register_trainer

register_solver("vasp", "VASPSolver", "abics.applications.latgas_abinitio_interface.vasp")
register_solver("qe", "QESolver", "abics.applications.latgas_abinitio_interface.qe")
register_solver("openmx", "OpenMXSolver", "abics.applications.latgas_abinitio_interface.openmx")
register_solver("aenet", "AenetSolver", "abics.applications.latgas_abinitio_interface.aenet")
register_solver("nequip", "NequipSolver", "abics.applications.latgas_abinitio_interface.nequip")
# register_solver("allegro", "NequipSolver", "abics.applications.latgas_abinitio_interface.nequip")
register_solver("mlip_3", "MLIP_3_Solver", "abics.applications.latgas_abinitio_interface.mlip_3")
register_solver("User", "UserFunctionSolver", "abics.applications.latgas_abinitio_interface.user_function_solver")

register_trainer("aenet", "Aenet_trainer", "abics.applications.latgas_abinitio_interface.aenet_trainer")
register_trainer("nequip", "Nequip_trainer", "abics.applications.latgas_abinitio_interface.nequip_trainer")
register_trainer("mlip_3", "MLIP_3_Trainer", "abics.applications.latgas_abinitio_interface.mlip_3_trainer")
112 changes: 4 additions & 108 deletions abics/applications/latgas_abinitio_interface/aenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

"""
Adapted from pymatgen.io.xcrysden distributed under the MIT License
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.
"""

from __future__ import annotations

import os
Expand All @@ -28,105 +22,9 @@
import numpy as np
from pymatgen.core import Structure

from .base_solver import SolverBase, register_solver
from .base_solver import SolverBase
from .params import ALParams, DFTParams


def to_XSF(structure: Structure, write_force_zero=False):
"""
Returns a string with the structure in XSF format
See http://www.xcrysden.org/doc/XSF.html
"""
lines = []
app = lines.append

app("CRYSTAL")
app("# Primitive lattice vectors in Angstrom")
app("PRIMVEC")
cell = structure.lattice.matrix
for i in range(3):
app(" %.14f %.14f %.14f" % tuple(cell[i]))

cart_coords = structure.cart_coords
app("# Cartesian coordinates in Angstrom.")
app("PRIMCOORD")
app(" %d 1" % len(cart_coords))
species = structure.species
site_properties = structure.site_properties
if "forces" not in site_properties.keys():
write_force_zero = True
else:
forces = site_properties["forces"]

if write_force_zero:
for a in range(len(cart_coords)):
app(
str(species[a])
+ " %20.14f %20.14f %20.14f" % tuple(cart_coords[a])
+ " 0.0 0.0 0.0"
)
else:
for a in range(len(cart_coords)):
app(
str(species[a])
+ " %20.14f %20.14f %20.14f" % tuple(cart_coords[a])
+ " %20.14f %20.14f %20.14f" % tuple(forces[a])
)

return "\n".join(lines)


def from_XSF(input_string: str):
"""
Initialize a `Structure` object from a string with data in XSF format.
Args:
input_string: String with the structure in XSF format.
See http://www.xcrysden.org/doc/XSF.html
cls_: Structure class to be created. default: pymatgen structure
"""
# CRYSTAL see (1)
# these are primitive lattice vectors (in Angstroms)
# PRIMVEC
# 0.0000000 2.7100000 2.7100000 see (2)
# 2.7100000 0.0000000 2.7100000
# 2.7100000 2.7100000 0.0000000

# these are conventional lattice vectors (in Angstroms)
# CONVVEC
# 5.4200000 0.0000000 0.0000000 see (3)
# 0.0000000 5.4200000 0.0000000
# 0.0000000 0.0000000 5.4200000

# these are atomic coordinates in a primitive unit cell (in Angstroms)
# PRIMCOORD
# 2 1 see (4)
# 16 0.0000000 0.0000000 0.0000000 see (5)
# 30 1.3550000 -1.3550000 -1.3550000

lattice, coords, species = [], [], []
lines = input_string.splitlines()

for i in range(len(lines)):
if "PRIMVEC" in lines[i]:
for j in range(i + 1, i + 4):
lattice.append([float(c) for c in lines[j].split()])

if "PRIMCOORD" in lines[i]:
num_sites = int(lines[i + 1].split()[0])

for j in range(i + 2, i + 2 + num_sites):
tokens = lines[j].split()
species.append(tokens[0])
coords.append([float(j) for j in tokens[1:4]])
break
else:
raise ValueError("Invalid XSF data")

s = Structure(lattice, species, coords, coords_are_cartesian=True)
return s

from .util import structure_to_XSF, structure_from_XSF

class AenetSolver(SolverBase):
"""
Expand Down Expand Up @@ -183,7 +81,7 @@ def update_info_by_structure(self, structure: Structure):
if self.ignore_species is not None:
structure = structure.copy()
structure.remove_species(self.ignore_species)
self.pos_info = to_XSF(structure)
self.pos_info = structure_to_XSF(structure)

def update_info_from_files(self, output_dir, rerun):
"""
Expand Down Expand Up @@ -262,7 +160,7 @@ def get_results(self, output_dir):
# Read results from files in output_dir and calculate values
Phys = namedtuple("PhysValues", ("energy", "structure"))
with open(os.path.join(output_dir, "structure.xsf")) as f:
structure = from_XSF(f.read())
structure = structure_from_XSF(f.read())
with open(os.path.join(output_dir, "stdout")) as f:
lines = f.read()
fi_io = io.StringIO(lines)
Expand Down Expand Up @@ -291,5 +189,3 @@ def create(cls, params: ALParams | DFTParams):
ignore_species = params.ignore_species
run_scheme = params.solver_run_scheme
return cls(path, ignore_species, run_scheme)

register_solver("aenet", AenetSolver)
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import numpy as np
from pymatgen.core import Structure

from .base_solver import SolverBase, register_solver
from .base_solver import SolverBase, register_solver_lazy
from .params import ALParams, DFTParams


Expand Down Expand Up @@ -229,4 +229,5 @@ def create(cls, params: ALParams | DFTParams):
return cls(ignore_species)


register_solver("aenetpylammps", AenetPyLammpsSolver)
# register_solver("aenetpylammps", AenetPyLammpsSolver)
register_solver_lazy("aenetpylammps", "AenetPyLammpsSolver", "abics.applications.latgas_abinitio_interface.aenet_pylammps")
27 changes: 22 additions & 5 deletions abics/applications/latgas_abinitio_interface/aenet_trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# ab-Initio Configuration Sampling tool kit (abICS)
# Copyright (C) 2019- The University of Tokyo
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from __future__ import annotations
from typing import Sequence

Expand All @@ -8,9 +24,10 @@
from pymatgen.core import Structure

from abics.util import expand_cmd_path
from abics.applications.latgas_abinitio_interface import aenet
from abics.applications.latgas_abinitio_interface.base_trainer import TrainerBase
from abics.applications.latgas_abinitio_interface.util import structure_to_XSF

class aenet_trainer:
class Aenet_trainer(TrainerBase):
def __init__(
self,
structures: Sequence[Structure],
Expand Down Expand Up @@ -48,15 +65,15 @@ def prepare(self, latgas_mode = True, st_dir = "aenetXSF"):
xsfdir = os.getcwd()
if latgas_mode:
for i, st in enumerate(self.structures):
xsf_string = aenet.to_XSF(st, write_force_zero=False)
xsf_string = structure_to_XSF(st, write_force_zero=False)
xsf_string = (
"# total energy = {} eV\n\n".format(self.energies[i]) + xsf_string
)
with open("structure.{}.xsf".format(i), "w") as fi:
fi.write(xsf_string)
else:
for i, st in enumerate(self.structures):
xsf_string = aenet.to_XSF(st, write_force_zero=False)
xsf_string = structure_to_XSF(st, write_force_zero=False)
xsf_string = (
"# total energy = {} eV\n\n".format(self.energies[i]) + xsf_string
)
Expand Down Expand Up @@ -170,7 +187,7 @@ def train(self, train_dir = "train"):
os.chdir(pathlib.Path(os.getcwd()).parent)
self.is_trained = True

def new_baseinput(self, baseinput_dir):
def new_baseinput(self, baseinput_dir, train_dir=""):
try:
assert self.is_trained
except AssertionError as e:
Expand Down
20 changes: 13 additions & 7 deletions abics/applications/latgas_abinitio_interface/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,21 +200,21 @@ def create(cls, params: ALParams | DFTParams) -> SolverBase:

__solver_table = {}

def register_solver(solver_name: str, solver_class) -> None:
def register_solver(solver_name: str, solver_class: str, solver_module: str) -> None:
"""
Register solver class.
Parameters
----------
solver_name : str
Solver name (case insensible).
solver_class : SolverBase
solver_class : str
Solver class, which should be a subclass of SolverBase.
solver_module : str
Module name including the solver class.
"""

if SolverBase not in solver_class.mro():
raise TypeError("solver_class must be a subclass of SolverBase")
__solver_table[solver_name.lower()] = solver_class
__solver_table[solver_name.lower()] = (solver_class, solver_module)


def create_solver(solver_name, params: ALParams | DFTParams) -> SolverBase:
Expand All @@ -236,5 +236,11 @@ def create_solver(solver_name, params: ALParams | DFTParams) -> SolverBase:
sn = solver_name.lower()
if sn not in __solver_table:
raise ValueError(f"Unknown solver: {solver_name}")
solver_class = __solver_table[sn]
return solver_class.create(params)

import importlib
solver_class_name, solver_module = __solver_table[sn]
mod = importlib.import_module(solver_module)
solver_class = getattr(mod, solver_class_name)
if SolverBase not in solver_class.mro():
raise TypeError("solver_class must be a subclass of SolverBase")
return solver_class.create(params)
Loading

0 comments on commit f887d03

Please sign in to comment.