diff --git a/graphene_mongo/converter.py b/graphene_mongo/converter.py index 80c3108..b07bc71 100644 --- a/graphene_mongo/converter.py +++ b/graphene_mongo/converter.py @@ -11,6 +11,7 @@ from .utils import ( get_field_description, get_query_fields, + get_queried_union_types, get_field_is_required, ExecutorEnum, sync_to_async, @@ -154,7 +155,7 @@ def convert_field_to_list(field, registry=None, executor: ExecutorEnum = Executo if isinstance(field.field, mongoengine.GenericReferenceField): def get_reference_objects(*args, **kwargs): - document = get_document(args[0][0]) + document = get_document(args[0]) document_field = mongoengine.ReferenceField(document) document_field = convert_mongoengine_field(document_field, registry) document_field_type = document_field.get_type().type @@ -164,7 +165,7 @@ def get_reference_objects(*args, **kwargs): for key, values in document_field_type._meta.filter_fields.items(): for each in values: filter_args.append(key + "__" + each) - for each in get_query_fields(args[0][3][0])[document_field_type._meta.name].keys(): + for each in args[4]: item = to_snake_case(each) if item in document._fields_ordered + tuple(filter_args): queried_fields.append(item) @@ -172,67 +173,62 @@ def get_reference_objects(*args, **kwargs): document.objects() .no_dereference() .only(*set(list(document_field_type._meta.required_fields) + queried_fields)) - .filter(pk__in=args[0][1]) + .filter(pk__in=args[1]) ) def get_non_querying_object(*args, **kwargs): - model = get_document(args[0][0]) - return [model(pk=each) for each in args[0][1]] + model = get_document(args[0]) + return [model(pk=each) for each in args[1]] def reference_resolver(root, *args, **kwargs): to_resolve = getattr(root, field.name or field.db_name) - if to_resolve: - choice_to_resolve = dict() - querying_union_types = list(get_query_fields(args[0]).keys()) - if "__typename" in querying_union_types: - querying_union_types.remove("__typename") - to_resolve_models = list() - for each in querying_union_types: - if executor == ExecutorEnum.SYNC: - to_resolve_models.append(registry._registry_string_map[each]) - else: - to_resolve_models.append(registry._registry_async_string_map[each]) - to_resolve_object_ids = list() - for each in to_resolve: - if isinstance(each, LazyReference): - to_resolve_object_ids.append(each.pk) - model = each.document_type._class_name - if model not in choice_to_resolve: - choice_to_resolve[model] = list() - choice_to_resolve[model].append(each.pk) - else: - to_resolve_object_ids.append(each["_ref"].id) - if each["_cls"] not in choice_to_resolve: - choice_to_resolve[each["_cls"]] = list() - choice_to_resolve[each["_cls"]].append(each["_ref"].id) - pool = ThreadPoolExecutor(5) - futures = list() - for model, object_id_list in choice_to_resolve.items(): - if model in to_resolve_models: - futures.append( - pool.submit( - get_reference_objects, - (model, object_id_list, registry, args), - ) + if not to_resolve: + return None + + choice_to_resolve = dict() + querying_union_types = get_queried_union_types(args[0]) + to_resolve_models = dict() + for each, queried_fields in querying_union_types.items(): + to_resolve_models[registry._registry_string_map[each]] = queried_fields + to_resolve_object_ids = list() + for each in to_resolve: + if isinstance(each, LazyReference): + to_resolve_object_ids.append(each.pk) + model = each.document_type._class_name + if model not in choice_to_resolve: + choice_to_resolve[model] = list() + choice_to_resolve[model].append(each.pk) + else: + to_resolve_object_ids.append(each["_ref"].id) + if each["_cls"] not in choice_to_resolve: + choice_to_resolve[each["_cls"]] = list() + choice_to_resolve[each["_cls"]].append(each["_ref"].id) + pool = ThreadPoolExecutor(5) + futures = list() + for model, object_id_list in choice_to_resolve.items(): + if model in to_resolve_models: + queried_fields = to_resolve_models[model] + futures.append( + pool.submit( + get_reference_objects, + *(model, object_id_list, registry, args, queried_fields), ) - else: - futures.append( - pool.submit( - get_non_querying_object, - (model, object_id_list, registry, args), - ) + ) + else: + futures.append( + pool.submit( + get_non_querying_object, + *(model, object_id_list, registry, args), ) - result = list() - for x in as_completed(futures): - result += x.result() - result_object_ids = list() - for each in result: - result_object_ids.append(each.id) - ordered_result = list() - for each in to_resolve_object_ids: - ordered_result.append(result[result_object_ids.index(each)]) - return ordered_result - return None + ) + result = list() + for x in as_completed(futures): + result += x.result() + result_object_ids = [each.id for each in result] + ordered_result = [ + result[result_object_ids.index(each)] for each in to_resolve_object_ids + ] + return ordered_result async def get_reference_objects_async(*args, **kwargs): document = get_document(args[0]) @@ -247,7 +243,7 @@ async def get_reference_objects_async(*args, **kwargs): for key, values in document_field_type._meta.filter_fields.items(): for each in values: filter_args.append(key + "__" + each) - for each in get_query_fields(args[3][0])[document_field_type._meta.name].keys(): + for each in args[4]: item = to_snake_case(each) if item in document._fields_ordered + tuple(filter_args): queried_fields.append(item) @@ -259,57 +255,53 @@ async def get_reference_objects_async(*args, **kwargs): ) async def get_non_querying_object_async(*args, **kwargs): - model = get_document(args[0]) - return [model(pk=each) for each in args[1]] + return get_non_querying_object(*args, **kwargs) async def reference_resolver_async(root, *args, **kwargs): to_resolve = getattr(root, field.name or field.db_name) - if to_resolve: - choice_to_resolve = dict() - querying_union_types = list(get_query_fields(args[0]).keys()) - if "__typename" in querying_union_types: - querying_union_types.remove("__typename") - to_resolve_models = list() - for each in querying_union_types: - if executor == ExecutorEnum.SYNC: - to_resolve_models.append(registry._registry_string_map[each]) - else: - to_resolve_models.append(registry._registry_async_string_map[each]) - to_resolve_object_ids = list() - for each in to_resolve: - if isinstance(each, LazyReference): - to_resolve_object_ids.append(each.pk) - model = each.document_type._class_name - if model not in choice_to_resolve: - choice_to_resolve[model] = list() - choice_to_resolve[model].append(each.pk) - else: - to_resolve_object_ids.append(each["_ref"].id) - if each["_cls"] not in choice_to_resolve: - choice_to_resolve[each["_cls"]] = list() - choice_to_resolve[each["_cls"]].append(each["_ref"].id) - loop = asyncio.get_event_loop() - tasks = [] - for model, object_id_list in choice_to_resolve.items(): - if model in to_resolve_models: - task = loop.create_task( - get_reference_objects_async(model, object_id_list, registry, args) - ) - else: - task = loop.create_task( - get_non_querying_object_async(model, object_id_list, registry, args) + if not to_resolve: + return None + + choice_to_resolve = dict() + querying_union_types = get_queried_union_types(args[0]) + to_resolve_models = dict() + for each, queried_fields in querying_union_types.items(): + to_resolve_models[registry._registry_async_string_map[each]] = queried_fields + to_resolve_object_ids = list() + for each in to_resolve: + if isinstance(each, LazyReference): + to_resolve_object_ids.append(each.pk) + model = each.document_type._class_name + if model not in choice_to_resolve: + choice_to_resolve[model] = list() + choice_to_resolve[model].append(each.pk) + else: + to_resolve_object_ids.append(each["_ref"].id) + if each["_cls"] not in choice_to_resolve: + choice_to_resolve[each["_cls"]] = list() + choice_to_resolve[each["_cls"]].append(each["_ref"].id) + loop = asyncio.get_event_loop() + tasks = [] + for model, object_id_list in choice_to_resolve.items(): + if model in to_resolve_models: + queried_fields = to_resolve_models[model] + task = loop.create_task( + get_reference_objects_async( + model, object_id_list, registry, args, queried_fields ) - tasks.append(task) - result = await asyncio.gather(*tasks) - result_object = {} - for items in result: - for item in items: - result_object[item.id] = item - ordered_result = list() - for each in to_resolve_object_ids: - ordered_result.append(result_object[each]) - return ordered_result - return None + ) + else: + task = loop.create_task( + get_non_querying_object_async(model, object_id_list, registry, args) + ) + tasks.append(task) + result = await asyncio.gather(*tasks) + result_object = {} + for items in result: + for item in items: + result_object[item.id] = item + ordered_result = [result_object[each] for each in to_resolve_object_ids] + return ordered_result return graphene.List( base_type._type, diff --git a/graphene_mongo/utils.py b/graphene_mongo/utils.py index 0a6053a..e5fd457 100644 --- a/graphene_mongo/utils.py +++ b/graphene_mongo/utils.py @@ -203,6 +203,40 @@ def get_query_fields(info): return query +def get_queried_union_types(info): + """A convenience function to get queried union types with its fields + + Args: + info (ResolveInfo) + + Returns: + dict[union_type_name, queried_fields(dict)] + """ + + fragments = {} + node = ast_to_dict(info.field_nodes[0]) + variables = info.variable_values + + for name, value in info.fragments.items(): + fragments[name] = ast_to_dict(value) + + fragments_queries: dict[str, dict] = {} + + selection_set = node.get("selection_set") if isinstance(node, dict) else node.selection_set + if selection_set: + for leaf in selection_set.selections: + if leaf.kind == "fragment_spread": + fragment_name = fragments[leaf.name.value].type_condition.name.value + fragments_queries[fragment_name] = collect_query_fields( + fragments[leaf.name.value], fragments, variables + ) + elif leaf.kind == "inline_fragment": + fragment_name = leaf.type_condition.name.value + fragments_queries[fragment_name] = collect_query_fields(leaf, fragments, variables) + + return fragments_queries + + def has_page_info(info): """A convenience function to call collect_query_fields with info for retrieving if page_info details are required