-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Extract and speed up device transfers (#357)
- Loading branch information
1 parent
cc84568
commit 6ece48e
Showing
7 changed files
with
92 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,4 +24,5 @@ This is the API for the ``utils`` module of ``metatrain``. | |
omegaconf | ||
output_gradient | ||
per_atom | ||
transfer | ||
units |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Data type and device transfers | ||
############################## | ||
|
||
.. automodule:: metatrain.utils.transfer | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import Dict, List | ||
|
||
import torch | ||
from metatensor.torch import TensorMap | ||
from metatensor.torch.atomistic import System | ||
|
||
|
||
@torch.jit.script | ||
def systems_and_targets_to_dtype_and_device( | ||
systems: List[System], | ||
targets: Dict[str, TensorMap], | ||
dtype: torch.dtype, | ||
device: torch.device, | ||
): | ||
""" | ||
Transfers the systems and targets to the specified dtype and device. | ||
:param systems: List of systems. | ||
:param targets: Dictionary of targets. | ||
:param dtype: Desired data type. | ||
:param device: Device to transfer to. | ||
""" | ||
|
||
systems = [system.to(dtype=dtype, device=device) for system in systems] | ||
targets = { | ||
key: value.to(dtype=dtype, device=device) for key, value in targets.items() | ||
} | ||
return systems, targets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import metatensor.torch | ||
import torch | ||
from metatensor.torch import Labels, TensorMap | ||
from metatensor.torch.atomistic import System | ||
|
||
from metatrain.utils.transfer import systems_and_targets_to_dtype_and_device | ||
|
||
|
||
def test_systems_and_targets_to_dtype_and_device(): | ||
system = System( | ||
positions=torch.tensor([[1.0, 1.0, 1.0]]), | ||
cell=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), | ||
types=torch.tensor([1]), | ||
) | ||
targets = TensorMap( | ||
keys=Labels.single(), | ||
blocks=[metatensor.torch.block_from_array(torch.tensor([[1.0]]))], | ||
) | ||
|
||
systems = [system] | ||
targets = {"energy": targets} | ||
|
||
assert systems[0].positions.dtype == torch.float32 | ||
assert systems[0].positions.device == torch.device("cpu") | ||
assert systems[0].cell.dtype == torch.float32 | ||
assert systems[0].types.device == torch.device("cpu") | ||
assert targets["energy"].block().values.dtype == torch.float32 | ||
assert targets["energy"].block().values.device == torch.device("cpu") | ||
|
||
systems, targets = systems_and_targets_to_dtype_and_device( | ||
systems, targets, torch.float64, torch.device("meta") | ||
) | ||
|
||
assert systems[0].positions.dtype == torch.float64 | ||
assert systems[0].positions.device == torch.device("meta") | ||
assert systems[0].cell.dtype == torch.float64 | ||
assert systems[0].types.device == torch.device("meta") | ||
assert targets["energy"].block().values.dtype == torch.float64 | ||
assert targets["energy"].block().values.device == torch.device("meta") |