Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(provider_engine): add support for uploading text and PDF files to Anthropic #498

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion basilisk/conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
from .attached_file import (
URL_PATTERN,
AttachmentFile,
AttachmentFileTypes,
ImageFile,
NotImageError,
get_mime_type,
parse_supported_attachment_formats,
)
from .conversation_helper import PROMPT_TITLE
from .conversation_model import (
Conversation,
Message,
MessageBlock,
MessageRoleEnum,
)
from .image_model import URL_PATTERN, ImageFile, ImageFileTypes, NotImageError

# Public names re-exported by the conversation package.
# "ImageFileTypes" is deliberately not listed: the names imported from
# .attached_file do not include it (the enum there is AttachmentFileTypes),
# and listing a name the module does not define makes
# `from basilisk.conversation import *` fail with AttributeError.
__all__ = [
    "AttachmentFile",
    "AttachmentFileTypes",
    "Conversation",
    "get_mime_type",
    "ImageFile",
    "Message",
    "MessageBlock",
    "MessageRoleEnum",
    "NotImageError",
    "parse_supported_attachment_formats",
    "PROMPT_TITLE",
    "URL_PATTERN",
]
Original file line number Diff line number Diff line change
Expand Up @@ -86,32 +86,129 @@ def resize_image(
return True


class ImageFileTypes(Enum):
def parse_supported_attachment_formats(
    supported_attachment_formats: set[str],
) -> str:
    """
    Build a file-dialog wildcard string from a set of MIME types.

    The MIME types are visited in sorted order. Each one is expanded to all
    extensions reported by ``mimetypes.guess_all_extensions`` and rendered as
    ``*.ext`` patterns; every pattern is joined with ``;``. MIME types with no
    known extension are logged with a warning and contribute nothing.

    Returns an empty string when no pattern could be produced.
    """
    patterns = []
    for mime_type in sorted(supported_attachment_formats):
        exts = mimetypes.guess_all_extensions(mime_type)
        if not exts:
            log.warning(f"No extensions found for MIME type {mime_type}")
            continue
        log.debug(f"Adding wildcard for MIME type {mime_type}: {exts}")
        patterns.append(";".join(f"*{ext}" for ext in exts))
    return ";".join(patterns)


def get_mime_type(path: str) -> str | None:
"""
Get the MIME type of a file.
"""
return mimetypes.guess_type(path)[0]


class AttachmentFileTypes(Enum):
UNKNOWN = "unknown"
IMAGE_LOCAL = "local"
IMAGE_MEMORY = "memory"
IMAGE_URL = "http"
LOCAL = "local"
MEMORY = "memory"
URL = "http"

@classmethod
def _missing_(cls, value: object) -> ImageFileTypes:
if isinstance(value, str) and value.lower() == "data":
return cls.IMAGE_URL
if isinstance(value, str) and value.lower() == "https":
return cls.IMAGE_URL
def _missing_(cls, value: object) -> AttachmentFileTypes:
if isinstance(value, str) and value.lower() in ("data", "https"):
return cls.URL
if isinstance(value, str) and value.lower() == "zip":
return cls.IMAGE_LOCAL
return cls.LOCAL
return cls.UNKNOWN


class NotImageError(ValueError):
pass


class ImageFile(BaseModel):
class AttachmentFile(BaseModel):
    """
    A file attached to a conversation message.

    The attachment is identified by ``location``; the protocol of that
    location determines its ``type`` (see ``AttachmentFileTypes``). ``name``
    and ``size`` are filled in from the location when no name was supplied.
    """

    # Path or URL of the attachment.
    location: PydanticUPath
    # Display name; derived from the location when not supplied.
    name: str | None = None
    # Free-form description of the attachment.
    description: str | None = None
    # Size in bytes, or None when unknown (e.g. URL attachments).
    size: int | None = None

    def __init__(self, /, **data: Any) -> None:
        """Validate ``data`` and derive the name and size when no name is given."""
        super().__init__(**data)
        # NOTE(review): the page this was recovered from lost its indentation;
        # the size lookup is assumed to live under the same guard as the name
        # lookup -- confirm against the repository.
        if not self.name:
            self.name = self._get_name()
            self.size = self._get_size()

    @property
    def type(self) -> AttachmentFileTypes:
        """Attachment kind, derived from the protocol of ``location``."""
        return AttachmentFileTypes(self.location.protocol)

    def _get_name(self) -> str:
        """Return the final path component of ``location``."""
        return self.location.name

    def _get_size(self) -> int | None:
        """Return the size in bytes from ``stat()``, or None for URL attachments."""
        if self.type == AttachmentFileTypes.URL:
            return None
        return self.location.stat().st_size

    @property
    def display_size(self) -> str:
        """Human-readable size in B, KB or MB, or a translated "Unknown"."""
        size = self.size
        if size is None:
            return _("Unknown")
        if size < 1024:
            return f"{size} B"
        if size < 1024 * 1024:
            return f"{size / 1024:.2f} KB"
        return f"{size / 1024 / 1024:.2f} MB"

    @property
    def send_location(self) -> UPath:
        """Location whose content is sent; subclasses may override it."""
        return self.location

    @property
    def mime_type(self) -> str | None:
        """MIME type guessed from ``send_location``; None for URLs or unknown types."""
        if self.type == AttachmentFileTypes.URL:
            return None
        # Use a private name for the discarded encoding so the gettext
        # function `_` used elsewhere in this class is never shadowed.
        mime_type, _encoding = mimetypes.guess_type(self.send_location)
        return mime_type

    @property
    def display_location(self) -> str:
        """Location as text, with long ``data:`` URLs shortened for display."""
        location = str(self.location)
        if location.startswith("data:"):
            location = f"{location[:50]}...{location[-10:]}"
        return location

    @staticmethod
    def remove_location(location: UPath) -> None:
        """Delete ``location`` from its filesystem, logging instead of raising."""
        log.debug(f"Removing image at {location}")
        try:
            fs = location.fs
            fs.rm(location.path)
        except Exception as e:
            # Best-effort cleanup: a failed delete must not break the caller.
            log.error(f"Error deleting image at {location}: {e}")

    def read_as_str(self, encoding: str = "utf-8") -> str:
        """
        Return the attachment content decoded as text.

        The encoding is passed explicitly so the result does not depend on
        the platform's locale default; UTF-8 is used unless the caller
        supplies another ``encoding``.
        """
        with self.location.open(mode="r", encoding=encoding) as file:
            return file.read()

    def encode_base64(self) -> str:
        """Return the raw attachment bytes encoded as a base64 string."""
        with self.location.open(mode="rb") as file:
            return base64.b64encode(file.read()).decode("utf-8")

    def __del__(self):
        """Remove in-memory attachments when the object is garbage-collected."""
        try:
            file_type = self.type
        except AttributeError:
            # The finalizer also runs for instances whose __init__ failed
            # validation; those have no usable `location`, so there is
            # nothing to clean up.
            return
        # URL and local attachments are left alone; only in-memory files
        # are removed.
        if file_type == AttachmentFileTypes.MEMORY:
            self.remove_location(self.location)


class ImageFile(AttachmentFile):
dimensions: tuple[int, int] | None = None
resize_location: PydanticUPath | None = Field(default=None, exclude=True)

Expand All @@ -131,7 +228,7 @@ def build_from_url(cls, url: str) -> ImageFile:
dimensions = get_image_dimensions(BytesIO(r.content))
return cls(
location=url,
type=ImageFileTypes.IMAGE_URL,
type=AttachmentFileTypes.URL,
size=size,
description=content_type,
dimensions=dimensions,
Expand Down Expand Up @@ -169,39 +266,17 @@ def validate_location(

def __init__(self, /, **data: Any) -> None:
super().__init__(**data)
if not self.name:
self.name = self._get_name()
self.size = self._get_size()
if not self.dimensions:
self.dimensions = self._get_dimensions()

__init__.__pydantic_base_init__ = True

@property
def type(self) -> ImageFileTypes:
return ImageFileTypes(self.location.protocol)

def _get_name(self) -> str:
return self.location.name

def _get_size(self) -> int | None:
if self.type == ImageFileTypes.IMAGE_URL:
return None
return self.location.stat().st_size

@property
def display_size(self) -> str:
size = self.size
if size is None:
return _("Unknown")
if size < 1024:
return f"{size} B"
if size < 1024 * 1024:
return f"{size / 1024:.2f} KB"
return f"{size / 1024 / 1024:.2f} MB"
def send_location(self) -> UPath:
return self.resize_location or self.location

def _get_dimensions(self) -> tuple[int, int] | None:
if self.type == ImageFileTypes.IMAGE_URL:
if self.type == AttachmentFileTypes.URL:
return None
with self.location.open(mode="rb") as image_file:
return get_image_dimensions(image_file)
Expand All @@ -216,7 +291,7 @@ def display_dimensions(self) -> str:
def resize(
self, conv_folder: UPath, max_width: int, max_height: int, quality: int
):
if ImageFileTypes.IMAGE_URL == self.type:
if AttachmentFileTypes.URL == self.type:
return
log.debug("Resizing image")
resize_location = conv_folder.joinpath(
Expand All @@ -234,10 +309,6 @@ def resize(
)
self.resize_location = resize_location if success else None

@property
def send_location(self) -> UPath:
return self.resize_location or self.location

@measure_time
def encode_image(self) -> str:
if self.size and self.size > 1024 * 1024 * 1024:
Expand All @@ -247,29 +318,15 @@ def encode_image(self) -> str:
with self.send_location.open(mode="rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")

@property
def mime_type(self) -> str | None:
if self.type == ImageFileTypes.IMAGE_URL:
return None
mime_type, _ = mimetypes.guess_type(self.send_location)
return mime_type

@property
def url(self) -> str:
if not isinstance(self.type, ImageFileTypes):
if not isinstance(self.type, AttachmentFileTypes):
raise ValueError("Invalid image type")
if self.type == ImageFileTypes.IMAGE_URL:
if self.type == AttachmentFileTypes.URL:
return str(self.location)
base64_image = self.encode_image()
return f"data:{self.mime_type};base64,{base64_image}"

@property
def display_location(self):
location = str(self.location)
if location.startswith("data:image/"):
location = f"{location[:50]}...{location[-10:]}"
return location

@staticmethod
def remove_location(location: UPath):
log.debug(f"Removing image at {location}")
Expand All @@ -280,9 +337,6 @@ def remove_location(location: UPath):
log.error(f"Error deleting image at {location}: {e}")

def __del__(self):
if self.type == ImageFileTypes.IMAGE_URL:
return
if self.resize_location:
self.remove_location(self.resize_location)
if self.type == ImageFileTypes.IMAGE_MEMORY:
self.remove_location(self.location)
super().__del__()
10 changes: 6 additions & 4 deletions basilisk/conversation/conversation_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from basilisk.config import conf
from basilisk.decorators import measure_time

from .image_model import ImageFile, ImageFileTypes
from .attached_file import AttachmentFile, AttachmentFileTypes, ImageFile

if TYPE_CHECKING:
from .conversation_model import Conversation
Expand All @@ -23,11 +23,13 @@


def save_attachments(
attachments: list[ImageFile], attachment_path: str, fs: ZipFileSystem
attachments: list[AttachmentFile | ImageFile],
attachment_path: str,
fs: ZipFileSystem,
):
attachment_mapping = {}
for attachment in attachments:
if attachment.type == ImageFileTypes.IMAGE_URL:
if attachment.type == AttachmentFileTypes.URL:
continue
new_location = f"{attachment_path}/{attachment.location.name}"
with attachment.location.open(mode="rb") as attachment_file:
Expand Down Expand Up @@ -56,7 +58,7 @@ def create_conv_main_file(conversation: Conversation, fs: ZipFileSystem):

def restore_attachments(attachments: list[ImageFile], storage_path: UPath):
for attachment in attachments:
if attachment.type == ImageFileTypes.IMAGE_URL:
if attachment.type == AttachmentFileTypes.URL:
continue
new_path = storage_path / attachment.location.name
with attachment.location.open(mode="rb") as attachment_file:
Expand Down
4 changes: 2 additions & 2 deletions basilisk/conversation/conversation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from basilisk.provider_ai_model import AIModelInfo

from .attached_file import AttachmentFile, ImageFile
from .conversation_helper import create_bskc_file, open_bskc_file
from .image_model import ImageFile


class MessageRoleEnum(Enum):
Expand All @@ -21,7 +21,7 @@ class MessageRoleEnum(Enum):
class Message(BaseModel):
role: MessageRoleEnum
content: str
attachments: list[ImageFile] | None = Field(default=None)
attachments: list[AttachmentFile | ImageFile] | None = Field(default=None)


class MessageBlock(BaseModel):
Expand Down
Loading