Skip to content

Commit

Permalink
refactoring: move document type into feature folder + add path_templa…
Browse files Browse the repository at this point in the history
…te field
  • Loading branch information
ciur authored Oct 18, 2024
1 parent f34f8bf commit 15bd548
Show file tree
Hide file tree
Showing 29 changed files with 325 additions and 114 deletions.
14 changes: 0 additions & 14 deletions papermerge/core/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@
update_doc_type,
)
from .doc_ver import get_doc_ver, get_last_doc_ver
from .document_types import (
create_document_type,
delete_document_type,
document_type_cf_count,
get_document_type,
get_document_types,
update_document_type,
)
from .engine import get_engine
from .exceptions import UserNotFound
from .folders import get_folder
Expand Down Expand Up @@ -75,15 +67,9 @@
"get_custom_field",
"delete_custom_field",
"update_custom_field",
"get_document_types",
"create_document_type",
"get_document_type",
"delete_document_type",
"update_document_type",
"update_doc_cfv",
"get_doc_cfv",
"update_doc_type",
"get_docs_by_type",
"document_type_cf_count",
"get_docs_count_by_type",
]
13 changes: 7 additions & 6 deletions papermerge/core/db/custom_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from sqlalchemy.orm import Session

from papermerge.core import schemas
from papermerge.core.db import models

from .models import CustomField

logger = logging.getLogger(__name__)


def get_custom_fields(session: Session) -> list[schemas.CustomField]:
stmt = select(models.CustomField)
stmt = select(CustomField)
db_items = session.scalars(stmt).all()
result = [schemas.CustomField.model_validate(db_item) for db_item in db_items]

Expand All @@ -25,7 +26,7 @@ def create_custom_field(
user_id: uuid.UUID,
extra_data: str | None = None,
) -> schemas.CustomField:
cfield = models.CustomField(
cfield = CustomField(
id=uuid.uuid4(),
name=name,
type=type,
Expand All @@ -42,14 +43,14 @@ def create_custom_field(
def get_custom_field(
session: Session, custom_field_id: uuid.UUID
) -> schemas.CustomField:
stmt = select(models.CustomField).where(models.CustomField.id == custom_field_id)
stmt = select(CustomField).where(CustomField.id == custom_field_id)
db_item = session.scalars(stmt).unique().one()
result = schemas.CustomField.model_validate(db_item)
return result


def delete_custom_field(session: Session, custom_field_id: uuid.UUID):
stmt = select(models.CustomField).where(models.CustomField.id == custom_field_id)
stmt = select(CustomField).where(CustomField.id == custom_field_id)
cfield = session.execute(stmt).scalars().one()
session.delete(cfield)
session.commit()
Expand All @@ -58,7 +59,7 @@ def delete_custom_field(session: Session, custom_field_id: uuid.UUID):
def update_custom_field(
session: Session, custom_field_id: uuid.UUID, attrs: schemas.UpdateCustomField
) -> schemas.CustomField:
stmt = select(models.CustomField).where(models.CustomField.id == custom_field_id)
stmt = select(CustomField).where(CustomField.id == custom_field_id)
cfield = session.execute(stmt).scalars().one()
session.add(cfield)

Expand Down
2 changes: 1 addition & 1 deletion papermerge/core/db/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
Page,
)
from papermerge.core.exceptions import InvalidDateFormat
from papermerge.core.features.document_types.db import document_type_cf_count
from papermerge.core.types import OrderEnum

from .common import get_ancestors
from .document_types import document_type_cf_count


def str2date(value: str | None) -> Optional[datetime.date]:
Expand Down
14 changes: 1 addition & 13 deletions papermerge/core/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class Document(Node):

ocr: Mapped[bool]
ocr_status: Mapped[str]
document_type: Mapped["DocumentType"] = relationship(
document_type: Mapped["DocumentType"] = relationship( # noqa: F821
primaryjoin="DocumentType.id == Document.document_type_id",
)
document_type_id: Mapped[UUID] = mapped_column(ForeignKey("document_types.id"))
Expand Down Expand Up @@ -276,15 +276,3 @@ class CustomFieldValue(Base):

def __repr__(self):
return f"CustomFieldValue(ID={self.id})"


class DocumentType(Base):
__tablename__ = "document_types"

id: Mapped[UUID] = mapped_column(primary_key=True)
name: Mapped[str]
custom_fields: Mapped[list["CustomField"]] = relationship(
secondary="document_type_custom_field"
)
user_id: Mapped[UUID] = mapped_column(ForeignKey("core_user.id"))
created_at: Mapped[datetime] = mapped_column(insert_default=func.now())
19 changes: 19 additions & 0 deletions papermerge/core/features/document_types/db/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .api import (
create_document_type,
delete_document_type,
document_type_cf_count,
get_document_type,
get_document_types,
update_document_type,
)
from .orm import DocumentType

__all__ = [
"document_type_cf_count",
"create_document_type",
"get_document_types",
"get_document_type",
"delete_document_type",
"update_document_type",
"DocumentType",
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from sqlalchemy.orm import Session

from papermerge.core import schemas
from papermerge.core.db import models
from papermerge.core.db.models import CustomField

from .orm import DocumentType

logger = logging.getLogger(__name__)


def get_document_types(session: Session) -> list[schemas.DocumentType]:
stmt = select(models.DocumentType)
stmt = select(DocumentType)
db_items = session.scalars(stmt).all()
result = [schemas.DocumentType.model_validate(db_item) for db_item in db_items]

Expand All @@ -20,7 +22,7 @@ def get_document_types(session: Session) -> list[schemas.DocumentType]:

def document_type_cf_count(session: Session, document_type_id: uuid.UUID):
"""count number of custom fields associated to document type"""
stmt = select(models.DocumentType).where(models.DocumentType.id == document_type_id)
stmt = select(DocumentType).where(DocumentType.id == document_type_id)
dtype = session.scalars(stmt).one()
return len(dtype.custom_fields)

Expand All @@ -30,18 +32,20 @@ def create_document_type(
name: str,
user_id: uuid.UUID,
custom_field_ids: list[uuid.UUID] | None = None,
path_template: str | None = None,
) -> schemas.DocumentType:
if custom_field_ids is None:
cf_ids = []
else:
cf_ids = custom_field_ids

stmt = select(models.CustomField).where(models.CustomField.id.in_(cf_ids))
stmt = select(CustomField).where(CustomField.id.in_(cf_ids))
custom_fields = session.execute(stmt).scalars().all()
dtype = models.DocumentType(
dtype = DocumentType(
id=uuid.uuid4(),
name=name,
custom_fields=custom_fields,
path_template=path_template,
user_id=user_id,
)
session.add(dtype)
Expand All @@ -53,14 +57,14 @@ def create_document_type(
def get_document_type(
session: Session, document_type_id: uuid.UUID
) -> schemas.DocumentType:
stmt = select(models.DocumentType).where(models.DocumentType.id == document_type_id)
stmt = select(DocumentType).where(DocumentType.id == document_type_id)
db_item = session.scalars(stmt).unique().one()
result = schemas.DocumentType.model_validate(db_item)
return result


def delete_document_type(session: Session, document_type_id: uuid.UUID):
stmt = select(models.DocumentType).where(models.DocumentType.id == document_type_id)
stmt = select(DocumentType).where(DocumentType.id == document_type_id)
cfield = session.execute(stmt).scalars().one()
session.delete(cfield)
session.commit()
Expand All @@ -69,12 +73,10 @@ def delete_document_type(session: Session, document_type_id: uuid.UUID):
def update_document_type(
session: Session, document_type_id: uuid.UUID, attrs: schemas.UpdateDocumentType
) -> schemas.DocumentType:
stmt = select(models.DocumentType).where(models.DocumentType.id == document_type_id)
stmt = select(DocumentType).where(DocumentType.id == document_type_id)
doc_type = session.execute(stmt).scalars().one()

stmt = select(models.CustomField).where(
models.CustomField.id.in_(attrs.custom_field_ids)
)
stmt = select(CustomField).where(CustomField.id.in_(attrs.custom_field_ids))
custom_fields = session.execute(stmt).scalars().all()

if attrs.name:
Expand All @@ -83,6 +85,8 @@ def update_document_type(
if attrs.custom_field_ids:
doc_type.custom_fields = custom_fields

doc_type.path_template = attrs.path_template

session.add(doc_type)
session.commit()
result = schemas.DocumentType.model_validate(doc_type)
Expand Down
20 changes: 20 additions & 0 deletions papermerge/core/features/document_types/db/orm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from datetime import datetime
from uuid import UUID

from sqlalchemy import ForeignKey, func
from sqlalchemy.orm import Mapped, mapped_column, relationship

from papermerge.core.db.base import Base


class DocumentType(Base):
__tablename__ = "document_types"

id: Mapped[UUID] = mapped_column(primary_key=True)
name: Mapped[str]
path_template: Mapped[str]
custom_fields: Mapped[list["CustomField"]] = relationship( # noqa: F821
secondary="document_type_custom_field"
)
user_id: Mapped[UUID] = mapped_column(ForeignKey("core_user.id"))
created_at: Mapped[datetime] = mapped_column(insert_default=func.now())
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class DocumentType(models.Model):
user = models.ForeignKey(
"User", related_name="document_types", on_delete=models.CASCADE
)
path_template = models.CharField(max_length=2048, null=True)
created_at = models.DateTimeField(
"created_at",
default=timezone.now,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from fastapi import APIRouter, Depends, HTTPException, Security
from sqlalchemy.exc import NoResultFound

from papermerge.core import db, schemas, utils
from papermerge.core import schemas, utils
from papermerge.core.auth import get_current_user, scopes

from .common import OPEN_API_GENERIC_JSON_DETAIL
from .paginator import PaginatorGeneric, paginate
from .params import CommonQueryParams
from papermerge.core.db import Session, get_session
from papermerge.core.features.document_types import db
from papermerge.core.routers.common import OPEN_API_GENERIC_JSON_DETAIL
from papermerge.core.routers.paginator import PaginatorGeneric, paginate
from papermerge.core.routers.params import CommonQueryParams

router = APIRouter(
prefix="/document-types",
Expand All @@ -26,7 +27,7 @@ def get_document_types_without_pagination(
user: Annotated[
schemas.User, Security(get_current_user, scopes=[scopes.DOCUMENT_TYPE_VIEW])
],
db_session: db.Session = Depends(db.get_session),
db_session: Session = Depends(get_session),
):
"""Get all document types without pagination/filtering/sorting
Expand All @@ -44,7 +45,7 @@ def get_document_types(
schemas.User, Security(get_current_user, scopes=[scopes.CUSTOM_FIELD_VIEW])
],
params: CommonQueryParams = Depends(),
db_session: db.Session = Depends(db.get_session),
db_session: Session = Depends(get_session),
):
"""Get all (paginated) document types
Expand All @@ -60,7 +61,7 @@ def get_document_type(
user: Annotated[
schemas.User, Security(get_current_user, scopes=[scopes.DOCUMENT_TYPE_VIEW])
],
db_session: db.Session = Depends(db.get_session),
db_session: Session = Depends(get_session),
):
"""Get document type
Expand All @@ -80,7 +81,7 @@ def create_document_type(
user: Annotated[
schemas.User, Security(get_current_user, scopes=[scopes.DOCUMENT_TYPE_CREATE])
],
db_session: db.Session = Depends(db.get_session),
db_session: Session = Depends(get_session),
) -> schemas.DocumentType:
"""Creates document type
Expand All @@ -90,6 +91,7 @@ def create_document_type(
document_type = db.create_document_type(
db_session,
name=dtype.name,
path_template=dtype.path_template,
custom_field_ids=dtype.custom_field_ids,
user_id=user.id,
)
Expand Down Expand Up @@ -118,7 +120,7 @@ def delete_document_type(
user: Annotated[
schemas.User, Security(get_current_user, scopes=[scopes.DOCUMENT_TYPE_DELETE])
],
db_session: db.Session = Depends(db.get_session),
db_session: Session = Depends(get_session),
) -> None:
"""Deletes document type
Expand All @@ -140,7 +142,7 @@ def update_document_type(
cur_user: Annotated[
schemas.User, Security(get_current_user, scopes=[scopes.DOCUMENT_TYPE_UPDATE])
],
db_session: db.Session = Depends(db.get_session),
db_session: Session = Depends(get_session),
) -> schemas.DocumentType:
"""Updates document type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from pydantic import BaseModel, ConfigDict

from .custom_fields import CustomField
from papermerge.core.schemas.custom_fields import CustomField


class DocumentType(BaseModel):
id: UUID
name: str
path_template: str | None = None
custom_fields: list[CustomField]

# Config
Expand All @@ -16,6 +17,7 @@ class DocumentType(BaseModel):

class CreateDocumentType(BaseModel):
name: str
path_template: str | None = None
custom_field_ids: list[UUID]

# Config
Expand All @@ -24,4 +26,5 @@ class CreateDocumentType(BaseModel):

class UpdateDocumentType(BaseModel):
name: str | None = None
path_template: str | None = None
custom_field_ids: list[UUID] | None = None
Loading

0 comments on commit 15bd548

Please sign in to comment.