Skip to content

Commit

Permalink
fix: pagination errors
Browse files Browse the repository at this point in the history
hasNextPage didn't become false when using after and first together

Fixed reverse Querying using last and before, in compliance to graphql relay spec

https://relay.dev/graphql/connections.htm#sec-Backward-pagination-arguments
https://relay.dev/graphql/connections.htm#sec-undefined.PageInfo
  • Loading branch information
mak626 committed Dec 21, 2023
1 parent b31b0b9 commit fc0f03b
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 121 deletions.
114 changes: 35 additions & 79 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
find_skip_and_limit,
get_model_reference_fields,
get_query_fields,
has_page_info,
)

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
Expand Down Expand Up @@ -276,7 +277,7 @@ def fields(self):
return self._type._meta.fields

def get_queryset(
self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args
self, model, info, required_fields=None, skip=None, limit=None, **args
) -> QuerySet:
if required_fields is None:
required_fields = list()
Expand Down Expand Up @@ -325,49 +326,22 @@ def get_queryset(
else:
args.update(queryset_or_filters)
if limit is not None:
if reversed:
if self.order_by:
order_by = self.order_by + ",-pk"
else:
order_by = "-pk"
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(order_by)
.skip(skip if skip else 0)
.limit(limit)
)
else:
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip if skip else 0)
.limit(limit)
)
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip if skip else 0)
.limit(limit)
)
elif skip is not None:
if reversed:
if self.order_by:
order_by = self.order_by + ",-pk"
else:
order_by = "-pk"
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(order_by)
.skip(skip)
)
else:
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip)
)
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip)
)
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)

def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
Expand Down Expand Up @@ -401,7 +375,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
skip = 0
count = 0
limit = None
reverse = False
first = args.pop("first", None)
after = args.pop("after", None)
if after:
Expand All @@ -410,14 +383,15 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
before = args.pop("before", None)
if before:
before = cursor_to_offset(before)
requires_page_info = has_page_info(info)
has_next_page = False

if resolved is not None:
items = resolved

if isinstance(items, QuerySet):
try:
if last is not None and after is not None:
if last is not None:
count = items.count(with_limit_and_skip=False)
else:
count = None
Expand All @@ -426,29 +400,24 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
else:
count = len(items)

skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)

if isinstance(items, QuerySet):
if limit:
_base_query: QuerySet = (
items.order_by("-pk").skip(skip) if reverse else items.skip(skip)
)
_base_query: QuerySet = items.skip(skip)
items = _base_query.limit(limit)
has_next_page = len(_base_query.skip(limit).only("id").limit(1)) != 0
has_next_page = len(_base_query.skip(skip + limit).only("id").limit(1)) != 0
elif skip:
items = items.skip(skip)
else:
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
else:
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
_base_query = items
items = items[skip : skip + limit]
has_next_page = (
(skip + limit) < len(_base_query) if requires_page_info else False
)
elif skip:
items = items[skip:]
iterables = list(items)
Expand Down Expand Up @@ -503,11 +472,11 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
else:
count = self.model.objects(args_copy).count()
if count != 0:
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, after=after, last=last, before=before, count=count
)
iterables = self.get_queryset(
self.model, info, required_fields, skip, limit, reverse, **args
self.model, info, required_fields, skip, limit, **args
)
list_length = len(iterables)
if isinstance(info, GraphQLResolveInfo):
Expand All @@ -519,14 +488,11 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a

elif "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
args["pk__in"] = args["pk__in"][::-1][skip : skip + limit]
else:
args["pk__in"] = args["pk__in"][skip : skip + limit]
args["pk__in"] = args["pk__in"][skip : skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
Expand All @@ -542,18 +508,13 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
field_name = to_snake_case(info.field_name)
items = getattr(_root, field_name, [])
count = len(items)
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
else:
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query) if requires_page_info else False
elif skip:
items = items[skip:]
iterables = items
Expand All @@ -567,11 +528,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
)
has_previous_page = True if skip else False

if reverse:
iterables = list(iterables)
iterables.reverse()
skip = limit

connection = connection_from_iterables(
edges=iterables,
start_offset=skip,
Expand Down
52 changes: 18 additions & 34 deletions graphene_mongo/fields_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
skip = 0
count = 0
limit = None
reverse = False
first = args.pop("first", None)
after = args.pop("after", None)
if after:
Expand All @@ -109,7 +108,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non

if isinstance(items, QuerySet):
try:
if last is not None and after is not None:
if last is not None:
count = await sync_to_async(items.count)(with_limit_and_skip=False)
else:
count = None
Expand All @@ -118,33 +117,30 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
else:
count = len(items)

skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)

if isinstance(items, QuerySet):
if limit:
_base_query: QuerySet = (
await sync_to_async(items.order_by("-pk").skip)(skip)
if reverse
else await sync_to_async(items.skip)(skip)
)
_base_query: QuerySet = await sync_to_async(items.skip)(skip)
items = await sync_to_async(_base_query.limit)(limit)
has_next_page = (
(await sync_to_async(len)(_base_query.skip(limit).only("id").limit(1)) != 0)
(
await sync_to_async(len)(
_base_query.skip(skip + limit).only("id").limit(1)
)
!= 0
)
if requires_page_info
else False
)
elif skip:
items = await sync_to_async(items.skip)(skip)
else:
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
else:
_base_query = items
items = items[skip : skip + limit]
_base_query = items
items = items[skip : skip + limit]
has_next_page = (
(skip + limit) < len(_base_query) if requires_page_info else False
)
Expand Down Expand Up @@ -195,11 +191,11 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
else:
count = await sync_to_async(self.model.objects(args_copy).count)()
if count != 0:
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, after=after, last=last, before=before, count=count
)
iterables = self.get_queryset(
self.model, info, required_fields, skip, limit, reverse, **args
self.model, info, required_fields, skip, limit, **args
)
iterables = await sync_to_async(list)(iterables)
list_length = len(iterables)
Expand All @@ -212,14 +208,11 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non

elif "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
args["pk__in"] = args["pk__in"][::-1][skip : skip + limit]
else:
args["pk__in"] = args["pk__in"][skip : skip + limit]
args["pk__in"] = args["pk__in"][skip : skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
Expand All @@ -236,16 +229,12 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
field_name = to_snake_case(info.field_name)
items = getattr(_root, field_name, [])
count = len(items)
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
else:
_base_query = items
items = items[skip : skip + limit]
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query) if requires_page_info else False
elif skip:
items = items[skip:]
Expand All @@ -261,11 +250,6 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
)
has_previous_page = True if requires_page_info and skip else False

if reverse:
iterables = await sync_to_async(list)(iterables)
iterables.reverse()
skip = limit

connection = connection_from_iterables(
edges=iterables,
start_offset=skip,
Expand Down
16 changes: 8 additions & 8 deletions graphene_mongo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,12 @@ def ast_to_dict(node, include_loc=False):


def find_skip_and_limit(first, last, after, before, count=None):
reverse = False
skip = 0
limit = None

if last is not None and count is None:
raise ValueError("Count Missing")

if first is not None and after is not None:
skip = after + 1
limit = first
Expand All @@ -274,29 +277,26 @@ def find_skip_and_limit(first, last, after, before, count=None):
skip = 0
limit = first
elif last is not None and before is not None:
reverse = False
if last >= before:
limit = before
else:
limit = last
skip = before - last
elif last is not None and after is not None:
if not count:
raise ValueError("Count Missing")
reverse = True
skip = after + 1
if last + after < count:
limit = last
else:
limit = count - after - 1
elif last is not None:
skip = 0
skip = count - last
limit = last
reverse = True
elif after is not None:
skip = after + 1
elif before is not None:
limit = before
return skip, limit, reverse

return skip, limit


def connection_from_iterables(
Expand Down

0 comments on commit fc0f03b

Please sign in to comment.