Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ciur committed Oct 14, 2024
1 parent 48d3fea commit 9d45e2b
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 10 deletions.
4 changes: 2 additions & 2 deletions papermerge/core/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
add_document_custom_field_values,
get_doc,
get_doc_cfv,
get_docs_by_type,
get_document_custom_field_values,
get_documents_by_type,
update_doc_cfv,
update_doc_type,
)
Expand Down Expand Up @@ -85,5 +85,5 @@
"get_document_custom_field_values",
"get_doc_cfv",
"update_doc_type",
"get_documents_by_type",
"get_docs_by_type",
]
21 changes: 16 additions & 5 deletions papermerge/core/db/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,14 +374,16 @@ def get_document_custom_field_values(
return result


def get_documents_by_type(
def get_docs_by_type(
session: Session,
type_id: UUID,
parent_id: UUID,
ancestor_id: UUID,
user_id: UUID,
):
"""
Returns node / cf / cfv for all documents with `document_type_id = type_id`
Returns list of documents + doc CFv for all documents with of given type
All fetched documents are descendants of `ancestor_id` node.
"""
stmt = """
SELECT node.title,
Expand Down Expand Up @@ -419,19 +421,28 @@ def get_documents_by_type(
ON cfv.field_id = cf.cf_id AND cfv.document_id = doc_id
WHERE node.parent_id = :parent_id
"""
str_parent_id = str(parent_id).replace("-", "")
str_parent_id = str(ancestor_id).replace("-", "")
str_type_id = str(type_id).replace("-", "")
params = {"parent_id": str_parent_id, "document_type_id": str_type_id}
results = []
rows = session.execute(text(stmt), params)
for document_id, group in itertools.groupby(rows, lambda r: r.doc_id):
items = list(group)
custom_fields = []

for item in items:
if item.cf_type == "date":
value = str2date(item.cf_value)
else:
value = item.cf_value
custom_fields.append((item.cf_name, value))

results.append(
schemas.DocumentCFV(
id=uuid.UUID(document_id),
title=items[0].title,
document_type_id=uuid.UUID(items[0].document_type_id),
custom_fields=[(i.cf_name, str(i.cf_value)) for i in items],
custom_fields=custom_fields,
)
)

Expand Down
11 changes: 8 additions & 3 deletions papermerge/core/schemas/custom_fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import date
from enum import Enum
from typing import TypeAlias
from uuid import UUID

from pydantic import BaseModel, ConfigDict
Expand Down Expand Up @@ -48,6 +49,10 @@ class CustomFieldValue(CustomField):
field_id: UUID


CFValueType: TypeAlias = str | int | date | bool | float | None
CFNameType: TypeAlias = str


class CFV(BaseModel):
# custom field value
# `core_documents.id`
Expand All @@ -57,15 +62,15 @@ class CFV(BaseModel):
# `custom_fields.id`
custom_field_id: UUID
# `custom_fields.name`
name: str
name: CFNameType
# `custom_fields.type`
type: CustomFieldType
# `custom_fields.extra_data`
extra_data: str | None
# `custom_field_values.id`
custom_field_value_id: UUID | None = None
# `custom_field_values.value_text` or `custom_field_values.value_int` or ...
value: str | int | date | bool | float | None = None
value: CFValueType = None


class DocumentCFV(BaseModel):
Expand All @@ -77,4 +82,4 @@ class DocumentCFV(BaseModel):
# user_id: UUID
document_type_id: UUID | None = None
thumbnail_url: str | None = None
custom_fields: list[tuple[str, str]]
custom_fields: list[tuple[CFNameType, CFValueType]]
74 changes: 74 additions & 0 deletions tests/core/models/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,80 @@ def test_document_with_cfv_update_document_type_to_none(
assert db_session.execute(stmt).scalar() == 0


@pytest.mark.django_db(transaction=True)
def test_get_docs_by_type_basic(db_session: Session, make_document_receipt):
"""
`db.get_docs_by_type` must return all documents of specific type
regardless if they (documents) have or no associated custom field values.
In this scenario all returned documents must have custom fields with empty
values.
And number of returned items must be equal to the number of documents
of type "Grocery"
"""
doc_1: Document = make_document_receipt(title="receipt_1.pdf")
make_document_receipt(title="receipt_2.pdf")
user_id = doc_1.user.id
parent_id = doc_1.parent.id
type_id = doc_1.document_type.id

items: list[schemas.DocumentCFV] = db.get_docs_by_type(
db_session, type_id=type_id, user_id=user_id, ancestor_id=parent_id
)

assert len(items) == 2

for i in range(0, 2):
cf = dict(items[i].custom_fields)
assert cf["EffectiveDate"] is None
assert cf["Shop"] is None
assert cf["Total"] is None


@pytest.mark.django_db(transaction=True)
def test_get_docs_by_type_one_doc_with_nonempty_cfv(
db_session: Session, make_document_receipt
):
"""
`db.get_docs_by_type` must return all documents of specific type
regardless if they (documents) have or no associated custom field values.
In this scenario one of the returned documents has all CFVs set to
non empty values and the other one - to all values empty
"""
doc_1: Document = make_document_receipt(title="receipt_1.pdf")
make_document_receipt(title="receipt_2.pdf")
user_id = doc_1.user.id
parent_id = doc_1.parent.id
type_id = doc_1.document_type.id

# update all CFV of receipt_1.pdf to non-empty values
db.update_doc_cfv(
db_session,
document_id=doc_1.id,
custom_fields={"Shop": "rewe", "EffectiveDate": "2024-10-15", "Total": "15.63"},
)

items: list[schemas.DocumentCFV] = db.get_docs_by_type(
db_session, type_id=type_id, user_id=user_id, ancestor_id=parent_id
)

assert len(items) == 2

for i in range(0, 2):
cf = dict(items[i].custom_fields)
if items[i].id == doc_1.id:
# receipt_1.pdf has all cf set correctly
assert cf["EffectiveDate"] == Date(2024, 10, 15)
assert cf["Shop"] == "rewe"
assert cf["Total"] == 15.63
else:
# receipt_2.pdf has all cf set to None
assert cf["EffectiveDate"] is None
assert cf["Shop"] is None
assert cf["Total"] is None


def test_str2date():
assert str2date("2024-10-30") == datetime(2024, 10, 30).date()
assert str2date("2024-10-30 00:00:00") == datetime(2024, 10, 30).date()

0 comments on commit 9d45e2b

Please sign in to comment.