Commit: Clean up.
knc6 committed Nov 4, 2024
1 parent 8bb073f commit e445579
Showing 15 changed files with 91 additions and 2,498 deletions.
2 changes: 1 addition & 1 deletion alignn/__init__.py
@@ -1,3 +1,3 @@
"""Version number."""

__version__ = "2024.8.30"
__version__ = "2024.10.30"
4 changes: 0 additions & 4 deletions alignn/config.py
@@ -6,8 +6,6 @@
from typing import Literal
from alignn.utils import BaseSettings
from alignn.models.alignn import ALIGNNConfig
-from alignn.models.alignn_ff2 import ALIGNNFF2Config
-from alignn.models.alignn_eff import ALIGNNeFFConfig
from alignn.models.alignn_atomwise import ALIGNNAtomWiseConfig

# import torch
@@ -211,8 +209,6 @@ class TrainingConfig(BaseSettings):
# model configuration
model: Union[
ALIGNNConfig,
-ALIGNNFF2Config,
-ALIGNNeFFConfig,
ALIGNNAtomWiseConfig,
# CGCNNConfig,
# ICGCNNConfig,
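With the ff2 and eff entries removed, the model field of TrainingConfig accepts only the remaining config classes. A minimal sketch of building such a config, assuming the other TrainingConfig fields keep their defaults:

# Illustrative sketch only (not part of this commit).
from alignn.config import TrainingConfig
from alignn.models.alignn_atomwise import ALIGNNAtomWiseConfig

config = TrainingConfig(
    dataset="user_data",
    target="target",
    model=ALIGNNAtomWiseConfig(name="alignn_atomwise", atom_input_features=92),
)
print(config.model.name)  # alignn_atomwise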
2 changes: 2 additions & 0 deletions alignn/dataset.py
@@ -14,6 +14,8 @@
tqdm.pandas()


+# NOTE: Use lmd_dataset,
+# need to fix adding lattice in dataloader
def load_graphs(
dataset=[],
name: str = "dft_3d",
4 changes: 2 additions & 2 deletions alignn/examples/sample_data_ff/config_example_atomwise.json
@@ -3,7 +3,7 @@
"dataset": "user_data",
"target": "target",
"atom_features": "cgcnn",
"neighbor_strategy": "radius_graph_jarvis",
"neighbor_strategy": "radius_graph",
"id_tag": "jid",
"dtype": "float32",
"random_seed": 123,
@@ -40,7 +40,7 @@
"distributed":false,
"use_lmdb": true,
"model": {
"name": "alignn_ff2",
"name": "alignn_atomwise",
"atom_input_features": 92,
"calculate_gradient":true,
"atomwise_output_features":0,
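A short sketch of loading the updated example config, assuming the keys of the JSON file not shown in this hunk are unchanged:

# Illustrative sketch only (not part of this commit).
from jarvis.db.jsonutils import loadjson
from alignn.config import TrainingConfig

cfg = TrainingConfig(**loadjson("alignn/examples/sample_data_ff/config_example_atomwise.json"))
assert cfg.neighbor_strategy == "radius_graph"
assert cfg.model.name == "alignn_atomwise"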
7 changes: 0 additions & 7 deletions alignn/ff/ff.py
@@ -30,9 +30,6 @@
from jarvis.db.jsonutils import loadjson
from alignn.graphs import Graph
from alignn.models.alignn_atomwise import ALIGNNAtomWise, ALIGNNAtomWiseConfig
-from alignn.models.alignn_ff2 import ALIGNNFF2, ALIGNNFF2Config
-from alignn.models.alignn_eff import ALIGNNeFF, ALIGNNeFFConfig
from alignn.config import TrainingConfig
from jarvis.analysis.defects.vacancy import Vacancy
import numpy as np
from alignn.pretrained import get_prediction
@@ -270,12 +267,8 @@ def __init__(
)
if self.model is None:

if config["model"]["name"] == "alignn_ff2":
model = ALIGNNFF2(ALIGNNFF2Config(**config["model"]))
if config["model"]["name"] == "alignn_atomwise":
model = ALIGNNAtomWise(ALIGNNAtomWiseConfig(**config["model"]))
if config["model"]["name"] == "alignn_eff":
model = ALIGNNeFF(ALIGNNeFFConfig(**config["model"]))
model.state_dict()
model.load_state_dict(
torch.load(
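With only the alignn_atomwise branch left in the dispatch above, the force field wraps an ALIGNNAtomWise model. A hedged usage sketch; AlignnAtomwiseCalculator and default_path are assumed to be the public helpers in alignn.ff.ff and are not shown in this diff:

# Illustrative sketch under the assumptions stated above.
from ase.build import bulk
from alignn.ff.ff import AlignnAtomwiseCalculator, default_path

calc = AlignnAtomwiseCalculator(path=default_path())  # loads an alignn_atomwise checkpoint
atoms = bulk("Si")
atoms.calc = calc
print(atoms.get_potential_energy())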
51 changes: 26 additions & 25 deletions alignn/graphs.py
@@ -15,6 +15,7 @@
import torch
import dgl
from tqdm import tqdm
+from jarvis.core.atoms import Atoms

# import matgl

@@ -63,13 +64,12 @@ def temp_graph(
g.ndata["Z"] = torch.tensor(atom_feats, dtype=torch.int64)
g.edata["r"] = torch.tensor(np.array(r), dtype=dtype)
g.edata["d"] = torch.tensor(d, dtype=dtype)
g.edata["pbc_offset"] = torch.tensor(images, dtype=dtype)
g.edata["pbc_offshift"] = torch.tensor(images, dtype=dtype)
# g.edata["pbc_offset"] = torch.tensor(images, dtype=dtype)
# g.edata["pbc_offshift"] = torch.tensor(images, dtype=dtype)
g.edata["images"] = torch.tensor(images, dtype=dtype)
# g.edata["lattice"] = torch.tensor(torch.repeat_interleave(torch.tensor(atoms.lattice_mat.flatten()), atoms.num_atoms), dtype=dtype)
-node_type = torch.tensor([0 for i in range(len(atoms.atomic_numbers))])
-g.ndata["node_type"] = node_type
-lattice_mat = atoms.lattice_mat
+# node_type = torch.tensor([0 for i in range(len(atoms.atomic_numbers))])
+# g.ndata["node_type"] = node_type
+# lattice_mat = atoms.lattice_mat
# g.ndata["lattice"] = torch.tensor(
# [lattice_mat for ii in range(g.num_nodes())]
# , dtype=dtype)
@@ -78,7 +78,6 @@
# , dtype=dtype)
g.ndata["pos"] = torch.tensor(atoms.cart_coords, dtype=dtype)
g.ndata["frac_coords"] = torch.tensor(atoms.frac_coords, dtype=dtype)
# g.ndata["V"] = torch.tensor([atoms.volume] * atoms.num_atoms, dtype=dtype)

return g, u, v, r

@@ -516,43 +515,45 @@ def atom_dgl_multigraph(
# u, v, r = build_undirected_edgedata(atoms, edges)

# build up atom attribute tensor
-comp = atoms.composition.to_dict()
-comp_dict = {}
-c_ind = 0
-for ii, jj in comp.items():
-if ii not in comp_dict:
-comp_dict[ii] = c_ind
-c_ind += 1
+# comp = atoms.composition.to_dict()
+# comp_dict = {}
+# c_ind = 0
+# for ii, jj in comp.items():
+# if ii not in comp_dict:
+# comp_dict[ii] = c_ind
+# c_ind += 1
sps_features = []
-node_types = []
+# node_types = []
for ii, s in enumerate(atoms.elements):
feat = list(get_node_attributes(s, atom_features=atom_features))
# if include_prdf_angles:
# feat=feat+list(prdf[ii])+list(adf[ii])
sps_features.append(feat)
-node_types.append(comp_dict[s])
+# node_types.append(comp_dict[s])
sps_features = np.array(sps_features)
node_features = torch.tensor(sps_features).type(
torch.get_default_dtype()
)
g = dgl.graph((u, v))
g.ndata["atom_features"] = node_features
g.ndata["node_type"] = torch.tensor(node_types, dtype=torch.int64)
node_type = torch.tensor([0 for i in range(len(atoms.atomic_numbers))])
g.ndata["node_type"] = node_type
# g.ndata["node_type"] = torch.tensor(node_types, dtype=torch.int64)
# node_type = torch.tensor([0 for i in range(len(atoms.atm_num))])
# g.ndata["node_type"] = node_type
# print('g.ndata["node_type"]',g.ndata["node_type"])
g.edata["r"] = torch.tensor(r).type(torch.get_default_dtype())
g.edata["r"] = torch.tensor(np.array(r)).type(
torch.get_default_dtype()
)
# images=torch.tensor(images).type(torch.get_default_dtype())
# print('images',images.shape,r.shape)
# print('type',torch.get_default_dtype())
g.edata["images"] = torch.tensor(images).type(
g.edata["images"] = torch.tensor(np.array(images)).type(
torch.get_default_dtype()
)
vol = atoms.volume
g.ndata["V"] = torch.tensor([vol for ii in range(atoms.num_atoms)])
g.ndata["coords"] = torch.tensor(atoms.cart_coords).type(
torch.get_default_dtype()
)
# g.ndata["coords"] = torch.tensor(atoms.cart_coords).type(
# torch.get_default_dtype()
# )
g.ndata["frac_coords"] = torch.tensor(atoms.frac_coords).type(
torch.get_default_dtype()
)
@@ -1048,7 +1049,7 @@ def setup_standardizer(self, ids):
@staticmethod
def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]):
"""Dataloader helper to batch graphs cross `samples`."""
-graphs, lattice, labels = map(list, zip(*samples))
+graphs, lattices, labels = map(list, zip(*samples))
batched_graph = dgl.batch(graphs)
return batched_graph, torch.tensor(lattices), torch.tensor(labels)

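For reference, the collate helper above now carries one 3x3 lattice per crystal through batching. A minimal sketch of the resulting batch, assuming each lattice is stored as a plain nested list:

# Illustrative sketch with two made-up (graph, lattice, label) samples.
import dgl
import torch

samples = [
    (dgl.graph(([0], [1])), [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]], 1.2),
    (dgl.graph(([0], [1])), [[3.5, 0.0, 0.0], [0.0, 3.5, 0.0], [0.0, 0.0, 3.5]], 3.4),
]
graphs, lattices, labels = map(list, zip(*samples))
batched_graph = dgl.batch(graphs)
lattices = torch.tensor(lattices)  # shape (2, 3, 3)
labels = torch.tensor(labels)      # shape (2,)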
34 changes: 23 additions & 11 deletions alignn/lmdb_dataset.py
@@ -58,11 +58,11 @@ def __getitem__(self, idx):
with self.env.begin() as txn:
serialized_data = txn.get(f"{idx}".encode())
if self.line_graph:
-graph, line_graph, lattice,label = pk.loads(serialized_data)
-return graph, line_graph, lattice,label
+graph, line_graph, lattice, label = pk.loads(serialized_data)
+return graph, line_graph, lattice, label
else:
-graph, lattice,label = pk.loads(serialized_data)
-return graph, lattice,label
+graph, lattice, label = pk.loads(serialized_data)
+return graph, lattice, label

def close(self):
"""Close connection."""
@@ -76,7 +76,7 @@ def __del__(self):
def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]):
"""Dataloader helper to batch graphs cross `samples`."""
# print('samples',samples)
-graphs, lattices,labels = map(list, zip(*samples))
+graphs, lattices, labels = map(list, zip(*samples))
# graphs, lgs, labels = map(list, zip(*samples))
batched_graph = dgl.batch(graphs)
return batched_graph, torch.tensor(lattices), torch.tensor(labels)
@@ -90,9 +90,19 @@ def collate_line_graph(
batched_graph = dgl.batch(graphs)
batched_line_graph = dgl.batch(line_graphs)
if len(labels[0].size()) > 0:
-return batched_graph, batched_line_graph, torch.tensor(lattices),torch.stack(labels)
+return (
+batched_graph,
+batched_line_graph,
+torch.tensor(lattices),
+torch.stack(labels),
+)
else:
-return batched_graph, batched_line_graph, torch.stack(lattices),torch.tensor(labels)
+return (
+batched_graph,
+batched_line_graph,
+torch.stack(lattices),
+torch.tensor(labels),
+)


def get_torch_dataset(
@@ -143,7 +153,7 @@ def get_torch_dataset(
for idx, (d) in tqdm(enumerate(dataset), total=len(dataset)):
ids.append(d[id_tag])
# g, lg = Graph.atom_dgl_multigraph(
atoms=Atoms.from_dict(d["atoms"])
atoms = Atoms.from_dict(d["atoms"])
g = Graph.atom_dgl_multigraph(
atoms,
cutoff=float(cutoff),
@@ -157,7 +167,9 @@
)
if line_graph:
g, lg = g
-lattice=torch.tensor(atoms.lattice_mat).type(torch.get_default_dtype())
+lattice = torch.tensor(atoms.lattice_mat).type(
+torch.get_default_dtype()
+)
label = torch.tensor(d[target]).type(torch.get_default_dtype())
# print('label',label,label.view(-1).long())
if classification:
@@ -184,9 +196,9 @@

# labels.append(label)
if line_graph:
-serialized_data = pk.dumps((g, lg, lattice,label))
+serialized_data = pk.dumps((g, lg, lattice, label))
else:
-serialized_data = pk.dumps((g, lattice,label))
+serialized_data = pk.dumps((g, lattice, label))
txn.put(f"{idx}".encode(), serialized_data)

env.close()
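A sketch of how one LMDB record is assembled, following get_torch_dataset above; the structure and target value here are made-up illustrations:

# Illustrative sketch only (not part of this commit).
import pickle as pk
import torch
from jarvis.core.atoms import Atoms
from alignn.graphs import Graph

atoms = Atoms(
    lattice_mat=[[0.0, 1.81, 1.81], [1.81, 0.0, 1.81], [1.81, 1.81, 0.0]],
    coords=[[0.0, 0.0, 0.0]],
    elements=["Cu"],
    cartesian=False,
)
g, lg = Graph.atom_dgl_multigraph(atoms, cutoff=8.0, compute_line_graph=True)
lattice = torch.tensor(atoms.lattice_mat).type(torch.get_default_dtype())
label = torch.tensor(0.0).type(torch.get_default_dtype())
serialized_data = pk.dumps((g, lg, lattice, label))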
8 changes: 1 addition & 7 deletions alignn/models/alignn.py
@@ -4,21 +4,15 @@
"""

from typing import Tuple, Union

import dgl
import dgl.function as fn
import numpy as np
import torch
from dgl.nn import AvgPooling

# from dgl.nn.functional import edge_softmax
from typing import Literal
from torch import nn
from torch.nn import functional as F

# from alignn.models.utils import RBFExpansion
# from alignn.utils import BaseSettings

from alignn.models.utils import RBFExpansion
from pydantic_settings import BaseSettings


(Diffs for the remaining 7 changed files are not shown here.)
