Skip to content

Commit

Permalink
Add InvalidConversationError (#1565)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu authored Oct 1, 2024
1 parent 8cf3d87 commit a462f03
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
7 changes: 4 additions & 3 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
DatasetTooSmallError,
IncorrectMessageKeyQuantityError,
InvalidContentTypeError,
InvalidConversationError,
InvalidExampleTypeError,
InvalidFileExtensionError,
InvalidLastChatMessageRoleError,
Expand Down Expand Up @@ -270,17 +271,17 @@ def slice_out_last_turn(
if conversation_through_previous_turn != full_conversation[:len(
conversation_through_previous_turn,
)]:
raise ValueError(
raise InvalidConversationError(
f'The full conversation must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {full_conversation=}',
)
if conversation_through_previous_turn != prompt_with_history[:len(
conversation_through_previous_turn,
)]:
raise ValueError(
raise InvalidConversationError(
f'The prompt_with_history must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}',
)
if prompt_with_history != full_conversation[:len(prompt_with_history)]:
raise ValueError(
raise InvalidConversationError(
f'prompt_with_history must be the first part of the full conversation. {prompt_with_history=}, {full_conversation=}',
)
prompt = prompt_with_history[len(conversation_through_previous_turn):]
Expand Down
15 changes: 15 additions & 0 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,18 @@ def __init__(self, files_searched: list[str]) -> None:
message,
files_searched=files_searched,
)


class InvalidConversationError(UserError):
"""Error thrown when the conversation is invalid."""

def __init__(self, message: str) -> None:
self.message = message
super().__init__(message)

def __reduce__(self):
# Return a tuple of class, a tuple of arguments, and optionally state
return (InvalidConversationError, (self.message,))

def __str__(self):
return self.message

0 comments on commit a462f03

Please sign in to comment.