Commit 809dedf

Get rid of ignite.
knc6 committed Mar 19, 2024
1 parent 7873c2b commit 809dedf
Showing 8 changed files with 138 additions and 582 deletions.
4 changes: 2 additions & 2 deletions alignn/config.py
@@ -160,7 +160,7 @@ class TrainingConfig(BaseSettings):
         "tinnet_O",
         "tinnet_N",
     ] = "dft_3d"
-    target: TARGET_ENUM = "formation_energy_peratom"
+    target: TARGET_ENUM = "exfoliation_energy"
     atom_features: Literal["basic", "atomic_number", "cfid", "cgcnn"] = "cgcnn"
     neighbor_strategy: Literal["k-nearest", "voronoi", "radius_graph"] = (
         "k-nearest"
@@ -226,7 +226,7 @@ class TrainingConfig(BaseSettings):
         # ALIGNN_LN_Config,
         # DenseALIGNNConfig,
         # ACGCNNConfig,
-    ] = ALIGNNConfig(name="alignn")
+    ] = ALIGNNAtomWiseConfig(name="alignn_atomwise")

     # @root_validator()
     # @model_validator(mode='before')
2 changes: 1 addition & 1 deletion alignn/examples/sample_data/config_example.json
@@ -37,7 +37,7 @@
     "max_neighbors": 12,
     "keep_data_order": true,
     "model": {
-        "name": "alignn",
+        "name": "alignn_atomwise",
         "alignn_layers": 4,
         "gcn_layers": 4,
         "atom_input_features": 92,
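For orientation, here is a minimal sketch, not part of the commit, of loading this sample config into the TrainingConfig class edited above and confirming that the model block now resolves to alignn_atomwise. The import path and the keyword construction of TrainingConfig are assumptions based on the repository layout shown in this diff.

import json

from alignn.config import TrainingConfig  # assumed import path

# Path is the example file changed in this commit.
with open("alignn/examples/sample_data/config_example.json") as f:
    cfg = TrainingConfig(**json.load(f))

print(cfg.model.name)  # expected to print "alignn_atomwise"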
3 changes: 2 additions & 1 deletion alignn/graphs.py
@@ -1,4 +1,5 @@
 """Module to generate networkx graphs."""
+
 from jarvis.core.atoms import get_supercell_dims
 from jarvis.core.specie import Specie
 from jarvis.core.utils import random_colors
@@ -861,7 +862,7 @@ def __getitem__(self, idx):
         """Get StructureDataset sample."""
         g = self.graphs[idx]
         label = self.labels[idx]
-
+        # id = self.ids[idx]
         if self.transform:
             g = self.transform(g)

6 changes: 4 additions & 2 deletions alignn/models/alignn_atomwise.py
@@ -334,8 +334,9 @@ def __init__(
         )

         if self.classification:
-            self.fc = nn.Linear(config.hidden_features, 2)
-            self.softmax = nn.LogSoftmax(dim=1)
+            self.fc = nn.Linear(config.hidden_features, 1)
+            self.softmax = nn.Sigmoid()
+            # self.softmax = nn.LogSoftmax(dim=1)
         else:
             self.fc = nn.Linear(config.hidden_features, config.output_features)
         self.link = None
@@ -544,6 +545,7 @@ def forward(
             out = self.link(out)

         if self.classification:
+            # out = torch.max(out,dim=1)
             out = self.softmax(out)
         result["out"] = out
         result["grad"] = forces
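One implication of this head change, noted here as general background rather than anything asserted by the commit: a single sigmoid output is normally trained with a binary cross-entropy loss, whereas the earlier two-logit LogSoftmax head pairs with an NLL loss. A self-contained PyTorch sketch of the sigmoid/BCE pairing (generic code, not the repository's training loop; the hidden size is a stand-in for config.hidden_features):

import torch
import torch.nn as nn

hidden_features = 256  # stand-in for config.hidden_features
fc = nn.Linear(hidden_features, 1)
sigmoid = nn.Sigmoid()
criterion = nn.BCELoss()  # expects probabilities, which Sigmoid provides

h = torch.randn(8, hidden_features)           # pooled graph features (batch of 8)
labels = torch.randint(0, 2, (8, 1)).float()  # binary targets as floats

prob = sigmoid(fc(h))           # shape (8, 1), values in [0, 1]
loss = criterion(prob, labels)
loss.backward()

An alternative with the same effect is nn.BCEWithLogitsLoss applied to the raw linear output, which folds the sigmoid into the loss for better numerical stability.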
28 changes: 14 additions & 14 deletions alignn/tests/test_prop.py
@@ -50,37 +50,37 @@
 # os.system(cmd3)


-def test_minor_configs():
-    tmp = config
-    # tmp["log_tensorboard"] = True
-    tmp["n_early_stopping"] = 2
-    tmp["model"]["name"] = "alignn"
-    config["write_predictions"] = True
-    result = train_dgl(tmp)
+# def test_minor_configs():
+# tmp = config
+# # tmp["log_tensorboard"] = True
+# tmp["n_early_stopping"] = 2
+# tmp["model"]["name"] = "alignn"
+# config["write_predictions"] = True
+# result = train_dgl(tmp)


 def test_models():

     config["write_predictions"] = True
-    config["model"]["name"] = "alignn"
+    config["model"]["name"] = "alignn_atomwise"
     t1 = time.time()
     result = train_dgl(config)
     t2 = time.time()
     print("Total time", t2 - t1)
-    print("train=", result["train"])
-    print("validation=", result["validation"])
+    # print("train=", result["train"])
+    # print("validation=", result["validation"])
     print()
     print()
     print()

-    config["model"]["name"] = "alignn"
+    config["model"]["name"] = "alignn_atomwise"
     config["classification_threshold"] = 0.0
     t1 = time.time()
     result = train_dgl(config)
     t2 = time.time()
     print("Total time", t2 - t1)
-    print("train=", result["train"])
-    print("validation=", result["validation"])
+    # print("train=", result["train"])
+    # print("validation=", result["validation"])
     print()
     print()
     print()
@@ -241,5 +241,5 @@ def test_del_files():
 # test_pretrained()
 # test_runtime_training()
 # test_alignn_train()
-# test_models()
+test_models()
 # test_calculator()
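As a usage note (an assumption about the surrounding workflow, not something stated in the diff), the re-enabled test_models can also be run on its own with pytest, for example:

import pytest

# Node id uses the test file path shown in this diff.
pytest.main(["-q", "alignn/tests/test_prop.py::test_models"])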