diff --git a/src/metatensor/models/cli/eval_model.py b/src/metatensor/models/cli/eval_model.py index f730326fc..31ca2a481 100644 --- a/src/metatensor/models/cli/eval_model.py +++ b/src/metatensor/models/cli/eval_model.py @@ -139,15 +139,17 @@ def eval_model( """ logging.basicConfig(level=logging.INFO, format="%(message)s") logger.info("Setting up evaluation set.") + dtype = next(model.parameters()).dtype options = expand_dataset_config(options) eval_structures = read_structures( filename=options["structures"]["read_from"], fileformat=options["structures"]["file_format"], + dtype=dtype, ) # Predict targets if hasattr(options, "targets"): - eval_targets = read_targets(options["targets"]) + eval_targets = read_targets(conf=options["targets"], dtype=dtype) eval_dataset = Dataset(structure=eval_structures, energy=eval_targets["energy"]) _eval_targets(model, eval_dataset) diff --git a/src/metatensor/models/cli/train_model.py b/src/metatensor/models/cli/train_model.py index 8c36beeb9..f316fdce6 100644 --- a/src/metatensor/models/cli/train_model.py +++ b/src/metatensor/models/cli/train_model.py @@ -147,11 +147,11 @@ def _train_model_hydra(options: DictConfig) -> None: necessary options for dataset preparation, model hyperparameters, and training. """ if options["base_precision"] == 64: - torch.set_default_dtype(torch.float64) + dtype = torch.float64 elif options["base_precision"] == 32: - torch.set_default_dtype(torch.float32) + dtype = torch.float32 elif options["base_precision"] == 16: - torch.set_default_dtype(torch.float16) + dtype = torch.float16 else: raise ValueError("Only 64, 32 or 16 are possible values for `base_precision`.") @@ -175,8 +175,9 @@ def _train_model_hydra(options: DictConfig) -> None: train_structures = read_structures( filename=train_options["structures"]["read_from"], fileformat=train_options["structures"]["file_format"], + dtype=dtype, ) - train_targets = read_targets(train_options["targets"]) + train_targets = read_targets(conf=train_options["targets"], dtype=dtype) train_dataset = Dataset(structure=train_structures, energy=train_targets["energy"]) train_size = 1.0 @@ -205,8 +206,9 @@ def _train_model_hydra(options: DictConfig) -> None: test_structures = read_structures( filename=test_options["structures"]["read_from"], fileformat=test_options["structures"]["file_format"], + dtype=dtype, ) - test_targets = read_targets(test_options["targets"]) + test_targets = read_targets(conf=test_options["targets"], dtype=dtype) test_dataset = Dataset(structure=test_structures, energy=test_targets["energy"]) check_units(actual_options=test_options, desired_options=train_options) @@ -234,10 +236,14 @@ def _train_model_hydra(options: DictConfig) -> None: validation_structures = read_structures( filename=validation_options["structures"]["read_from"], fileformat=validation_options["structures"]["file_format"], + dtype=dtype, + ) + validation_targets = read_targets( + conf=validation_options["targets"], dtype=dtype ) - validation_targets = read_targets(validation_options["targets"]) validation_dataset = Dataset( - structure=validation_structures, energy=validation_targets["energy"] + structure=validation_structures, + energy=validation_targets["energy"], ) check_units(actual_options=validation_options, desired_options=train_options) diff --git a/src/metatensor/models/utils/data/readers/readers.py b/src/metatensor/models/utils/data/readers/readers.py index 34ac6ec94..15849dd46 100644 --- a/src/metatensor/models/utils/data/readers/readers.py +++ b/src/metatensor/models/utils/data/readers/readers.py @@ -15,7 +15,11 @@ def _base_reader( - readers: dict, filename: str, fileformat: Optional[str] = None, **reader_kwargs + readers: dict, + filename: str, + fileformat: Optional[str] = None, + dtype: torch.dtype = torch.float64, + **reader_kwargs, ): if fileformat is None: fileformat = Path(filename).suffix @@ -25,20 +29,22 @@ def _base_reader( except KeyError: raise ValueError(f"fileformat {fileformat!r} is not supported") - return reader(filename, **reader_kwargs) + return reader(filename, dtype=dtype, **reader_kwargs) def read_energy( filename: str, target_value: str = "energy", fileformat: Optional[str] = None, + dtype: torch.dtype = torch.float64, ) -> List[TensorBlock]: """Read energy informations from a file. :param filename: name of the file to read :param target_value: target value key name to be parsed from the file. :param fileformat: format of the structure file. If :py:obj:`None` the format is - determined from the suffix. + determined from the suffix + :param dtype: desired data type of returned tensor :returns: target value stored stored as a :class:`metatensor.TensorBlock` """ return _base_reader( @@ -46,6 +52,7 @@ def read_energy( filename=filename, fileformat=fileformat, key=target_value, + dtype=dtype, ) @@ -53,13 +60,15 @@ def read_forces( filename: str, target_value: str = "forces", fileformat: Optional[str] = None, + dtype: torch.dtype = torch.float64, ) -> List[TensorBlock]: """Read force informations from a file. :param filename: name of the file to read - :param target_value: target value key name to be parsed from the file. + :param target_value: target value key name to be parsed from the file :param fileformat: format of the structure file. If :py:obj:`None` the format is - determined from the suffix. + determined from the suffix + :param dtype: desired data type of returned tensor :returns: target value stored stored as a :class:`metatensor.TensorBlock` """ return _base_reader( @@ -67,6 +76,7 @@ def read_forces( filename=filename, fileformat=fileformat, key=target_value, + dtype=dtype, ) @@ -74,13 +84,15 @@ def read_stress( filename: str, target_value: str = "stress", fileformat: Optional[str] = None, + dtype: torch.dtype = torch.float64, ) -> List[TensorBlock]: """Read stress informations from a file. :param filename: name of the file to read :param target_value: target value key name to be parsed from the file. :param fileformat: format of the structure file. If :py:obj:`None` the format is - determined from the suffix. + determined from the suffix + :param dtype: desired data type of returned tensor :returns: target value stored stored as a :class:`metatensor.TensorBlock` """ return _base_reader( @@ -88,22 +100,28 @@ def read_stress( filename=filename, fileformat=fileformat, key=target_value, + dtype=dtype, ) def read_structures( filename: str, fileformat: Optional[str] = None, + dtype: torch.dtype = torch.float64, ) -> List[System]: """Read structure informations from a file. :param filename: name of the file to read :param fileformat: format of the structure file. If :py:obj:`None` the format is determined from the suffix. + :param dtype: desired data type of returned tensor :returns: list of structures """ return _base_reader( - readers=STRUCTURE_READERS, filename=filename, fileformat=fileformat + readers=STRUCTURE_READERS, + filename=filename, + fileformat=fileformat, + dtype=dtype, ) @@ -111,6 +129,7 @@ def read_virial( filename: str, target_value: str = "virial", fileformat: Optional[str] = None, + dtype: torch.dtype = torch.float64, ) -> List[TensorBlock]: """Read virial informations from a file. @@ -118,6 +137,7 @@ def read_virial( :param target_value: target value key name to be parsed from the file. :param fileformat: format of the structure file. If :py:obj:`None` the format is determined from the suffix. + :param dtype: desired data type of returned tensor :returns: target value stored stored as a :class:`metatensor.TensorBlock` """ return _base_reader( @@ -125,10 +145,14 @@ def read_virial( filename=filename, fileformat=fileformat, key=target_value, + dtype=dtype, ) -def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]: +def read_targets( + conf: DictConfig, + dtype: torch.dtype = torch.float64, +) -> Dict[str, List[TensorMap]]: """Reading all target information from a fully expanded config. To get such a config you can use @@ -140,6 +164,7 @@ def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]: added. Other gradients are silentlty irgnored. :param conf: config containing the keys for what should be read. + :param dtype: desired data type of returned tensor :returns: Dictionary containing one TensorMaps for each target section in the config.""" target_dictionary = {} @@ -150,6 +175,7 @@ def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]: filename=target["read_from"], target_value=target["key"], fileformat=target["file_format"], + dtype=dtype, ) if target["forces"]: @@ -158,6 +184,7 @@ def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]: filename=target["forces"]["read_from"], target_value=target["forces"]["key"], fileformat=target["forces"]["file_format"], + dtype=dtype, ) except KeyError: logger.warning( @@ -183,6 +210,7 @@ def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]: filename=target["stress"]["read_from"], target_value=target["stress"]["key"], fileformat=target["stress"]["file_format"], + dtype=dtype, ) except KeyError: logger.warning( @@ -203,6 +231,7 @@ def read_targets(conf: DictConfig) -> Dict[str, List[TensorMap]]: filename=target["virial"]["read_from"], target_value=target["virial"]["key"], fileformat=target["virial"]["file_format"], + dtype=dtype, ) except KeyError: logger.warning( diff --git a/src/metatensor/models/utils/data/readers/structures/ase.py b/src/metatensor/models/utils/data/readers/structures/ase.py index ec7ec4d7f..8fcf8ea66 100644 --- a/src/metatensor/models/utils/data/readers/structures/ase.py +++ b/src/metatensor/models/utils/data/readers/structures/ase.py @@ -6,14 +6,17 @@ from rascaline.torch.system import System, systems_to_torch -def read_structures_ase(filename: str) -> List[System]: +def read_structures_ase( + filename: str, dtype: torch.dtype = torch.float64 +) -> List[System]: """Store structure informations using ase. :param filename: name of the file to read + :param dtype: desired data type of returned tensor :returns: A list of structures """ systems = [AseSystem(atoms) for atoms in ase.io.read(filename, ":")] - return [s.to(dtype=torch.get_default_dtype()) for s in systems_to_torch(systems)] + return [s.to(dtype=dtype) for s in systems_to_torch(systems)] diff --git a/src/metatensor/models/utils/data/readers/targets/ase.py b/src/metatensor/models/utils/data/readers/targets/ase.py index 5edcdef22..0cb689579 100644 --- a/src/metatensor/models/utils/data/readers/targets/ase.py +++ b/src/metatensor/models/utils/data/readers/targets/ase.py @@ -9,11 +9,13 @@ def read_energy_ase( filename: str, key: str, + dtype: torch.dtype = torch.float64, ) -> List[TensorBlock]: """Store energy information in a List of :class:`metatensor.TensorBlock`. :param filename: name of the file to read :param key: target value key name to be parsed from the file. + :param dtype: desired data type of returned tensor :returns: TensorMap containing the given information @@ -24,7 +26,7 @@ def read_energy_ase( blocks = [] for i_structure, atoms in enumerate(frames): - values = torch.tensor([[atoms.info[key]]], dtype=torch.get_default_dtype()) + values = torch.tensor([[atoms.info[key]]], dtype=dtype) samples = Labels(["structure"], torch.tensor([[i_structure]])) block = TensorBlock( @@ -41,12 +43,14 @@ def read_energy_ase( def read_forces_ase( filename: str, key: str = "energy", + dtype: torch.dtype = torch.float64, ) -> List[TensorBlock]: """Store force information in a List of :class:`metatensor.TensorBlock` which can be used as ``position`` gradients. :param filename: name of the file to read :param key: target value key name to be parsed from the file. + :param dtype: desired data type of returned tensor :returns: TensorMap containing the given information @@ -59,7 +63,7 @@ def read_forces_ase( blocks = [] for i_structure, atoms in enumerate(frames): # We store forces as positions gradients which means we invert the sign - values = -torch.tensor(atoms.arrays[key], dtype=torch.get_default_dtype()) + values = -torch.tensor(atoms.arrays[key], dtype=dtype) values = values.reshape(-1, 3, 1) samples = Labels( @@ -82,39 +86,48 @@ def read_forces_ase( def read_virial_ase( filename: str, key: str = "virial", + dtype: torch.dtype = torch.float64, ) -> List[TensorBlock]: """Store virial information in a List of :class:`metatensor.TensorBlock` which can be used as ``strain`` gradients. :param filename: name of the file to read :param key: target value key name to be parsed from the file + :param dtype: desired data type of returned tensor :returns: TensorMap containing the given information """ - return _read_virial_stress_ase(filename=filename, key=key, is_virial=True) + return _read_virial_stress_ase( + filename=filename, key=key, is_virial=True, dtype=dtype + ) def read_stress_ase( filename: str, key: str = "stress", + dtype: torch.dtype = torch.float64, ) -> List[TensorBlock]: """Store stress information in a List of :class:`metatensor.TensorBlock` which can be used as ``strain`` gradients. :param filename: name of the file to read :param key: target value key name to be parsed from the file + :param dtype: desired data type of returned tensor :returns: TensorMap containing the given information """ - return _read_virial_stress_ase(filename=filename, key=key, is_virial=False) + return _read_virial_stress_ase( + filename=filename, key=key, is_virial=False, dtype=dtype + ) def _read_virial_stress_ase( filename: str, key: str, is_virial: bool = True, + dtype: torch.dtype = torch.float64, ) -> List[TensorBlock]: """Store stress or virial information in a List of :class:`metatensor.TensorBlock` which can be used as ``strain`` gradients. @@ -122,6 +135,7 @@ def _read_virial_stress_ase( :param filename: name of the file to read :param key: target value key name to be parsed from the file :param is_virial: if target values are stored as stress or virials. + :param dtype: desired data type of returned tensor :returns: TensorMap containing the given information @@ -138,7 +152,7 @@ def _read_virial_stress_ase( blocks = [] for i_structure, atoms in enumerate(frames): - values = torch.tensor(atoms.info[key].tolist(), dtype=torch.get_default_dtype()) + values = torch.tensor(atoms.info[key].tolist(), dtype=dtype) if values.shape == (9,): warnings.warn( diff --git a/tests/utils/data/targets/test_targets_ase.py b/tests/utils/data/targets/test_targets_ase.py index 0c681f71a..62054b663 100644 --- a/tests/utils/data/targets/test_targets_ase.py +++ b/tests/utils/data/targets/test_targets_ase.py @@ -42,10 +42,10 @@ def test_read_energy_ase(monkeypatch, tmp_path): structures = ase_systems() ase.io.write(filename, structures) - results = read_energy_ase(filename=filename, key="true_energy") + results = read_energy_ase(filename=filename, key="true_energy", dtype=torch.float16) for result, atoms in zip(results, structures): - expected = torch.tensor([[atoms.info["true_energy"]]]) + expected = torch.tensor([[atoms.info["true_energy"]]], dtype=torch.float16) torch.testing.assert_close(result.values, expected) @@ -57,10 +57,11 @@ def test_read_forces_ase(monkeypatch, tmp_path): structures = ase_systems() ase.io.write(filename, structures) - results = read_forces_ase(filename=filename, key="forces") + results = read_forces_ase(filename=filename, key="forces", dtype=torch.float16) for result, atoms in zip(results, structures): - expected = -torch.tensor(atoms.get_array("forces")).reshape(-1, 3, 1) + expected = -torch.tensor(atoms.get_array("forces"), dtype=torch.float16) + expected = expected.reshape(-1, 3, 1) torch.testing.assert_close(result.values, expected) @@ -72,10 +73,12 @@ def test_read_stress_ase(monkeypatch, tmp_path): structures = ase_systems() ase.io.write(filename, structures) - results = read_stress_ase(filename=filename, key="stress-3x3") + results = read_stress_ase(filename=filename, key="stress-3x3", dtype=torch.float16) for result, atoms in zip(results, structures): - expected = atoms.cell.volume * torch.tensor(atoms.info["stress-3x3"]) + expected = atoms.cell.volume * torch.tensor( + atoms.info["stress-3x3"], dtype=torch.float16 + ) expected = expected.reshape(-1, 3, 3, 1) torch.testing.assert_close(result.values, expected) @@ -103,10 +106,10 @@ def test_read_virial_ase(monkeypatch, tmp_path): structures = ase_systems() ase.io.write(filename, structures) - results = read_virial_ase(filename=filename, key="stress-3x3") + results = read_virial_ase(filename=filename, key="stress-3x3", dtype=torch.float16) for result, atoms in zip(results, structures): - expected = -torch.tensor(atoms.info["stress-3x3"]) + expected = -torch.tensor(atoms.info["stress-3x3"], dtype=torch.float16) expected = expected.reshape(-1, 3, 3, 1) torch.testing.assert_close(result.values, expected) diff --git a/tests/utils/data/test_readers.py b/tests/utils/data/test_readers.py index 859578cdf..196eee981 100644 --- a/tests/utils/data/test_readers.py +++ b/tests/utils/data/test_readers.py @@ -29,14 +29,16 @@ def test_read_structures(fileformat, monkeypatch, tmp_path): structures = ase_systems() ase.io.write(filename, structures) - results = read_structures(filename, fileformat=fileformat) + results = read_structures(filename, fileformat=fileformat, dtype=torch.float16) assert isinstance(results, list) assert len(results) == len(structures) for structure, result in zip(structures, results): assert isinstance(result, torch.ScriptObject) - torch.testing.assert_close(result.positions, torch.tensor(structure.positions)) + torch.testing.assert_close( + result.positions, torch.tensor(structure.positions, dtype=torch.float16) + ) torch.testing.assert_close( result.species, torch.tensor([1, 1], dtype=torch.int32) ) @@ -55,11 +57,14 @@ def test_read_energies(fileformat, monkeypatch, tmp_path): structures = ase_systems() ase.io.write(filename, structures) - results = read_energy(filename, fileformat=fileformat, target_value="true_energy") + results = read_energy( + filename, fileformat=fileformat, target_value="true_energy", dtype=torch.float16 + ) assert type(results) is list assert len(results) == len(structures) for i_structure, result in enumerate(results): + assert result.values.dtype is torch.float16 assert result.samples.names == ["structure"] assert result.samples.values == torch.tensor([[i_structure]]) assert result.properties == Labels.single() @@ -73,11 +78,14 @@ def test_read_forces(fileformat, monkeypatch, tmp_path): structures = ase_systems() ase.io.write(filename, structures) - results = read_forces(filename, fileformat=fileformat, target_value="forces") + results = read_forces( + filename, fileformat=fileformat, target_value="forces", dtype=torch.float16 + ) assert type(results) is list assert len(results) == len(structures) for i_structure, result in enumerate(results): + assert result.values.dtype is torch.float16 assert result.samples.names == ["sample", "structure", "atom"] assert torch.all(result.samples["sample"] == torch.tensor(0)) assert torch.all(result.samples["structure"] == torch.tensor(i_structure)) @@ -94,7 +102,9 @@ def test_read_stress_virial(reader, fileformat, monkeypatch, tmp_path): structures = ase_systems() ase.io.write(filename, structures) - results = reader(filename, fileformat=fileformat, target_value="stress-3x3") + results = reader( + filename, fileformat=fileformat, target_value="stress-3x3", dtype=torch.float16 + ) assert type(results) is list assert len(results) == len(structures) @@ -103,6 +113,7 @@ def test_read_stress_virial(reader, fileformat, monkeypatch, tmp_path): Labels(["xyz_2"], torch.arange(3).reshape(-1, 1)), ] for result in results: + assert result.values.dtype is torch.float16 assert result.samples.names == ["sample"] assert result.samples.values == torch.tensor([[0]]) assert result.components == components @@ -149,7 +160,7 @@ def test_read_targets(stress_dict, virial_dict, monkeypatch, tmp_path, caplog): } caplog.set_level(logging.INFO) - result = read_targets(OmegaConf.create(conf)) + result = read_targets(OmegaConf.create(conf), dtype=torch.float16) assert any(["Forces found" in rec.message for rec in caplog.records]) @@ -166,10 +177,12 @@ def test_read_targets(stress_dict, virial_dict, monkeypatch, tmp_path, caplog): assert target.keys == Labels(["lambda", "sigma"], torch.tensor([(0, 1)])) result_block = target.block() + assert result_block.values.dtype is torch.float16 assert result_block.samples.names == ["structure"] assert result_block.properties == Labels.single() pos_grad = result_block.gradient("positions") + assert pos_grad.values.dtype is torch.float16 assert pos_grad.samples.names == ["sample", "structure", "atom"] assert pos_grad.components == [ Labels(["xyz"], torch.arange(3).reshape(-1, 1)) @@ -181,7 +194,7 @@ def test_read_targets(stress_dict, virial_dict, monkeypatch, tmp_path, caplog): Labels(["xyz_1"], torch.arange(3).reshape(-1, 1)), Labels(["xyz_2"], torch.arange(3).reshape(-1, 1)), ] - + assert disp_grad.values.dtype is torch.float16 assert disp_grad.samples.names == ["sample"] assert disp_grad.components == components assert disp_grad.properties == Labels.single()