Skip to content

Commit

Permalink
Handle run timeouts on long function calls #115
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed May 1, 2024
1 parent 53843ab commit 4c0ea9e
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,11 @@ def get_completion(self,
self.client.beta.threads.messages.create(
thread_id=self.thread.id,
role="user",
content="Please repeat the exact same function calls again in the same order."
content="Please repeat the exact same tool calls in the same order with the same arguments."
)

self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice)
self._create_run(recipient_agent, additional_instructions, event_handler, 'required',
temperature=0)

self._run_until_done()

Expand All @@ -157,8 +158,18 @@ def get_completion(self,

# change tool call ids
tool_calls = self.run.required_action.submit_tool_outputs.tool_calls
for i, tool_call in enumerate(tool_calls):
tool_outputs[i]["tool_call_id"] = tool_call.id

if len(tool_calls) != len(tool_outputs):
tool_outputs = []
for i, tool_call in enumerate(tool_calls):
tool_outputs.append({"tool_call_id": tool_call.id,
"output": "Error: openai run timed out. You can try again one more time."})
else:
for i, tool_name in enumerate(tool_names):
for tool_call in tool_calls:
if tool_call.function.name == tool_name:
tool_outputs[i]["tool_call_id"] = tool_call.id
break

self._submit_tool_outputs(tool_outputs, event_handler)
else:
Expand Down Expand Up @@ -210,7 +221,7 @@ def get_completion(self,

return full_message

def _create_run(self, recipient_agent, additional_instructions, event_handler, tool_choice):
def _create_run(self, recipient_agent, additional_instructions, event_handler, tool_choice, temperature=None):
if event_handler:
with self.client.beta.threads.runs.stream(
thread_id=self.thread.id,
Expand All @@ -221,6 +232,7 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t
max_prompt_tokens=recipient_agent.max_prompt_tokens,
max_completion_tokens=recipient_agent.max_completion_tokens,
truncation_strategy=recipient_agent.truncation_strategy,
temperature=temperature
) as stream:
stream.until_done()
self.run = stream.get_final_run()
Expand All @@ -233,6 +245,7 @@ def _create_run(self, recipient_agent, additional_instructions, event_handler, t
max_prompt_tokens=recipient_agent.max_prompt_tokens,
max_completion_tokens=recipient_agent.max_completion_tokens,
truncation_strategy=recipient_agent.truncation_strategy,
temperature=temperature
)

def _run_until_done(self):
Expand Down

0 comments on commit 4c0ea9e

Please sign in to comment.