# Public API of the conversation package.
# Every name listed here must be bound in this module: a star-import looks
# each entry up as a module attribute and fails on the first missing one.
# "ImageFileTypes" is intentionally absent -- the enum was renamed to
# AttachmentFileTypes and is no longer imported here.
__all__ = [
    "AttachmentFile",
    "AttachmentFileTypes",
    "Conversation",
    "get_mime_type",
    "ImageFile",
    "Message",
    "MessageBlock",
    "MessageRoleEnum",
    "NotImageError",
    "parse_supported_attachment_formats",
    "PROMPT_TITLE",
    "URL_PATTERN",
]
class AttachmentFileTypes(Enum):
    """Kind of storage an attachment lives on, looked up from a protocol string.

    The canonical values are ``"unknown"``, ``"local"``, ``"memory"`` and
    ``"http"``. A few extra protocol names are accepted as aliases through
    ``_missing_``; anything unrecognised resolves to ``UNKNOWN`` instead of
    raising ``ValueError``.
    """

    UNKNOWN = "unknown"
    LOCAL = "local"
    MEMORY = "memory"
    URL = "http"

    @classmethod
    def _missing_(cls, value: object) -> "AttachmentFileTypes":
        """Resolve values that are not an exact member value.

        Matching is case-insensitive for aliases and canonical values alike:
        ``"data"`` and ``"https"`` map to ``URL``, ``"zip"`` maps to
        ``LOCAL``, and a canonical value in another case (``"HTTP"``,
        ``"Memory"``) maps to its member. Non-string or unrecognised values
        map to ``UNKNOWN``.
        """
        if not isinstance(value, str):
            return cls.UNKNOWN
        protocol = value.lower()
        if protocol in ("data", "https"):
            return cls.URL
        if protocol == "zip":
            return cls.LOCAL
        # Exact-case canonical values never reach _missing_, so this only
        # catches differently-cased spellings of a member value.
        return next(
            (member for member in cls if member.value == protocol),
            cls.UNKNOWN,
        )
location.startswith("data:"): + location = f"{location[:50]}...{location[-10:]}" + return location + + @staticmethod + def remove_location(location: UPath): + log.debug(f"Removing image at {location}") + try: + fs = location.fs + fs.rm(location.path) + except Exception as e: + log.error(f"Error deleting image at {location}: {e}") + + def read_as_str(self): + with self.location.open(mode="r") as file: + return file.read() + + def encode_base64(self) -> str: + with self.location.open(mode="rb") as file: + return base64.b64encode(file.read()).decode("utf-8") + + def __del__(self): + if self.type == AttachmentFileTypes.URL: + return + if self.type == AttachmentFileTypes.MEMORY: + self.remove_location(self.location) + + +class ImageFile(AttachmentFile): dimensions: tuple[int, int] | None = None resize_location: PydanticUPath | None = Field(default=None, exclude=True) @@ -131,7 +228,7 @@ def build_from_url(cls, url: str) -> ImageFile: dimensions = get_image_dimensions(BytesIO(r.content)) return cls( location=url, - type=ImageFileTypes.IMAGE_URL, + type=AttachmentFileTypes.URL, size=size, description=content_type, dimensions=dimensions, @@ -169,39 +266,17 @@ def validate_location( def __init__(self, /, **data: Any) -> None: super().__init__(**data) - if not self.name: - self.name = self._get_name() - self.size = self._get_size() if not self.dimensions: self.dimensions = self._get_dimensions() __init__.__pydantic_base_init__ = True @property - def type(self) -> ImageFileTypes: - return ImageFileTypes(self.location.protocol) - - def _get_name(self) -> str: - return self.location.name - - def _get_size(self) -> int | None: - if self.type == ImageFileTypes.IMAGE_URL: - return None - return self.location.stat().st_size - - @property - def display_size(self) -> str: - size = self.size - if size is None: - return _("Unknown") - if size < 1024: - return f"{size} B" - if size < 1024 * 1024: - return f"{size / 1024:.2f} KB" - return f"{size / 1024 / 1024:.2f} MB" + def 
send_location(self) -> UPath: + return self.resize_location or self.location def _get_dimensions(self) -> tuple[int, int] | None: - if self.type == ImageFileTypes.IMAGE_URL: + if self.type == AttachmentFileTypes.URL: return None with self.location.open(mode="rb") as image_file: return get_image_dimensions(image_file) @@ -216,7 +291,7 @@ def display_dimensions(self) -> str: def resize( self, conv_folder: UPath, max_width: int, max_height: int, quality: int ): - if ImageFileTypes.IMAGE_URL == self.type: + if AttachmentFileTypes.URL == self.type: return log.debug("Resizing image") resize_location = conv_folder.joinpath( @@ -234,10 +309,6 @@ def resize( ) self.resize_location = resize_location if success else None - @property - def send_location(self) -> UPath: - return self.resize_location or self.location - @measure_time def encode_image(self) -> str: if self.size and self.size > 1024 * 1024 * 1024: @@ -247,29 +318,15 @@ def encode_image(self) -> str: with self.send_location.open(mode="rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8") - @property - def mime_type(self) -> str | None: - if self.type == ImageFileTypes.IMAGE_URL: - return None - mime_type, _ = mimetypes.guess_type(self.send_location) - return mime_type - @property def url(self) -> str: - if not isinstance(self.type, ImageFileTypes): + if not isinstance(self.type, AttachmentFileTypes): raise ValueError("Invalid image type") - if self.type == ImageFileTypes.IMAGE_URL: + if self.type == AttachmentFileTypes.URL: return str(self.location) base64_image = self.encode_image() return f"data:{self.mime_type};base64,{base64_image}" - @property - def display_location(self): - location = str(self.location) - if location.startswith("data:image/"): - location = f"{location[:50]}...{location[-10:]}" - return location - @staticmethod def remove_location(location: UPath): log.debug(f"Removing image at {location}") @@ -280,9 +337,6 @@ def remove_location(location: UPath): log.error(f"Error 
deleting image at {location}: {e}") def __del__(self): - if self.type == ImageFileTypes.IMAGE_URL: - return if self.resize_location: self.remove_location(self.resize_location) - if self.type == ImageFileTypes.IMAGE_MEMORY: - self.remove_location(self.location) + super().__del__() diff --git a/basilisk/conversation/conversation_helper.py b/basilisk/conversation/conversation_helper.py index 175384cf..862b82ee 100644 --- a/basilisk/conversation/conversation_helper.py +++ b/basilisk/conversation/conversation_helper.py @@ -11,7 +11,7 @@ from basilisk.config import conf from basilisk.decorators import measure_time -from .image_model import ImageFile, ImageFileTypes +from .attached_file import AttachmentFile, AttachmentFileTypes, ImageFile if TYPE_CHECKING: from .conversation_model import Conversation @@ -23,11 +23,13 @@ def save_attachments( - attachments: list[ImageFile], attachment_path: str, fs: ZipFileSystem + attachments: list[AttachmentFile | ImageFile], + attachment_path: str, + fs: ZipFileSystem, ): attachment_mapping = {} for attachment in attachments: - if attachment.type == ImageFileTypes.IMAGE_URL: + if attachment.type == AttachmentFileTypes.URL: continue new_location = f"{attachment_path}/{attachment.location.name}" with attachment.location.open(mode="rb") as attachment_file: @@ -56,7 +58,7 @@ def create_conv_main_file(conversation: Conversation, fs: ZipFileSystem): def restore_attachments(attachments: list[ImageFile], storage_path: UPath): for attachment in attachments: - if attachment.type == ImageFileTypes.IMAGE_URL: + if attachment.type == AttachmentFileTypes.URL: continue new_path = storage_path / attachment.location.name with attachment.location.open(mode="rb") as attachment_file: diff --git a/basilisk/conversation/conversation_model.py b/basilisk/conversation/conversation_model.py index 77c3bcf2..ca7b0447 100644 --- a/basilisk/conversation/conversation_model.py +++ b/basilisk/conversation/conversation_model.py @@ -8,8 +8,8 @@ from 
basilisk.provider_ai_model import AIModelInfo +from .attached_file import AttachmentFile, ImageFile from .conversation_helper import create_bskc_file, open_bskc_file -from .image_model import ImageFile class MessageRoleEnum(Enum): @@ -21,7 +21,7 @@ class MessageRoleEnum(Enum): class Message(BaseModel): role: MessageRoleEnum content: str - attachments: list[ImageFile] | None = Field(default=None) + attachments: list[AttachmentFile | ImageFile] | None = Field(default=None) class MessageBlock(BaseModel): diff --git a/basilisk/gui/conversation_tab.py b/basilisk/gui/conversation_tab.py index 24896848..e9340c61 100644 --- a/basilisk/gui/conversation_tab.py +++ b/basilisk/gui/conversation_tab.py @@ -20,12 +20,15 @@ from basilisk.conversation import ( PROMPT_TITLE, URL_PATTERN, + AttachmentFile, Conversation, ImageFile, Message, MessageBlock, MessageRoleEnum, NotImageError, + get_mime_type, + parse_supported_attachment_formats, ) from basilisk.decorators import ensure_no_task_running from basilisk.message_segment_manager import ( @@ -39,6 +42,7 @@ from .base_conversation import BaseConversation from .html_view_window import show_html_view_window +from .read_only_message_dialog import ReadOnlyMessageDialog from .search_dialog import SearchDialog, SearchDirection if TYPE_CHECKING: @@ -100,7 +104,7 @@ def __init__( self.bskc_path = bskc_path self.conv_storage_path = conv_storage_path or self.conv_storage_path() self.conversation = conversation or Conversation() - self.image_files: list[ImageFile] = [] + self.attachment_files: list[AttachmentFile | ImageFile] = [] self.last_time = 0 self.message_segment_manager = MessageSegmentManager() self.recording_thread: Optional[RecordingThread] = None @@ -159,26 +163,26 @@ def init_ui(self): sizer.Add(self.prompt, proportion=1, flag=wx.EXPAND) self.prompt.SetFocus() - self.images_list_label = wx.StaticText( + self.attachments_list_label = wx.StaticText( self, # Translators: This is a label for models in the main window - 
label=_("&Images:"), + label=_("&Attachments:"), ) - sizer.Add(self.images_list_label, proportion=0, flag=wx.EXPAND) - self.images_list = wx.ListCtrl( + sizer.Add(self.attachments_list_label, proportion=0, flag=wx.EXPAND) + self.attachments_list = wx.ListCtrl( self, size=(800, 100), style=wx.LC_REPORT ) - self.images_list.Bind(wx.EVT_CONTEXT_MENU, self.on_images_context_menu) - self.images_list.Bind(wx.EVT_KEY_DOWN, self.on_images_key_down) - self.images_list.InsertColumn(0, _("Name")) - self.images_list.InsertColumn(1, _("Size")) - self.images_list.InsertColumn(2, _("Dimensions")) - self.images_list.InsertColumn(3, _("Path")) - self.images_list.SetColumnWidth(0, 200) - self.images_list.SetColumnWidth(1, 100) - self.images_list.SetColumnWidth(2, 100) - self.images_list.SetColumnWidth(3, 200) - sizer.Add(self.images_list, proportion=0, flag=wx.ALL | wx.EXPAND) + self.attachments_list.Bind( + wx.EVT_CONTEXT_MENU, self.on_attachments_context_menu + ) + self.attachments_list.Bind( + wx.EVT_KEY_DOWN, self.on_attachments_key_down + ) + self.attachments_list.InsertColumn(0, _("Name")) + self.attachments_list.InsertColumn(1, _("Size")) + self.attachments_list.SetColumnWidth(0, 200) + self.attachments_list.SetColumnWidth(1, 100) + sizer.Add(self.attachments_list, proportion=0, flag=wx.ALL | wx.EXPAND) label = self.create_model_widget() sizer.Add(label, proportion=0, flag=wx.EXPAND) sizer.Add(self.model_list, proportion=0, flag=wx.ALL | wx.EXPAND) @@ -237,6 +241,7 @@ def init_ui(self): self.Bind(wx.EVT_CHAR_HOOK, self.on_char_hook) def init_data(self, profile: Optional[config.ConversationProfile]): + self.refresh_attachments_list() self.apply_profile(profile, True) self.refresh_messages(need_clear=False) @@ -267,16 +272,20 @@ def on_account_change(self, event: wx.CommandEvent): ProviderCapability.STT in account.provider.engine_cls.capabilities ) - def on_images_context_menu(self, event: wx.ContextMenuEvent): - selected = self.images_list.GetFirstSelected() + def 
def on_attachments_key_down(self, event: wx.KeyEvent):
    """Dispatch keyboard shortcuts pressed in the attachments list.

    With Ctrl held: ``C`` calls ``on_copy_image_url`` and ``V`` calls
    ``on_attachments_paste``. With no modifier: Delete calls
    ``on_attachments_remove`` and Return / numpad Enter calls
    ``on_show_attachment_details``. The event is skipped in every case.
    """
    pressed = event.GetKeyCode()
    held = event.GetModifiers()

    if held == wx.MOD_CONTROL:
        if pressed == ord("C"):
            self.on_copy_image_url(None)
        if pressed == ord("V"):
            self.on_attachments_paste(None)

    if held == wx.MOD_NONE:
        if pressed == wx.WXK_DELETE:
            self.on_attachments_remove(None)
        if pressed in (wx.WXK_RETURN, wx.WXK_NUMPAD_ENTER):
            self.on_show_attachment_details(None)

    event.Skip()
def add_attachments_dlg(self, event: wx.CommandEvent = None):
    """Let the user pick files to attach, filtered to the engine's formats.

    The file filter is built from the current engine's
    ``supported_attachment_formats``. If that yields no wildcard, an error
    box is shown and no dialog is opened. The selected paths are passed to
    ``add_attachments``.

    Args:
        event: The triggering event; unused, so the method can also be
            called directly.
    """
    wildcard = parse_supported_attachment_formats(
        self.current_engine.supported_attachment_formats
    )
    if not wildcard:
        wx.MessageBox(
            # Translators: This message is displayed when there are no supported attachment formats.
            _("This provider does not support any attachment formats."),
            _("Error"),
            wx.OK | wx.ICON_ERROR,
        )
        return
    wildcard = _("All supported formats") + f" ({wildcard})|{wildcard}"

    # The context manager destroys the dialog even if an exception is raised
    # while it is open, which the manual Destroy() call did not guarantee.
    with wx.FileDialog(
        self,
        message=_("Select one or more files to attach"),
        style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST | wx.FD_MULTIPLE,
        wildcard=wildcard,
    ) as file_dialog:
        if file_dialog.ShowModal() != wx.ID_OK:
            return
        paths = file_dialog.GetPaths()
    self.add_attachments(paths)
def on_copy_image_url(self, event: wx.CommandEvent):
    """Copy the selected attachment's location to the clipboard as text.

    Does nothing when no row is selected in the attachments list.

    Args:
        event: The triggering event; unused (callers may pass ``None``).
    """
    selected = self.attachments_list.GetFirstSelected()
    if selected == wx.NOT_FOUND:
        return
    # ``location`` is a path object, not a str; wx.TextDataObject needs
    # plain text, so convert explicitly.
    location = str(self.attachment_files[selected].location)
    with wx.TheClipboard as clipboard:
        clipboard.SetData(wx.TextDataObject(location))
def add_attachments(self, paths: list[str | AttachmentFile | ImageFile]):
    """Add attachments to the pending list and refresh the list control.

    Items that are already ``AttachmentFile`` / ``ImageFile`` instances are
    appended unchanged. String paths have their MIME type guessed with
    ``get_mime_type`` and must be in the current engine's
    ``supported_attachment_formats``; ``image/*`` paths become ``ImageFile``
    and everything else becomes ``AttachmentFile``.

    A path whose type is unknown or unsupported is reported in an error box
    and skipped. The remaining items are still processed, and the list is
    always refreshed so it matches ``attachment_files``.

    Args:
        paths: File paths and/or ready-made attachment objects to add.
    """
    log.debug(f"Adding attachments: {paths}")
    for path in paths:
        if isinstance(path, (AttachmentFile, ImageFile)):
            self.attachment_files.append(path)
            continue
        mime_type = get_mime_type(path)
        supported_attachment_formats = (
            self.current_engine.supported_attachment_formats
        )
        # get_mime_type returns None when the type cannot be guessed;
        # treat that the same as an unsupported format.
        if mime_type is None or mime_type not in supported_attachment_formats:
            wx.MessageBox(
                # Translators: This message is displayed when an attachment format is not supported.
                _(
                    "This attachment format is not supported by the current provider. Source:"
                )
                + f"\n{path}",
                _("Error"),
                wx.OK | wx.ICON_ERROR,
            )
            # Skip only this file: returning here would leave attachments
            # added earlier in this call out of the refreshed list.
            continue
        if mime_type.startswith("image/"):
            file = ImageFile(location=path)
        else:
            file = AttachmentFile(location=path)
        self.attachment_files.append(file)
    self.refresh_attachments_list()
    if self.attachment_files:
        # The list control is hidden while empty, so only focus it when
        # there is something to show.
        self.attachments_list.SetFocus()
def _check_attachments_valid(self) -> bool:
    """Check that every pending attachment can be sent with the current engine.

    Non-URL attachments must have a MIME type in the engine's
    ``supported_attachment_formats`` and their location must exist. The
    first attachment that fails is removed from ``attachment_files``, an
    error box is shown, and ``False`` is returned.

    URL attachments are not checked: their ``mime_type`` is ``None``, so
    the format test would reject every one of them.

    Returns:
        ``True`` if no attachment was rejected, ``False`` otherwise.
    """
    # Imported here because this module's top-level import list does not
    # include AttachmentFileTypes.
    from basilisk.conversation import AttachmentFileTypes

    supported_attachment_formats = (
        self.current_engine.supported_attachment_formats
    )
    # Iterate over a copy: the loop body removes from the original list.
    for attachment in list(self.attachment_files):
        if attachment.type == AttachmentFileTypes.URL:
            continue
        if attachment.mime_type not in supported_attachment_formats:
            self.attachment_files.remove(attachment)
            wx.MessageBox(
                # Translators: This message is displayed when an attachment format is not supported.
                _(
                    "This attachment format is not supported by the current provider. Source:"
                )
                + f"\n{attachment.location}",
                _("Error"),
                wx.OK | wx.ICON_ERROR,
            )
            return False
        if not attachment.location.exists():
            self.attachment_files.remove(attachment)
            wx.MessageBox(
                # Translators: This message is displayed when an attachment file does not exist.
                _("The attachment file does not exist: %s")
                % attachment.location,
                _("Error"),
                wx.OK | wx.ICON_ERROR,
            )
            return False
    return True
config.conf().conversation.focus_history_after_send: self.messages.SetFocus() self._end_task() diff --git a/basilisk/gui/main_frame.py b/basilisk/gui/main_frame.py index 3fc5db32..f0e376a3 100644 --- a/basilisk/gui/main_frame.py +++ b/basilisk/gui/main_frame.py @@ -298,7 +298,7 @@ def screen_capture( def post_screen_capture(self, imagefile: ImageFile | str): log.debug("Screen capture received") - self.current_tab.add_images([imagefile]) + self.current_tab.add_attachments([imagefile]) if not self.IsShown(): self.Show() self.Restore() @@ -469,7 +469,7 @@ def on_add_image(self, event, from_url=False): if from_url: current_tab.add_image_url_dlg() else: - current_tab.add_image_files() + current_tab.add_attachments_dlg() def on_transcribe_audio( self, event: wx.Event, from_microphone: bool = False diff --git a/basilisk/provider_capability.py b/basilisk/provider_capability.py index b18897fe..fa02ff4c 100644 --- a/basilisk/provider_capability.py +++ b/basilisk/provider_capability.py @@ -2,6 +2,7 @@ class ProviderCapability(Enum): + DOCUMENT = "document" IMAGE = "image" TEXT = "text" STT = "stt" diff --git a/basilisk/provider_engine/anthropic_engine.py b/basilisk/provider_engine/anthropic_engine.py index cde1107f..921ff06b 100644 --- a/basilisk/provider_engine/anthropic_engine.py +++ b/basilisk/provider_engine/anthropic_engine.py @@ -7,11 +7,13 @@ from anthropic import Anthropic from anthropic.types import Message as AnthropicMessage from anthropic.types import TextBlock +from anthropic.types.document_block_param import DocumentBlockParam from anthropic.types.image_block_param import ImageBlockParam, Source +from anthropic.types.text_block_param import TextBlockParam from basilisk.conversation import ( + AttachmentFileTypes, Conversation, - ImageFileTypes, Message, MessageBlock, MessageRoleEnum, @@ -31,6 +33,15 @@ class AnthropicEngine(BaseEngine): capabilities: set[ProviderCapability] = { ProviderCapability.TEXT, ProviderCapability.IMAGE, + ProviderCapability.DOCUMENT, + } 
+ supported_attachment_formats: set[str] = { + "image/gif", + "image/jpeg", + "image/png", + "image/webp", + "application/pdf", + "text/plain", } def __init__(self, account: Account) -> None: @@ -160,15 +171,29 @@ def convert_message(self, message: Message) -> dict: contents = [TextBlock(text=message.content, type="text")] if message.attachments: for attachment in message.attachments: - if attachment.type != ImageFileTypes.IMAGE_URL: + mime_type = attachment.mime_type + if attachment.type != AttachmentFileTypes.URL: source = Source( - data=attachment.encode_image(), + data=None, media_type=attachment.mime_type, type="base64", ) - contents.append( - ImageBlockParam(source=source, type="image") - ) + if mime_type.startswith("image/"): + source["data"] = attachment.encode_image() + contents.append( + ImageBlockParam(source=source, type="image") + ) + elif mime_type.startswith("application/"): + source["data"] = attachment.encode_base64() + contents.append( + DocumentBlockParam(source=source, type="document") + ) + elif mime_type in ("text/csv", "text/plain"): + source["data"] = attachment.read_as_str() + source["type"] = "text" + contents.append( + TextBlockParam(source=source, type="document") + ) return {"role": message.role.value, "content": contents} prepare_message_request = convert_message diff --git a/basilisk/provider_engine/base_engine.py b/basilisk/provider_engine/base_engine.py index 36a833b6..ef50a3c6 100644 --- a/basilisk/provider_engine/base_engine.py +++ b/basilisk/provider_engine/base_engine.py @@ -15,6 +15,7 @@ class BaseEngine(ABC): capabilities: set[ProviderCapability] = set() + supported_attachment_formats: set[str] = {} def __init__(self, account: Account) -> None: self.account = account diff --git a/basilisk/provider_engine/gemini_engine.py b/basilisk/provider_engine/gemini_engine.py index 1c0a091e..c085d7ed 100644 --- a/basilisk/provider_engine/gemini_engine.py +++ b/basilisk/provider_engine/gemini_engine.py @@ -7,9 +7,9 @@ import 
google.generativeai as genai from basilisk.conversation import ( + AttachmentFileTypes, Conversation, ImageFile, - ImageFileTypes, Message, MessageBlock, MessageRoleEnum, @@ -28,6 +28,13 @@ class GeminiEngine(BaseEngine): ProviderCapability.TEXT, ProviderCapability.IMAGE, } + supported_attachment_formats: set[str] = { + "image/png", + "image/jpeg", + "image/webp", + "image/heic", + "image/heif", + } def __init__(self, account: Account) -> None: super().__init__(account) @@ -115,7 +122,7 @@ def convert_role(self, role: MessageRoleEnum) -> str: ) def convert_image(self, image: ImageFile) -> genai.protos.Part: - if image.type == ImageFileTypes.IMAGE_URL: + if image.type == AttachmentFileTypes.URL: raise NotImplementedError("Image URL not supported") with image.send_location.open("rb") as f: blob = genai.protos.Blob(mime_type=image.mime_type, data=f.read()) diff --git a/basilisk/provider_engine/openai_engine.py b/basilisk/provider_engine/openai_engine.py index 5d3c6f48..c2fa7406 100644 --- a/basilisk/provider_engine/openai_engine.py +++ b/basilisk/provider_engine/openai_engine.py @@ -38,6 +38,12 @@ class OpenAIEngine(BaseEngine): ProviderCapability.STT, ProviderCapability.TTS, } + supported_attachment_formats: set[str] = { + "image/gif", + "image/jpeg", + "image/png", + "image/webp", + } def __init__(self, account: Account) -> None: super().__init__(account)