diff --git a/pyproject.toml b/pyproject.toml index 8e7fcd940..0c37a1431 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ authors = [ dependencies = [ "ase", "torch", + "rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch", "metatensor-core", "metatensor-torch" ] @@ -54,9 +55,7 @@ requires = [ build-backend = "setuptools.build_meta" [project.optional-dependencies] -soap-bpnn = [ - "rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch", -] +soap-bpnn = [] [tool.setuptools.packages.find] where = ["src"] diff --git a/src/metatensor_models/readers/__init__.py b/src/metatensor_models/readers/__init__.py new file mode 100644 index 000000000..27441dfcb --- /dev/null +++ b/src/metatensor_models/readers/__init__.py @@ -0,0 +1,44 @@ +""""Readers for structures and target values.""" + +from typing import List, Dict, Optional + +from pathlib import Path + +from metatensor.torch import TensorMap + +from .structures import STRUCTURE_READERS +from .targets import TARGET_READERS + +from rascaline.torch.system import Systems + + +def read_structures(filename: str, fileformat: Optional[str] = None) -> List[Systems]: + """Reads a structure information from file.""" + + if fileformat is None: + fileformat = Path(filename).suffix + + try: + reader = STRUCTURE_READERS[fileformat] + except KeyError: + raise ValueError(f"fileformat '{fileformat}' is not supported") + + return reader(filename) + + +def read_targets( + filename: str, + target_value: str, + fileformat: Optional[str] = None, +) -> Dict[str, TensorMap]: + """Reads target information from file.""" + + if fileformat is None: + fileformat = Path(filename).suffix + + try: + reader = TARGET_READERS[fileformat] + except KeyError: + raise ValueError(f"fileformat '{fileformat}' is not supported") + + return reader(filename, target_value) diff --git a/src/metatensor_models/readers/structures/__init__.py b/src/metatensor_models/readers/structures/__init__.py new file mode 100644 index 000000000..123a893df --- /dev/null +++ b/src/metatensor_models/readers/structures/__init__.py @@ -0,0 +1,3 @@ +from .ase import read_ase + +STRUCTURE_READERS = {".xyz": read_ase} diff --git a/src/metatensor_models/readers/structures/ase.py b/src/metatensor_models/readers/structures/ase.py new file mode 100644 index 000000000..26d94ba35 --- /dev/null +++ b/src/metatensor_models/readers/structures/ase.py @@ -0,0 +1,11 @@ +from typing import List + +import ase.io +from rascaline.systems import AseSystem +from rascaline.torch.system import Systems, systems_to_torch + + +def read_ase(filename: str) -> List[Systems]: + systems = [AseSystem(atoms) for atoms in ase.io.read(filename, ":")] + + return systems_to_torch(systems) diff --git a/src/metatensor_models/readers/targets/__init__.py b/src/metatensor_models/readers/targets/__init__.py new file mode 100644 index 000000000..69a98d1a8 --- /dev/null +++ b/src/metatensor_models/readers/targets/__init__.py @@ -0,0 +1,3 @@ +from .ase import read_ase + +TARGET_READERS = {".xyz": read_ase} diff --git a/src/metatensor_models/readers/targets/ase.py b/src/metatensor_models/readers/targets/ase.py new file mode 100644 index 000000000..160535b18 --- /dev/null +++ b/src/metatensor_models/readers/targets/ase.py @@ -0,0 +1,38 @@ +from typing import Dict, List, Union + +import ase.io +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + + +def read_ase( + filename: str, + target_values: Union[List[str], str], +) -> Dict[str, TensorMap]: + """Store target informations from file in a :class:`metatensor.TensorMap`. + + :returns: + TensorMap containing the given information + """ + + if type(target_values) is str: + target_values = [target_values] + + frames = ase.io.read(filename, ":") + + target_dictionary = {} + for target_value in target_values: + values = [f.info[target_value] for f in frames] + + n_structures = len(values) + + block = TensorBlock( + values=torch.tensor(values).reshape(-1, 1), + samples=Labels(["structure"], torch.arange(n_structures).reshape(-1, 1)), + components=[], + properties=Labels([target_value], torch.tensor([(0,)])), + ) + + target_dictionary[target_value] = TensorMap(Labels.single(), [block]) + + return target_dictionary