diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 6694d32c..c3647510 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -1,7 +1,10 @@ import secrets import uuid +from datetime import datetime from typing import Any, AsyncIterator, Optional, cast +import panel as pn +from emoji import emojize from fastapi import status as http_status_code import ragna @@ -232,6 +235,21 @@ def get_chat(self, *, user: str, id: uuid.UUID) -> schemas.Chat: with self._database.get_session() as session: return self._database.get_chat(session, user=user, id=id) + # This and `improve_message` were copied from the old [`ApiWrapper`](https://github.com/Quansight/ragna/issues/521). + # The interface they provide is open for discussion + async def get_improved_chats(self): + json_data = [ + chat.model_dump(mode="json") for chat in self.get_chats(user=pn.state.user) + ] + for chat in json_data: + chat["messages"] = [self._improve_message(msg) for msg in chat["messages"]] + return json_data + + def _improve_message(self, msg): + msg["timestamp"] = datetime.strptime(msg["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") + msg["content"] = emojize(msg["content"], language="alias") + return msg + async def prepare_chat(self, *, user: str, id: uuid.UUID) -> schemas.Message: core_chat = self._to_core.chat(self.get_chat(user=user, id=id), user=user) core_message = await core_chat.prepare() @@ -265,6 +283,12 @@ async def answer_stream( session, chat=self._to_schema.chat(core_chat), user=user ) + async def answer_improved(self, chat_id, prompt): + async for message in self.answer_stream( + user=pn.state.user, chat_id=uuid.UUID(chat_id), prompt=prompt + ): + yield self._improve_message(message.model_dump(mode="json")) + def delete_chat(self, *, user: str, id: uuid.UUID) -> None: with self._database.get_session() as session: self._database.delete_chat(session, user=user, id=id) @@ -375,3 +399,20 @@ def chat(self, chat: core.Chat) -> schemas.Chat: messages=[self.message(message) for message in chat._messages], prepared=chat._prepared, ) + + async def start_and_prepare( + self, name, input, corpus_name, source_storage, assistant, params + ): + chat = self.create_chat( + user=pn.state.user, + chat_creation=schemas.ChatCreation( + name=name, + input=input, + source_storage=source_storage, + assistant=assistant, + corpus_name=corpus_name, + params=params, + ), + ) + await self._engine.prepare_chat(user=pn.state.user, id=chat.id) + return str(chat.id) diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py deleted file mode 100644 index 2f8f37b2..00000000 --- a/ragna/deploy/_ui/api_wrapper.py +++ /dev/null @@ -1,62 +0,0 @@ -import uuid -from datetime import datetime - -import emoji -import panel as pn -import param - -from ragna.deploy import _schemas as schemas -from ragna.deploy._engine import Engine - - -class ApiWrapper(param.Parameterized): - def __init__(self, engine: Engine): - super().__init__() - self._user = pn.state.user - self._engine = engine - - async def get_corpus_names(self): - return await self._engine.get_corpuses() - - async def get_corpus_metadata(self): - return await self._engine.get_corpus_metadata() - - async def get_chats(self): - json_data = [ - chat.model_dump(mode="json") - for chat in self._engine.get_chats(user=self._user) - ] - for chat in json_data: - chat["messages"] = [self.improve_message(msg) for msg in chat["messages"]] - return json_data - - async def answer(self, chat_id, prompt): - async for message in self._engine.answer_stream( - user=self._user, chat_id=uuid.UUID(chat_id), prompt=prompt - ): - yield self.improve_message(message.model_dump(mode="json")) - - def get_components(self): - return self._engine.get_components() - - async def start_and_prepare( - self, name, input, corpus_name, source_storage, assistant, params - ): - chat = self._engine.create_chat( - user=self._user, - chat_creation=schemas.ChatCreation( - name=name, - input=input, - source_storage=source_storage, - assistant=assistant, - corpus_name=corpus_name, - params=params, - ), - ) - await self._engine.prepare_chat(user=self._user, id=chat.id) - return str(chat.id) - - def improve_message(self, msg): - msg["timestamp"] = datetime.strptime(msg["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") - msg["content"] = emoji.emojize(msg["content"], language="alias") - return msg diff --git a/ragna/deploy/_ui/app.py b/ragna/deploy/_ui/app.py index 4163c378..aa691311 100644 --- a/ragna/deploy/_ui/app.py +++ b/ragna/deploy/_ui/app.py @@ -5,7 +5,6 @@ from . import js from . import styles as ui -from .api_wrapper import ApiWrapper from .main_page import MainPage pn.extension( @@ -68,10 +67,8 @@ def get_template(self): return template def index_page(self): - api_wrapper = ApiWrapper(self._engine) - template = self.get_template() - main_page = MainPage(api_wrapper=api_wrapper, template=template) + main_page = MainPage(engine=self._engine, template=template) template.main.append(main_page) return template diff --git a/ragna/deploy/_ui/central_view.py b/ragna/deploy/_ui/central_view.py index 147ee686..557a83bc 100644 --- a/ragna/deploy/_ui/central_view.py +++ b/ragna/deploy/_ui/central_view.py @@ -167,12 +167,12 @@ def _build_message(self, *args, **kwargs) -> Optional[RagnaChatMessage]: class CentralView(pn.viewable.Viewer): current_chat = param.ClassSelector(class_=dict, default=None) - def __init__(self, api_wrapper, **params): + def __init__(self, engine, **params): super().__init__(**params) # FIXME: make this dynamic from the login self.user = "" - self.api_wrapper = api_wrapper + self._engine = engine self.chat_info_button = pn.widgets.Button( # The name will be filled at runtime in self.header name="", @@ -310,7 +310,9 @@ async def chat_callback( self, content: str, user: str, instance: pn.chat.ChatInterface ): try: - answer_stream = self.api_wrapper.answer(self.current_chat["id"], content) + answer_stream = self._engine.answer_improved( + self.current_chat["id"], content + ) answer = await anext(answer_stream) message = RagnaChatMessage( diff --git a/ragna/deploy/_ui/left_sidebar.py b/ragna/deploy/_ui/left_sidebar.py index b55ce1eb..4d2ca754 100644 --- a/ragna/deploy/_ui/left_sidebar.py +++ b/ragna/deploy/_ui/left_sidebar.py @@ -11,10 +11,10 @@ class LeftSidebar(pn.viewable.Viewer): current_chat_id = param.String(default=None) refresh_counter = param.Integer(default=0) - def __init__(self, api_wrapper, **params): + def __init__(self, engine, **params): super().__init__(**params) - self.api_wrapper = api_wrapper + self._engine = engine self.on_click_chat = None self.on_click_new_chat = None @@ -105,7 +105,7 @@ def __panel__(self): + self.chat_buttons + [ pn.layout.VSpacer(), - pn.pane.HTML(f"user: {self.api_wrapper._user}"), + pn.pane.HTML(f"user: {pn.state.user}"), pn.pane.HTML(f"version: {ragna_version}"), # self.footer() ] diff --git a/ragna/deploy/_ui/main_page.py b/ragna/deploy/_ui/main_page.py index 7e4822ae..020c0258 100644 --- a/ragna/deploy/_ui/main_page.py +++ b/ragna/deploy/_ui/main_page.py @@ -13,9 +13,9 @@ class MainPage(pn.viewable.Viewer, param.Parameterized): current_chat_id = param.String(default=None) chats = param.List(default=None) - def __init__(self, api_wrapper, template): + def __init__(self, engine, template): super().__init__() - self.api_wrapper = api_wrapper + self._engine = engine self.template = template self.components = None @@ -23,12 +23,12 @@ def __init__(self, api_wrapper, template): self.corpus_names = None self.modal = None - self.central_view = CentralView(api_wrapper=self.api_wrapper) + self.central_view = CentralView(engine=self._engine) self.central_view.on_click_chat_info = ( lambda event, title, content: self.show_right_sidebar(title, content) ) - self.left_sidebar = LeftSidebar(api_wrapper=self.api_wrapper) + self.left_sidebar = LeftSidebar(engine=self._engine) self.left_sidebar.on_click_chat = self.on_click_chat self.left_sidebar.on_click_new_chat = self.open_modal @@ -41,10 +41,10 @@ def __init__(self, api_wrapper, template): ) async def refresh_data(self): - self.chats = await self.api_wrapper.get_chats() - self.components = self.api_wrapper.get_components() - self.corpus_metadata = await self.api_wrapper.get_corpus_metadata() - self.corpus_names = await self.api_wrapper.get_corpus_names() + self.chats = await self._engine.get_improved_chats() + self.components = self._engine.get_components() + self.corpus_metadata = await self._engine.get_corpus_metadata() + self.corpus_names = await self._engine.get_corpuses() @param.depends("chats", watch=True) def after_update_chats(self): @@ -73,7 +73,7 @@ async def open_modal(self, event): await self.refresh_data() self.modal = ModalConfiguration( - api_wrapper=self.api_wrapper, + engine=self._engine, components=self.components, corpus_metadata=self.corpus_metadata, corpus_names=self.corpus_names, diff --git a/ragna/deploy/_ui/modal_configuration.py b/ragna/deploy/_ui/modal_configuration.py index 6acf3a33..f14da10d 100644 --- a/ragna/deploy/_ui/modal_configuration.py +++ b/ragna/deploy/_ui/modal_configuration.py @@ -89,12 +89,10 @@ class ModalConfiguration(pn.viewable.Viewer): error = param.Boolean(default=False) - def __init__( - self, api_wrapper, components, corpus_names, corpus_metadata, **params - ): + def __init__(self, engine, components, corpus_names, corpus_metadata, **params): super().__init__(chat_name=get_default_chat_name(), **params) - self.api_wrapper = api_wrapper + self._engine = engine self.corpus_names = corpus_names self.corpus_metadata = corpus_metadata @@ -105,7 +103,7 @@ def __init__( self.document_uploader = pn.widgets.FileInput( multiple=True, css_classes=["file-input"], - accept=",".join(self.api_wrapper.get_components().documents), + accept=",".join(self._engine.get_components().documents), ) # Most widgets (including those that use from_param) should be placed after the super init call @@ -158,15 +156,15 @@ async def did_click_on_start_chat_button(self, event): return self.start_chat_button.disabled = True - documents = self.api_wrapper._engine.register_documents( - user=self.api_wrapper._user, + documents = self._engine.register_documents( + user=pn.state.user, document_registrations=[ schemas.DocumentRegistration(name=name) for name in self.document_uploader.filename ], ) - if self.api_wrapper._engine.supports_store_documents: + if self._engine.supports_store_documents: def make_content_stream(data: bytes) -> AsyncIterator[bytes]: async def content_stream() -> AsyncIterator[bytes]: @@ -174,8 +172,8 @@ async def content_stream() -> AsyncIterator[bytes]: return content_stream() - await self.api_wrapper._engine.store_documents( - user=self.api_wrapper._user, + await self._engine.store_documents( + user=pn.state.user, ids_and_streams=[ (document.id, make_content_stream(data)) for document, data in zip( @@ -207,7 +205,7 @@ async def did_finish_upload(self, input, corpus_name=None): corpus_name = self.corpus_name_input.value try: - new_chat_id = await self.api_wrapper.start_and_prepare( + new_chat_id = await self._engine.start_and_prepare( name=self.chat_name, input=input, corpus_name=corpus_name, @@ -249,7 +247,7 @@ def change_upload_files_label(self, mode="normal"): def create_config(self, components): if self.config is None: # Retrieve the components from the API and build a config object - components = self.api_wrapper.get_components() + components = self._engine.get_components() # TODO : use the components to set up the default values for the various params config = ChatConfig()