Skip to content

Commit

Permalink
FF fix
Browse files Browse the repository at this point in the history
  • Loading branch information
knc6 committed Jan 31, 2025
1 parent ec9e8d3 commit 1be0faa
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 3 deletions.
65 changes: 65 additions & 0 deletions alignn/ff/ff.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,71 @@ def get_figshare_model_ff(
return dir_path


def get_figshare_model_prop(
model_name="jv_mbj_bandgap_alignn", dir_path=None, filename="best_model.pt"
):
"""Get ALIGNN-FF torch models from figshare."""
all_models_prop = get_all_models_prop()
# https://doi.org/10.6084/m9.figshare.23695695
if dir_path is None:
dir_path = str(os.path.join(os.path.dirname(__file__), model_name))
# cwd=os.getcwd()
dir_path = os.path.abspath(dir_path)
if not os.path.exists(dir_path):
os.makedirs(dir_path)
# os.chdir(dir_path)
url = all_models_prop[model_name]
zfile = model_name + ".zip"
path = str(os.path.join(dir_path, zfile))
# path = str(os.path.join(os.path.dirname(__file__), zfile))
print("dir_path", dir_path)
best_path = os.path.join(dir_path, filename)
if not os.path.exists(best_path):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(
total=total_size_in_bytes, unit="iB", unit_scale=True
)
with open(path, "wb") as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
zp = zipfile.ZipFile(path)
names = zp.namelist()
chks = []
cfg = []
for i in names:
if "checkpoint_" in i and "pt" in i:
tmp = i
# fname = i
chks.append(i)
if filename in i:
tmp = i
chks.append(i)

if "config.json" in i:
cfg = i

config = zipfile.ZipFile(path).read(cfg)
# print("Loading the zipfile...", zipfile.ZipFile(path).namelist())
data = zipfile.ZipFile(path).read(tmp)
# print('dir_path',dir_path,filename)
# new_file, filename = tempfile.mkstemp()
filename = os.path.join(dir_path, filename)
with open(filename, "wb") as f:
f.write(data)
filename = os.path.join(dir_path, "config.json")
with open(filename, "wb") as f:
f.write(config)
os.remove(path)
# print("Using model file", url, "from ", chks)
# print("Path", os.path.abspath(path))
# print("Config", os.path.abspath(cfg))
return dir_path


def default_path():
"""Get default model path."""
dpath = get_figshare_model_ff(model_name="v12.2.2024_dft_3d_307k")
Expand Down
3 changes: 3 additions & 0 deletions alignn/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
# See also, alignn/ff/ff.py
# Both alignn and alignn_atomwise
# models are shared

# See: alignn/ff/all_models_alignn.json
# to load as a calculator
all_models = {
"jv_formation_energy_peratom_alignn": [
"https://figshare.com/ndownloader/files/31458679",
Expand Down
28 changes: 25 additions & 3 deletions alignn/tests/test_alignn_ff.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
ForceField,
)
from jarvis.io.vasp.inputs import Poscar
from alignn.ff.ff import get_figshare_model_prop, get_figshare_model_ff
import os

# JVASP-25139
pos = """Rb8
1.0
Expand Down Expand Up @@ -127,11 +129,31 @@ def test_qclean():

def test_ialignn_ff():
from alignn.ff.calculators import iAlignnAtomwiseCalculator

calc = iAlignnAtomwiseCalculator()
atoms = Poscar.from_string(pos).atoms.ase_converter()
atoms.calc=calc
en=atoms.get_potential_energy()
results=atoms.calc.results
atoms.calc = calc
en = atoms.get_potential_energy()
results = atoms.calc.results


def test_alexandria_gap():
p = get_figshare_model_ff("alex_band_gap")
calc = AlignnAtomwiseCalculator(path=p)
atoms = Poscar.from_string(pos).atoms.ase_converter()
atoms.calc = calc
val = atoms.get_potential_energy() # gap


def test_jdft_mbj_gap():
# p=get_figshare_model_ff('alex_band_gap')
p = get_figshare_model_prop()
calc = AlignnAtomwiseCalculator(path=p)
atoms = Poscar.from_string(pos).atoms.ase_converter()
atoms.calc = calc
val = atoms.get_potential_energy() # gap


# print('test_graph_builder')
# test_graph_builder()
# print('test_ev')
Expand Down

0 comments on commit 1be0faa

Please sign in to comment.