Skip to content

Commit

Permalink
Extract and speed up device transfers (#357)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Oct 14, 2024
1 parent cc84568 commit 6ece48e
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 20 deletions.
3 changes: 3 additions & 0 deletions docs/src/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# to include the documentation
os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1"
os.environ["RASCALINE_IMPORT_FOR_SPHINX"] = "1"
os.environ["PYTORCH_JIT"] = "0"

import metatrain # noqa: E402

Expand Down Expand Up @@ -53,9 +54,11 @@ def generate_examples():
# METATENSOR_IMPORT_FOR_SPHINX=1). So instead we run it inside a small script, and
# include the corresponding output later.
del os.environ["METATENSOR_IMPORT_FOR_SPHINX"]
del os.environ["PYTORCH_JIT"]
script = os.path.join(ROOT, "docs", "generate_examples", "generate-examples.py")
subprocess.run([sys.executable, script])
os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1"
os.environ["PYTORCH_JIT"] = "0"


def setup(app):
Expand Down
1 change: 1 addition & 0 deletions docs/src/dev-docs/utils/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ This is the API for the ``utils`` module of ``metatrain``.
omegaconf
output_gradient
per_atom
transfer
units
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/transfer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Data type and device transfers
##############################

.. automodule:: metatrain.utils.transfer
    :members:
    :undoc-members:
    :show-inheritance:
17 changes: 7 additions & 10 deletions src/metatrain/experimental/alchemical_model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_system_with_neighbor_lists,
)
from ...utils.per_atom import average_by_num_atoms
from ...utils.transfer import systems_and_targets_to_dtype_and_device
from . import AlchemicalModel
from .utils.composition import calculate_composition_weights
from .utils.normalize import (
Expand Down Expand Up @@ -222,11 +223,9 @@ def train(

systems, targets = batch
assert len(systems[0].known_neighbor_lists()) > 0
systems = [system.to(dtype=dtype, device=device) for system in systems]
targets = {
key: value.to(dtype=dtype, device=device)
for key, value in targets.items()
}
systems, targets = systems_and_targets_to_dtype_and_device(
systems, targets, dtype, device
)
for additive_model in model.additive_models:
targets = remove_additive(
systems, targets, additive_model, model.dataset_info.targets
Expand Down Expand Up @@ -262,11 +261,9 @@ def train(
for batch in val_dataloader:
systems, targets = batch
assert len(systems[0].known_neighbor_lists()) > 0
systems = [system.to(dtype=dtype, device=device) for system in systems]
targets = {
key: value.to(dtype=dtype, device=device)
for key, value in targets.items()
}
systems, targets = systems_and_targets_to_dtype_and_device(
systems, targets, dtype, device
)
for additive_model in model.additive_models:
targets = remove_additive(
systems, targets, additive_model, model.dataset_info.targets
Expand Down
17 changes: 7 additions & 10 deletions src/metatrain/experimental/soap_bpnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
get_system_with_neighbor_lists,
)
from ...utils.per_atom import average_by_num_atoms
from ...utils.transfer import systems_and_targets_to_dtype_and_device
from .model import SoapBpnn


Expand Down Expand Up @@ -251,11 +252,9 @@ def train(
targets = remove_additive(
systems, targets, additive_model, train_targets
)
systems = [system.to(dtype=dtype, device=device) for system in systems]
targets = {
key: value.to(dtype=dtype, device=device)
for key, value in targets.items()
}
systems, targets = systems_and_targets_to_dtype_and_device(
systems, targets, dtype, device
)
predictions = evaluate_model(
model,
systems,
Expand Down Expand Up @@ -294,11 +293,9 @@ def train(
targets = remove_additive(
systems, targets, additive_model, train_targets
)
systems = [system.to(dtype=dtype, device=device) for system in systems]
targets = {
key: value.to(dtype=dtype, device=device)
for key, value in targets.items()
}
systems, targets = systems_and_targets_to_dtype_and_device(
systems, targets, dtype, device
)
predictions = evaluate_model(
model,
systems,
Expand Down
28 changes: 28 additions & 0 deletions src/metatrain/utils/transfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Dict, List

import torch
from metatensor.torch import TensorMap
from metatensor.torch.atomistic import System


@torch.jit.script
def systems_and_targets_to_dtype_and_device(
    systems: List[System],
    targets: Dict[str, TensorMap],
    dtype: torch.dtype,
    device: torch.device,
):
    """
    Move a batch of systems and its targets to a given dtype and device.

    Each system and each target is converted through its own ``to`` method.
    The converted objects are gathered in a new list and a new dictionary; the
    containers passed in are left as they were.

    :param systems: List of systems.
    :param targets: Dictionary of targets.
    :param dtype: Desired data type.
    :param device: Device to transfer to.
    :return: Tuple ``(systems, targets)`` with the converted systems in their
        original order and the converted targets under their original keys.
    """
    # Explicit annotations so that TorchScript knows the element types of the
    # containers that start out empty.
    moved_systems: List[System] = []
    for system in systems:
        moved_systems.append(system.to(dtype=dtype, device=device))

    moved_targets: Dict[str, TensorMap] = {}
    for name, tensor_map in targets.items():
        moved_targets[name] = tensor_map.to(dtype=dtype, device=device)

    return moved_systems, moved_targets
39 changes: 39 additions & 0 deletions tests/utils/test_transfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import metatensor.torch
import torch
from metatensor.torch import Labels, TensorMap
from metatensor.torch.atomistic import System

from metatrain.utils.transfer import systems_and_targets_to_dtype_and_device


def test_systems_and_targets_to_dtype_and_device():
    """Check that systems and targets end up with the requested dtype and device."""
    cpu = torch.device("cpu")
    meta = torch.device("meta")

    # One system with a single atom and an identity-matrix cell.
    systems = [
        System(
            positions=torch.tensor([[1.0, 1.0, 1.0]]),
            cell=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
            types=torch.tensor([1]),
        )
    ]

    # One target, "energy", made of a single block holding a 1x1 array.
    energy = TensorMap(
        keys=Labels.single(),
        blocks=[metatensor.torch.block_from_array(torch.tensor([[1.0]]))],
    )
    targets = {"energy": energy}

    # Starting point: everything is float32 and lives on the CPU.
    first_system = systems[0]
    assert first_system.positions.dtype == torch.float32
    assert first_system.positions.device == cpu
    assert first_system.cell.dtype == torch.float32
    assert first_system.types.device == cpu
    energy_values = targets["energy"].block().values
    assert energy_values.dtype == torch.float32
    assert targets["energy"].block().values.device == cpu

    systems, targets = systems_and_targets_to_dtype_and_device(
        systems, targets, torch.float64, meta
    )

    # After the transfer: float64 values on the "meta" device.
    first_system = systems[0]
    assert first_system.positions.dtype == torch.float64
    assert first_system.positions.device == meta
    assert first_system.cell.dtype == torch.float64
    assert first_system.types.device == meta
    energy_values = targets["energy"].block().values
    assert energy_values.dtype == torch.float64
    assert targets["energy"].block().values.device == meta

0 comments on commit 6ece48e

Please sign in to comment.