Skip to content

Commit

Permalink
wrote method to recurse through datamodel
Browse files Browse the repository at this point in the history
  • Loading branch information
frehburg committed Oct 17, 2024
1 parent 9537006 commit 126c6ef
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 1 deletion.
27 changes: 26 additions & 1 deletion src/phenopacket_mapper/data_standards/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from dataclasses import dataclass, field
from pathlib import Path
from typing import Union, List, Literal, Dict, Optional, Any, Callable, Tuple
from typing import Union, List, Literal, Dict, Optional, Any, Callable, Tuple, Iterable
import warnings

import pandas as pd
Expand Down Expand Up @@ -522,3 +522,28 @@ def __getattr__(self, var_name: str) -> Union[DataField, DataSection, 'OrGroup']
if f.id == var_name:
return f
raise AttributeError(f"'OrGroup' object has no attribute '{var_name}'")


def recursive_collect_all_members_data_model(
data_model: Union[DataModel, DataSection, OrGroup, DataField]
) -> Iterable[Union[DataSection, OrGroup, DataField]]:
"""Recursively collect all members of a DataModel, DataSection, OrGroup, or DataField
:param data_model: DataModel, DataSection, OrGroup, or DataField to collect all members from
:return: Iterable of DataSection, OrGroup, and DataField members
"""
if isinstance(data_model, DataModel):
for f in data_model.fields:
yield from recursive_collect_all_members_data_model(f)
elif isinstance(data_model, DataSection):
yield data_model
for f in data_model.fields:
yield from recursive_collect_all_members_data_model(f)
elif isinstance(data_model, OrGroup):
yield data_model
for f in data_model.fields:
yield from recursive_collect_all_members_data_model(f)
elif isinstance(data_model, DataField):
yield data_model
else:
raise ValueError(f"Unsupported data_model type: {type(data_model)}")
62 changes: 62 additions & 0 deletions tests/utils/test_recursive_collect_all_members_data_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest

from phenopacket_mapper import DataModel
from phenopacket_mapper.data_standards import DataField, DataSection, OrGroup
from phenopacket_mapper.data_standards.data_model import recursive_collect_all_members_data_model

df1 = DataField(
name="test_field_1",
specification=str,
)

df2 = DataField(
name="test_field_2",
specification=int,
)

df3 = DataField(
name="test_field_3",
specification=bool,
)

ds1 = DataSection(
name="test_section_1",
fields=(df1, df2)
)

og1 = OrGroup(
name="test_or_group_1",
fields=(df1, df2)
)


@pytest.mark.parametrize(
"data_model, members",
[
(
DataModel(
name="test",
fields=(df1, df2)
),
[df1, df2]
), # tabular data model
(
DataModel(
name="test",
fields=(ds1, df3)
),
[df1, df2, ds1, df3]
), # hierarchical with section data model
(
DataModel(
name="test",
fields=(og1, df3)
),
[df1, df2, og1, df3]
), # hierarchical with or group data model
]
)
def test_recursive_collect_all_members_data_model(data_model, members):
assert set(recursive_collect_all_members_data_model(data_model)) == set(members)

0 comments on commit 126c6ef

Please sign in to comment.