Skip to content

Commit

Permalink
Merge pull request #873 from PrefectHQ/streaming
Browse files Browse the repository at this point in the history
Update assistants to use streaming API
  • Loading branch information
jlowin authored Mar 20, 2024
2 parents 166c88d + 5df04de commit 7257289
Show file tree
Hide file tree
Showing 12 changed files with 628 additions and 496 deletions.
Binary file modified docs/assets/images/docs/assistants/code_interpreter.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/assets/images/docs/assistants/quickstart.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/images/docs/assistants/talking.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
221 changes: 139 additions & 82 deletions docs/docs/interactive/assistants.md

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies = [
"httpx>=0.24.1",
"jinja2>=3.1.2",
"jsonpatch>=1.33",
"openai>=1.1.0",
"openai>=1.4.0",
"pydantic>=2.4.2",
"pydantic_settings",
"rich>=12",
Expand All @@ -26,6 +26,7 @@ dependencies = [
# need for windows
"tzdata>=2023.3",
"uvicorn>=0.22.0",
"partialjson>=0.0.5",
]

[project.optional-dependencies]
Expand Down
5 changes: 3 additions & 2 deletions src/marvin/beta/assistants/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .runs import Run
from .threads import Thread, ThreadMonitor
from .threads import Thread
from .assistants import Assistant
from .formatting import pprint_message, pprint_messages
from .handlers import PrintHandler
from .formatting import pprint_messages, pprint_steps, pprint_run
from marvin.tools.assistants import Retrieval, CodeInterpreter
62 changes: 33 additions & 29 deletions src/marvin/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import TYPE_CHECKING, Callable, Optional, Union

from openai.types.beta.threads.required_action_function_tool_call import (
RequiredActionFunctionToolCall,
)
from pydantic import BaseModel, Field, PrivateAttr
from openai import AssistantEventHandler, AsyncAssistantEventHandler
from pydantic import BaseModel, Field, PrivateAttr, field_validator

import marvin.utilities.openai
import marvin.utilities.tools
from marvin.beta.assistants.handlers import PrintHandler
from marvin.tools.assistants import AssistantTool
from marvin.types import Tool
from marvin.utilities.asyncio import (
Expand All @@ -16,14 +15,16 @@
)
from marvin.utilities.logging import get_logger

from .threads import Message, Thread
from .threads import Thread

if TYPE_CHECKING:
from .runs import Run


logger = get_logger("Assistants")

NOT_PROVIDED = "__NOT_PROVIDED__"


class Assistant(BaseModel, ExposeSyncMethodsMixin):
"""
Expand All @@ -41,9 +42,10 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin):
instructions (list): List of instructions for the assistant.
"""

model_config = dict(extra="forbid")
id: Optional[str] = None
name: str = "Assistant"
model: str = "gpt-4-1106-preview"
model: str = Field(None, validate_default=True)
instructions: Optional[str] = Field(None, repr=False)
tools: list[Union[AssistantTool, Callable]] = []
file_ids: list[str] = []
Expand All @@ -57,6 +59,12 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin):
description="A default thread for the assistant.",
)

@field_validator("model", mode="before")
def default_model(cls, model):
if model is None:
model = marvin.settings.openai.assistants.model
return model

def clear_default_thread(self):
self.default_thread = Thread()

Expand All @@ -79,31 +87,32 @@ async def say_async(
message: str,
file_paths: Optional[list[str]] = None,
thread: Optional[Thread] = None,
return_user_message: bool = False,
event_handler_class: type[
Union[AssistantEventHandler, AsyncAssistantEventHandler]
] = NOT_PROVIDED,
**run_kwargs,
) -> list[Message]:
"""
A convenience method for adding a user message to the assistant's
default thread, running the assistant, and returning the assistant's
messages.
"""
) -> "Run":
thread = thread or self.default_thread

if event_handler_class is NOT_PROVIDED:
event_handler_class = PrintHandler

# post the message
user_message = await thread.add_async(message, file_paths=file_paths)

# run the thread
async with self:
await thread.run_async(assistant=self, **run_kwargs)
from marvin.beta.assistants.runs import Run

# load all messages, including the user message
response_messages = await thread.get_messages_async(
after_message=user_message.id
run = Run(
# provide the user message as part of the run to print
messages=[user_message],
assistant=self,
thread=thread,
event_handler_class=event_handler_class,
**run_kwargs,
)
result = await run.run_async()

if return_user_message:
response_messages = [user_message] + response_messages
return response_messages
return result

def __enter__(self):
return run_sync(self.__aenter__())
Expand Down Expand Up @@ -176,13 +185,8 @@ def chat(self, thread: Thread = None):
thread = self.default_thread
return thread.chat(assistant=self)

def pre_run_hook(self, run: "Run"):
def pre_run_hook(self):
pass

def post_run_hook(
self,
run: "Run",
tool_calls: Optional[list[RequiredActionFunctionToolCall]] = None,
tool_outputs: Optional[list[dict[str, str]]] = None,
):
def post_run_hook(self, run: "Run"):
pass
Loading

0 comments on commit 7257289

Please sign in to comment.