diff --git a/tests/utils/test_transfer.py b/tests/utils/test_transfer.py index 22a9e30e8..cceb3bf23 100644 --- a/tests/utils/test_transfer.py +++ b/tests/utils/test_transfer.py @@ -17,8 +17,18 @@ def test_systems_and_targets_to_dtype_and_device(): blocks=[metatensor.torch.block_from_array(torch.tensor([[1.0]]))], ) + systems = [system] + targets = {"energy": targets} + + assert systems[0].positions.dtype == torch.float32 + assert systems[0].positions.device == torch.device("cpu") + assert systems[0].cell.dtype == torch.float32 + assert systems[0].types.device == torch.device("cpu") + assert targets["energy"].block().values.dtype == torch.float32 + assert targets["energy"].block().values.device == torch.device("cpu") + systems, targets = systems_and_targets_to_dtype_and_device( - [system], {"energy": targets}, torch.float64, torch.device("meta") + systems, targets, torch.float64, torch.device("meta") ) assert systems[0].positions.dtype == torch.float64