diff --git a/docs/src/conf.py b/docs/src/conf.py index c9de1502..d367c9eb 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -10,6 +10,7 @@ # to include the documentation os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1" os.environ["RASCALINE_IMPORT_FOR_SPHINX"] = "1" +os.environ["PYTORCH_JIT"] = "0" import metatrain # noqa: E402 @@ -53,9 +54,11 @@ def generate_examples(): # METATENSOR_IMPORT_FOR_SPHINX=1). So instead we run it inside a small script, and # include the corresponding output later. del os.environ["METATENSOR_IMPORT_FOR_SPHINX"] + del os.environ["PYTORCH_JIT"] script = os.path.join(ROOT, "docs", "generate_examples", "generate-examples.py") subprocess.run([sys.executable, script]) os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1" + os.environ["PYTORCH_JIT"] = "0" def setup(app): diff --git a/docs/src/dev-docs/utils/index.rst b/docs/src/dev-docs/utils/index.rst index 312ee534..0f5cf8ab 100644 --- a/docs/src/dev-docs/utils/index.rst +++ b/docs/src/dev-docs/utils/index.rst @@ -24,4 +24,5 @@ This is the API for the ``utils`` module of ``metatrain``. omegaconf output_gradient per_atom + transfer units diff --git a/docs/src/dev-docs/utils/transfer.rst b/docs/src/dev-docs/utils/transfer.rst new file mode 100644 index 00000000..ba971029 --- /dev/null +++ b/docs/src/dev-docs/utils/transfer.rst @@ -0,0 +1,7 @@ +Data type and device transfers +############################## + +.. automodule:: metatrain.utils.transfer + :members: + :undoc-members: + :show-inheritance: diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py index 0dd13a2b..c9b9b5f6 100644 --- a/src/metatrain/experimental/alchemical_model/trainer.py +++ b/src/metatrain/experimental/alchemical_model/trainer.py @@ -25,6 +25,7 @@ get_system_with_neighbor_lists, ) from ...utils.per_atom import average_by_num_atoms +from ...utils.transfer import systems_and_targets_to_dtype_and_device from . import AlchemicalModel from .utils.composition import calculate_composition_weights from .utils.normalize import ( @@ -222,11 +223,9 @@ def train( systems, targets = batch assert len(systems[0].known_neighbor_lists()) > 0 - systems = [system.to(dtype=dtype, device=device) for system in systems] - targets = { - key: value.to(dtype=dtype, device=device) - for key, value in targets.items() - } + systems, targets = systems_and_targets_to_dtype_and_device( + systems, targets, dtype, device + ) for additive_model in model.additive_models: targets = remove_additive( systems, targets, additive_model, model.dataset_info.targets @@ -262,11 +261,9 @@ def train( for batch in val_dataloader: systems, targets = batch assert len(systems[0].known_neighbor_lists()) > 0 - systems = [system.to(dtype=dtype, device=device) for system in systems] - targets = { - key: value.to(dtype=dtype, device=device) - for key, value in targets.items() - } + systems, targets = systems_and_targets_to_dtype_and_device( + systems, targets, dtype, device + ) for additive_model in model.additive_models: targets = remove_additive( systems, targets, additive_model, model.dataset_info.targets diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py index b67d9dd8..cdfc3ba8 100644 --- a/src/metatrain/experimental/soap_bpnn/trainer.py +++ b/src/metatrain/experimental/soap_bpnn/trainer.py @@ -23,6 +23,7 @@ get_system_with_neighbor_lists, ) from ...utils.per_atom import average_by_num_atoms +from ...utils.transfer import systems_and_targets_to_dtype_and_device from .model import SoapBpnn @@ -251,11 +252,9 @@ def train( targets = remove_additive( systems, targets, additive_model, train_targets ) - systems = [system.to(dtype=dtype, device=device) for system in systems] - targets = { - key: value.to(dtype=dtype, device=device) - for key, value in targets.items() - } + systems, targets = systems_and_targets_to_dtype_and_device( + systems, targets, dtype, device + ) predictions = evaluate_model( model, systems, @@ -294,11 +293,9 @@ def train( targets = remove_additive( systems, targets, additive_model, train_targets ) - systems = [system.to(dtype=dtype, device=device) for system in systems] - targets = { - key: value.to(dtype=dtype, device=device) - for key, value in targets.items() - } + systems, targets = systems_and_targets_to_dtype_and_device( + systems, targets, dtype, device + ) predictions = evaluate_model( model, systems, diff --git a/src/metatrain/utils/transfer.py b/src/metatrain/utils/transfer.py new file mode 100644 index 00000000..5aae6929 --- /dev/null +++ b/src/metatrain/utils/transfer.py @@ -0,0 +1,28 @@ +from typing import Dict, List + +import torch +from metatensor.torch import TensorMap +from metatensor.torch.atomistic import System + + +@torch.jit.script +def systems_and_targets_to_dtype_and_device( + systems: List[System], + targets: Dict[str, TensorMap], + dtype: torch.dtype, + device: torch.device, +): + """ + Transfers the systems and targets to the specified dtype and device. + + :param systems: List of systems. + :param targets: Dictionary of targets. + :param dtype: Desired data type. + :param device: Device to transfer to. + """ + + systems = [system.to(dtype=dtype, device=device) for system in systems] + targets = { + key: value.to(dtype=dtype, device=device) for key, value in targets.items() + } + return systems, targets diff --git a/tests/utils/test_transfer.py b/tests/utils/test_transfer.py new file mode 100644 index 00000000..cceb3bf2 --- /dev/null +++ b/tests/utils/test_transfer.py @@ -0,0 +1,39 @@ +import metatensor.torch +import torch +from metatensor.torch import Labels, TensorMap +from metatensor.torch.atomistic import System + +from metatrain.utils.transfer import systems_and_targets_to_dtype_and_device + + +def test_systems_and_targets_to_dtype_and_device(): + system = System( + positions=torch.tensor([[1.0, 1.0, 1.0]]), + cell=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), + types=torch.tensor([1]), + ) + targets = TensorMap( + keys=Labels.single(), + blocks=[metatensor.torch.block_from_array(torch.tensor([[1.0]]))], + ) + + systems = [system] + targets = {"energy": targets} + + assert systems[0].positions.dtype == torch.float32 + assert systems[0].positions.device == torch.device("cpu") + assert systems[0].cell.dtype == torch.float32 + assert systems[0].types.device == torch.device("cpu") + assert targets["energy"].block().values.dtype == torch.float32 + assert targets["energy"].block().values.device == torch.device("cpu") + + systems, targets = systems_and_targets_to_dtype_and_device( + systems, targets, torch.float64, torch.device("meta") + ) + + assert systems[0].positions.dtype == torch.float64 + assert systems[0].positions.device == torch.device("meta") + assert systems[0].cell.dtype == torch.float64 + assert systems[0].types.device == torch.device("meta") + assert targets["energy"].block().values.dtype == torch.float64 + assert targets["energy"].block().values.device == torch.device("meta")