Skip to content

Commit

Permalink
fix[converter]: convert_field_to_list resolver error
Browse files Browse the repository at this point in the history
new get_query_fields cannot find union types called. Introduced get_queried_union_types to find it
  • Loading branch information
mak626 committed Jan 13, 2025
1 parent cd132a7 commit 049cac8
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 103 deletions.
198 changes: 95 additions & 103 deletions graphene_mongo/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .utils import (
get_field_description,
get_query_fields,
get_queried_union_types,
get_field_is_required,
ExecutorEnum,
sync_to_async,
Expand Down Expand Up @@ -154,7 +155,7 @@ def convert_field_to_list(field, registry=None, executor: ExecutorEnum = Executo
if isinstance(field.field, mongoengine.GenericReferenceField):

def get_reference_objects(*args, **kwargs):
document = get_document(args[0][0])
document = get_document(args[0])
document_field = mongoengine.ReferenceField(document)
document_field = convert_mongoengine_field(document_field, registry)
document_field_type = document_field.get_type().type
Expand All @@ -164,75 +165,70 @@ def get_reference_objects(*args, **kwargs):
for key, values in document_field_type._meta.filter_fields.items():
for each in values:
filter_args.append(key + "__" + each)
for each in get_query_fields(args[0][3][0])[document_field_type._meta.name].keys():
for each in args[4]:
item = to_snake_case(each)
if item in document._fields_ordered + tuple(filter_args):
queried_fields.append(item)
return (
document.objects()
.no_dereference()
.only(*set(list(document_field_type._meta.required_fields) + queried_fields))
.filter(pk__in=args[0][1])
.filter(pk__in=args[1])
)

def get_non_querying_object(*args, **kwargs):
model = get_document(args[0][0])
return [model(pk=each) for each in args[0][1]]
model = get_document(args[0])
return [model(pk=each) for each in args[1]]

def reference_resolver(root, *args, **kwargs):
to_resolve = getattr(root, field.name or field.db_name)
if to_resolve:
choice_to_resolve = dict()
querying_union_types = list(get_query_fields(args[0]).keys())
if "__typename" in querying_union_types:
querying_union_types.remove("__typename")
to_resolve_models = list()
for each in querying_union_types:
if executor == ExecutorEnum.SYNC:
to_resolve_models.append(registry._registry_string_map[each])
else:
to_resolve_models.append(registry._registry_async_string_map[each])
to_resolve_object_ids = list()
for each in to_resolve:
if isinstance(each, LazyReference):
to_resolve_object_ids.append(each.pk)
model = each.document_type._class_name
if model not in choice_to_resolve:
choice_to_resolve[model] = list()
choice_to_resolve[model].append(each.pk)
else:
to_resolve_object_ids.append(each["_ref"].id)
if each["_cls"] not in choice_to_resolve:
choice_to_resolve[each["_cls"]] = list()
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
pool = ThreadPoolExecutor(5)
futures = list()
for model, object_id_list in choice_to_resolve.items():
if model in to_resolve_models:
futures.append(
pool.submit(
get_reference_objects,
(model, object_id_list, registry, args),
)
if not to_resolve:
return None

choice_to_resolve = dict()
querying_union_types = get_queried_union_types(args[0])
to_resolve_models = dict()
for each, queried_fields in querying_union_types.items():
to_resolve_models[registry._registry_string_map[each]] = queried_fields
to_resolve_object_ids = list()
for each in to_resolve:
if isinstance(each, LazyReference):
to_resolve_object_ids.append(each.pk)
model = each.document_type._class_name
if model not in choice_to_resolve:
choice_to_resolve[model] = list()
choice_to_resolve[model].append(each.pk)
else:
to_resolve_object_ids.append(each["_ref"].id)
if each["_cls"] not in choice_to_resolve:
choice_to_resolve[each["_cls"]] = list()
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
pool = ThreadPoolExecutor(5)
futures = list()
for model, object_id_list in choice_to_resolve.items():
if model in to_resolve_models:
queried_fields = to_resolve_models[model]
futures.append(
pool.submit(
get_reference_objects,
*(model, object_id_list, registry, args, queried_fields),
)
else:
futures.append(
pool.submit(
get_non_querying_object,
(model, object_id_list, registry, args),
)
)
else:
futures.append(
pool.submit(
get_non_querying_object,
*(model, object_id_list, registry, args),
)
result = list()
for x in as_completed(futures):
result += x.result()
result_object_ids = list()
for each in result:
result_object_ids.append(each.id)
ordered_result = list()
for each in to_resolve_object_ids:
ordered_result.append(result[result_object_ids.index(each)])
return ordered_result
return None
)
result = list()
for x in as_completed(futures):
result += x.result()
result_object_ids = [each.id for each in result]
ordered_result = [
result[result_object_ids.index(each)] for each in to_resolve_object_ids
]
return ordered_result

async def get_reference_objects_async(*args, **kwargs):
document = get_document(args[0])
Expand All @@ -247,7 +243,7 @@ async def get_reference_objects_async(*args, **kwargs):
for key, values in document_field_type._meta.filter_fields.items():
for each in values:
filter_args.append(key + "__" + each)
for each in get_query_fields(args[3][0])[document_field_type._meta.name].keys():
for each in args[4]:
item = to_snake_case(each)
if item in document._fields_ordered + tuple(filter_args):
queried_fields.append(item)
Expand All @@ -259,57 +255,53 @@ async def get_reference_objects_async(*args, **kwargs):
)

async def get_non_querying_object_async(*args, **kwargs):
model = get_document(args[0])
return [model(pk=each) for each in args[1]]
return get_non_querying_object(*args, **kwargs)

async def reference_resolver_async(root, *args, **kwargs):
to_resolve = getattr(root, field.name or field.db_name)
if to_resolve:
choice_to_resolve = dict()
querying_union_types = list(get_query_fields(args[0]).keys())
if "__typename" in querying_union_types:
querying_union_types.remove("__typename")
to_resolve_models = list()
for each in querying_union_types:
if executor == ExecutorEnum.SYNC:
to_resolve_models.append(registry._registry_string_map[each])
else:
to_resolve_models.append(registry._registry_async_string_map[each])
to_resolve_object_ids = list()
for each in to_resolve:
if isinstance(each, LazyReference):
to_resolve_object_ids.append(each.pk)
model = each.document_type._class_name
if model not in choice_to_resolve:
choice_to_resolve[model] = list()
choice_to_resolve[model].append(each.pk)
else:
to_resolve_object_ids.append(each["_ref"].id)
if each["_cls"] not in choice_to_resolve:
choice_to_resolve[each["_cls"]] = list()
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
loop = asyncio.get_event_loop()
tasks = []
for model, object_id_list in choice_to_resolve.items():
if model in to_resolve_models:
task = loop.create_task(
get_reference_objects_async(model, object_id_list, registry, args)
)
else:
task = loop.create_task(
get_non_querying_object_async(model, object_id_list, registry, args)
if not to_resolve:
return None

choice_to_resolve = dict()
querying_union_types = get_queried_union_types(args[0])
to_resolve_models = dict()
for each, queried_fields in querying_union_types.items():
to_resolve_models[registry._registry_async_string_map[each]] = queried_fields
to_resolve_object_ids = list()
for each in to_resolve:
if isinstance(each, LazyReference):
to_resolve_object_ids.append(each.pk)
model = each.document_type._class_name
if model not in choice_to_resolve:
choice_to_resolve[model] = list()
choice_to_resolve[model].append(each.pk)
else:
to_resolve_object_ids.append(each["_ref"].id)
if each["_cls"] not in choice_to_resolve:
choice_to_resolve[each["_cls"]] = list()
choice_to_resolve[each["_cls"]].append(each["_ref"].id)
loop = asyncio.get_event_loop()
tasks = []
for model, object_id_list in choice_to_resolve.items():
if model in to_resolve_models:
queried_fields = to_resolve_models[model]
task = loop.create_task(
get_reference_objects_async(
model, object_id_list, registry, args, queried_fields
)
tasks.append(task)
result = await asyncio.gather(*tasks)
result_object = {}
for items in result:
for item in items:
result_object[item.id] = item
ordered_result = list()
for each in to_resolve_object_ids:
ordered_result.append(result_object[each])
return ordered_result
return None
)
else:
task = loop.create_task(
get_non_querying_object_async(model, object_id_list, registry, args)
)
tasks.append(task)
result = await asyncio.gather(*tasks)
result_object = {}
for items in result:
for item in items:
result_object[item.id] = item
ordered_result = [result_object[each] for each in to_resolve_object_ids]
return ordered_result

return graphene.List(
base_type._type,
Expand Down
34 changes: 34 additions & 0 deletions graphene_mongo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,40 @@ def get_query_fields(info):
return query


def get_queried_union_types(info):
"""A convenience function to get queried union types with its fields
Args:
info (ResolveInfo)
Returns:
dict[union_type_name, queried_fields(dict)]
"""

fragments = {}
node = ast_to_dict(info.field_nodes[0])
variables = info.variable_values

for name, value in info.fragments.items():
fragments[name] = ast_to_dict(value)

fragments_queries: dict[str, dict] = {}

selection_set = node.get("selection_set") if isinstance(node, dict) else node.selection_set
if selection_set:
for leaf in selection_set.selections:
if leaf.kind == "fragment_spread":
fragment_name = fragments[leaf.name.value].type_condition.name.value
fragments_queries[fragment_name] = collect_query_fields(
fragments[leaf.name.value], fragments, variables
)
elif leaf.kind == "inline_fragment":
fragment_name = leaf.type_condition.name.value
fragments_queries[fragment_name] = collect_query_fields(leaf, fragments, variables)

return fragments_queries


def has_page_info(info):
"""A convenience function to call collect_query_fields with info
for retrieving if page_info details are required
Expand Down

0 comments on commit 049cac8

Please sign in to comment.