Skip to content

Commit

Permalink
simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
zby committed Dec 10, 2024
1 parent d090360 commit 7f44963
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions prompete/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def render(self, template: str, **kwargs: Any) -> str: ...
class Chat:
model: str
renderer: Optional[Renderer] = None
messages: list[Union[dict, Message]] = field(default_factory=list)
messages: list[dict] = field(default_factory=list)
system_prompt: Optional[Union[Prompt, str, dict, Message]] = None
fail_on_tool_error: bool = (
True # if False the error message is passed to the LLM to fix the call, if True exception is raised
Expand Down Expand Up @@ -186,9 +186,9 @@ def llm_reply(self, tools=[], strict=False, **kwargs) -> ModelResponse:
return result

def process(self, **kwargs):
if not self.messages:
raise ValueError("No messages to process")
message = Message(**self.messages[-1])
message = self.get_tool_calls_message()
if not message:
raise ValueError("No message to process")
results = process_message(message, self.saved_tools, **kwargs)
outputs = []
for result in results:
Expand All @@ -209,12 +209,14 @@ def process(self, **kwargs):

return outputs

def get_last_message(self) -> Optional[Union[dict, Message]]:
def get_tool_calls_message(self) -> Message:
"""
Return the last message in the chat history, or None if the history is empty.
Return the last message in the chat history if it has 'tool_calls' key, or None if the history is empty.
"""
return self.messages[-1] if self.messages else None

if not self.messages:
return None
message = Message(**self.messages[-1])
return message if hasattr(message, "tool_calls") else None

if __name__ == "__main__":
import os
Expand Down

0 comments on commit 7f44963

Please sign in to comment.