Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for sort in collection.find function #359

Merged
merged 7 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,6 @@ arango/version.py

# test results
*_results.txt

# devcontainers
.devcontainer
7 changes: 6 additions & 1 deletion arango/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@
from arango.typings import Fields, Headers, Json, Jsons, Params
from arango.utils import (
build_filter_conditions,
build_sort_expression,
get_batches,
get_doc_id,
is_none_or_bool,
is_none_or_int,
is_none_or_str,
validate_sort_parameters,
)


Expand Down Expand Up @@ -753,6 +755,7 @@ def find(
skip: Optional[int] = None,
limit: Optional[int] = None,
allow_dirty_read: bool = False,
sort: Sequence[Json] = [],
apetenchea marked this conversation as resolved.
Show resolved Hide resolved
) -> Result[Cursor]:
"""Return all documents that match the given filters.

Expand All @@ -771,16 +774,18 @@ def find(
assert isinstance(filters, dict), "filters must be a dict"
assert is_none_or_int(skip), "skip must be a non-negative int"
assert is_none_or_int(limit), "limit must be a non-negative int"
if sort:
validate_sort_parameters(sort)

skip_val = skip if skip is not None else 0
limit_val = limit if limit is not None else "null"
query = f"""
FOR doc IN @@collection
{build_filter_conditions(filters)}
LIMIT {skip_val}, {limit_val}
{build_sort_expression(sort)}
RETURN doc
"""

bind_vars = {"@collection": self.name}

request = Request(
Expand Down
39 changes: 39 additions & 0 deletions arango/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,42 @@ def build_filter_conditions(filters: Json) -> str:
conditions.append(f"doc.{field} == {json.dumps(v)}")

return "FILTER " + " AND ".join(conditions)


def validate_sort_parameters(sort: Sequence[Json]) -> bool:
"""Validate sort parameters for an AQL query.

:param sort: Document sort parameters.
:type sort: Sequence[Json]
:return: Validation success.
:rtype: bool
:raise arango.exceptions.DocumentGetError: If sort parameters are invalid.
Copy link
Member

@apetenchea apetenchea Jan 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring says DocumentGetError, but the function raises DocumentParseError. I think it's better to create a new error type, inheriting ArangoClientError in exceptions.py, use that and fix the docstring (for example, SortValidationError).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I created a new section (Parameter Validation Exceptions) under exceptions.py, which can come handy for other types of validation errors (filter etc.) in the future.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!
Please update the docstring of this function using the newly created SortValidationError, along with the two raise statements at lines 143 and 147.

"""
assert isinstance(sort, Sequence)
for param in sort:
if "sort_by" not in param or "sort_order" not in param:
raise DocumentParseError(
"Each sort parameter must have 'sort_by' and 'sort_order'."
)
if param["sort_order"].upper() not in ["ASC", "DESC"]:
raise DocumentParseError("'sort_order' must be either 'ASC' or 'DESC'")
return True


def build_sort_expression(sort: Sequence[Json]) -> str:
apetenchea marked this conversation as resolved.
Show resolved Hide resolved
"""Build a sort condition for an AQL query.

:param sort: Document sort parameters.
:type sort: Sequence[Json]
:return: The complete AQL sort condition.
:rtype: str
"""
if not sort:
return ""

sort_chunks = []
for sort_param in sort:
chunk = f"doc.{sort_param['sort_by']} {sort_param['sort_order']}"
sort_chunks.append(chunk)

return "SORT " + ", ".join(sort_chunks)
6 changes: 6 additions & 0 deletions docs/document.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ Standard documents are managed via collection API wrapper:
assert student['GPA'] == 3.6
assert student['last'] == 'Kim'

# Retrieve one or more matching documents, sorted by a field.
for student in students.find({'first': 'John'}, sort=[{'sort_by': 'GPA', 'sort_order': 'DESC'}]):
assert student['_key'] == 'john'
assert student['GPA'] == 3.6
assert student['last'] == 'Kim'

# Retrieve a document by key.
students.get('john')

Expand Down
20 changes: 20 additions & 0 deletions tests/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,26 @@ def test_document_find(col, bad_col, docs):
# Set up test documents
col.import_bulk(docs)

# Test find with sort expression (single field)
found = list(col.find({}, sort=[{"sort_by": "text", "sort_order": "ASC"}]))
assert len(found) == 6
assert found[0]["text"] == "bar"
assert found[-1]["text"] == "foo"

# Test find with sort expression (multiple fields)
found = list(
col.find(
{},
sort=[
{"sort_by": "text", "sort_order": "ASC"},
{"sort_by": "val", "sort_order": "DESC"},
],
)
)
assert len(found) == 6
assert found[0]["val"] == 6
assert found[-1]["val"] == 1
apetenchea marked this conversation as resolved.
Show resolved Hide resolved

# Test find (single match) with default options
found = list(col.find({"val": 2}))
assert len(found) == 1
Expand Down
Loading