Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Initial support for apiv2 #1085

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/python/cryoet_data_portal/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ test:

.PHONY: codegen
codegen:
python -m cryoet_data_portal._codegen
cd src && python3 -m cryoet_data_portal._codegen
# Need to run pre-commit twice because black and ruff fight with each other.
# Ignore the return code because that is non-zero when pre-commit applies a fix.
-pre-commit run --files src/cryoet_data_portal/_models.py src/cryoet_data_portal/data/schema.graphql
Expand Down
32 changes: 21 additions & 11 deletions client/python/cryoet_data_portal/src/cryoet_data_portal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,46 @@

from ._client import Client
from ._models import (
Annotation,
Alignment,
AnnotationAuthor,
AnnotationFile,
Dataset,
AnnotationShape,
Annotation,
DatasetAuthor,
DatasetFunding,
Deposition,
Dataset,
DepositionAuthor,
Deposition,
Frame,
PerSectionAlignmentParameters,
PerSectionParameters,
Run,
TiltSeries,
Tomogram,
Tiltseries,
TomogramAuthor,
TomogramVoxelSpacing,
Tomogram,
)

__version__ = "3.1.0"

__all__ = [
"Client",
"Annotation",
"AnnotationFile",
"Alignment",
"AnnotationAuthor",
"Dataset",
"AnnotationFile",
"AnnotationShape",
"Annotation",
"DatasetAuthor",
"DatasetFunding",
"Deposition",
"Dataset",
"DepositionAuthor",
"Deposition",
"Frame",
"PerSectionAlignmentParameters",
"PerSectionParameters",
"Run",
"TiltSeries",
"Tomogram",
"Tiltseries",
"TomogramAuthor",
"TomogramVoxelSpacing",
"Tomogram",
]
15 changes: 10 additions & 5 deletions client/python/cryoet_data_portal/src/cryoet_data_portal/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from cryoet_data_portal._constants import USER_AGENT

DEFAULT_URL = "https://graphql.cryoetdataportal.cziscience.com/v1/graphql"
DEFAULT_URL = "http://localhost:9009/graphql"


class Client:
Expand Down Expand Up @@ -46,7 +46,9 @@ def __init__(self, url: Optional[str] = None):
self.client = GQLClient(transport=transport, schema=schema_str)
self.ds = DSLSchema(self.client.schema)

def build_query(self, cls, gql_class_name: str, query_filters=None):
def build_query(
self, cls, root_field: str, gql_class_name: str, query_filters=None
):
ds = self.ds
query_filters = {} if not query_filters else {"where": query_filters}
gql_type = getattr(ds, gql_class_name)
Expand All @@ -55,7 +57,7 @@ def build_query(self, cls, gql_class_name: str, query_filters=None):
]
query = dsl_gql(
DSLQuery(
getattr(ds.query_root, gql_class_name)(**query_filters).select(
getattr(ds.Query, root_field)(**query_filters).select(
*scalar_fields,
),
),
Expand All @@ -64,8 +66,11 @@ def build_query(self, cls, gql_class_name: str, query_filters=None):

def find(self, cls, query_filters=None):
gql_type = cls._get_gql_type()
response = self.client.execute(self.build_query(cls, gql_type, query_filters))
return [cls(self, **item) for item in response[gql_type]]
gql_root = cls._get_gql_root_field()
response = self.client.execute(
self.build_query(cls, gql_root, gql_type, query_filters)
)
return [cls(self, **item) for item in response[gql_root]]

def find_one(self, *args, **kwargs):
for result in self.find(*args, **kwargs):
Expand Down
61 changes: 41 additions & 20 deletions client/python/cryoet_data_portal/src/cryoet_data_portal/_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
GraphQLType,
build_schema,
print_schema,
get_named_type,
)
from jinja2 import Environment, FileSystemLoader

Expand All @@ -31,7 +32,7 @@
"Float": ("FloatField()", "float"),
"Int": ("IntField()", "int"),
"String": ("StringField()", "str"),
"date": ("DateField()", "date"),
"DateTime": ("DateField()", "date"),
"numeric": ("FloatField()", "float"),
"_numeric": ("StringField()", "str"),
"tomogram_type_enum": ("StringField()", "str"),
Expand All @@ -40,19 +41,24 @@

"""Maps GraphQL type names to model class names."""
GQL_TO_MODEL_TYPE = {
"datasets": "Dataset",
"dataset_authors": "DatasetAuthor",
"dataset_funding": "DatasetFunding",
"runs": "Run",
"tomogram_voxel_spacings": "TomogramVoxelSpacing",
"tomograms": "Tomogram",
"tomogram_authors": "TomogramAuthor",
"annotations": "Annotation",
"annotation_files": "AnnotationFile",
"annotation_authors": "AnnotationAuthor",
"tiltseries": "TiltSeries",
"depositions": "Deposition",
"deposition_authors": "DepositionAuthor",
"Alignment": "Alignment",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If these are always expected to be identical now, maybe this should just be a tuple of gql/model names instead of a mapping?

"AnnotationAuthor": "AnnotationAuthor",
"AnnotationFile": "AnnotationFile",
"AnnotationShape": "AnnotationShape",
"Annotation": "Annotation",
"DatasetAuthor": "DatasetAuthor",
"DatasetFunding": "DatasetFunding",
"Dataset": "Dataset",
"DepositionAuthor": "DepositionAuthor",
"Deposition": "Deposition",
"Frame": "Frame",
"PerSectionAlignmentParameters": "PerSectionAlignmentParameters",
"PerSectionParameters": "PerSectionParameters",
"Run": "Run",
"Tiltseries": "Tiltseries",
"TomogramAuthor": "TomogramAuthor",
"TomogramVoxelSpacing": "TomogramVoxelSpacing",
"Tomogram": "Tomogram",
}


Expand Down Expand Up @@ -81,7 +87,8 @@ class ModelInfo:
"""The information about a parsed model."""

name: str
gql_name: str
gql_type: str
root_field: str
fields: Tuple[FieldInfo, ...]
description: Optional[str] = None

Expand Down Expand Up @@ -127,7 +134,8 @@ def get_models(schema: GraphQLSchema) -> Tuple[ModelInfo, ...]:
models.append(
ModelInfo(
name=model,
gql_name=gql_type.name,
gql_type=gql_type.name,
root_field=get_root_field_name(schema, gql_type),
description=gql_type.description,
fields=fields,
),
Expand Down Expand Up @@ -156,6 +164,17 @@ def load_schema(path: Path) -> GraphQLSchema:
return build_schema(schema_str)


def get_root_field_name(schema, gql_type: GraphQLObjectType) -> str:
    """Return the name of the root query field whose type is ``gql_type``.

    Each field on the schema's ``Query`` type is reduced to its named type
    with ``get_named_type`` and compared by name against ``gql_type``; the
    first matching field is returned.

    NOTE: this assumes all queried types are present at the query root!

    Args:
        schema: The schema to search; it must provide ``get_type``.
        gql_type: The GraphQL object type to locate at the query root.

    Returns:
        The name of the first field on ``Query`` whose named type matches.

    Raises:
        RuntimeError: If the schema has no ``Query`` type, or if no field on
            it resolves to ``gql_type``.
    """
    root = schema.get_type("Query")
    if root is None:
        # Fail with a descriptive error instead of an AttributeError below.
        raise RuntimeError(
            f"Could not find root field for {gql_type.name}: "
            "schema has no Query type",
        )
    for name, field in root.fields.items():
        # Compare on the underlying named type, ignoring any wrapper types.
        if get_named_type(field.type).name == gql_type.name:
            return name
    raise RuntimeError(f"Could not find root field for {gql_type.name}")


def parse_fields(gql_type: GraphQLObjectType) -> Tuple[FieldInfo, ...]:
"""Returns the field information parsed from a GraphQL object type."""
fields = []
Expand All @@ -181,7 +200,7 @@ def _parse_field(
) -> Optional[FieldInfo]:
logging.debug("_parse_field: %s, %s", name, field)
field_type = _maybe_unwrap_non_null(field.type)
if isinstance(field_type, GraphQLList):
if field_type.name.endswith("Connection"): # TODO can we clean this up?
return _parse_model_list_field(gql_type, name, field_type)
if isinstance(field_type, GraphQLObjectType) and (
field_type.name in GQL_TO_MODEL_TYPE
Expand Down Expand Up @@ -225,7 +244,7 @@ def _parse_model_field(
name=name,
description=f"The {model_name} this {source_model_name} is a part of",
annotation_type=model,
default_value=f'ItemRelationship("{model}", "{model_field}_id", "id")',
default_value=f'ItemRelationship("{model}", "{model_field}Id", "id")',
)
return None

Expand All @@ -236,7 +255,9 @@ def _parse_model_list_field(
field_type: GraphQLList[GraphQLType],
) -> Optional[FieldInfo]:
logging.debug("_parse_model_list_field: %s", field_type)
of_type = _maybe_unwrap_non_null(field_type.of_type)
of_type = get_named_type(
get_named_type(field_type.fields["edges"].type).fields["node"].type
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I find this chained call pretty hard to read. Consider unwrapping it with something like:

Suggested change
of_type = get_named_type(
get_named_type(field_type.fields["edges"].type).fields["node"].type
)
edges_type = get_named_type(field_type.fields["edges"].type)
of_type = get_named_type(edges_type.fields["node"].type)

if not isinstance(of_type, GraphQLNamedType):
return None
of_model = GQL_TO_MODEL_TYPE.get(of_type.name)
Expand All @@ -249,7 +270,7 @@ def _parse_model_list_field(
name=name,
description=f"The {of_model_name} of this {source_model_name}",
annotation_type=f"List[{of_model}]",
default_value=f'ListRelationship("{of_model}", "id", "{source_field}_id")',
default_value=f'ListRelationship("{of_model}", "id", "{source_field}Id")',
)
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,28 +92,22 @@ def convert(self, value):
return value


class StringField(BaseField):
...
class StringField(BaseField): ...


class IntField(BaseField):
...
class IntField(BaseField): ...


class DateField(BaseField):
def convert(self, value):
if value:
return datetime.date(
datetime.strptime(value, "%Y-%m-%d").astimezone(timezone.utc),
)
return datetime.fromisoformat(value)


class BooleanField(BaseField):
...
class BooleanField(BaseField): ...


class FloatField(BaseField):
...
class FloatField(BaseField): ...


class QueryChain(GQLField):
Expand Down Expand Up @@ -200,6 +194,7 @@ class Model:
"""The base class that all CryoET Portal Domain Object classes descend from. Documented methods apply to all domain objects."""

_gql_type: str
_gql_root_field: str

def __init__(self, client: Client, **kwargs):
self._client = client
Expand Down Expand Up @@ -233,6 +228,10 @@ def _get_relationship_fields(cls):
def _get_gql_type(cls):
return cls._gql_type

@classmethod
def _get_gql_root_field(cls):
return cls._gql_root_field

@classmethod
def find(
cls,
Expand Down
Loading
Loading