Improve error handling in model response parsing
Add error handling for invalid JSON and invalid response formats in model responses

- Introduce a `ParseError` class to represent parsing errors; parsers now return it in place of `ParsedEdit` when an error occurs.
- Handle `JSONDecodeError` and `ValidationError` exceptions in the `JsonParser` class, returning a `ParseError` instance with the appropriate error message and partial response.
- Modify the `Conversation` class to handle `ParseError` instances, displaying the error message to the user and not adding the error response to the conversation messages.
- Update the `AgentHandler` class to handle `ParseError` instances returned from the parser, displaying the error message to the user and returning an empty list of commands.

This change improves the robustness of the system: when parsing of a model response fails, the user now gets explicit feedback instead of an empty response being silently passed along.
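As an illustration of the new control flow, here is a minimal, self-contained sketch of the pattern this commit introduces. It is not mentat's actual code: the `FileEdit` stand-in, the `{"conversation": ..., "edits": [...]}` JSON schema, and the helper names are invented for the example; the real definitions live in `mentat.parsers` and appear in the diff below.

```python
import json
from typing import List, Union

from pydantic import BaseModel, ValidationError


class FileEdit(BaseModel):
    """Simplified stand-in for mentat.parsers.file_edit.FileEdit."""
    file_path: str
    new_content: str = ""


class ParsedEdit(BaseModel):
    """A successful parse of an LLM response, as in the diff below."""
    full_response: str
    conversation: str
    file_edits: List[FileEdit]
    interrupted: bool = False


class ParseError(BaseModel):
    """A parsing failure, carrying whatever partial text was received."""
    error_message: str
    partial_response: str = ""


ParsedLLMResponse = Union[ParsedEdit, ParseError]


def parse_llm_response(content: str) -> ParsedLLMResponse:
    """Return a ParsedEdit on success and a ParseError value on failure.

    The JSON schema used here is invented for the sketch; the real
    JsonParser expects a different layout.
    """
    try:
        data = json.loads(content)
        if not isinstance(data, dict):
            return ParseError(error_message="Invalid format in model response", partial_response=content)
        return ParsedEdit(
            full_response=content,
            conversation=data.get("conversation", ""),
            file_edits=[FileEdit(**edit) for edit in data.get("edits", [])],
        )
    except json.JSONDecodeError:
        return ParseError(error_message="Invalid JSON in model response", partial_response=content)
    except (ValidationError, TypeError):
        return ParseError(error_message="Invalid format in model response", partial_response=content)


def determine_commands(content: str) -> List[str]:
    """Mirror AgentHandler: branch on the concrete type, bail out on error."""
    parsed = parse_llm_response(content)
    if isinstance(parsed, ParseError):
        print(f"Error parsing model response: {parsed.error_message}")
        return []
    return parsed.full_response.strip().split("\n")
```

The benefit of the union type is that a failed parse can no longer masquerade as an empty `ParsedLLMResponse`: callers must pass an `isinstance` check before they can touch `file_edits`.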
mentatai[bot] committed Dec 6, 2024
1 parent 129c616 commit 7b90f43
Showing 4 changed files with 54 additions and 29 deletions.
mentat/agent_handler.py (8 changes: 7 additions & 1 deletion)
@@ -9,6 +9,8 @@
     ChatCompletionSystemMessageParam,
 )
 
+from mentat.parsers.parser import ParseError, ParsedEdit
+
 from mentat.prompts.prompts import read_prompt
 from mentat.session_context import SESSION_CONTEXT
 from mentat.session_input import ask_yes_no, collect_user_input
@@ -91,8 +93,12 @@ async def _determine_commands(self) -> List[str]:

         messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
         parsed_llm_response = await ctx.config.parser.parse_llm_response(content)
+
+        if isinstance(parsed_llm_response, ParseError):
+            ctx.stream.send(f"Error parsing model response: {parsed_llm_response.error_message}", style="error")
+            return []
+
         ctx.conversation.add_model_message(content, messages, parsed_llm_response)
 
         commands = content.strip().split("\n")
         return commands

mentat/conversation.py (30 changes: 19 additions & 11 deletions)
@@ -22,14 +22,14 @@
     raise_if_context_exceeds_max,
 )
 from mentat.parsers.file_edit import FileEdit
-from mentat.parsers.parser import ParsedLLMResponse
+from mentat.parsers.parser import ParsedLLMResponse, ParsedEdit, ParseError
 from mentat.session_context import SESSION_CONTEXT
 from mentat.transcripts import ModelMessage, TranscriptMessage, UserMessage
 from mentat.utils import add_newline
 
 
 class MentatAssistantMessageParam(ChatCompletionAssistantMessageParam):
-    parsed_llm_response: ParsedLLMResponse
+    parsed_llm_response: ParsedEdit  # Only successful parses should be stored in messages
 
 
 class Conversation:
@@ -72,6 +72,9 @@ def add_model_message(
         parsed_llm_response: ParsedLLMResponse,
     ):
         """Used for actual model output messages"""
+        if isinstance(parsed_llm_response, ParseError):
+            raise ValueError("Cannot add error response to conversation messages")
+
         self.add_transcript_message(ModelMessage(message=message, prior_messages=messages_snapshot))
         self.add_message(
             MentatAssistantMessageParam(
@@ -203,16 +206,17 @@ async def _stream_model_response(
         async with stream.interrupt_catcher(parser.shutdown):
             parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response))
 
-        # Sampler and History require previous_file_lines
-        for file_edit in parsed_llm_response.file_edits:
-            file_edit.previous_file_lines = code_file_manager.file_lines.get(file_edit.file_path, []).copy()
-
         llm_api_handler.display_cost_stats(response.current_response())
 
-        messages.append(
-            ChatCompletionAssistantMessageParam(role="assistant", content=parsed_llm_response.full_response)
-        )
-        self.add_model_message(parsed_llm_response.full_response, messages, parsed_llm_response)
+        if isinstance(parsed_llm_response, ParsedEdit):
+            # Sampler and History require previous_file_lines
+            for file_edit in parsed_llm_response.file_edits:
+                file_edit.previous_file_lines = code_file_manager.file_lines.get(file_edit.file_path, []).copy()
+
+            messages.append(
+                ChatCompletionAssistantMessageParam(role="assistant", content=parsed_llm_response.full_response)
+            )
+            self.add_model_message(parsed_llm_response.full_response, messages, parsed_llm_response)
 
         return parsed_llm_response

@@ -235,7 +239,11 @@ async def get_model_response(self) -> ParsedLLMResponse:
                 " different model.",
                 style="error",
             )
-            return ParsedLLMResponse("", "", list[FileEdit]())
+            return ParseError(error_message="Rate limit error from OpenAI servers")
+
+        if isinstance(response, ParseError):
+            stream.send(f"Error in model response: {response.error_message}", style="error")
+
         return response
 
     async def remaining_context(self) -> int | None:
mentat/parsers/json_parser.py (4 changes: 2 additions & 2 deletions)
@@ -127,10 +127,10 @@ async def stream_and_parse_llm_response(self, response: AsyncIterator[str]) -> P
         except JSONDecodeError:
             # Should never happen with OpenAI's response_format set to json
             stream.send("Error processing model response: Invalid JSON", style="error")
-            return ParsedLLMResponse(message, "", [])
+            return ParseError(error_message="Invalid JSON in model response", partial_response=message)
         except ValidationError:
             stream.send("Error processing model response: Invalid format given", style="error")
-            return ParsedLLMResponse(message, "", [])
+            return ParseError(error_message="Invalid format in model response", partial_response=message)
 
         file_edits: Dict[Path, FileEdit] = {}
         for obj in response_json["content"]:
mentat/parsers/parser.py (41 changes: 26 additions & 15 deletions)
@@ -7,8 +7,11 @@
 from pathlib import Path
 from typing import AsyncIterator
 
+from typing import Union
+
 import attr
 from openai.types.chat.completion_create_params import ResponseFormat
+from pydantic import BaseModel
 
 from mentat.code_file_manager import CodeFileManager
 from mentat.errors import ModelError
@@ -27,12 +30,21 @@
 from mentat.utils import convert_string_to_asynciter
 
 
-@attr.define
-class ParsedLLMResponse:
-    full_response: str = attr.field()
-    conversation: str = attr.field()
-    file_edits: list[FileEdit] = attr.field()
-    interrupted: bool = attr.field(default=False)
+class ParsedEdit(BaseModel):
+    """Represents a successful parse of an LLM response with edits"""
+    full_response: str
+    conversation: str
+    file_edits: list[FileEdit]
+    interrupted: bool = False
+
+
+class ParseError(BaseModel):
+    """Represents an error that occurred during parsing"""
+    error_message: str
+    partial_response: str = ""
+
+
+ParsedLLMResponse = Union[ParsedEdit, ParseError]
 
 
 class Parser(ABC):
@@ -186,10 +198,9 @@ async def stream_and_parse_llm_response(self, response: AsyncIterator[str]) -> P
                 await printer_task
                 logging.debug("LLM Response:")
                 logging.debug(message)
-                return ParsedLLMResponse(
-                    message,
-                    conversation,
-                    [file_edit for file_edit in file_edits.values()],
+                return ParseError(
+                    error_message=str(e),
+                    partial_response=message,
                 )
 
         # Rename map handling
@@ -275,11 +286,11 @@ async def stream_and_parse_llm_response(self, response: AsyncIterator[str]) -> P

         logging.debug("LLM Response:")
         logging.debug(message)
-        return ParsedLLMResponse(
-            message,
-            conversation,
-            [file_edit for file_edit in file_edits.values()],
-            interrupted,
+        return ParsedEdit(
+            full_response=message,
+            conversation=conversation,
+            file_edits=[file_edit for file_edit in file_edits.values()],
+            interrupted=interrupted,
         )

     # Ideally this would be called in this class instead of subclasses
