Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add GET endpoints for documents #547

Merged
merged 51 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
a9ebd9c
Give option for database to return all documents
smokestacklightnin Jan 27, 2025
ce37b0c
Add `get_documents`
smokestacklightnin Jan 27, 2025
d0197f5
Merge remote-tracking branch 'upstream/main' into ui/enh/document-viewer
smokestacklightnin Jan 28, 2025
b970b96
Add `GET` endpoint for `/documents`
smokestacklightnin Jan 28, 2025
f33e8ea
Add `GET` endpoint for a specific document `/documents/{id}`
smokestacklightnin Jan 28, 2025
c68f874
Fix mypy error
smokestacklightnin Jan 28, 2025
e7bf467
Add `get_document` to engine for convenience
smokestacklightnin Jan 28, 2025
ce6a4f2
Clean up
smokestacklightnin Jan 28, 2025
c4bbcaf
Add support for MIME types in `core.Document`s
smokestacklightnin Jan 28, 2025
1ab6f58
Add `GET` `/documents/{id}/content` endpoint
smokestacklightnin Jan 28, 2025
e57950c
Call correct method
smokestacklightnin Jan 28, 2025
1bc57b0
Use the builtin `mimetypes` library instead of custom logic
smokestacklightnin Jan 28, 2025
54e20a3
Merge remote-tracking branch 'upstream/main' into ui/enh/document-viewer
smokestacklightnin Jan 28, 2025
f67997e
Add mime_type to `Document` schema
smokestacklightnin Jan 28, 2025
23898f0
Add MIME type to `Document` ORM object
smokestacklightnin Jan 28, 2025
e53b180
Add MIME type to ORM <> Schema converters and Core <> Schema converters
smokestacklightnin Jan 28, 2025
2cfa02a
Add `mime_type` to initializer for `LocalDocument`
smokestacklightnin Jan 28, 2025
32dc1ab
Remove unnecessary type conversion
smokestacklightnin Jan 29, 2025
32e11ba
Make code more concise
smokestacklightnin Jan 29, 2025
162a3ff
Enforce keyword arguments
smokestacklightnin Jan 29, 2025
e122713
Help expression scale
smokestacklightnin Jan 29, 2025
06b5ab7
Use `__getitem__` instead of `next(iter(...))`
smokestacklightnin Jan 29, 2025
a9bf0fd
Use traditional `if` statement instead of ternary operator
smokestacklightnin Jan 29, 2025
94ca3f9
Add empty `test_endpoints.py` file
smokestacklightnin Jan 30, 2025
fffd104
Prevent naming collisions
smokestacklightnin Jan 30, 2025
e199054
Add test for `GET documents` endpoint
smokestacklightnin Jan 31, 2025
773a030
Add test for `GET document`
smokestacklightnin Jan 31, 2025
d0cd74c
Add test for `GET` document content
smokestacklightnin Jan 31, 2025
e107d5d
Fix typo in `PUT` methods
smokestacklightnin Jan 31, 2025
72d007f
Clean up `raise_for_status()`
smokestacklightnin Jan 31, 2025
f4028fd
Use `__getitem__` instead of `next(iter(...))`
smokestacklightnin Feb 1, 2025
6abfb00
Remove unique names where appropriate
smokestacklightnin Feb 1, 2025
565ccab
Make sorting key not private
smokestacklightnin Feb 1, 2025
3f0000f
Add and use `upload_documents` to minimize repeated code in repeated …
smokestacklightnin Feb 1, 2025
22a31b6
Store document text content in a variable to be reused
smokestacklightnin Feb 1, 2025
d6db904
Use `upload_documents` in `test_components.py`
smokestacklightnin Feb 2, 2025
435a268
Fix typo
smokestacklightnin Feb 2, 2025
a6acfd6
Allow for specification of MIME types in `upload_documents`
smokestacklightnin Feb 3, 2025
4cfa644
Test with user-specified MIME types in `test_get_documents`
smokestacklightnin Feb 3, 2025
0959572
Make `mime_types` parametrization reusable across multiple tests
smokestacklightnin Feb 3, 2025
2f9f2b7
Test with user-specified MIME types in `test_get_document`
smokestacklightnin Feb 3, 2025
a3456b7
Test with user-specified MIME types in `test_get_document_content`
smokestacklightnin Feb 3, 2025
4d1b1d5
Add `mime_type` to `DocumentRegistration`
smokestacklightnin Feb 4, 2025
d5c6c16
Include `mime_type` when registering documents
smokestacklightnin Feb 4, 2025
6cfb136
Use `read` instead of `iter_lines`
smokestacklightnin Feb 4, 2025
5eafb6d
Remove redundant assertion of equality
smokestacklightnin Feb 4, 2025
249c487
Use `zip(..., strict=True)` to force arguments to be the same length
smokestacklightnin Feb 4, 2025
5fd702e
Remove assertion that should be part of other tests
smokestacklightnin Feb 4, 2025
83e7607
Test equality of bytes, rather than of strings
smokestacklightnin Feb 4, 2025
3a2d98e
Assert equal lengths instead of using `zip(..., strict=True)`
smokestacklightnin Feb 4, 2025
68f302e
Use `iter_lines` instead of receiving bytes
smokestacklightnin Feb 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion ragna/core/_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import io
import mimetypes
import uuid
from functools import cached_property
from pathlib import Path
Expand All @@ -25,11 +26,15 @@ def __init__(
name: str,
metadata: dict[str, Any],
handler: Optional[DocumentHandler] = None,
mime_type: str | None = None,
):
self.id = id or uuid.uuid4()
self.name = name
self.metadata = metadata
self.handler = handler or self.get_handler(name)
self.mime_type = (
mime_type or mimetypes.guess_type(name)[0] or "application/octet-stream"
)

@staticmethod
def supported_suffixes() -> set[str]:
Expand Down Expand Up @@ -76,8 +81,11 @@ def __init__(
name: str,
metadata: dict[str, Any],
handler: Optional[DocumentHandler] = None,
mime_type: str | None = None,
):
super().__init__(id=id, name=name, metadata=metadata, handler=handler)
super().__init__(
id=id, name=name, metadata=metadata, handler=handler, mime_type=mime_type
)
if "path" not in self.metadata:
metadata["path"] = str(ragna.local_root() / "documents" / str(self.id))

Expand Down
23 changes: 23 additions & 0 deletions ragna/deploy/_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import uuid
from typing import Annotated, Any, AsyncIterator

Expand Down Expand Up @@ -40,6 +41,28 @@ async def content_stream() -> AsyncIterator[bytes]:
],
)

@router.get("/documents")
async def get_documents(user: UserDependency) -> list[schemas.Document]:
    # No ids are passed, so the engine is asked for all documents rather than
    # for a selected subset.
    all_documents = engine.get_documents(user=user.name)
    return all_documents

@router.get("/documents/{id}")
async def get_document(user: UserDependency, id: uuid.UUID) -> schemas.Document:
    # Look up exactly one document by the id taken from the URL path.
    requested_document = engine.get_document(user=user.name, id=id)
    return requested_document

@router.get("/documents/{id}/content")
async def get_document_content(
    user: UserDependency, id: uuid.UUID
) -> StreamingResponse:
    # Send the stored bytes of one document back to the client, labelled with the
    # document's MIME type and an inline Content-Disposition carrying its name.
    from urllib.parse import quote

    schema_document = engine.get_document(user=user.name, id=id)
    core_document = engine._to_core.document(schema_document)

    # The name is user-supplied, so it must not be pasted into the header as-is:
    # spaces, quotes or semicolons would be mis-parsed by clients and characters
    # outside latin-1 cannot be encoded as a header value at all. Names that
    # survive percent-encoding unchanged are sent as a quoted string; everything
    # else uses the RFC 5987 `filename*` form.
    filename = schema_document.name
    encoded_filename = quote(filename)
    if encoded_filename == filename:
        content_disposition = f'inline; filename="{filename}"'
    else:
        content_disposition = f"inline; filename*=utf-8''{encoded_filename}"

    return StreamingResponse(
        io.BytesIO(core_document.read()),
        media_type=core_document.mime_type,
        headers={"Content-Disposition": content_disposition},
    )

@router.get("/components")
def get_components() -> schemas.Components:
return engine.get_components()
Expand Down
22 changes: 13 additions & 9 deletions ragna/deploy/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,24 +132,24 @@ def add_documents(
session.commit()

def _get_orm_documents(
self, session: Session, *, user: str, ids: Collection[uuid.UUID]
self, session: Session, *, user: str, ids: Collection[uuid.UUID] | None = None
) -> list[orm.Document]:
# FIXME also check if the user is allowed to access the documents
# FIXME: maybe just take the user id to avoid getting it twice in add_chat?
documents = (
session.execute(select(orm.Document).where(orm.Document.id.in_(ids)))
.scalars()
.all()
)
if len(documents) != len(ids):
expr = select(orm.Document)
if ids is not None:
expr = expr.where(orm.Document.id.in_(ids))
documents = session.execute(expr).scalars().all()

if (ids is not None) and (len(documents) != len(ids)):
raise RagnaException(
str(set(ids) - {document.id for document in documents})
)

return documents # type: ignore[no-any-return]

def get_documents(
self, session: Session, *, user: str, ids: Collection[uuid.UUID]
self, session: Session, *, user: str, ids: Collection[uuid.UUID] | None = None
) -> list[schemas.Document]:
return [
self._to_schema.document(document)
Expand Down Expand Up @@ -288,6 +288,7 @@ def document(
user_id=user_id,
name=document.name,
metadata_=document.metadata,
mime_type=document.mime_type,
)

def source(self, source: schemas.Source) -> orm.Source:
Expand Down Expand Up @@ -354,7 +355,10 @@ def api_key(self, api_key: orm.ApiKey) -> schemas.ApiKey:

def document(self, document: orm.Document) -> schemas.Document:
return schemas.Document(
id=document.id, name=document.name, metadata=document.metadata_
id=document.id,
name=document.name,
metadata=document.metadata_,
mime_type=document.mime_type,
)

def source(self, source: orm.Source) -> schemas.Source:
Expand Down
22 changes: 16 additions & 6 deletions ragna/deploy/_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import secrets
import uuid
from typing import Any, AsyncIterator, Optional, cast
from typing import Any, AsyncIterator, Collection, Optional, cast

from fastapi import status as http_status_code

Expand Down Expand Up @@ -156,7 +156,9 @@ def register_documents(
# We create core.Document's first, because they might update the metadata
core_documents = [
self._config.document(
name=registration.name, metadata=registration.metadata
name=registration.name,
metadata=registration.metadata,
mime_type=registration.mime_type,
)
for registration in document_registrations
]
Expand All @@ -182,17 +184,23 @@ async def store_documents(

streams = dict(ids_and_streams)

with self._database.get_session() as session:
documents = self._database.get_documents(
session, user=user, ids=streams.keys()
)
documents = self.get_documents(user=user, ids=streams.keys())

for document in documents:
core_document = cast(
ragna.core.LocalDocument, self._to_core.document(document)
)
await core_document._write(streams[document.id])

def get_documents(
    self, *, user: str, ids: Collection[uuid.UUID] | None = None
) -> list[schemas.Document]:
    """Fetch documents from the database inside a fresh session.

    Args:
        user: Name of the user the lookup is performed for.
        ids: Ids of the documents to fetch. ``None`` is forwarded unchanged to
            the database layer.

    Returns:
        The documents as returned by the database layer.
    """
    database = self._database
    with database.get_session() as session:
        documents = database.get_documents(session, user=user, ids=ids)
        return documents

def get_document(self, *, user: str, id: uuid.UUID) -> schemas.Document:
    """Fetch a single document by id.

    Convenience wrapper that asks ``get_documents`` for a one-element id list
    and returns the first entry of the result.
    """
    matches = self.get_documents(user=user, ids=[id])
    return matches[0]

def create_chat(
self, *, user: str, chat_creation: schemas.ChatCreation
) -> schemas.Chat:
Expand Down Expand Up @@ -280,6 +288,7 @@ def document(self, document: schemas.Document) -> core.Document:
id=document.id,
name=document.name,
metadata=document.metadata,
mime_type=document.mime_type,
)

def source(self, source: schemas.Source) -> core.Source:
Expand Down Expand Up @@ -328,6 +337,7 @@ def document(self, document: core.Document) -> schemas.Document:
id=document.id,
name=document.name,
metadata=document.metadata,
mime_type=document.mime_type,
)

def source(self, source: core.Source) -> schemas.Source:
Expand Down
1 change: 1 addition & 0 deletions ragna/deploy/_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class Document(Base):
# Mind the trailing underscore here. Unfortunately, this is necessary, because
# metadata without the underscore is reserved by SQLAlchemy
metadata_ = Column(Json, nullable=False)
mime_type = Column(types.String, nullable=False)
chats = relationship(
"Chat",
secondary=document_chat_association_table,
Expand Down
2 changes: 2 additions & 0 deletions ragna/deploy/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,14 @@ class Components(BaseModel):
# Payload used to register a document by name (plus optional metadata).
class DocumentRegistration(BaseModel):
name: str
# Defaults to a fresh empty dict per instance.
metadata: dict[str, Any] = Field(default_factory=dict)
# Optional MIME type; None means the caller did not specify one.
mime_type: str | None = None


# Schema representation of a document.
class Document(BaseModel):
# A random UUID is generated when no id is supplied.
id: uuid.UUID = Field(default_factory=uuid.uuid4)
name: str
metadata: dict[str, Any]
# Required here, in contrast to the optional DocumentRegistration.mime_type.
mime_type: str


class Source(BaseModel):
Expand Down
20 changes: 6 additions & 14 deletions tests/deploy/api/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ragna import assistants
from ragna.core import RagnaException
from ragna.deploy import Config
from tests.deploy.api.utils import upload_documents
from tests.deploy.utils import make_api_app, make_api_client


Expand Down Expand Up @@ -56,17 +57,8 @@ def test_unknown_component(tmp_local_root):
with open(document_path, "w") as file:
file.write("!\n")

with make_api_client(
config=Config(), ignore_unavailable_components=False
) as client:
document = (
client.post("/api/documents", json=[{"name": document_path.name}])
.raise_for_status()
.json()[0]
)

with open(document_path, "rb") as file:
client.put("/api/documents", files={"documents": (document["id"], file)})
with make_api_client(config=config, ignore_unavailable_components=False) as client:
document = upload_documents(client=client, document_paths=[document_path])[0]

response = client.post(
"/api/chats",
Expand All @@ -80,7 +72,7 @@ def test_unknown_component(tmp_local_root):
},
)

assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

error = response.json()["error"]
assert "Unknown component" in error["message"]
error = response.json()["error"]
assert "Unknown component" in error["message"]
109 changes: 109 additions & 0 deletions tests/deploy/api/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import mimetypes

import pytest

from ragna.deploy import Config
from tests.deploy.api.utils import upload_documents
from tests.deploy.utils import make_api_client

# Sample document bodies shared by the endpoint tests; each is one line of text
# terminated by a newline.
_document_content_text = [
    f"Needs more {needs_more_of}\n" for needs_more_of in ("reverb", "cowbell")
]


# Reusable parametrization: decorating a test with `@mime_types` runs it once
# per entry below, passing the value as the `mime_type` argument.
mime_types = pytest.mark.parametrize(
("mime_type",),
[
(None,), # Let the mimetypes library decide
("text/markdown",),
("application/pdf",),
],
)


@mime_types
def test_get_documents(tmp_local_root, mime_type):
    """Uploaded documents must all be returned by ``GET /api/documents``."""
    config = Config(local_root=tmp_local_root)

    document_root = config.local_root / "documents"
    document_root.mkdir()

    # Write one text file per entry of the shared sample content.
    document_paths = []
    for idx, content in enumerate(_document_content_text):
        document_path = document_root / f"test{idx}.txt"
        with open(document_path, "w") as file:
            file.write(content)
        document_paths.append(document_path)

    with make_api_client(config=config, ignore_unavailable_components=False) as client:
        uploaded = upload_documents(
            client=client,
            document_paths=document_paths,
            mime_types=[mime_type] * len(document_paths),
        )
        listed = client.get("/api/documents").raise_for_status().json()

    # Sort both sides by id in case they are retrieved in different orders
    assert sorted(uploaded, key=lambda d: d["id"]) == sorted(
        listed, key=lambda d: d["id"]
    )


@mime_types
def test_get_document(tmp_local_root, mime_type):
    """``GET /api/documents/{id}`` must echo the document returned on upload."""
    config = Config(local_root=tmp_local_root)

    document_root = config.local_root / "documents"
    document_root.mkdir()
    document_path = document_root / "test.txt"
    with open(document_path, "w") as file:
        file.write(_document_content_text[0])

    with make_api_client(config=config, ignore_unavailable_components=False) as client:
        uploaded = upload_documents(
            client=client,
            document_paths=[document_path],
            mime_types=[mime_type],
        )
        document = uploaded[0]
        response = client.get(f"/api/documents/{document['id']}").raise_for_status()
        fetched = response.json()

    assert document == fetched


@mime_types
def test_get_document_content(tmp_local_root, mime_type):
    """The content endpoint must stream the file back with the right MIME type."""
    config = Config(local_root=tmp_local_root)

    document_root = config.local_root / "documents"
    document_root.mkdir()
    document_path = document_root / "test.txt"
    document_content = _document_content_text[0]
    with open(document_path, "w") as file:
        file.write(document_content)

    # Without an explicit MIME type the expectation is whatever the mimetypes
    # library guesses from the file name.
    if mime_type is None:
        expected_mime_type = mimetypes.guess_type(document_path.name)[0]
    else:
        expected_mime_type = mime_type

    with make_api_client(config=config, ignore_unavailable_components=False) as client:
        document = upload_documents(
            client=client,
            document_paths=[document_path],
            mime_types=[mime_type],
        )[0]

        content_url = f"/api/documents/{document['id']}/content"
        with client.stream("GET", content_url) as response:
            # Drop any parameters (e.g. a charset) that follow the media type.
            response_mime_type = response.headers["content-type"].split(";")[0]
            received_lines = list(response.iter_lines())

    # The received lines are compared against the content without its newline
    assert received_lines == [document_content.replace("\n", "")]

    assert document["mime_type"] == response_mime_type
    assert response_mime_type == expected_mime_type
)
37 changes: 37 additions & 0 deletions tests/deploy/api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import contextlib


def upload_documents(*, client, document_paths, mime_types=None):
    """Register the given files with the API and upload their content.

    Args:
        client: HTTP client exposing ``post`` and ``put``; the returned
            responses must provide ``raise_for_status()`` and ``json()``.
        document_paths: Paths of the files to upload. Only ``.name`` is sent
            when registering; the file itself is opened in binary mode.
        mime_types: Optional MIME types, one per path. ``None`` (the default)
            registers every document without a MIME type.

    Returns:
        The JSON-decoded list of documents returned by the registration request.

    Raises:
        AssertionError: If ``mime_types`` is given but its length differs from
            ``document_paths``.
        Exception: Whatever ``raise_for_status()`` raises when either the
            registration or the upload request fails.
    """
    if mime_types is None:
        mime_types = [None for _ in document_paths]
    else:
        assert len(mime_types) == len(document_paths)

    registrations = [
        {
            "name": document_path.name,
            "mime_type": mime_type,
        }
        for document_path, mime_type in zip(document_paths, mime_types)
    ]
    documents = (
        client.post("/api/documents", json=registrations).raise_for_status().json()
    )

    # ExitStack keeps every file open for the duration of the upload and closes
    # all of them afterwards, even if the request fails.
    with contextlib.ExitStack() as stack:
        files = [
            stack.enter_context(open(document_path, "rb"))
            for document_path in document_paths
        ]
        # Check the upload response as well; otherwise a failed upload would go
        # unnoticed and only surface later as a confusing downstream failure.
        client.put(
            "/api/documents",
            files=[
                ("documents", (document["id"], file))
                for document, file in zip(documents, files)
            ],
        ).raise_for_status()

    return documents