Skip to content

Commit

Permalink
Match the pylint requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
ko tsz wai authored and ko tsz wai committed Dec 6, 2022
2 parents 6984884 + fab4c43 commit 4e53c05
Show file tree
Hide file tree
Showing 22 changed files with 842 additions and 194 deletions.
1 change: 0 additions & 1 deletion .github/workflows/linting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,3 @@ jobs:
run: |
# pip install pydocstyle --upgrade --quiet
pydocstyle --count megnet
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ setuptools*
gulptmp_4_1
.coverage
.mypy_cache
.vs
env
venv
7 changes: 1 addition & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ default_language_version:

ci:
autoupdate_schedule: monthly
skip: [flake8, mypy, pylint]
skip: [flake8, pylint]

repos:

Expand Down Expand Up @@ -50,11 +50,6 @@ repos:
hooks:
- id: flake8

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.961
hooks:
- id: mypy

- repo: local
hooks:
- id: pylint
Expand Down
46 changes: 39 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,48 @@
# Introduction

This is a reimplementation of the [MatErials Graph Network (MEGNet)](https://github.com/materialsvirtuallab/megnet)
and [3-body MEGNet (m3gnet)](https://github.com/materialsvirtuallab/m3gnet) in DGL in an effort to
improve its extensibility and scalability. The original MEGNet and M3GNet were implemented in TensorFlow. It is a
collaboration between the Materials Virtual Lab and Intel Labs (Santiago Miret, Marcel Nassar, Carmelo Gonzales).
Mathematical graphs are a natural representation for a collection of atoms (e.g., molecules or crystals). Graph deep
learning models have been shown to consistently deliver exceptional performance as surrogate models for the prediction
of materials properties.

This repository is a unified reimplementation of the [3-body MatErials Graph Network (m3gnet)](https://github.com/materialsvirtuallab/m3gnet)
and its predecessors, [MEGNet](https://github.com/materialsvirtuallab/megnet) using the [Deep Graph Library (DGL)](https://www.dgl.ai).
The goal is to improve the usability, extensibility and scalability of these models. The original M3GNet and MEGNet were
implemented in TensorFlow.

This effort is a collaboration between the [Materials Virtual Lab](http://materialsvirtuallab.org) and Intel Labs
(Santiago Miret, Marcel Nassar, Carmelo Gonzales).

# Status

This repository is still a work in progress. At the present moment, only the MEGNet architecture has been implemented.
We are still extensively testing the implementation. The plan is to complete implementation of M3GNet by end of 2022.
At the present moment, only the simpler MEGNet architecture has been implemented. The implementation has been
extensively tested. It is reasonably robust and performs several times faster than the original TF implementation.

For users wishing to use the pre-trained models as-is, we recommend you check out the official [MEGNet](https://github.com/materialsvirtuallab/megnet)
The plan is to complete implementation of M3GNet by end of 2022.

For users wishing to use the pre-trained models, we recommend you check out the official [MEGNet](https://github.com/materialsvirtuallab/megnet)
and [M3GNet](https://github.com/materialsvirtuallab/m3gnet) implementations. For new users wishing to train new MEGNet
models, we welcome you to use this DGL implementation. Any contributions, e.g., code improvements or issue reports, are
very welcome!

# References

Please cite the following works:

- MEGNET
```txt
Chen, C.; Ye, W.; Zuo, Y.; Zheng, C.; Ong, S. P. Graph Networks as a Universal Machine Learning Framework for
Molecules and Crystals. Chem. Mater. 2019, 31 (9), 3564–3572. https://doi.org/10.1021/acs.chemmater.9b01294.
```
- M3GNet
```txt
Chen, C., Ong, S.P. A universal graph deep learning interatomic potential for the periodic table. Nat Comput Sci,
2, 718–728 (2022). https://doi.org/10.1038/s43588-022-00349-3.
```
# Acknowledgements
This work was primarily supported by the Materials Project, funded by the U.S. Department of Energy, Office of Science,
Office of Basic Energy Sciences, Materials Sciences and Engineering Division under contract no.
DE-AC02-05-CH11231: Materials Project program KC23MP. This work used the Expanse supercomputing cluster at the Extreme
Science and Engineering Discovery Environment (XSEDE), which is supported by National Science Foundation grant number
ACI-1548562.
2 changes: 1 addition & 1 deletion configs/qm9_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ data:
verbose: False
split:
val_size: 1000
test_size: 1000
test_size: 1000
53 changes: 25 additions & 28 deletions utils/utils.py → examples/qm9_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os, json
from __future__ import annotations

import json
import os
from collections import namedtuple
from random import seed as python_seed
from time import sleep
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import yaml
from dgl.data import QM9EdgeDataset
from dgl.dataloading import GraphDataLoader
Expand All @@ -24,7 +23,6 @@ def prepare_munch_object(path: str) -> Munch:


def prepare_config(path: str) -> Munch:
root = '/'.join(path.split('/')[:-1])
config = prepare_munch_object(path)

# for k, v in config.model.items():
Expand All @@ -33,19 +31,19 @@ def prepare_config(path: str) -> Munch:
return config


def compute_data_stats(dataset) -> Tuple[torch.Tensor]:
def compute_data_stats(dataset) -> tuple:
graphs, targets = zip(*dataset)
targets = torch.cat(targets)

z_mean_list = []
num_bond_mean_list = []

for g in graphs:
z_mean_list.append(torch.mean(g.ndata['attr']))
z_mean_list.append(torch.mean(g.ndata["attr"]))
temp = 0
for ii in range(g.num_nodes()):
temp += len(g.successors(ii))
num_bond_mean_list.append(torch.tensor(temp/g.num_nodes()))
num_bond_mean_list.append(torch.tensor(temp / g.num_nodes()))

data_std, data_mean = torch.std_mean(targets)

Expand All @@ -55,22 +53,21 @@ def compute_data_stats(dataset) -> Tuple[torch.Tensor]:
return data_std, data_mean, data_zmean, num_bond_mean


def prepare_data(config: Munch) -> namedtuple:
print('## Started data processing ##')
def prepare_data(config: Munch) -> tuple:
print("## Started data processing ##")

if config.data.dataset == 'qm9':
if config.data.dataset == "qm9":
dataset = QM9EdgeDataset(**config.data.source)

val_size = config.data.split.val_size
test_size = config.data.split.test_size
train_size = len(dataset) - val_size - test_size

train_data, val_data, test_data = random_split(
dataset, [train_size, val_size, test_size])
train_data, val_data, test_data = random_split(dataset, [train_size, val_size, test_size])

data_std, data_mean, data_zmean, num_bond_mean = compute_data_stats(train_data)

data = namedtuple('Data', ['train', 'val', 'test', 'std', 'mean'])
data = namedtuple("data", ["train", "val", "test", "std", "mean"])

data.train = train_data
data.val = val_data
Expand All @@ -80,8 +77,7 @@ def prepare_data(config: Munch) -> namedtuple:
data.z_mean = data_zmean
data.num_bond_mean = num_bond_mean


print('## Finished data processing ##')
print("## Finished data processing ##")

return data

Expand All @@ -93,22 +89,22 @@ def set_seed(seed: int) -> None:
dgl_seed(seed)


def create_dataloaders(config: Munch, data: namedtuple):
dataloaders = namedtuple('Dataloaders', ['train', 'val', 'test'])
def create_dataloaders(config: Munch, data: tuple):
dataloaders = namedtuple("Dataloaders", ["train", "val", "test"])

dataloaders.train = GraphDataLoader(
data.train,
pin_memory=False,
batch_size=config.data.batch_size
# **config.experiment.train,
)
dataloaders.val = GraphDataLoader(data.val)#, **config.experiment.val)
dataloaders.test = GraphDataLoader(data.test)#, **config.experiment.test)
dataloaders.val = GraphDataLoader(data.val) # , **config.experiment.val)
dataloaders.test = GraphDataLoader(data.test) # , **config.experiment.test)

return dataloaders


class StreamingJSONWriter(object):
class StreamingJSONWriter:
"""
Serialize streaming data to JSON.
Expand All @@ -118,13 +114,14 @@ class StreamingJSONWriter(object):
When a new item is added, the file cursor is moved backwards to overwrite
the list closing bracket.
"""

def __init__(self, filename, encoder=json.JSONEncoder):
if os.path.exists(filename):
self.file = open(filename, 'r+')
self.delimeter = ','
self.file = open(filename, "r+")
self.delimeter = ","
else:
self.file = open(filename, 'w')
self.delimeter = '['
self.file = open(filename, "w")
self.delimeter = "["
self.encoder = encoder

def dump(self, obj):
Expand All @@ -134,9 +131,9 @@ def dump(self, obj):
data = json.dumps(obj, cls=self.encoder)
close_str = "\n]\n"
self.file.seek(max(self.file.seek(0, os.SEEK_END) - len(close_str), 0))
self.file.write("%s\n %s%s" % (self.delimeter, data, close_str))
self.file.write(f"{self.delimeter}\n {data}{close_str}")
self.file.flush()
self.delimeter = ','
self.delimeter = ","

def close(self):
self.file.close()
Loading

0 comments on commit 4e53c05

Please sign in to comment.