Skip to content

Commit

Permalink
Add explicit dtype argument to readers
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Feb 14, 2024
1 parent a956320 commit 23fa33b
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 38 deletions.
4 changes: 3 additions & 1 deletion src/metatensor/models/cli/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,17 @@ def eval_model(
"""
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger.info("Setting up evaluation set.")
dtype = next(model.parameters()).dtype

options = expand_dataset_config(options)
eval_structures = read_structures(
filename=options["structures"]["read_from"],
fileformat=options["structures"]["file_format"],
dtype=dtype,
)
# Predict targets
if hasattr(options, "targets"):
eval_targets = read_targets(options["targets"])
eval_targets = read_targets(conf=options["targets"], dtype=dtype)
eval_dataset = Dataset(structure=eval_structures, energy=eval_targets["energy"])
_eval_targets(model, eval_dataset)

Expand Down
20 changes: 13 additions & 7 deletions src/metatensor/models/cli/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ def _train_model_hydra(options: DictConfig) -> None:
necessary options for dataset preparation, model hyperparameters, and training.
"""
if options["base_precision"] == 64:
torch.set_default_dtype(torch.float64)
dtype = torch.float64
elif options["base_precision"] == 32:
torch.set_default_dtype(torch.float32)
dtype = torch.float32
elif options["base_precision"] == 16:
torch.set_default_dtype(torch.float16)
dtype = torch.float16
else:
raise ValueError("Only 64, 32 or 16 are possible values for `base_precision`.")

Expand All @@ -175,8 +175,9 @@ def _train_model_hydra(options: DictConfig) -> None:
train_structures = read_structures(
filename=train_options["structures"]["read_from"],
fileformat=train_options["structures"]["file_format"],
dtype=dtype,
)
train_targets = read_targets(train_options["targets"])
train_targets = read_targets(conf=train_options["targets"], dtype=dtype)
train_dataset = Dataset(structure=train_structures, energy=train_targets["energy"])
train_size = 1.0

Expand Down Expand Up @@ -205,8 +206,9 @@ def _train_model_hydra(options: DictConfig) -> None:
test_structures = read_structures(
filename=test_options["structures"]["read_from"],
fileformat=test_options["structures"]["file_format"],
dtype=dtype,
)
test_targets = read_targets(test_options["targets"])
test_targets = read_targets(conf=test_options["targets"], dtype=dtype)
test_dataset = Dataset(structure=test_structures, energy=test_targets["energy"])
check_units(actual_options=test_options, desired_options=train_options)

Expand Down Expand Up @@ -234,10 +236,14 @@ def _train_model_hydra(options: DictConfig) -> None:
validation_structures = read_structures(
filename=validation_options["structures"]["read_from"],
fileformat=validation_options["structures"]["file_format"],
dtype=dtype,
)
validation_targets = read_targets(
conf=validation_options["targets"], dtype=dtype
)
validation_targets = read_targets(validation_options["targets"])
validation_dataset = Dataset(
structure=validation_structures, energy=validation_targets["energy"]
structure=validation_structures,
energy=validation_targets["energy"],
)
check_units(actual_options=validation_options, desired_options=train_options)

Expand Down
45 changes: 37 additions & 8 deletions src/metatensor/models/utils/data/readers/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@


def _base_reader(
readers: dict, filename: str, fileformat: Optional[str] = None, **reader_kwargs
readers: dict,
filename: str,
fileformat: Optional[str] = None,
dtype: torch.dtype = torch.float64,
**reader_kwargs,
):
if fileformat is None:
fileformat = Path(filename).suffix
Expand All @@ -25,110 +29,130 @@ def _base_reader(
except KeyError:
raise ValueError(f"fileformat {fileformat!r} is not supported")

return reader(filename, **reader_kwargs)
return reader(filename, dtype=dtype, **reader_kwargs)


def read_energy(
    filename: str,
    target_value: str = "energy",
    fileformat: Optional[str] = None,
    dtype: torch.dtype = torch.float64,
) -> List[TensorBlock]:
    """Read energy information from a file.

    The call is dispatched through ``_base_reader`` using the ``ENERGY_READERS``
    table; ``target_value`` is forwarded to the selected reader as ``key``.

    :param filename: name of the file to read
    :param target_value: target value key name to be parsed from the file
    :param fileformat: format of the structure file. If :py:obj:`None` the format is
        determined from the suffix.
    :param dtype: desired data type of returned tensor
    :returns: target values stored as a list of :class:`metatensor.TensorBlock`
    """
    return _base_reader(
        readers=ENERGY_READERS,
        filename=filename,
        fileformat=fileformat,
        key=target_value,
        dtype=dtype,
    )


def read_forces(
    filename: str,
    target_value: str = "forces",
    fileformat: Optional[str] = None,
    dtype: torch.dtype = torch.float64,
) -> List[TensorBlock]:
    """Read force information from a file.

    The call is dispatched through ``_base_reader`` using the ``FORCES_READERS``
    table; ``target_value`` is forwarded to the selected reader as ``key``.

    :param filename: name of the file to read
    :param target_value: target value key name to be parsed from the file
    :param fileformat: format of the structure file. If :py:obj:`None` the format is
        determined from the suffix.
    :param dtype: desired data type of returned tensor
    :returns: target values stored as a list of :class:`metatensor.TensorBlock`
    """
    return _base_reader(
        readers=FORCES_READERS,
        filename=filename,
        fileformat=fileformat,
        key=target_value,
        dtype=dtype,
    )


def read_stress(
    filename: str,
    target_value: str = "stress",
    fileformat: Optional[str] = None,
    dtype: torch.dtype = torch.float64,
) -> List[TensorBlock]:
    """Read stress information from a file.

    The call is dispatched through ``_base_reader`` using the ``STRESS_READERS``
    table; ``target_value`` is forwarded to the selected reader as ``key``.

    :param filename: name of the file to read
    :param target_value: target value key name to be parsed from the file
    :param fileformat: format of the structure file. If :py:obj:`None` the format is
        determined from the suffix.
    :param dtype: desired data type of returned tensor
    :returns: target values stored as a list of :class:`metatensor.TensorBlock`
    """
    return _base_reader(
        readers=STRESS_READERS,
        filename=filename,
        fileformat=fileformat,
        key=target_value,
        dtype=dtype,
    )


def read_structures(
    filename: str,
    fileformat: Optional[str] = None,
    dtype: torch.dtype = torch.float64,
) -> List[System]:
    """Read structure information from a file.

    The call is dispatched through ``_base_reader`` using the
    ``STRUCTURE_READERS`` table.

    :param filename: name of the file to read
    :param fileformat: format of the structure file. If :py:obj:`None` the format is
        determined from the suffix.
    :param dtype: desired data type of returned tensor
    :returns: list of structures
    """
    # Each keyword is passed exactly once; repeating ``readers``/``filename``/
    # ``fileformat`` in the same call would be a syntax error.
    return _base_reader(
        readers=STRUCTURE_READERS,
        filename=filename,
        fileformat=fileformat,
        dtype=dtype,
    )


def read_virial(
    filename: str,
    target_value: str = "virial",
    fileformat: Optional[str] = None,
    dtype: torch.dtype = torch.float64,
) -> List[TensorBlock]:
    """Read virial information from a file.

    The call is dispatched through ``_base_reader`` using the ``VIRIAL_READERS``
    table; ``target_value`` is forwarded to the selected reader as ``key``.

    :param filename: name of the file to read
    :param target_value: target value key name to be parsed from the file
    :param fileformat: format of the structure file. If :py:obj:`None` the format is
        determined from the suffix.
    :param dtype: desired data type of returned tensor
    :returns: target values stored as a list of :class:`metatensor.TensorBlock`
    """
    return _base_reader(
        readers=VIRIAL_READERS,
        filename=filename,
        fileformat=fileformat,
        key=target_value,
        dtype=dtype,
    )


def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]:
def read_targets(
conf: DictConfig,
dtype: torch.dtype = torch.float64,
) -> Dict[str, List[TensorMap]]:
"""Reading all target information from a fully expanded config.
To get such a config you can use
Expand All @@ -140,6 +164,7 @@ def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]:
added. Other gradients are silently ignored.
:param conf: config containing the keys for what should be read.
:param dtype: desired data type of returned tensor
:returns: Dictionary containing one TensorMaps for each target section in the
config."""
target_dictionary = {}
Expand All @@ -150,6 +175,7 @@ def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]:
filename=target["read_from"],
target_value=target["key"],
fileformat=target["file_format"],
dtype=dtype,
)

if target["forces"]:
Expand All @@ -158,6 +184,7 @@ def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]:
filename=target["forces"]["read_from"],
target_value=target["forces"]["key"],
fileformat=target["forces"]["file_format"],
dtype=dtype,
)
except KeyError:
logger.warning(
Expand All @@ -183,6 +210,7 @@ def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]:
filename=target["stress"]["read_from"],
target_value=target["stress"]["key"],
fileformat=target["stress"]["file_format"],
dtype=dtype,
)
except KeyError:
logger.warning(
Expand All @@ -203,6 +231,7 @@ def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]:
filename=target["virial"]["read_from"],
target_value=target["virial"]["key"],
fileformat=target["virial"]["file_format"],
dtype=dtype,
)
except KeyError:
logger.warning(
Expand Down
7 changes: 5 additions & 2 deletions src/metatensor/models/utils/data/readers/structures/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from rascaline.torch.system import System, systems_to_torch


def read_structures_ase(
    filename: str, dtype: torch.dtype = torch.float64
) -> List[System]:
    """Store structure information using ase.

    :param filename: name of the file to read
    :param dtype: desired data type of returned tensor
    :returns:
        A list of structures
    """
    # The ":" index asks ase.io.read for every frame in the file.
    systems = [AseSystem(atoms) for atoms in ase.io.read(filename, ":")]

    # Convert to the explicitly requested dtype rather than the process-wide
    # torch default, so the caller fully controls the precision.
    return [s.to(dtype=dtype) for s in systems_to_torch(systems)]
24 changes: 19 additions & 5 deletions src/metatensor/models/utils/data/readers/targets/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
def read_energy_ase(
filename: str,
key: str,
dtype: torch.dtype = torch.float64,
) -> List[TensorBlock]:
"""Store energy information in a List of :class:`metatensor.TensorBlock`.
:param filename: name of the file to read
:param key: target value key name to be parsed from the file.
:param dtype: desired data type of returned tensor
:returns:
TensorMap containing the given information
Expand All @@ -24,7 +26,7 @@ def read_energy_ase(

blocks = []
for i_structure, atoms in enumerate(frames):
values = torch.tensor([[atoms.info[key]]], dtype=torch.get_default_dtype())
values = torch.tensor([[atoms.info[key]]], dtype=dtype)
samples = Labels(["structure"], torch.tensor([[i_structure]]))

block = TensorBlock(
Expand All @@ -41,12 +43,14 @@ def read_energy_ase(
def read_forces_ase(
filename: str,
key: str = "energy",
dtype: torch.dtype = torch.float64,
) -> List[TensorBlock]:
"""Store force information in a List of :class:`metatensor.TensorBlock` which can be
used as ``position`` gradients.
:param filename: name of the file to read
:param key: target value key name to be parsed from the file.
:param dtype: desired data type of returned tensor
:returns:
TensorMap containing the given information
Expand All @@ -59,7 +63,7 @@ def read_forces_ase(
blocks = []
for i_structure, atoms in enumerate(frames):
# We store forces as positions gradients which means we invert the sign
values = -torch.tensor(atoms.arrays[key], dtype=torch.get_default_dtype())
values = -torch.tensor(atoms.arrays[key], dtype=dtype)
values = values.reshape(-1, 3, 1)

samples = Labels(
Expand All @@ -82,46 +86,56 @@ def read_forces_ase(
def read_virial_ase(
    filename: str,
    key: str = "virial",
    dtype: torch.dtype = torch.float64,
) -> List[TensorBlock]:
    """Store virial information in a List of :class:`metatensor.TensorBlock` which can
    be used as ``strain`` gradients.

    Thin wrapper around ``_read_virial_stress_ase`` with ``is_virial=True``.

    :param filename: name of the file to read
    :param key: target value key name to be parsed from the file
    :param dtype: desired data type of returned tensor
    :returns:
        List of :class:`metatensor.TensorBlock` containing the given information
    """
    # ``dtype`` must be forwarded explicitly; otherwise the helper falls back to
    # its own default and the caller's requested precision is ignored.
    return _read_virial_stress_ase(
        filename=filename, key=key, is_virial=True, dtype=dtype
    )


def read_stress_ase(
    filename: str,
    key: str = "stress",
    dtype: torch.dtype = torch.float64,
) -> List[TensorBlock]:
    """Store stress information in a List of :class:`metatensor.TensorBlock` which can
    be used as ``strain`` gradients.

    Thin wrapper around ``_read_virial_stress_ase`` with ``is_virial=False``.

    :param filename: name of the file to read
    :param key: target value key name to be parsed from the file
    :param dtype: desired data type of returned tensor
    :returns:
        List of :class:`metatensor.TensorBlock` containing the given information
    """
    # ``dtype`` must be forwarded explicitly; otherwise the helper falls back to
    # its own default and the caller's requested precision is ignored.
    return _read_virial_stress_ase(
        filename=filename, key=key, is_virial=False, dtype=dtype
    )


def _read_virial_stress_ase(
filename: str,
key: str,
is_virial: bool = True,
dtype: torch.dtype = torch.float64,
) -> List[TensorBlock]:
"""Store stress or virial information in a List of :class:`metatensor.TensorBlock`
which can be used as ``strain`` gradients.
:param filename: name of the file to read
:param key: target value key name to be parsed from the file
:param is_virial: if target values are stored as stress or virials.
:param dtype: desired data type of returned tensor
:returns:
TensorMap containing the given information
Expand All @@ -138,7 +152,7 @@ def _read_virial_stress_ase(
blocks = []
for i_structure, atoms in enumerate(frames):

values = torch.tensor(atoms.info[key].tolist(), dtype=torch.get_default_dtype())
values = torch.tensor(atoms.info[key].tolist(), dtype=dtype)

if values.shape == (9,):
warnings.warn(
Expand Down
Loading

0 comments on commit 23fa33b

Please sign in to comment.