Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract and speed up device transfers #357

Merged
merged 4 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/src/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# to include the documentation
os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1"
os.environ["RASCALINE_IMPORT_FOR_SPHINX"] = "1"
os.environ["PYTORCH_JIT"] = "0"

import metatrain # noqa: E402

Expand Down Expand Up @@ -53,9 +54,11 @@ def generate_examples():
# METATENSOR_IMPORT_FOR_SPHINX=1). So instead we run it inside a small script, and
# include the corresponding output later.
del os.environ["METATENSOR_IMPORT_FOR_SPHINX"]
del os.environ["PYTORCH_JIT"]
script = os.path.join(ROOT, "docs", "generate_examples", "generate-examples.py")
subprocess.run([sys.executable, script])
os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1"
os.environ["PYTORCH_JIT"] = "0"


def setup(app):
Expand Down
1 change: 1 addition & 0 deletions docs/src/dev-docs/utils/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ This is the API for the ``utils`` module of ``metatrain``.
omegaconf
output_gradient
per_atom
transfer
units
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/transfer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Data type and device transfers
##############################

.. automodule:: metatrain.utils.transfer
:members:
:undoc-members:
:show-inheritance:
17 changes: 7 additions & 10 deletions src/metatrain/experimental/alchemical_model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_system_with_neighbor_lists,
)
from ...utils.per_atom import average_by_num_atoms
from ...utils.transfer import systems_and_targets_to_dtype_and_device
from . import AlchemicalModel
from .utils.composition import calculate_composition_weights
from .utils.normalize import (
Expand Down Expand Up @@ -222,11 +223,9 @@ def train(

systems, targets = batch
assert len(systems[0].known_neighbor_lists()) > 0
systems = [system.to(dtype=dtype, device=device) for system in systems]
targets = {
key: value.to(dtype=dtype, device=device)
for key, value in targets.items()
}
systems, targets = systems_and_targets_to_dtype_and_device(
systems, targets, dtype, device
)
for additive_model in model.additive_models:
targets = remove_additive(
systems, targets, additive_model, model.dataset_info.targets
Expand Down Expand Up @@ -262,11 +261,9 @@ def train(
for batch in val_dataloader:
systems, targets = batch
assert len(systems[0].known_neighbor_lists()) > 0
systems = [system.to(dtype=dtype, device=device) for system in systems]
targets = {
key: value.to(dtype=dtype, device=device)
for key, value in targets.items()
}
systems, targets = systems_and_targets_to_dtype_and_device(
systems, targets, dtype, device
)
for additive_model in model.additive_models:
targets = remove_additive(
systems, targets, additive_model, model.dataset_info.targets
Expand Down
17 changes: 7 additions & 10 deletions src/metatrain/experimental/soap_bpnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
get_system_with_neighbor_lists,
)
from ...utils.per_atom import average_by_num_atoms
from ...utils.transfer import systems_and_targets_to_dtype_and_device
from .model import SoapBpnn


Expand Down Expand Up @@ -251,11 +252,9 @@ def train(
targets = remove_additive(
systems, targets, additive_model, train_targets
)
systems = [system.to(dtype=dtype, device=device) for system in systems]
targets = {
key: value.to(dtype=dtype, device=device)
for key, value in targets.items()
}
systems, targets = systems_and_targets_to_dtype_and_device(
systems, targets, dtype, device
)
predictions = evaluate_model(
model,
systems,
Expand Down Expand Up @@ -294,11 +293,9 @@ def train(
targets = remove_additive(
systems, targets, additive_model, train_targets
)
systems = [system.to(dtype=dtype, device=device) for system in systems]
targets = {
key: value.to(dtype=dtype, device=device)
for key, value in targets.items()
}
systems, targets = systems_and_targets_to_dtype_and_device(
systems, targets, dtype, device
)
predictions = evaluate_model(
model,
systems,
Expand Down
28 changes: 28 additions & 0 deletions src/metatrain/utils/transfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Dict, List

import torch
from metatensor.torch import TensorMap
from metatensor.torch.atomistic import System


@torch.jit.script
def systems_and_targets_to_dtype_and_device(
systems: List[System],
targets: Dict[str, TensorMap],
dtype: torch.dtype,
device: torch.device,
):
"""
Transfers the systems and targets to the specified dtype and device.

:param systems: List of systems.
:param targets: Dictionary of targets.
:param dtype: Desired data type.
:param device: Device to transfer to.
"""

systems = [system.to(dtype=dtype, device=device) for system in systems]
targets = {

Check warning on line 25 in src/metatrain/utils/transfer.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/transfer.py#L24-L25

Added lines #L24 - L25 were not covered by tests
key: value.to(dtype=dtype, device=device) for key, value in targets.items()
}
return systems, targets

Check warning on line 28 in src/metatrain/utils/transfer.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/transfer.py#L28

Added line #L28 was not covered by tests
39 changes: 39 additions & 0 deletions tests/utils/test_transfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import metatensor.torch
import torch
from metatensor.torch import Labels, TensorMap
from metatensor.torch.atomistic import System

from metatrain.utils.transfer import systems_and_targets_to_dtype_and_device


def test_systems_and_targets_to_dtype_and_device():
system = System(
positions=torch.tensor([[1.0, 1.0, 1.0]]),
cell=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
types=torch.tensor([1]),
)
targets = TensorMap(
keys=Labels.single(),
blocks=[metatensor.torch.block_from_array(torch.tensor([[1.0]]))],
)

systems = [system]
targets = {"energy": targets}

assert systems[0].positions.dtype == torch.float32
assert systems[0].positions.device == torch.device("cpu")
assert systems[0].cell.dtype == torch.float32
assert systems[0].types.device == torch.device("cpu")
assert targets["energy"].block().values.dtype == torch.float32
assert targets["energy"].block().values.device == torch.device("cpu")

systems, targets = systems_and_targets_to_dtype_and_device(
systems, targets, torch.float64, torch.device("meta")
)

assert systems[0].positions.dtype == torch.float64
Luthaf marked this conversation as resolved.
Show resolved Hide resolved
assert systems[0].positions.device == torch.device("meta")
assert systems[0].cell.dtype == torch.float64
assert systems[0].types.device == torch.device("meta")
assert targets["energy"].block().values.dtype == torch.float64
assert targets["energy"].block().values.device == torch.device("meta")