Remove pytorch-ignite, train with alignn_atomwise model only. #144

Merged: 22 commits, Mar 19, 2024
11 changes: 6 additions & 5 deletions .github/workflows/main.yml
@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9]
python-version: ['3.10']

steps:
- uses: actions/checkout@v2
@@ -27,24 +27,25 @@ jobs:
pip install flake8 pytest pycodestyle pydocstyle
pycodestyle --ignore E203,W503 --exclude=tests alignn
pydocstyle --match-dir=core --match-dir=io --match-dir=ai --match-dir=analysis --match-dir=db --match-dir=tasks --count alignn
flake8 --ignore E203,W503 --exclude=tests --statistics --count --exit-zero alignn
flake8 --ignore E203,W503 --exclude=tests,scripts --statistics --count --exit-zero alignn
- name: Test with pytest
run: |
export DGLBACKEND=pytorch
export CUDA_VISIBLE_DEVICES="-1"
#pip install dgl-cu111
pip install flake8 pytest pycodestyle pydocstyle codecov pytest-cov coverage
pip install phonopy flake8 pytest pycodestyle pydocstyle codecov pytest-cov coverage
#pip uninstall -y torch nvidia-cublas-cu11 nvidia-cuda-nvrtc-cu11 nvidia-cuda-runtime-cu11 nvidia-cudnn-cu11
#conda install -y pytorch-cpu
pip install torch==2.0.0
#pip install attrs==22.1.0 certifi==2022.9.24 charset-normalizer==2.1.1 codecov==2.1.12 contourpy==1.0.5 coverage==6.5.0 cycler==0.11.0 dgl==0.9.1 flake8==5.0.4 fonttools==4.38.0 idna==3.4 iniconfig==1.1.1 jarvis-tools==2022.9.16 joblib==1.2.0 kiwisolver==1.4.4 matplotlib==3.6.1 mccabe==0.7.0 networkx==3.0b1 numpy==1.23.4 packaging==21.3 pandas==1.5.1 Pillow==9.2.0 pluggy==1.0.0 psutil==5.9.3 py==1.11.0 pycodestyle==2.9.1 pydantic==1.10.2 pydocstyle==6.1.1 pyflakes==2.5.0 pyparsing==2.4.7 pytest==7.1.3 pytest-cov==4.0.0 python-dateutil==2.8.2 pytorch-ignite==0.5.0.dev20221024 pytz==2022.5 requests==2.28.1 scikit-learn==1.1.2 scipy==1.9.3 six==1.16.0 snowballstemmer==2.2.0 spglib==2.0.1 threadpoolctl==3.1.0 tomli==2.0.1 toolz==0.12.0 torch==1.12.1 tqdm==4.64.1 typing_extensions==4.4.0 urllib3==1.26.12 xmltodict==0.13.0
echo 'PIP freeze'
pip freeze
coverage run -m pytest
coverage report -m -i
codecov
codecov --token="85bd9c5d-9e55-4f6d-bd69-350ee5e3bb41"
echo 'Train folder'
train_folder.py -h
echo 'Train alignn'
train_alignn.py -h
echo 'Pre-trained models'
pretrained.py -h
#train_folder.py --root_dir "alignn/examples/sample_data" --config "alignn/examples/sample_data/config_example.json" --output_dir=temp
42 changes: 30 additions & 12 deletions README.md
@@ -109,8 +109,10 @@ pip install dgl==1.0.1+cu117 -f https://data.dgl.ai/wheels/cu117/repo.html
Examples
---------

#### Dataset
The main script to train a model is `train_folder.py`. A user needs at least the following to train a model: 1) `id_prop.csv` with the name of each structure file and its corresponding target value, and 2) `config_example.json`, a config file with training and hyperparameter settings.
Here, we provide examples for property prediction tasks, development of machine-learning force fields (MLFF), usage of pre-trained property predictors and MLFFs, web apps, etc.

#### Dataset preparation for property prediction tasks
The main script to train a model is `train_alignn.py`. A user needs at least the following to train a model: 1) `id_prop.csv` with the name of each structure file and its corresponding target value, and 2) `config_example.json`, a config file with training and hyperparameter settings.

Users can keep their structure files in `POSCAR`, `.cif`, `.xyz`, or `.pdb` format in a directory. In the examples below we will use POSCAR format files. In the same directory, there should be an `id_prop.csv` file; a minimal illustrative example follows.
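A minimal illustrative `id_prop.csv`; the POSCAR names and target values are hypothetical, with one structure file and one value per row and no header:

```
POSCAR-JVASP-90856.vasp,0.042
POSCAR-JVASP-86097.vasp,0.513
POSCAR-JVASP-64906.vasp,1.874
```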

@@ -123,31 +125,35 @@ The dataset is split 80:10:10 into training, validation, and test sets (controlled by the `train_ratio`, `val_ratio`, and `test_ratio` parameters).
A brief help guide (`-h`) can be obtained as follows.

```
train_folder.py -h
train_alignn.py -h
```
#### Regression example
Now, the model is trained as follows. Please increase the `batch_size` parameter to something like 32 or 64 in `config_example.json` for general training (a config fragment is shown after the command).

```
train_folder.py --root_dir "alignn/examples/sample_data" --config "alignn/examples/sample_data/config_example.json" --output_dir=temp
train_alignn.py --root_dir "alignn/examples/sample_data" --config "alignn/examples/sample_data/config_example.json" --output_dir=temp
```
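As noted above, the `batch_size` change is a one-line edit in `config_example.json`; an illustrative fragment, with all other keys left as they are:

```
{
  ...
  "batch_size": 64,
  ...
}
```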
#### Classification example
While the above example is for regression, the following example shows a classification task for metal/non-metal, based on the above bandgap values. We transform the target values
into 1 or 0 based on a threshold of 0.01 eV (controlled by the parameter `classification_threshold`) and train a similar classification model; a sketch of this mapping is shown after the command. Currently, the script allows binary classification tasks only.
```
train_folder.py --root_dir "alignn/examples/sample_data" --classification_threshold 0.01 --config "alignn/examples/sample_data/config_example.json" --output_dir=temp
train_alignn.py --root_dir "alignn/examples/sample_data" --classification_threshold 0.01 --config "alignn/examples/sample_data/config_example.json" --output_dir=temp
```
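An illustrative sketch (not the alignn source) of what `classification_threshold` does to the target column; the 1 = non-metal, 0 = metal convention shown here is an assumption:

```
import pandas as pd

# hypothetical id_prop.csv with bandgap values (eV) as the target column
df = pd.read_csv("id_prop.csv", header=None, names=["file", "bandgap"])
threshold = 0.01
# bandgaps above the threshold become class 1 (non-metal), the rest class 0 (metal)
df["label"] = (df["bandgap"] > threshold).astype(int)
print(df[["file", "label"]].head())
```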

#### Multi-output model example
While the above regression example was for single-output values, we can train multi-output regression models as well.
An example is given below for training formation energy per atom, bandgap, and total energy per atom simultaneously. The script to generate the example data is provided in the scripts folder of `sample_data_multi_prop`. Another example, training electron and phonon densities of states, is also provided. An illustrative multi-column `id_prop.csv` is shown after the command.
```
train_folder.py --root_dir "alignn/examples/sample_data_multi_prop" --config "alignn/examples/sample_data/config_example.json" --output_dir=temp
train_alignn.py --root_dir "alignn/examples/sample_data_multi_prop" --config "alignn/examples/sample_data/config_example.json" --output_dir=temp
```
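As referenced above, an illustrative multi-column `id_prop.csv`; the file names and values are hypothetical, with one column per target (here: formation energy per atom, bandgap, and total energy per atom):

```
POSCAR-JVASP-107772.vasp,-0.42,1.37,-5.61
POSCAR-JVASP-1002.vasp,-0.11,0.99,-4.88
POSCAR-JVASP-96.vasp,-0.73,0.00,-6.02
```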
#### Automated model training
Users can try training on multiple datasets (such as JARVIS-DFT, Materials Project, QM9_JCTC, etc.) using the example scripts in the [alignn/scripts/train_*.py](https://github.com/usnistgov/alignn/tree/main/alignn/scripts) folder. This is done primarily to make the trainings more automated rather than requiring manual creation of folders/CSV files, etc.
These scripts automatically download datasets from [Databases in jarvis-tools](https://jarvis-tools.readthedocs.io/en/master/databases.html) and train several models. Make sure you specify your specific queuing system details in the scripts.

#### Other examples

Additional example trainings are available for [2D exfoliation energy](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/alignn_jarvis_leaderboard.ipynb) and [superconductor transition temperature](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/ALIGNN_Sc.ipynb).

<a name="pretrained"></a>
Using pre-trained models
-------------------------
@@ -177,6 +183,8 @@ The following [notebook](https://colab.research.google.com/github/knc6/jarvis-to

The following [notebook](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Train_ALIGNNFF_Mlearn.ipynb) provides an example of ALIGNN-FF model.

For additional notebooks, check out [JARVIS-Tools-Notebooks](https://github.com/JARVIS-Materials-Design/jarvis-tools-notebooks?tab=readme-ov-file#artificial-intelligencemachine-learning).
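A typical command-line prediction with a pre-trained property model looks like the following; the model name and file path are examples taken from the repo's sample data, so confirm the available options with `pretrained.py -h`:

```
pretrained.py --model_name jv_formation_energy_peratom_alignn --file_format poscar --file_path alignn/examples/sample_data/POSCAR-JVASP-10.vasp
```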

<a name="webapp"></a>
Web-app
------------
@@ -191,6 +199,8 @@ A basic web-app for direct prediction is available at [JARVIS-ALIGNN app](https:
ALIGNN-FF
-------------------------

Atomistic line graph neural network-based FF (ALIGNN-FF) can be used to model both structurally and chemically diverse systems with any combination of 89 elements from the periodic table. To train the ALIGNN-FF model, we have used the JARVIS-DFT dataset, which contains around 75,000 materials and 4 million energy-force entries, out of which 307,113 are used in the training. These models can be further finetuned, or new models can be developed from scratch on a new dataset.

[ASE calculator](https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html) provides an interface to various codes. An example for ALIGNN-FF is given below, followed by a minimal runnable sketch of the setup. Note that there are multiple pretrained ALIGNN-FF models available; here we use the default_path model. As more accurate models are developed, they will be made available as well:

@@ -226,14 +236,21 @@
```
...
plt.ylabel('Total energy (eV)')
plt.show()
```
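For a runnable starting point, here is a minimal self-contained sketch of the calculator setup; it assumes `AlignnAtomwiseCalculator` and `default_path` are importable from `alignn.ff.ff`, so verify against your installed version:

```
# sketch: compute energy and forces for bulk Si with the default ALIGNN-FF model
from ase.build import bulk
from alignn.ff.ff import AlignnAtomwiseCalculator, default_path

model_path = default_path()  # directory of the default pretrained ALIGNN-FF model
calc = AlignnAtomwiseCalculator(path=model_path)

atoms = bulk("Si", "diamond", a=5.43)  # hypothetical test structure
atoms.calc = calc
print("Total energy (eV):", atoms.get_potential_energy())
print("Forces (eV/A):", atoms.get_forces())
```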

To train ALIGNN-FF, use the `train_folder_ff.py` script, which uses the `alignn_atomwise` model:
To train ALIGNN-FF, use the `train_alignn.py` script, which uses the `alignn_atomwise` model:

The AtomWise prediction example uses a similar setup as before, but instead of `id_prop.csv` it requires an `id_prop.json` file (see the example in the sample_data_ff directory; a hypothetical entry is sketched after the command below). An example of compiling vasprun.xml files into an id_prop.json is kept [here](https://colab.research.google.com/gist/knc6/5513b21f5fd83a7943509ffdf5c3608b/make_id_prop.ipynb). Note that ALIGNN-FF requires the energy stored as energy per atom:

The AtomWise prediction example uses a similar setup as before, but instead of `id_prop.csv` it requires an `id_prop.json` file (see the example in the sample_data_ff directory). Note that ALIGNN-FF requires the energy stored as energy per atom:

```
train_folder_ff.py --root_dir "alignn/examples/sample_data_ff" --config "alignn/examples/sample_data_ff/config_example_atomwise.json" --output_dir=temp
train_alignn.py --root_dir "alignn/examples/sample_data_ff" --config "alignn/examples/sample_data_ff/config_example_atomwise.json" --output_dir=temp
```
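For orientation, a hypothetical `id_prop.json` entry is sketched below; the field names are illustrative, so treat the files in `alignn/examples/sample_data_ff` as the authoritative schema, and note that the energy is stored per atom:

```
[
  {
    "jid": "sample-1",
    "total_energy": -4.21,
    "forces": [[0.0, 0.0, 0.01], [0.0, 0.0, -0.01]],
    "atoms": {
      "lattice_mat": [[5.43, 0.0, 0.0], [0.0, 5.43, 0.0], [0.0, 0.0, 5.43]],
      "coords": [[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]],
      "elements": ["Si", "Si"],
      "cartesian": false
    }
  }
]
```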


To finetune a model, also pass the `--restart_model_path` flag in the command above, with the path to a pretrained ALIGNN-FF model that has the same model configuration; a hypothetical invocation follows.
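A hypothetical finetuning invocation (the checkpoint path below is a placeholder for your own pretrained model file):

```
train_alignn.py --root_dir "alignn/examples/sample_data_ff" --config "alignn/examples/sample_data_ff/config_example_atomwise.json" --restart_model_path "path/to/best_model.pt" --output_dir=temp_ft
```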

An example of training an MLFF for silicon is provided [here](https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Train_ALIGNNFF_Mlearn.ipynb). It is highly recommended to get familiar with this example before developing a new model. Note: new model config options such as `lg_on_fly` and `add_reverse_forces` should be defaulted to True for newer versions. For MD runs, `use_cutoff_function` is recommended.


A pretrained ALIGNN-FF (under active development right now) can be used for predicting several properties, such as:

@@ -248,7 +265,7 @@
```
...
```

To know about other tasks, type:

```
run_alignn_ff.py -h
```


Several supporting scripts for structure optimization, equation of state, phonon, and related calculations are provided in the repo as well. If you need further assistance with a particular task, feel free to raise a GitHub issue.

<a name="performances"></a>

@@ -386,9 +403,10 @@ Useful notes (based on some of the queries we received)
1) If you are using GPUs, make sure you have a compatible dgl-cuda version installed, for example dgl-cu101 or dgl-cu111, e.g. `pip install dgl-cu111`.
2) While conventional '.cif' and '.pdb' files can be read using jarvis-tools, for complex files you might have to install `cif2cell` and `pytraj` respectively, i.e. `pip install cif2cell==2.0.0a3` and `conda install -c ambermd pytraj`.
3) Make sure you use a `batch_size` of 32 or 64 for large datasets, and not 2 as given in the example config file, else training will take much longer and performance might drop a lot.
4) Note that `train_folder.py` and `pretrained.py` in the alignn folder are executable Python scripts, so they should work even if you don't provide their absolute paths.
4) Note that `train_alignn.py` and `pretrained.py` in the alignn folder are executable Python scripts, so they should work even if you don't provide their absolute paths.
5) Learn about the issue with QM9 results here: https://github.com/usnistgov/alignn/issues/54
6) Make sure your `pandas` version is 1.2.3.
6) Make sure your `pandas` version is >1.2.3.
7) Starting March 2024, the pytorch-ignite dependency will be removed to enable the conda-forge build.


<a name="refs"></a>
3 changes: 2 additions & 1 deletion alignn/__init__.py
@@ -1,2 +1,3 @@
"""Version number."""
__version__ = "2024.2.4"

__version__ = "2024.3.4"
63 changes: 32 additions & 31 deletions alignn/config.py
@@ -3,18 +3,19 @@
import subprocess
from typing import Optional, Union
import os
from pydantic import root_validator
from pydantic.typing import Literal
from typing import Literal
from alignn.utils import BaseSettings
from alignn.models.modified_cgcnn import CGCNNConfig
from alignn.models.icgcnn import ICGCNNConfig
from alignn.models.gcn import SimpleGCNConfig
from alignn.models.densegcn import DenseGCNConfig
from alignn.models.alignn import ALIGNNConfig
from alignn.models.alignn_atomwise import ALIGNNAtomWiseConfig
from alignn.models.dense_alignn import DenseALIGNNConfig
from alignn.models.alignn_cgcnn import ACGCNNConfig
from alignn.models.alignn_layernorm import ALIGNNConfig as ALIGNN_LN_Config

# from alignn.models.modified_cgcnn import CGCNNConfig
# from alignn.models.icgcnn import ICGCNNConfig
# from alignn.models.gcn import SimpleGCNConfig
# from alignn.models.densegcn import DenseGCNConfig
# from pydantic import model_validator
# from alignn.models.dense_alignn import DenseALIGNNConfig
# from alignn.models.alignn_cgcnn import ACGCNNConfig
# from alignn.models.alignn_layernorm import ALIGNNConfig as ALIGNN_LN_Config

# from typing import List

@@ -159,11 +160,11 @@ class TrainingConfig(BaseSettings):
"tinnet_O",
"tinnet_N",
] = "dft_3d"
target: TARGET_ENUM = "formation_energy_peratom"
target: TARGET_ENUM = "exfoliation_energy"
atom_features: Literal["basic", "atomic_number", "cfid", "cgcnn"] = "cgcnn"
neighbor_strategy: Literal[
"k-nearest", "voronoi", "radius_graph"
] = "k-nearest"
neighbor_strategy: Literal["k-nearest", "voronoi", "radius_graph"] = (
"k-nearest"
)
id_tag: Literal["jid", "id", "_oqmd_entry_id"] = "jid"

# logging configuration
@@ -216,26 +217,26 @@

# model configuration
model: Union[
CGCNNConfig,
ICGCNNConfig,
SimpleGCNConfig,
DenseGCNConfig,
ALIGNNConfig,
ALIGNNAtomWiseConfig,
ALIGNN_LN_Config,
DenseALIGNNConfig,
ACGCNNConfig,
] = ALIGNNConfig(name="alignn")
# ] = CGCNNConfig(name="cgcnn")

@root_validator()
def set_input_size(cls, values):
"""Automatically configure node feature dimensionality."""
values["model"].atom_input_features = FEATURESET_SIZE[
values["atom_features"]
]

return values
# CGCNNConfig,
# ICGCNNConfig,
# SimpleGCNConfig,
# DenseGCNConfig,
# ALIGNN_LN_Config,
# DenseALIGNNConfig,
# ACGCNNConfig,
] = ALIGNNAtomWiseConfig(name="alignn_atomwise")

# @root_validator()
# @model_validator(mode='before')
# def set_input_size(cls, values):
# """Automatically configure node feature dimensionality."""
# values["model"].atom_input_features = FEATURESET_SIZE[
# values["atom_features"]
# ]

# return values

# @property
# def atom_input_features(self):
2 changes: 1 addition & 1 deletion alignn/examples/sample_data/config_example.json
Expand Up @@ -37,7 +37,7 @@
"max_neighbors": 12,
"keep_data_order": true,
"model": {
"name": "alignn",
"name": "alignn_atomwise",
"alignn_layers": 4,
"gcn_layers": 4,
"atom_input_features": 92,
3 changes: 2 additions & 1 deletion alignn/graphs.py
@@ -1,4 +1,5 @@
"""Module to generate networkx graphs."""

from jarvis.core.atoms import get_supercell_dims
from jarvis.core.specie import Specie
from jarvis.core.utils import random_colors
@@ -861,7 +862,7 @@ def __getitem__(self, idx):
"""Get StructureDataset sample."""
g = self.graphs[idx]
label = self.labels[idx]

# id = self.ids[idx]
if self.transform:
g = self.transform(g)

3 changes: 2 additions & 1 deletion alignn/models/alignn.py
@@ -2,6 +2,7 @@

A prototype crystal line graph network dgl implementation.
"""

from typing import Tuple, Union

import dgl
@@ -11,7 +12,7 @@
from dgl.nn import AvgPooling

# from dgl.nn.functional import edge_softmax
from pydantic.typing import Literal
from typing import Literal
from torch import nn
from torch.nn import functional as F

9 changes: 6 additions & 3 deletions alignn/models/alignn_atomwise.py
@@ -2,6 +2,7 @@

A prototype crystal line graph network dgl implementation.
"""

from typing import Tuple, Union
from torch.autograd import grad
import dgl
@@ -11,7 +12,7 @@
import torch

# from dgl.nn.functional import edge_softmax
from pydantic.typing import Literal
from typing import Literal
from torch import nn
from torch.nn import functional as F
from alignn.models.utils import RBFExpansion
@@ -333,8 +334,9 @@ def __init__(
)

if self.classification:
self.fc = nn.Linear(config.hidden_features, 2)
self.softmax = nn.LogSoftmax(dim=1)
self.fc = nn.Linear(config.hidden_features, 1)
self.softmax = nn.Sigmoid()
# self.softmax = nn.LogSoftmax(dim=1)
else:
self.fc = nn.Linear(config.hidden_features, config.output_features)
self.link = None
@@ -543,6 +545,7 @@ def forward(
out = self.link(out)

if self.classification:
# out = torch.max(out,dim=1)
out = self.softmax(out)
result["out"] = out
result["grad"] = forces