Skip to content

Commit

Permalink
Review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 5, 2024
1 parent 4f89649 commit ce0ee06
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
3 changes: 2 additions & 1 deletion src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ def train_model(
for key in intersecting_keys:
if target_info_dict[key] != target_info_dict_single[key]:
raise ValueError(

Check warning on line 238 in src/metatrain/cli/train.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/cli/train.py#L238

Added line #L238 was not covered by tests
f"Target information for key {key} differs between training sets."
f"Target information for key {key} differs between training sets. "
f"Got {target_info_dict[key]} and {target_info_dict_single[key]}."
)
target_info_dict.update(target_info_dict_single)

Expand Down
35 changes: 25 additions & 10 deletions src/metatrain/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,35 @@ def __init__(
unit: Union[None, str] = "",
):
# one of these will be set to True inside the _check_layout method
self.is_scalar = False
self.is_cartesian = False
self.is_spherical = False
self._is_scalar = False
self._is_cartesian = False
self._is_spherical = False

self._check_layout(layout)

self.quantity = quantity # float64: otherwise metatensor can't serialize
self.layout = layout
self.unit = unit if unit is not None else ""

@property
def is_scalar(self) -> bool:
"""Whether the target is a scalar."""
return self._is_scalar

Check warning on line 49 in src/metatrain/utils/data/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/data/dataset.py#L49

Added line #L49 was not covered by tests

@property
def is_cartesian(self) -> bool:
"""Whether the target is a Cartesian tensor."""
return self._is_cartesian

Check warning on line 54 in src/metatrain/utils/data/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/data/dataset.py#L54

Added line #L54 was not covered by tests

@property
def is_spherical(self) -> bool:
"""Whether the target is a spherical tensor."""
return self._is_spherical

Check warning on line 59 in src/metatrain/utils/data/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/data/dataset.py#L59

Added line #L59 was not covered by tests

@property
def gradients(self) -> List[str]:
"""Sorted and unique list of gradient names."""
if self.is_scalar:
if self._is_scalar:
return sorted(self.layout.block().gradients_list())
else:
return []
Expand Down Expand Up @@ -102,22 +117,22 @@ def _check_layout(self, layout: TensorMap) -> None:
)
components_first_block = layout.block(0).components
if len(components_first_block) == 0:
self.is_scalar = True
self._is_scalar = True
elif components_first_block[0].names[0].startswith("xyz"):
self.is_cartesian = True
self._is_cartesian = True
elif (
len(components_first_block) == 1
and components_first_block[0].names[0] == "o3_mu"
):
self.is_spherical = True
self._is_spherical = True
else:
raise ValueError(
"The layout ``TensorMap`` of a target should be "
"either scalars, Cartesian tensors or spherical tensors. The type of "
"the target could not be determined."
)

if self.is_scalar:
if self._is_scalar:
if layout.keys.names != ["_"]:
raise ValueError(

Check warning on line 137 in src/metatrain/utils/data/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/data/dataset.py#L137

Added line #L137 was not covered by tests
"The layout ``TensorMap`` of a scalar target should have "
Expand All @@ -136,7 +151,7 @@ def _check_layout(self, layout: TensorMap) -> None:
"scalar targets. "
f"Found '{gradient_name}' instead."
)
if self.is_cartesian:
if self._is_cartesian:
if layout.keys.names != ["_"]:
raise ValueError(

Check warning on line 156 in src/metatrain/utils/data/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/data/dataset.py#L156

Added line #L156 was not covered by tests
"The layout ``TensorMap`` of a Cartesian tensor target should have "
Expand All @@ -152,7 +167,7 @@ def _check_layout(self, layout: TensorMap) -> None:
"Gradients of Cartesian tensor targets are not supported."
)

if self.is_spherical:
if self._is_spherical:
if layout.keys.names != ["o3_lambda", "o3_sigma"]:
raise ValueError(

Check warning on line 172 in src/metatrain/utils/data/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/data/dataset.py#L172

Added line #L172 was not covered by tests
"The layout ``TensorMap`` of a spherical tensor target "
Expand Down

0 comments on commit ce0ee06

Please sign in to comment.