diff --git a/src/metatensor/models/utils/data/readers/structures/__init__.py b/src/metatensor/models/utils/data/readers/structures/__init__.py index 69e633e73..78457ca29 100644 --- a/src/metatensor/models/utils/data/readers/structures/__init__.py +++ b/src/metatensor/models/utils/data/readers/structures/__init__.py @@ -1,4 +1,4 @@ from .ase import read_structures_ase -STRUCTURE_READERS = {".xyz": read_structures_ase} +STRUCTURE_READERS = {".extxyz": read_structures_ase, ".xyz": read_structures_ase} """:py:class:`dict`: dictionary mapping file suffixes to a structure reader""" diff --git a/src/metatensor/models/utils/data/readers/targets/__init__.py b/src/metatensor/models/utils/data/readers/targets/__init__.py index 934d8387a..4536bda04 100644 --- a/src/metatensor/models/utils/data/readers/targets/__init__.py +++ b/src/metatensor/models/utils/data/readers/targets/__init__.py @@ -1,13 +1,13 @@ from .ase import read_energy_ase, read_forces_ase, read_stress_ase, read_virial_ase -ENERGY_READERS = {".xyz": read_energy_ase} +ENERGY_READERS = {".extxyz": read_energy_ase, ".xyz": read_energy_ase} """:py:class:`dict`: dictionary mapping file suffixes to a target energy reader""" -FORCES_READERS = {".xyz": read_forces_ase} +FORCES_READERS = {".extxyz": read_forces_ase, ".xyz": read_forces_ase} """:py:class:`dict`: dictionary mapping file suffixes to a target forces reader""" -STRESS_READERS = {".xyz": read_stress_ase} +STRESS_READERS = {".extxyz": read_stress_ase, ".xyz": read_stress_ase} """:py:class:`dict`: dictionary mapping file suffixes to a target stress reader""" -VIRIAL_READERS = {".xyz": read_virial_ase} +VIRIAL_READERS = {".extxyz": read_virial_ase, ".xyz": read_virial_ase} """:py:class:`dict`: dictionary mapping file suffixes to a target virial reader""" diff --git a/tests/utils/data/test_readers.py b/tests/utils/data/test_readers.py index 196eee981..96fdd094f 100644 --- a/tests/utils/data/test_readers.py +++ b/tests/utils/data/test_readers.py @@ -21,7 +21,7 @@ ) -@pytest.mark.parametrize("fileformat", (None, ".xyz")) +@pytest.mark.parametrize("fileformat", (None, ".xyz", ".extxyz")) def test_read_structures(fileformat, monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) @@ -49,7 +49,7 @@ def test_read_structures_unknown_fileformat(): read_structures("foo.bar") -@pytest.mark.parametrize("fileformat", (None, ".xyz")) +@pytest.mark.parametrize("fileformat", (None, ".xyz", ".extxyz")) def test_read_energies(fileformat, monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) @@ -70,7 +70,7 @@ def test_read_energies(fileformat, monkeypatch, tmp_path): assert result.properties == Labels.single() -@pytest.mark.parametrize("fileformat", (None, ".xyz")) +@pytest.mark.parametrize("fileformat", (None, ".xyz", ".extxyz")) def test_read_forces(fileformat, monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) @@ -94,7 +94,7 @@ def test_read_forces(fileformat, monkeypatch, tmp_path): @pytest.mark.parametrize("reader", [read_stress, read_virial]) -@pytest.mark.parametrize("fileformat", (None, ".xyz")) +@pytest.mark.parametrize("fileformat", (None, ".xyz", ".extxyz")) def test_read_stress_virial(reader, fileformat, monkeypatch, tmp_path): monkeypatch.chdir(tmp_path)