From ce0ee06f44e8067e33bdcac7e7cc2460758e47fc Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Tue, 5 Nov 2024 17:22:18 +0100 Subject: [PATCH] Review changes --- src/metatrain/cli/train.py | 3 ++- src/metatrain/utils/data/dataset.py | 35 ++++++++++++++++++++--------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index 2d93cf7e1..e053f3810 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -236,7 +236,8 @@ def train_model( for key in intersecting_keys: if target_info_dict[key] != target_info_dict_single[key]: raise ValueError( - f"Target information for key {key} differs between training sets." + f"Target information for key {key} differs between training sets. " + f"Got {target_info_dict[key]} and {target_info_dict_single[key]}." ) target_info_dict.update(target_info_dict_single) diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 79eb7b794..20c192503 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -33,9 +33,9 @@ def __init__( unit: Union[None, str] = "", ): # one of these will be set to True inside the _check_layout method - self.is_scalar = False - self.is_cartesian = False - self.is_spherical = False + self._is_scalar = False + self._is_cartesian = False + self._is_spherical = False self._check_layout(layout) @@ -43,10 +43,25 @@ def __init__( self.layout = layout self.unit = unit if unit is not None else "" + @property + def is_scalar(self) -> bool: + """Whether the target is a scalar.""" + return self._is_scalar + + @property + def is_cartesian(self) -> bool: + """Whether the target is a Cartesian tensor.""" + return self._is_cartesian + + @property + def is_spherical(self) -> bool: + """Whether the target is a spherical tensor.""" + return self._is_spherical + @property def gradients(self) -> List[str]: """Sorted and unique list of gradient names.""" - if self.is_scalar: + if self._is_scalar: return sorted(self.layout.block().gradients_list()) else: return [] @@ -102,14 +117,14 @@ def _check_layout(self, layout: TensorMap) -> None: ) components_first_block = layout.block(0).components if len(components_first_block) == 0: - self.is_scalar = True + self._is_scalar = True elif components_first_block[0].names[0].startswith("xyz"): - self.is_cartesian = True + self._is_cartesian = True elif ( len(components_first_block) == 1 and components_first_block[0].names[0] == "o3_mu" ): - self.is_spherical = True + self._is_spherical = True else: raise ValueError( "The layout ``TensorMap`` of a target should be " @@ -117,7 +132,7 @@ def _check_layout(self, layout: TensorMap) -> None: "the target could not be determined." ) - if self.is_scalar: + if self._is_scalar: if layout.keys.names != ["_"]: raise ValueError( "The layout ``TensorMap`` of a scalar target should have " @@ -136,7 +151,7 @@ def _check_layout(self, layout: TensorMap) -> None: "scalar targets. " f"Found '{gradient_name}' instead." ) - if self.is_cartesian: + if self._is_cartesian: if layout.keys.names != ["_"]: raise ValueError( "The layout ``TensorMap`` of a Cartesian tensor target should have " @@ -152,7 +167,7 @@ def _check_layout(self, layout: TensorMap) -> None: "Gradients of Cartesian tensor targets are not supported." ) - if self.is_spherical: + if self._is_spherical: if layout.keys.names != ["o3_lambda", "o3_sigma"]: raise ValueError( "The layout ``TensorMap`` of a spherical tensor target "