Add an interface to PET #68
Merged
Changes from 12 commits
31 commits
fc65cc8
pet wrapper minimal
spozdn 6abe7ee
Not dataset, but list of graphs
frostedoyster 54a96ab
Capability checks
frostedoyster 36b6a77
Convert dataset to `ase.Atoms`, not `pyg.graphs`
frostedoyster 41e60a3
Convert energy and forces to numpy when converting systems to ASE
frostedoyster 7a6024a
Merge branch 'main' into pet_wrapper
frostedoyster 4da14ff
Basic folder structure for PET
frostedoyster 1d655f2
Updated PET Model API
abmazitov ef792c7
Update of the model API
abmazitov 33a3359
Merge branch 'pet_wrapper' of https://github.com/lab-cosmo/metatensor…
abmazitov f47869a
Updated torch compatibility
abmazitov 40898d0
Removed non-relevant tests
abmazitov 45b90ad
Make PET train!
frostedoyster 0f4d692
Clean-up
frostedoyster 977ba61
avoiding pyg interlay
spozdn b84e5cc
Fixed a few torchscript roadblocks
frostedoyster ddde485
Run PET tests on GitHub
frostedoyster 19c99e9
making wrapper utilities torch scriptable
spozdn 15a77ec
cleanup
spozdn 60f55e1
Allow PET to evaluate
frostedoyster 539bb45
Merge branch 'main' into pet_wrapper
frostedoyster 5648b0e
Update output keys of the PET wrapper
frostedoyster 8e6548e
Add `system_to_ase` to tests and docs
frostedoyster 4010eed
Merge branch 'main' into pet_wrapper
frostedoyster 06e32d8
Apply suggestions from reviewers
frostedoyster 7886abd
Clean up PET hypers
frostedoyster 67488ad
Fix description of PET tests in `tox.ini`
frostedoyster 7983d69
Re-add MLIP-related hypers to default hypers
frostedoyster 45b8ef0
Apply suggestions from code review
frostedoyster cff80c0
Merge branch 'main' into pet_wrapper
frostedoyster e705a33
Merge branch 'pet_wrapper' of https://github.com/lab-cosmo/metatensor…
frostedoyster
65 changes: 65 additions & 0 deletions
src/metatensor/models/cli/conf/architecture/experimental.pet.yaml
@@ -0,0 +1,65 @@
ARCHITECTURAL_HYPERS:
  CUTOFF_DELTA: 0.2
  AVERAGE_POOLING: False
  TRANSFORMERS_CENTRAL_SPECIFIC: False
  HEADS_CENTRAL_SPECIFIC: False
  ADD_TOKEN_FIRST: True
  ADD_TOKEN_SECOND: True
  N_GNN_LAYERS: 3
  TRANSFORMER_D_MODEL: 128
  TRANSFORMER_N_HEAD: 4
  TRANSFORMER_DIM_FEEDFORWARD: 512
  HEAD_N_NEURONS: 128
  N_TRANS_LAYERS: 3
  ACTIVATION: silu
  USE_LENGTH: True
  USE_ONLY_LENGTH: False
  R_CUT: 5.0
  R_EMBEDDING_ACTIVATION: False
  COMPRESS_MODE: mlp
  BLEND_NEIGHBOR_SPECIES: False
  AVERAGE_BOND_ENERGIES: False
  USE_BOND_ENERGIES: True
  USE_ADDITIONAL_SCALAR_ATTRIBUTES: False
  SCALAR_ATTRIBUTES_SIZE: None
  TRANSFORMER_TYPE: PostLN # PostLN or PreLN
  USE_LONG_RANGE: False
  K_CUT: None # should be float; only used when USE_LONG_RANGE is True


FITTING_SCHEME:
  INITIAL_LR: 1e-4
  EPOCH_NUM_ATOMIC: 1000000000000000000
  SCHEDULER_STEP_SIZE_ATOMIC: 500000000
  EPOCHS_WARMUP_ATOMIC: 250000000
  GLOBAL_AUG: True
  SLIDING_FACTOR: 0.7
  ATOMIC_BATCH_SIZE: 850
  MAX_TIME: 234000
  ENERGY_WEIGHT: 0.1 # only used when fitting MLIP
  MULTI_GPU: False
  RANDOM_SEED: 0
  CUDA_DETERMINISTIC: False
  MODEL_TO_START_WITH: None
  SUPPORT_MISSING_VALUES: False
  USE_WEIGHT_DECAY: False
  WEIGHT_DECAY: 0.0
  DO_GRADIENT_CLIPPING: False
  GRADIENT_CLIPPING_MAX_NORM: None # must be overwritten if DO_GRADIENT_CLIPPING is True
  USE_SHIFT_AGNOSTIC_LOSS: False # only used when fitting general target. Primary use case: EDOS
  ENERGIES_LOSS: per_structure # per_structure or per_atom

MLIP_SETTINGS: # only used when fitting MLIP
  ENERGY_KEY: energy
  FORCES_KEY: forces
  USE_ENERGIES: True
  USE_FORCES: True

GENERAL_TARGET_SETTINGS: # only used when fitting general target
  TARGET_TYPE: structural
  TARGET_AGGREGATION: sum # sum or mean; only used when TARGET_TYPE is structural
  TARGET_DIM: 42
  TARGET_KEY: structural_target

UTILITY_FLAGS: # for internal usage; do not change/overwrite
  CALCULATION_TYPE: None
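Several entries in this file are written as a bare `None` (for example `K_CUT`, `SCALAR_ATTRIBUTES_SIZE` and `MODEL_TO_START_WITH`). Standard YAML loaders read that token as the string `"None"` rather than as a null value, and this diff does not show how PET treats it. The sketch below is one way a consumer could normalise such values after loading; `normalize_none` is a hypothetical helper, not part of PET or of this repository, and the `loaded` dictionary is a reduced stand-in for the real file.

```python
from typing import Any, Dict


def normalize_none(hypers: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `hypers` in which the string "None" becomes Python None."""
    cleaned: Dict[str, Any] = {}
    for key, value in hypers.items():
        if isinstance(value, dict):
            # Recurse into nested sections such as ARCHITECTURAL_HYPERS.
            cleaned[key] = normalize_none(value)
        elif value == "None":
            cleaned[key] = None
        else:
            cleaned[key] = value
    return cleaned


# Reduced stand-in for what a YAML loader would return for the file above.
loaded = {
    "ARCHITECTURAL_HYPERS": {"R_CUT": 5.0, "USE_LONG_RANGE": False, "K_CUT": "None"},
    "UTILITY_FLAGS": {"CALCULATION_TYPE": "None"},
}
hypers = normalize_none(loaded)
```

The helper returns a new dictionary, so the loaded defaults stay untouched and can be reused.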
@@ -0,0 +1,2 @@
from .model import Model, DEFAULT_HYPERS  # noqa: F401
from .train import train  # noqa: F401
@@ -0,0 +1,102 @@
import torch
import numpy as np
from typing import Dict, List, Optional
from metatensor.torch import Labels, TensorMap, TensorBlock
from metatensor.torch.atomistic import (
    ModelCapabilities,
    ModelOutput,
    NeighborsListOptions,
    System,
)
from omegaconf import OmegaConf
from pet.molecule import batch_to_dict
from pet.pet import PET
from pet.hypers import Hypers

from ... import ARCHITECTURE_CONFIG_PATH
from .utils import systems_to_pyg_graphs


DEFAULT_HYPERS = OmegaConf.to_container(
    OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "experimental.pet.yaml")
)

DEFAULT_MODEL_HYPERS = DEFAULT_HYPERS["ARCHITECTURAL_HYPERS"]

# We hardcode some of the hypers to make the PET model work as an MLIP.
DEFAULT_MODEL_HYPERS.update(
    {"D_OUTPUT": 1, "TARGET_TYPE": "structural", "TARGET_AGGREGATION": "sum"}
)

ARCHITECTURE_NAME = "experimental.pet"

class Model(torch.nn.Module):
    def __init__(
        self, capabilities: ModelCapabilities, hypers: Dict = DEFAULT_MODEL_HYPERS
    ) -> None:
        super().__init__()
        self.name = ARCHITECTURE_NAME
        self.hypers = hypers
        self.cutoff = self.hypers["R_CUT"]
        self.all_species = capabilities.species
        self.capabilities = capabilities
        self.pet = PET(Hypers(self.hypers), 0.0, len(self.all_species))

    def set_trained_model(self, trained_model: torch.nn.Module) -> None:
        self.pet = trained_model

    def requested_neighbors_lists(
        self,
    ) -> List[NeighborsListOptions]:
        return [
            NeighborsListOptions(
                model_cutoff=self.cutoff,
                full_list=True,
            )
        ]

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorMap]:
        if selected_atoms is not None:
            raise NotImplementedError("PET does not support selected atoms.")
        options = self.requested_neighbors_lists()[0]
        batch = systems_to_pyg_graphs(systems, options, self.all_species)
        predictions = self.pet(batch_to_dict(batch))
        total_energies: Dict[str, TensorMap] = {}
        for output_name in outputs:
            total_energies[output_name] = predictions
            total_energies[output_name] = TensorMap(
                keys=Labels(
                    names=["lambda", "sigma"],
                    values=torch.tensor(
                        [[0, 1]],
                        device=predictions.device,
                    ),
                ),
                blocks=[
                    TensorBlock(
                        samples=Labels(
                            names=["structure"],
                            values=torch.arange(
                                len(predictions),
                                device=predictions.device,
                            ).view(-1, 1),
                        ),
                        components=[],
                        properties=Labels(
                            names=["property"],
                            values=torch.tensor(
                                len(outputs),
                                device=predictions.device,
                            ).view(1, -1),
                        ),
                        values=total_energies[output_name],
                    )
                ],
            )
        return total_energies
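In this file, `DEFAULT_MODEL_HYPERS` is the same dictionary object as `DEFAULT_HYPERS["ARCHITECTURAL_HYPERS"]`. It is updated in place at import time and is also the default value of the `hypers` argument, so every `Model` built without explicit hypers shares one dictionary. The following is a minimal sketch of a defensive-copy pattern that avoids this sharing. `build_model_hypers`, `DEFAULTS` and `MLIP_OVERRIDES` are illustrative names that do not appear in the pull request, and `DEFAULTS` is a reduced stand-in for the real hypers.

```python
import copy
from typing import Any, Dict, Optional

# Reduced stand-in for the architectural hypers loaded from YAML.
DEFAULTS: Dict[str, Any] = {"R_CUT": 5.0, "N_GNN_LAYERS": 3}

# The values the diff hardcodes so that PET behaves as an MLIP.
MLIP_OVERRIDES: Dict[str, Any] = {
    "D_OUTPUT": 1,
    "TARGET_TYPE": "structural",
    "TARGET_AGGREGATION": "sum",
}


def build_model_hypers(user_hypers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Merge defaults, user values and MLIP overrides without mutating any input."""
    hypers = copy.deepcopy(DEFAULTS)
    if user_hypers is not None:
        hypers.update(user_hypers)
    hypers.update(MLIP_OVERRIDES)
    return hypers


custom = build_model_hypers({"R_CUT": 4.0})
default = build_model_hypers()
```

Each call returns a fresh dictionary, so changing `custom` afterwards cannot affect `default` or the module-level defaults.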
@@ -0,0 +1,6 @@
from pathlib import Path

DATASET_PATH = str(
    Path(__file__).parent.resolve()
    / "../../../../../../tests/resources/qm9_reduced_100.xyz"
)
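`DATASET_PATH` calls `resolve()` before the relative part is appended, so the resulting string still contains the `..` segments. It still points at the right file, but a normalised path can be easier to read in error messages. The sketch below compares the two orderings inside a temporary directory; the directory layout and file names are made up for the example.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp).resolve()

    # Hypothetical layout: a tests package three levels below the root,
    # and a resource file directly under the root.
    here = root / "src" / "pkg" / "tests"
    here.mkdir(parents=True)
    target = root / "resources" / "data.xyz"
    target.parent.mkdir()
    target.write_text("")

    # Ordering used in the diff: resolve first, then append the relative part.
    as_in_diff = here.resolve() / "../../../resources/data.xyz"

    # Alternative: append first, then resolve, which collapses the ".." segments.
    normalised = (here / "../../../resources/data.xyz").resolve()

    points_to_file = as_in_diff.exists()
    keeps_dotdot = ".." in as_in_diff.parts
    is_clean = normalised == target
```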
47 changes: 47 additions & 0 deletions
src/metatensor/models/experimental/pet/tests/test_functionality.py
@@ -0,0 +1,47 @@
import ase
import rascaline.torch
import torch
from metatensor.torch.atomistic import (
    MetatensorAtomisticModel,
    ModelCapabilities,
    ModelEvaluationOptions,
    ModelOutput,
)

from metatensor.models.experimental.pet import DEFAULT_HYPERS, Model
from metatensor.models.utils.neighbors_lists import get_system_with_neighbors_lists


def test_prediction_subset():
    """Tests that the model can predict on a subset
    of the elements it was trained on."""

    capabilities = ModelCapabilities(
        length_unit="Angstrom",
        species=[1, 6, 7, 8],
        outputs={
            "energy": ModelOutput(
                quantity="energy",
                unit="eV",
            )
        },
    )

    model = Model(capabilities, DEFAULT_HYPERS["ARCHITECTURAL_HYPERS"]).to(
        torch.float64
    )
    structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    system = rascaline.torch.systems_to_torch(structure)
    system = get_system_with_neighbors_lists(system, model.requested_neighbors_lists())

    evaluation_options = ModelEvaluationOptions(
        length_unit=capabilities.length_unit,
        outputs=capabilities.outputs,
    )

    model = MetatensorAtomisticModel(model.eval(), model.capabilities)
    model(
        [system],
        evaluation_options,
        check_consistency=True,
    )
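The test evaluates an O2 molecule with a model whose capabilities list the species 1, 6, 7 and 8. How `systems_to_pyg_graphs` encodes species is not shown in this diff. One common approach, sketched here purely as an assumption, maps each atomic number to its position in the species list and rejects elements the model does not know. `species_to_indices` is a hypothetical helper.

```python
from typing import List


def species_to_indices(atomic_numbers: List[int], all_species: List[int]) -> List[int]:
    """Map atomic numbers to their index in `all_species`; fail on unknown elements."""
    lookup = {species: index for index, species in enumerate(all_species)}
    unknown = sorted(set(atomic_numbers) - set(lookup))
    if unknown:
        raise ValueError(f"species {unknown} are not supported by this model")
    return [lookup[number] for number in atomic_numbers]


# O2 uses only one of the four species the model declares.
indices = species_to_indices([8, 8], [1, 6, 7, 8])
```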
@@ -0,0 +1,101 @@
import logging
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import rascaline
import torch
from metatensor.learn.data import DataLoader
from metatensor.learn.data.dataset import _BaseDataset
from metatensor.torch.atomistic import ModelCapabilities, NeighborsListOptions, System

from ...utils.composition import calculate_composition_weights
from ...utils.compute_loss import compute_model_loss
from ...utils.data import (
    check_datasets,
    collate_fn,
    combine_dataloaders,
    get_all_targets,
)
from ...utils.data.system_to_ase import system_to_ase
from ...utils.extract_targets import get_outputs_dict
from ...utils.info import finalize_aggregated_info, update_aggregated_info
from ...utils.neighbors_lists import get_system_with_neighbors_lists
from ...utils.logging import MetricLogger
from ...utils.loss import TensorMapDictLoss
from ...utils.merge_capabilities import merge_capabilities
from ...utils.model_io import load_checkpoint, save_model
from .utils import systems_to_pyg_graphs
from .model import DEFAULT_HYPERS, Model


logger = logging.getLogger(__name__)

# disable rascaline logger
rascaline.set_logging_callback(lambda x, y: None)

# Filter out the second derivative and device warnings from rascaline-torch
warnings.filterwarnings("ignore", category=UserWarning, message="second derivative")
warnings.filterwarnings(
    "ignore", category=UserWarning, message="Systems data is on device"
)

def train(
    train_datasets: List[Union[_BaseDataset, torch.utils.data.Subset]],
    validation_datasets: List[Union[_BaseDataset, torch.utils.data.Subset]],
    requested_capabilities: ModelCapabilities,
    hypers: Dict = DEFAULT_HYPERS,
    continue_from: Optional[str] = None,
    output_dir: str = ".",
    device_str: str = "cpu",
):
    if len(requested_capabilities.outputs) != 1:
        raise ValueError("PET only supports a single output")
    target_name = next(iter(requested_capabilities.outputs.keys()))
    if requested_capabilities.outputs[target_name].quantity != "energy":
        raise ValueError("PET only supports energies as output")
    if requested_capabilities.outputs[target_name].per_atom:
        raise ValueError("PET does not support per-atom energies")

    if len(train_datasets) != 1:
        raise ValueError("PET only supports a single training dataset")
    if len(validation_datasets) != 1:
        raise ValueError("PET only supports a single validation dataset")

    train_dataset = train_datasets[0]
    validation_dataset = validation_datasets[0]

    # dummy dataloaders due to https://github.com/lab-cosmo/metatensor/issues/521
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn,
    )
    validation_dataloader = DataLoader(
        validation_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn,
    )

    # only energies or energies and forces?
    do_forces = next(iter(next(iter(train_dataset))[1].values())).values.has_gradient("positions")
    all_species = requested_capabilities.species

    ase_train_dataset = []
    for (system,), targets in train_dataloader:
        ase_atoms = system_to_ase(system)
        ase_atoms.info['energy'] = targets[target_name].block().values.squeeze(-1).detach().cpu().numpy()
        if do_forces:
            ase_atoms.arrays["forces"] = targets[target_name].block().gradient('positions').values.squeeze(-1).detach().cpu().numpy()
        ase_train_dataset.append(ase_atoms)

    ase_validation_dataset = []
    # Bind the validation targets here; discarding them would reuse the
    # last training targets left over from the loop above.
    for (system,), targets in validation_dataloader:
        ase_atoms = system_to_ase(system)
        ase_atoms.info['energy'] = targets[target_name].block().values.squeeze(-1).detach().cpu().numpy()
        if do_forces:
            ase_atoms.arrays["forces"] = targets[target_name].block().gradient('positions').values.squeeze(-1).detach().cpu().numpy()
        ase_validation_dataset.append(ase_atoms)
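The checks at the top of `train` restrict PET to a single, per-structure energy target. The sketch below isolates that logic so it can be exercised without metatensor installed. `OutputSpec` is a hypothetical stand-in that mirrors only the `quantity` and `per_atom` attributes used above, and `single_energy_target` is an illustrative name rather than a function from the pull request.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class OutputSpec:
    """Hypothetical stand-in for the two ModelOutput attributes the checks read."""

    quantity: str = "energy"
    per_atom: bool = False


def single_energy_target(outputs: Dict[str, OutputSpec]) -> str:
    """Return the name of the only target, or raise if the request is unsupported."""
    if len(outputs) != 1:
        raise ValueError("PET only supports a single output")
    target_name = next(iter(outputs))
    spec = outputs[target_name]
    if spec.quantity != "energy":
        raise ValueError("PET only supports energies as output")
    if spec.per_atom:
        raise ValueError("PET does not support per-atom energies")
    return target_name


target_name = single_energy_target({"energy": OutputSpec()})
```

Keeping the validation in one small function makes each rejection path easy to unit-test on its own.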
@@ -0,0 +1,5 @@
from .systems_to_pyg_graphs import systems_to_pyg_graphs

__all__ = [
    "systems_to_pyg_graphs",
]
Do we really need these hypers if in practice you never change them?
Also, we need a page explaining them as for the other architectures.
We could do this, depending on how keen @serfg is. However, this should not be required for experimental models
No, it is not required, but it may make sense to clean this up a bit.