-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
35a5996
commit a071ca1
Showing
34 changed files
with
270 additions
and
1,355 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
metatensor-models | ||
----------------- | ||
|
||
This is a repository for models using metatensor, in one shape or another. The only | ||
requirement is for these models to be able to take metatensor objects as inputs and | ||
outputs. The models do not need to live entirely in this repository: in the most extreme | ||
case, this repository can simply contain a wrapper to an external model. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,11 @@ | ||
from pathlib import Path | ||
import torch | ||
|
||
PACKAGE_ROOT = Path(__file__).parent.resolve() | ||
|
||
CONFIG_PATH = PACKAGE_ROOT / "cli" / "conf" | ||
ARCHITECTURE_CONFIG_PATH = CONFIG_PATH / "architecture" | ||
|
||
__version__ = "2023.11.29" | ||
|
||
torch.set_default_dtype(torch.float64) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .eval_model import eval_model # noqa | ||
from .export_model import export_model # noqa | ||
from .train_model import train_model # noqa |
32 changes: 32 additions & 0 deletions
32
src/metatensor/models/cli/conf/architecture/soap_bpnn.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# default hyperparameters for the SOAP-BPNN model | ||
name: soap_bpnn | ||
|
||
model: | ||
soap: | ||
cutoff: 5.0 | ||
max_radial: 8 | ||
max_angular: 6 | ||
atomic_gaussian_width: 0.3 | ||
radial_basis: | ||
Gto: {} | ||
center_atom_weight: 1.0 | ||
cutoff_function: | ||
ShiftedCosine: | ||
width: 1.0 | ||
radial_scaling: | ||
Willatt2018: | ||
rate: 1.0 | ||
scale: 2.0 | ||
exponent: 7.0 | ||
|
||
bpnn: | ||
num_hidden_layers: 2 | ||
num_neurons_per_layer: 32 | ||
activation_function: SiLU | ||
|
||
training: | ||
batch_size: 8 | ||
num_epochs: 100 | ||
learning_rate: 0.001 | ||
log_interval: 10 | ||
checkpoint_interval: 25 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
defaults: | ||
- architecture: ??? |
2 changes: 1 addition & 1 deletion
2
src/metatensor/models/scripts/evaluate.py → src/metatensor/models/cli/eval_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
def evaluate(): | ||
def eval_model(): | ||
"""evaluate a model""" | ||
print("Run evaluate...") |
2 changes: 1 addition & 1 deletion
2
src/metatensor/models/scripts/export.py → src/metatensor/models/cli/export_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
def export(): | ||
def export_model(): | ||
"""export a model""" | ||
print("Run exort...") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import importlib | ||
import logging | ||
|
||
import hydra | ||
from omegaconf import DictConfig, OmegaConf | ||
|
||
from metatensor.models.utils.data import Dataset | ||
from metatensor.models.utils.data.readers import read_structures, read_targets | ||
|
||
from .. import CONFIG_PATH | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@hydra.main(config_path=str(CONFIG_PATH), config_name="config", version_base=None) | ||
def train_model(config: DictConfig) -> None: | ||
"""train a model.""" | ||
|
||
logger.info("Setting up dataset") | ||
structures = read_structures(config["dataset"]["structure_path"]) | ||
targets = read_targets( | ||
config["dataset"]["targets_path"], | ||
target_value=config["dataset"]["target_value"], | ||
) | ||
dataset = Dataset(structures, targets) | ||
|
||
logger.info("Setting up model") | ||
architetcure_name = config["architecture"]["name"] | ||
architecture = importlib.import_module(f"metatensor.models.{architetcure_name}") | ||
model = architecture.Model( | ||
all_species=dataset.all_species, | ||
hypers=OmegaConf.to_container(config["architecture"]["model"]), | ||
) | ||
|
||
logger.info("Run training") | ||
architecture.train( | ||
model=model, | ||
train_dataset=dataset, | ||
hypers=OmegaConf.to_container(config["architecture"]["training"]), | ||
) |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .model import SoapBPNN # noqa: F401 | ||
from .model import Model # noqa: F401 | ||
from .train import train # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from pathlib import Path | ||
|
||
from metatensor.models import ARCHITECTURE_CONFIG_PATH | ||
from omegaconf import OmegaConf | ||
|
||
|
||
DEAFAULT_HYPERS = OmegaConf.to_container( | ||
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml") | ||
) | ||
DATASET_PATH = str( | ||
Path(__file__).parent.resolve() / "../../../../../tests/resources/qm9_reduced_100.xyz" | ||
) |
This file was deleted.
Oops, something went wrong.
13 changes: 3 additions & 10 deletions
13
src/metatensor/models/soap_bpnn/tests/test_functionality.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,18 @@ | ||
import os | ||
|
||
import ase | ||
import rascaline.torch | ||
import torch | ||
import yaml | ||
|
||
from metatensor.models.soap_bpnn import SoapBPNN | ||
|
||
from metatensor.models.soap_bpnn import Model | ||
|
||
path = os.path.dirname(__file__) | ||
hypers_path = os.path.join(path, "../default.yml") | ||
dataset_path = os.path.join(path, "data/qm9_reduced_100.xyz") | ||
from . import DEAFAULT_HYPERS | ||
|
||
|
||
def test_prediction_subset(): | ||
"""Tests that the model can predict on a subset | ||
of the elements it was trained on.""" | ||
|
||
all_species = [1, 6, 7, 8] | ||
hypers = yaml.safe_load(open(hypers_path, "r")) | ||
soap_bpnn = SoapBPNN(all_species, hypers).to(torch.float64) | ||
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64) | ||
|
||
structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) | ||
soap_bpnn([rascaline.torch.systems_to_torch(structure)]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.