From b9e32cca18d2f5b6b6694b3cf60cca6b5eaa6f7d Mon Sep 17 00:00:00 2001 From: Douglas Reid Date: Tue, 3 Oct 2023 10:17:35 -0700 Subject: [PATCH] fix(agents): generate proper history blocks in agent (#570) With the inflight changes for streaming, we unfortunately merged a commit that left message selection in a broken state. The existing testing was not enough to capture the issue. This PR attempts to restore proper message selection functionality. Follow on PRs will be necessary to clean up and streamline the selection bits added in this PR. Co-authored-by: Douglas Reid --- .../agents/examples/example_assistant.py | 18 ++- .../agents/functional/functions_based.py | 86 ++++++++++-- .../agents/functional/output_parser.py | 43 +++++- src/steamship/agents/schema/action.py | 52 ++++--- src/steamship/agents/schema/agent.py | 4 + src/steamship/agents/schema/chathistory.py | 6 +- .../agents/schema/message_selectors.py | 55 ++++++-- src/steamship/agents/service/agent_service.py | 6 +- src/steamship/data/tags/tag_constants.py | 2 + src/steamship/utils/repl.py | 2 +- .../agents/test_agent_service.py | 40 +++++- .../agents/test_functions_based_agent.py | 128 +++++++++++++++++- tests/steamship_tests/utils/fixtures.py | 3 +- 13 files changed, 384 insertions(+), 61 deletions(-) diff --git a/src/steamship/agents/examples/example_assistant.py b/src/steamship/agents/examples/example_assistant.py index 159c3f412..825fef539 100644 --- a/src/steamship/agents/examples/example_assistant.py +++ b/src/steamship/agents/examples/example_assistant.py @@ -1,9 +1,14 @@ +from typing import Type + +from pydantic.fields import Field + from steamship.agents.functional import FunctionsBasedAgent from steamship.agents.llms.openai import ChatOpenAI from steamship.agents.schema.message_selectors import MessageWindowMessageSelector from steamship.agents.service.agent_service import AgentService from steamship.agents.tools.image_generation import DalleTool from steamship.agents.tools.search import SearchTool +from steamship.invocable import Config from steamship.utils.repl import AgentREPL @@ -13,6 +18,13 @@ class MyFunctionsBasedAssistant(AgentService): to provide an overview of the types of tasks it can accomplish (here, search and image generation).""" + class AgentConfig(Config): + model_name: str = Field(default="gpt-4") + + @classmethod + def config_cls(cls) -> Type[Config]: + return MyFunctionsBasedAssistant.AgentConfig + def __init__(self, **kwargs): super().__init__(**kwargs) self.set_default_agent( @@ -21,7 +33,7 @@ def __init__(self, **kwargs): SearchTool(), DalleTool(), ], - llm=ChatOpenAI(self.client, temperature=0), + llm=ChatOpenAI(self.client, temperature=0, model_name=self.config.model_name), message_selector=MessageWindowMessageSelector(k=2), ) ) @@ -31,4 +43,6 @@ def __init__(self, **kwargs): # AgentREPL provides a mechanism for local execution of an AgentService method. # This is used for simplified debugging as agents and tools are developed and # added. - AgentREPL(MyFunctionsBasedAssistant, agent_package_config={}).run(dump_history_on_exit=True) + AgentREPL(MyFunctionsBasedAssistant, agent_package_config={"model_name": "gpt-3.5-turbo"}).run( + dump_history_on_exit=True + ) diff --git a/src/steamship/agents/functional/functions_based.py b/src/steamship/agents/functional/functions_based.py index 13ba79a1b..66fa9becb 100644 --- a/src/steamship/agents/functional/functions_based.py +++ b/src/steamship/agents/functional/functions_based.py @@ -1,9 +1,12 @@ +import json +from operator import attrgetter from typing import List -from steamship import Block +from steamship import Block, MimeTypes, Tag from steamship.agents.functional.output_parser import FunctionsBasedOutputParser -from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, Tool -from steamship.data.tags.tag_constants import RoleTag +from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, FinishAction, Tool +from steamship.data.tags.tag_constants import ChatTag, RoleTag, TagKind, TagValueKey +from steamship.data.tags.tag_utils import get_tag class FunctionsBasedAgent(ChatAgent): @@ -54,6 +57,8 @@ def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]: # get most recent context messages_from_memory.extend(context.chat_history.select_messages(self.message_selector)) + messages_from_memory.sort(key=attrgetter("index_in_file")) + # de-dupe the messages from memory ids = [context.chat_history.last_user_message.id] for msg in messages_from_memory: @@ -67,10 +72,8 @@ def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]: # this should happen BEFORE any agent/assistant messages related to tool selection messages.append(context.chat_history.last_user_message) - # get completed steps - actions = context.completed_steps - for action in actions: - messages.extend(action.to_chat_messages()) + # get working history (completed actions) + messages.extend(self._function_calls_since_last_user_message(context)) return messages @@ -81,4 +84,71 @@ def next_action(self, context: AgentContext) -> Action: # Run the default LLM on those messages output_blocks = self.llm.chat(messages=messages, tools=self.tools) - return self.output_parser.parse(output_blocks[0].text, context) + future_action = self.output_parser.parse(output_blocks[0].text, context) + if not isinstance(future_action, FinishAction): + # record the LLM's function response in history + self._record_action_selection(future_action, context) + return future_action + + def _function_calls_since_last_user_message(self, context: AgentContext) -> List[Block]: + function_calls = [] + for block in context.chat_history.messages[::-1]: # is this too inefficient at scale? + if block.chat_role == RoleTag.USER: + return reversed(function_calls) + if get_tag(block.tags, kind=TagKind.ROLE, name=RoleTag.FUNCTION): + function_calls.append(block) + elif get_tag(block.tags, kind=TagKind.FUNCTION_SELECTION): + function_calls.append(block) + return reversed(function_calls) + + def _to_openai_function_selection(self, action: Action) -> str: + """NOTE: Temporary placeholder. Should be refactored""" + fc = {"name": action.tool} + args = {} + for block in action.input: + for t in block.tags: + if t.kind == TagKind.FUNCTION_ARG: + args[t.name] = block.as_llm_input(exclude_block_wrapper=True) + + fc["arguments"] = json.dumps(args) # the arguments must be a string value NOT a dict + return json.dumps(fc) + + def _record_action_selection(self, action: Action, context: AgentContext): + tags = [ + Tag( + kind=TagKind.CHAT, + name=ChatTag.ROLE, + value={TagValueKey.STRING_VALUE: RoleTag.ASSISTANT}, + ), + Tag(kind=TagKind.FUNCTION_SELECTION, name=action.tool), + ] + context.chat_history.file.append_block( + text=self._to_openai_function_selection(action), tags=tags, mime_type=MimeTypes.TXT + ) + + def record_action_run(self, action: Action, context: AgentContext): + super().record_action_run(action, context) + + if isinstance(action, FinishAction): + return + + tags = [ + Tag( + kind=TagKind.ROLE, + name=RoleTag.FUNCTION, + value={TagValueKey.STRING_VALUE: action.tool}, + ), + # need the following tag for backwards compatibility with older gpt-4 plugin + Tag( + kind="name", + name=action.tool, + ), + ] + # TODO(dougreid): I'm not convinced this is correct for tools that return multiple values. + # It _feels_ like these should be named and inlined as a single message in history, etc. + for block in action.output: + context.chat_history.file.append_block( + text=block.as_llm_input(exclude_block_wrapper=True), + tags=tags, + mime_type=block.mime_type, + ) diff --git a/src/steamship/agents/functional/output_parser.py b/src/steamship/agents/functional/output_parser.py index 9dbb8fa79..30ced7cf7 100644 --- a/src/steamship/agents/functional/output_parser.py +++ b/src/steamship/agents/functional/output_parser.py @@ -4,9 +4,9 @@ from json import JSONDecodeError from typing import Dict, List, Optional -from steamship import Block, MimeTypes, Steamship +from steamship import Block, MimeTypes, Steamship, Tag from steamship.agents.schema import Action, AgentContext, FinishAction, OutputParser, Tool -from steamship.data.tags.tag_constants import RoleTag +from steamship.data.tags.tag_constants import RoleTag, TagKind from steamship.utils.utils import is_valid_uuid4 @@ -43,16 +43,45 @@ def _extract_action_from_function_call(self, text: str, context: AgentContext) - try: args = json.loads(arguments) if text := args.get("text"): - input_blocks.append(Block(text=text, mime_type=MimeTypes.TXT)) + input_blocks.append( + Block( + text=text, + tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")], + mime_type=MimeTypes.TXT, + ) + ) elif uuid_arg := args.get("uuid"): - input_blocks.append(Block.get(context.client, _id=uuid_arg)) + existing_block = Block.get(context.client, _id=uuid_arg) + tag = Tag.create( + existing_block.client, + file_id=existing_block.file_id, + block_id=existing_block.id, + kind=TagKind.FUNCTION_ARG, + name="uuid", + ) + existing_block.tags.append(tag) + input_blocks.append(existing_block) except json.decoder.JSONDecodeError: if isinstance(arguments, str): if is_valid_uuid4(arguments): - input_blocks.append(Block.get(context.client, _id=uuid_arg)) + existing_block = Block.get(context.client, _id=arguments) + tag = Tag.create( + existing_block.client, + file_id=existing_block.file_id, + block_id=existing_block.id, + kind=TagKind.FUNCTION_ARG, + name="uuid", + ) + existing_block.tags.append(tag) + input_blocks.append(existing_block) else: - input_blocks.append(Block(text=arguments, mime_type=MimeTypes.TXT)) - + input_blocks.append( + Block( + text=arguments, + tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")], + mime_type=MimeTypes.TXT, + ) + ) return Action(tool=tool.name, input=input_blocks, context=context) @staticmethod diff --git a/src/steamship/agents/schema/action.py b/src/steamship/agents/schema/action.py index 929f5afde..3fd154863 100644 --- a/src/steamship/agents/schema/action.py +++ b/src/steamship/agents/schema/action.py @@ -2,9 +2,7 @@ from pydantic import BaseModel -from steamship import Block, Tag -from steamship.data import TagKind -from steamship.data.tags.tag_constants import RoleTag +from steamship import Block class Action(BaseModel): @@ -28,25 +26,35 @@ class Action(BaseModel): Setting this to True means that the executing Agent should halt any reasoning. """ - def to_chat_messages(self) -> List[Block]: - tags = [ - Tag(kind=TagKind.ROLE, name=RoleTag.FUNCTION), - Tag(kind="name", name=self.tool), - ] - blocks = [] - for block in self.output: - # TODO(dougreid): should we revisit as_llm_input? we might need only the UUID... - blocks.append( - Block( - text=block.as_llm_input(exclude_block_wrapper=True), - tags=tags, - mime_type=block.mime_type, - ) - ) - - # TODO(dougreid): revisit when have multiple output functions. - # Current thinking: LLM will be OK with multiple function blocks in a row. NEEDS validation. - return blocks + # def to_chat_messages(self) -> List[Block]: + # blocks = [] + # for arg in self.input: + # + # + # blocks.append( + # Block( + # text=json.dumps({"name": f"{self.tool}", "arguments": "{ \"text\": \"who is the current president of Taiwan?\" }"}), + # ) + # ) + # + # tags = [ + # Tag(kind=TagKind.ROLE, name=RoleTag.FUNCTION), + # Tag(kind="name", name=self.tool), + # ] + # + # for block in self.output: + # # TODO(dougreid): should we revisit as_llm_input? we might need only the UUID... + # blocks.append( + # Block( + # text=block.as_llm_input(exclude_block_wrapper=True), + # tags=tags, + # mime_type=block.mime_type, + # ) + # ) + # + # # TODO(dougreid): revisit when have multiple output functions. + # # Current thinking: LLM will be OK with multiple function blocks in a row. NEEDS validation. + # return blocks class FinishAction(Action): diff --git a/src/steamship/agents/schema/agent.py b/src/steamship/agents/schema/agent.py index 890b29dc4..4d2767511 100644 --- a/src/steamship/agents/schema/agent.py +++ b/src/steamship/agents/schema/agent.py @@ -31,6 +31,10 @@ class Agent(BaseModel, ABC): def next_action(self, context: AgentContext) -> Action: pass + def record_action_run(self, action: Action, context: AgentContext): + # TODO(dougreid): should this method (or just bit) actually be on AgentContext? + context.completed_steps.append(action) + class LLMAgent(Agent): """LLMAgents choose next actions for an AgentService based on interactions with an LLM.""" diff --git a/src/steamship/agents/schema/chathistory.py b/src/steamship/agents/schema/chathistory.py index 43de4cf68..c8ecdc3e1 100644 --- a/src/steamship/agents/schema/chathistory.py +++ b/src/steamship/agents/schema/chathistory.py @@ -144,7 +144,11 @@ def append_message_with_role( text=text, tags=tags, content=content, url=url, mime_type=mime_type ) # don't index status messages - if self.embedding_index is not None and role is not RoleTag.AGENT: + if self.embedding_index is not None and role not in [ + RoleTag.AGENT, + RoleTag.TOOL, + RoleTag.LLM, + ]: chunk_tags = self.text_splitter.chunk_text_to_tags( block, kind=TagKind.CHAT, name=ChatTag.CHUNK ) diff --git a/src/steamship/agents/schema/message_selectors.py b/src/steamship/agents/schema/message_selectors.py index 6c5b0f260..545f67a2b 100644 --- a/src/steamship/agents/schema/message_selectors.py +++ b/src/steamship/agents/schema/message_selectors.py @@ -5,7 +5,8 @@ from pydantic.main import BaseModel from steamship import Block -from steamship.data.tags.tag_constants import RoleTag +from steamship.data.tags.tag_constants import RoleTag, TagKind +from steamship.data.tags.tag_utils import get_tag class MessageSelector(BaseModel, ABC): @@ -29,23 +30,53 @@ def is_assistant_message(block: Block) -> bool: return role == RoleTag.ASSISTANT +def is_function_message(block: Block) -> bool: + is_function_selection = get_tag(block.tags, kind=TagKind.FUNCTION_SELECTION) + return is_function_selection + + +def is_tool_function_message(block: Block) -> bool: + is_function_call = get_tag(block.tags, kind=TagKind.ROLE, name=RoleTag.FUNCTION) + return is_function_call + + +def is_user_history_message(block: Block) -> bool: + return is_user_message(block) or ( + is_assistant_message(block) and not is_function_message(block) + ) + + class MessageWindowMessageSelector(MessageSelector): k: int def get_messages(self, messages: List[Block]) -> List[Block]: msgs = messages[:] - msgs.pop() # don't add the current prompt to the memory - if len(msgs) <= (self.k * 2): - return msgs - + # msgs.pop() + have_seen_user_message = False + if is_user_message(msgs[-1]): + have_seen_user_message = True + msgs.pop() # don't add the current prompt to the memory selected_msgs = [] + conversation_messages = 0 limit = self.k * 2 - scope = msgs[len(messages) - limit :] - for block in scope: - if is_user_message(block) or is_assistant_message(block): + message_index = len(msgs) - 1 + while (conversation_messages < limit) and (message_index > 0): + # TODO(dougreid): i _think_ we don't need the function return if we have a user-assistant pair + # but, for safety here, we try to add non-current function blocks from past iterations. + block = msgs[message_index] + if is_user_message(block): + have_seen_user_message = True + if is_user_history_message(block): selected_msgs.append(block) + conversation_messages += 1 + elif have_seen_user_message and ( + is_function_message(block) or is_tool_function_message(block) + ): + # conditionally append working function call messages + selected_msgs.append(block) + message_index -= 1 - return selected_msgs + return reversed(selected_msgs) def tokens(block: Block) -> int: @@ -62,9 +93,11 @@ def get_messages(self, messages: List[Block]) -> List[Block]: current_tokens = 0 msgs = messages[:] - msgs.pop() # don't add the current prompt to the memory + if is_user_message(msgs[-1]): + msgs.pop() # don't add the current prompt to the memory + for block in reversed(msgs): - if block.chat_role != RoleTag.SYSTEM and current_tokens < self.max_tokens: + if is_user_history_message(block) and current_tokens < self.max_tokens: block_tokens = tokens(block) if block_tokens + current_tokens < self.max_tokens: selected_messages.append(block) diff --git a/src/steamship/agents/service/agent_service.py b/src/steamship/agents/service/agent_service.py index 26c75410c..e9cc17ac3 100644 --- a/src/steamship/agents/service/agent_service.py +++ b/src/steamship/agents/service/agent_service.py @@ -141,7 +141,7 @@ def run_action(self, agent: Agent, action: Action, context: AgentContext): }, ) action.output = output_blocks - context.completed_steps.append(action) + agent.record_action_run(action, context) return tool = next((tool for tool in agent.tools if tool.name == action.tool), None) @@ -182,7 +182,7 @@ def run_action(self, agent: Agent, action: Action, context: AgentContext): action.is_final = ( tool.is_final ) # Permit the tool to decide if this action should halt the reasoning loop. - context.completed_steps.append(action) + agent.record_action_run(action, context) if context.action_cache and tool.cacheable: context.action_cache.update(key=action, value=action.output) @@ -253,7 +253,7 @@ def run_agent(self, agent: Agent, context: AgentContext): }, ) - context.completed_steps.append(action) + agent.record_action_run(action, context) output_text_length = 0 if action.output is not None: output_text_length = sum([len(block.text or "") for block in action.output]) diff --git a/src/steamship/data/tags/tag_constants.py b/src/steamship/data/tags/tag_constants.py index d4847bf97..a5a61577c 100644 --- a/src/steamship/data/tags/tag_constants.py +++ b/src/steamship/data/tags/tag_constants.py @@ -33,6 +33,8 @@ class TagKind(str, Enum): AGENT_STATUS_MESSAGE = "agent-status-message" TOOL_STATUS_MESSAGE = "tool-status-message" LLM_STATUS_MESSAGE = "llm-status-message" + FUNCTION_ARG = "function-arg" + FUNCTION_SELECTION = "function-selection" class DocTag(str, Enum): diff --git a/src/steamship/utils/repl.py b/src/steamship/utils/repl.py index bcfd163c4..0378e7208 100644 --- a/src/steamship/utils/repl.py +++ b/src/steamship/utils/repl.py @@ -172,7 +172,7 @@ def run_with_client(self, client: Steamship, **kwargs): except ImportError: def colored(text: str, color: str, **kwargs): - print(text) + return text print("Starting REPL for Agent...") print("If you make code changes, restart this REPL. Press CTRL+C to exit at any time.\n") diff --git a/tests/steamship_tests/agents/test_agent_service.py b/tests/steamship_tests/agents/test_agent_service.py index ecb9ffecd..5e2517b36 100644 --- a/tests/steamship_tests/agents/test_agent_service.py +++ b/tests/steamship_tests/agents/test_agent_service.py @@ -25,7 +25,6 @@ def _blocks_from_invoke(client: Steamship, potential_blocks) -> List[Block]: @pytest.mark.usefixtures("client") def test_example_with_caching_service(client: Steamship): - # TODO(dougreid): replace the example agent with fake/free/fast tools to minimize test time / costs? example_caching_agent_path = ( @@ -85,7 +84,6 @@ def test_example_with_caching_service(client: Steamship): class FakeUncachableTool(Tool): - name = "FakeUncacheableTool" human_description = "Fake tool" agent_description = "Ignored" @@ -248,3 +246,41 @@ def test_context_logging_to_chat_history_everything(client: Steamship): assert not has_status_message(chat_history.messages, RoleTag.AGENT) assert not has_status_message(chat_history.messages, RoleTag.LLM) assert has_status_message(chat_history.messages, RoleTag.TOOL) + + +@pytest.mark.usefixtures("client") +def test_non_duplicate_messages(client: Steamship): + example_agent_service_path = ( + SRC_PATH / "steamship" / "agents" / "examples" / "example_assistant.py" + ) + version_config_template = { + "model_name": {"type": "string"}, + } + instance_config = {"model_name": "gpt-3.5-turbo"} + with deploy_package( + client=client, + py_path=example_agent_service_path, + version_config_template=version_config_template, + instance_config=instance_config, + wait_for_init=True, + ) as ( + _, + _, + agent_service, + ): + context_id = "test-for-message-duplication" + agent_service.blocks_from_invoke( + "prompt", prompt="who is the president of Taiwan?", context_id=context_id + ) + final_blocks = agent_service.blocks_from_invoke( + "prompt", prompt="totally. thanks.", context_id=context_id + ) + + assert ( + len(final_blocks) == 1 + ), f"There should only be a single block. Got: {len(final_blocks)}" + text = final_blocks[0].text + assert "\n" not in text, f"Unexpected response. Should be single line. Got: {text}" + assert ( + "function_call" not in text + ), f"Unexpected response. Should not include function call. Got: {text}" diff --git a/tests/steamship_tests/agents/test_functions_based_agent.py b/tests/steamship_tests/agents/test_functions_based_agent.py index ad788ea4a..54ce6946e 100644 --- a/tests/steamship_tests/agents/test_functions_based_agent.py +++ b/tests/steamship_tests/agents/test_functions_based_agent.py @@ -1,12 +1,17 @@ +import json + import pytest -from steamship import Block, Steamship +from steamship import Block, Steamship, Tag from steamship.agents.functional import FunctionsBasedAgent from steamship.agents.llms.openai import ChatOpenAI -from steamship.agents.schema import AgentContext, FinishAction +from steamship.agents.schema import Action, AgentContext, FinishAction from steamship.agents.schema.message_selectors import MessageWindowMessageSelector from steamship.agents.tools.image_generation import DalleTool from steamship.agents.tools.search import SearchTool +from steamship.data import TagKind, TagValueKey +from steamship.data.tags.tag_constants import ChatTag, RoleTag +from steamship.data.tags.tag_utils import get_tag, get_tag_value_key @pytest.mark.usefixtures("client") @@ -112,7 +117,7 @@ def test_functions_based_agent_tool_chaining_without_memory(client: Steamship): action.output = [] action.output.append(Block(text="George Washington")) - ctx.completed_steps.append(action) + agent.record_action_run(action, ctx) second_action = agent.next_action(context=ctx) assert not isinstance(second_action, FinishAction) @@ -150,3 +155,120 @@ def test_functions_based_agent_tools_with_memory(client: Steamship): found = True assert found + + +@pytest.mark.usefixtures("client") +def test_proper_message_selection(client: Steamship): + context_id = "test-for-message-selection" + context_keys = {"id": context_id} + agent_context = AgentContext.get_or_create(client=client, context_keys=context_keys) + + test_agent = FunctionsBasedAgent( + tools=[ + SearchTool(), + DalleTool(), + ], + llm=ChatOpenAI(client, temperature=0), + message_selector=MessageWindowMessageSelector(k=2), + ) + + # simulate prompting and tool selecttion + agent_context.chat_history.append_user_message(text="Who is the current president of Taiwan?") + agent_context.chat_history.append_agent_message(text="Ignore me.") + agent_context.chat_history.append_llm_message(text="OpenAI ChatComplete...") + + # simulate running the Tool + arg_block = Block( + text="current president of Taiwan", tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")] + ) + action = Action(tool="SearchTool", input=[arg_block], output=[Block(text="Tsai Ing-wen")]) + test_agent._record_action_selection(action=action, context=agent_context) + agent_context.chat_history.append_tool_message(text="Some tool message.") + test_agent.record_action_run(action=action, context=agent_context) + + # simulate completing + agent_context.chat_history.append_agent_message(text="something about Tsai Ing-wen") + agent_context.chat_history.append_llm_message(text="OpenAI ChatComplete...") + agent_context.chat_history.append_agent_message(text="Finish Action...") + agent_context.chat_history.append_assistant_message( + text="The current president of Taiwan is Tsai Ing-wen." + ) + + # simulate next prompt + agent_context.chat_history.append_user_message(text="totally. thanks.") + + selected_messages = test_agent.build_chat_history_for_tool(agent_context) + + expected_messages = [ + Block( + text=test_agent.PROMPT, + tags=[ + Tag( + kind=TagKind.CHAT, name=ChatTag.ROLE, value={TagValueKey.STRING_VALUE: "system"} + ) + ], + ), + Block( + text="Who is the current president of Taiwan?", + tags=[ + Tag(kind=TagKind.CHAT, name=ChatTag.ROLE, value={TagValueKey.STRING_VALUE: "user"}) + ], + ), + Block( + text=json.dumps( + {"name": "SearchTool", "arguments": '{"text": "current president of Taiwan"}'} + ), + tags=[ + Tag( + kind=TagKind.CHAT, + name=ChatTag.ROLE, + value={TagValueKey.STRING_VALUE: "assistant"}, + ), + Tag(kind="function-selection", name="SearchTool"), + ], + ), + Block( + text="Tsai Ing-wen", + tags=[ + Tag( + kind=ChatTag.ROLE, + name=RoleTag.FUNCTION, + value={TagValueKey.STRING_VALUE: "SearchTool"}, + ) + ], + ), + Block( + text="The current president of Taiwan is Tsai Ing-wen.", + tags=[ + Tag( + kind=TagKind.CHAT, + name=ChatTag.ROLE, + value={TagValueKey.STRING_VALUE: "assistant"}, + ) + ], + ), + Block( + text="totally. thanks.", + tags=[ + Tag(kind=TagKind.CHAT, name=ChatTag.ROLE, value={TagValueKey.STRING_VALUE: "user"}) + ], + ), + ] + + assert len(selected_messages) == len( + expected_messages + ), "Missing selected messages from prepared messages" + for idx, selected_msg in enumerate(selected_messages): + expected_msg = expected_messages[idx] + assert ( + selected_msg.text == expected_msg.text + ), f"Got: {selected_msg.text}, want: {expected_msg.text}" + for t in expected_msg.tags: + if t.value: + assert get_tag_value_key( + tags=selected_msg.tags, key=TagValueKey.STRING_VALUE, kind=t.kind, name=t.name + ), "Expected tag not found in selected message" + else: + assert get_tag( + tags=selected_msg.tags, kind=t.kind, name=t.name + ), "Expected tag not found in selected message" diff --git a/tests/steamship_tests/utils/fixtures.py b/tests/steamship_tests/utils/fixtures.py index 36135cfb3..5540115a8 100644 --- a/tests/steamship_tests/utils/fixtures.py +++ b/tests/steamship_tests/utils/fixtures.py @@ -59,7 +59,8 @@ def _test_something(invocable_handler): steamship = get_steamship_client() workspace_handle = random_name() workspace = Workspace.create(client=steamship, handle=workspace_handle) - new_client = get_steamship_client(workspace=workspace_handle) + # NOTE: get_steamship_client takes either `workspace_handle` or `workspace_id`, but NOT `workspace` as a keyword arg + new_client = get_steamship_client(workspace_handle=workspace_handle) def handle(verb: str, invocation_path: str, arguments: Optional[dict] = None) -> dict: _handler = _create_handler(known_invocable_for_testing=invocable)