Skip to content

Commit

Permalink
Make sure to remove previous lmdb files.
Browse files Browse the repository at this point in the history
  • Loading branch information
knc6 committed Apr 30, 2024
1 parent 8bf645a commit 7e30e31
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
4 changes: 4 additions & 0 deletions alignn/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ def get_train_val_loaders(
train_sample = filename + "_train.data"
val_sample = filename + "_val.data"
test_sample = filename + "_test.data"
if os.path.exists(train_sample):
print("If you are training from scratch, run")
cmd = "rm -r " + train_sample + " " + val_sample + " " + test_sample
print(cmd)
# print ('output_dir data',output_dir)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
Expand Down
27 changes: 25 additions & 2 deletions alignn/tests/test_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,13 @@ def test_pretrained():
get_multiple_predictions(atoms_array=[Si, Si])


def test_alignn_train():
world_size = int(torch.cuda.device_count())


def test_alignn_train_regression():
# Regression
cmd = "rm -rf train_data test_data val_data"
os.system(cmd)
root_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../examples/sample_data/")
)
Expand All @@ -138,11 +144,15 @@ def test_alignn_train():
"../examples/sample_data/config_example.json",
)
)
world_size = int(torch.cuda.device_count())
train_for_folder(
rank=0, world_size=world_size, root_dir=root_dir, config_name=config
)


def test_alignn_train_regression_multi_out():
cmd = "rm -rf train_data test_data val_data"
os.system(cmd)
# Regression multi-out
root_dir = os.path.abspath(
os.path.join(
os.path.dirname(__file__), "../examples/sample_data_multi_prop/"
Expand All @@ -158,6 +168,11 @@ def test_alignn_train():
rank=0, world_size=world_size, root_dir=root_dir, config_name=config
)


def test_alignn_train_classification():
cmd = "rm -rf train_data test_data val_data"
os.system(cmd)
# Classification
root_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../examples/sample_data/")
)
Expand All @@ -175,6 +190,11 @@ def test_alignn_train():
classification_threshold=0.01,
)


def test_alignn_train_ff():
cmd = "rm -rf train_data test_data val_data"
os.system(cmd)
# FF
root_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../examples/sample_data_ff/")
)
Expand Down Expand Up @@ -248,6 +268,9 @@ def test_del_files():
os.system(cmd)


# test_alignn_train_ff()
# test_alignn_train_classification()
# test_alignn_train()
# test_minor_configs()
# test_pretrained()
# test_runtime_training()
Expand Down

0 comments on commit 7e30e31

Please sign in to comment.