diff --git a/setup.cfg b/setup.cfg index 75dbef3..339f8d9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,6 +23,9 @@ package_dir= packages=find: python_requires = >=3.8 install_requires = + gitpython + xmlschema + pydantic [options.packages.find] where=src diff --git a/src/pynxxas/nxdl/__init__.py b/src/pynxxas/nxdl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/pynxxas/nxdl/models.py b/src/pynxxas/nxdl/models.py new file mode 100644 index 0000000..e95bab6 --- /dev/null +++ b/src/pynxxas/nxdl/models.py @@ -0,0 +1,329 @@ +"""NXDL models""" + +from typing import Optional, Union, List + +import pydantic + + +class Item(pydantic.BaseModel): + + class Config: + extra = "forbid" + + +class EnumerationItem(Item): + value: str + doc: Optional[str] = None + + +class Enumeration(Item): + item: List[EnumerationItem] + + +class DimensionItem(Item): + index: int + required: bool + value: Optional[str] = None + ref: Optional[str] = None + + +class DocDimensionItem(Item): + index: str + value: str + + +class DocDimensions(Item): + dim: List[DocDimensionItem] + + +class Dimensions(Item): + rank: Optional[Union[int, str]] = None + dim: Optional[List[DimensionItem]] = None + doc: Optional[Union[DocDimensions, str]] = None + + +class Attribute(Item): + name: str + type: str + doc: Optional[str] = None + recommended: Optional[bool] = None + optional: Optional[bool] = None + dimensions: Optional[Dimensions] = None + enumeration: Optional[Enumeration] = None + deprecated: Optional[str] = None + + +class Field(Item): + name: str + type: str + nameType: Optional[str] = None + units: Optional[str] = None + signal: Optional[int] = None + axis: Optional[int] = None + primary: Optional[int] = None + axes: Optional[str] = None + doc: Optional[str] = None + recommended: Optional[bool] = None + optional: Optional[bool] = None + minOccurs: Optional[int] = None + maxOccurs: Optional[Union[int, str]] = None + stride: Optional[bool] = None + data_offset: Optional[bool] = None + dimensions: Optional[Dimensions] = None + enumeration: Optional[Enumeration] = None + attribute: Optional[List[Attribute]] = None + deprecated: Optional[str] = None + + class Config: + extra = "forbid" + + +class Group(Item): + type: str + name: Optional[str] = None # type[2:].upper() + doc: Optional[List[Optional[str]]] = None # NXmirror returns [None] + recommended: Optional[bool] = None + optional: Optional[bool] = None + minOccurs: Optional[int] = None + maxOccurs: Optional[Union[int, str]] = None + attribute: Optional[List[Attribute]] = None + field: Optional[List[Field]] = None + group: Optional[List["Group"]] = None + link: Optional[List["Link"]] = None + deprecated: Optional[str] = None + + @pydantic.model_validator(mode="after") + def name_from_type(self) -> "Group": + if self.name is None: + self.name = self.type[2:].upper() + return self + + +class Link(Item): + name: str + target: str + doc: Optional[str] = None + + class Config: + extra = "forbid" + + +class Choice(Item): + name: str + group: List[Group] + + class Config: + extra = "forbid" + + +class Symbol(Item): + name: str + doc: str + + +class Symbols(Item): + symbol: Optional[List[Symbol]] = None + doc: Optional[str] = None + + +class Definition(Item): + xmlns: str + xmlns_xsi: str + name: str + type: str + category: str + xsi_schemaLocation: str + ignoreExtraGroups: bool + ignoreExtraFields: bool + ignoreExtraAttributes: bool + xmlns_xs: Optional[str] = None + xmlns_ns: Optional[str] = None + extends: Optional[str] = None + deprecated: Optional[str] = None + doc: Optional[str] = None + symbols: Optional[Symbols] = None + attribute: Optional[List[Attribute]] = None + field: Optional[List[Field]] = None + group: Optional[List[Group]] = None + link: Optional[List[Link]] = None + choice: Optional[List[Choice]] = None + + class Config: + extra = "forbid" + + +def load_enumeration_item(enum_item: dict) -> dict: + data = dict() + for key, value in enum_item.items(): + if key.startswith("@"): + key = key[1:] + data[key] = value + return data + + +def load_enumeration(enumeration: dict) -> dict: + data = dict() + for key, value in enumeration.items(): + data[key] = [load_enumeration_item(item) for item in value] + return data + + +def load_doc_dimension_item(dim_item: dict) -> dict: + data = dict() + for key, value in dim_item.items(): + if key.startswith("@"): + key = key[1:] + data[key] = value + return data + + +def load_dimension_item(dim_item: dict) -> dict: + data = dict() + for key, value in dim_item.items(): + if key.startswith("@"): + key = key[1:] + data[key] = value + return data + + +def load_doc_dimensions(dimensions: dict) -> dict: + data = dict() + for key, value in dimensions.items(): + if key.startswith("@"): + key = key[1:] + elif key == "dim": + value = [load_doc_dimension_item(item) for item in value] + data[key] = value + return data + + +def load_dimensions(dimensions: dict) -> dict: + data = dict() + for key, value in dimensions.items(): + if key.startswith("@"): + key = key[1:] + elif key == "dim": + value = [load_dimension_item(item) for item in value] + elif key == "doc" and isinstance(value, dict): + value = load_doc_dimensions(value) + data[key] = value + return data + + +def load_link(link: dict) -> dict: + data = dict() + for key, value in link.items(): + if key.startswith("@"): + key = key[1:] + data[key] = value + return data + + +def load_choice(choice: dict) -> dict: + data = dict() + for key, value in choice.items(): + if key.startswith("@"): + key = key[1:] + elif key == "group": + value = [load_group(group) for group in value] + data[key] = value + return data + + +def load_attribute(attr: dict) -> dict: + data = dict() + for key, value in attr.items(): + if key.startswith("@"): + key = key[1:] + elif key == "enumeration": + value = load_enumeration(value) + elif key == "dimensions": + value = load_dimensions(value) + data[key] = value + return data + + +def load_field(field: dict) -> dict: + data = dict() + for key, value in field.items(): + if key.startswith("@"): + key = key[1:] + elif key == "attribute": + value = [load_attribute(attr) for attr in value] + elif key == "enumeration": + value = load_enumeration(value) + elif key == "dimensions": + value = load_dimensions(value) + data[key] = value + return data + + +def load_group(group: dict) -> dict: + data = dict() + for key, value in group.items(): + if key.startswith("@"): + key = key[1:] + elif key == "attribute": + value = [load_attribute(attr) for attr in value] + elif key == "field": + value = [load_field(field) for field in value] + elif key == "group": + value = [load_group(group) for group in value] + elif key == "link": + value = [load_link(attr) for attr in value] + data[key] = value + return data + + +def load_symbol(symbol: dict) -> dict: + data = dict() + for key, value in symbol.items(): + if key.startswith("@"): + key = key[1:] + data[key] = value + return data + + +def load_symbols(symbols: dict) -> dict: + data = dict() + for key, value in symbols.items(): + if key.startswith("@"): + key = key[1:] + elif key == "symbol": + value = [load_symbol(attr) for attr in value] + data[key] = value + return data + + +def load_definition(definition: str) -> Definition: + data = dict() + + for key, value in definition.items(): + if key.startswith("@"): + key = key[1:].replace(":", "_") + elif key == "attribute": + value = [load_attribute(attr) for attr in value] + elif key == "field": + value = [load_field(attr) for attr in value] + elif key == "group": + value = [load_group(attr) for attr in value] + elif key == "link": + value = [load_link(attr) for attr in value] + elif key == "choice": + value = [load_choice(attr) for attr in value] + elif key == "symbols": + value = load_symbols(value) + data[key] = value + + return Definition(**data) + + +if __name__ == "__main__": + from . import repo + + names = repo.get_nxdl_class_names( + url="https://github.com/XraySpectroscopy/nexus_definitions.git" + ) + + for name in names: + load_definition(repo.get_nxdl_class(name)) diff --git a/src/pynxxas/nxdl/repo.py b/src/pynxxas/nxdl/repo.py new file mode 100644 index 0000000..678093f --- /dev/null +++ b/src/pynxxas/nxdl/repo.py @@ -0,0 +1,75 @@ +"""NXDL repository""" + +import os +import tempfile +from glob import glob +from functools import lru_cache +from typing import Optional + +import git +import xmlschema + +DEFAULT_URL = "https://github.com/nexusformat/definitions.git" + + +@lru_cache(maxsize=1) +def get_nxdl_schema(**repo_options): + """Returns the NDXL schema""" + repo = _get_repo(**repo_options) + schema_file = os.path.join(repo.working_dir, "nxdl.xsd") + return xmlschema.XMLSchema(schema_file, validation="lax") + + +@lru_cache(maxsize=1) +def get_nxdl_class_names(**repo_options): + """Returns all nxdl file names from the repo""" + working_dir = _get_repo(**repo_options).working_dir + pattern = os.path.join(working_dir, "*", "*.nxdl.xml") + return [ + os.path.basename(filename).replace(".nxdl.xml", "") + for filename in glob(pattern) + ] + + +def get_nxdl_class(name: str, **repo_options) -> dict: + """Load the content of an nxdl file""" + schema = get_nxdl_schema(**repo_options) + base_url = schema.base_url.replace("file://", "") + for dirname in ("base_classes", "applications", "contributed_definitions"): + xml_file = os.path.join(base_url, dirname, f"{name}.nxdl.xml") + if os.path.exists(xml_file): + break + return schema.to_dict(xml_file, process_namespaces=True, use_defaults=True) + + +@lru_cache(maxsize=1) +def _get_repo( + nxdl_version: Optional[str] = None, + localdir: Optional[str] = None, + url: Optional[str] = None, + branch: Optional[str] = None, +) -> git.Repo: + """Git repository with NeXus definition files (*.nxdl.xml)""" + if not localdir: + localdir = os.path.join(tempfile.gettempdir(), "nexus_definitions") + if not url: + url = DEFAULT_URL + if not branch: + branch = "main" + remote = "origin" + + if os.path.exists(localdir): + repo = git.Repo(localdir) + origin = repo.remotes[remote] + origin.set_url(url) + origin.fetch() + origin.pull() + else: + repo = git.Repo.clone_from(url, localdir) + + if nxdl_version: + repo.git.checkout(f"v{nxdl_version}") + else: + repo.git.checkout(branch) + repo.git.reset(f"{remote}/{branch}", hard=True) + return repo diff --git a/src/pynxxas/tests/conftest.py b/src/pynxxas/tests/conftest.py new file mode 100644 index 0000000..2f4092a --- /dev/null +++ b/src/pynxxas/tests/conftest.py @@ -0,0 +1,8 @@ +import pytest +from ..nxdl import repo + + +@pytest.fixture(scope="session") +def repo_directory(tmpdir_factory) -> str: + root = tmpdir_factory.mktemp("nexus_definitions") + return repo._get_repo(localdir=str(root / "official_repo")).working_dir diff --git a/src/pynxxas/tests/test_nxdl.py b/src/pynxxas/tests/test_nxdl.py new file mode 100644 index 0000000..1a12315 --- /dev/null +++ b/src/pynxxas/tests/test_nxdl.py @@ -0,0 +1,10 @@ +from ..nxdl import repo +from ..nxdl import models + + +def test_nxdl_models(repo_directory): + names = repo.get_nxdl_class_names(localdir=repo_directory) + + for name in names: + definition = repo.get_nxdl_class(name, localdir=repo_directory) + assert models.load_definition(definition) diff --git a/src/pynxxas/tests/test_todo.py b/src/pynxxas/tests/test_todo.py deleted file mode 100644 index 4f6c6c3..0000000 --- a/src/pynxxas/tests/test_todo.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_todo(): - pass