Address review comments.
Oleksii-Klimov committed Nov 30, 2023
1 parent 3d00d81 commit ddce14c
Showing 8 changed files with 176 additions and 72 deletions.
43 changes: 26 additions & 17 deletions aidial_assistant/application/assistant_application.py
@@ -52,23 +52,33 @@ def _get_request_args(request: Request) -> dict[str, str]:
    return {k: v for k, v in args.items() if v is not None}


def _extract_addon_url(addon: Addon) -> str:
    if addon.url is None:
        raise RequestParameterValidationError(
            "Missing required addon url.",
            param="addons",
        )

    return addon.url
def _validate_addons(addons: list[Addon] | None):
    if addons and any(addon.url is None for addon in addons):
        for index, addon in enumerate(addons):
            if addon.url is None:
                raise RequestParameterValidationError(
                    f"Missing required addon url at index {index}.",
                    param="addons",
                )


def _validate_messages(messages: list[Message]) -> None:
    if not messages:
        raise RequestParameterValidationError(
            "Message list cannot be empty.", param="messages"
        )

    if messages[-1].role != Role.USER:
        raise RequestParameterValidationError(
            "Last message must be from the user.", param="messages"
        )


def _validate_request(request: Request) -> None:
    _validate_messages(request.messages)
    _validate_addons(request.addons)


class AssistantApplication(ChatCompletion):
    def __init__(self, config_dir: Path):
        self.args = parse_args(config_dir)
@@ -77,6 +87,7 @@ def __init__(self, config_dir: Path):
    async def chat_completion(
        self, request: Request, response: Response
    ) -> None:
        _validate_request(request)
        chat_args = self.args.openai_conf.dict() | _get_request_args(request)

        model = ModelClient(
@@ -89,10 +100,8 @@ async def chat_completion(
            buffer_size=self.args.chat_conf.buffer_size,
        )

        addons = (
            [_extract_addon_url(addon) for addon in request.addons]
            if request.addons
            else []
        addons: list[str] = (
            [addon.url for addon in request.addons] if request.addons else []  # type: ignore
        )
        token_source = AddonTokenSource(request.headers, addons)

@@ -122,7 +131,6 @@ async def chat_completion(
        chain = CommandChain(
            model_client=model, name="ASSISTANT", command_dict=command_dict
        )
        _validate_messages(request.messages)
        history = History(
            assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
                tools=tool_descriptions
@@ -134,9 +142,10 @@ async def chat_completion(
        )
        discarded_messages: int | None = None
        if request.max_prompt_tokens is not None:
            old_size = history.user_message_count()
            history = await history.trim(request.max_prompt_tokens, model)
            discarded_messages = old_size - history.user_message_count()
            original_size = history.user_message_count
            history = await history.truncate(request.max_prompt_tokens, model)
            truncated_size = history.user_message_count
            discarded_messages = original_size - truncated_size

        choice = response.create_single_choice()
        choice.open()
@@ -155,5 +164,5 @@ async def chat_completion(

        response.set_usage(model.prompt_tokens, model.completion_tokens)

        if discarded_messages:
        if discarded_messages is not None:
            response.set_discarded_messages(discarded_messages)
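
The reworked validation reports the position of the offending addon instead of failing while extracting the first missing URL, and it runs up front before any model calls. Below is a minimal, self-contained sketch of that behaviour; the Addon and RequestParameterValidationError classes here are simplified stand-ins for the real SDK types, not the project's actual definitions.

from dataclasses import dataclass


@dataclass
class Addon:
    url: str | None = None  # simplified stand-in for the request model


class RequestParameterValidationError(Exception):
    def __init__(self, message: str, param: str):
        super().__init__(message)
        self.param = param


def _validate_addons(addons: list[Addon] | None):
    # Same logic as the diff above: report the index of the first addon without a url.
    if addons and any(addon.url is None for addon in addons):
        for index, addon in enumerate(addons):
            if addon.url is None:
                raise RequestParameterValidationError(
                    f"Missing required addon url at index {index}.",
                    param="addons",
                )


try:
    _validate_addons([Addon(url="http://example.com/addon"), Addon(url=None)])
except RequestParameterValidationError as e:
    print(e, e.param)  # Missing required addon url at index 1. addons
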
4 changes: 2 additions & 2 deletions aidial_assistant/chain/command_chain.py
@@ -68,7 +68,7 @@ def _log_messages(self, messages: list[Message]):
    async def run_chat(self, history: History, callback: ChainCallback):
        dialogue = Dialogue()
        try:
            messages = history.to_protocol_messages_with_system_message()
            messages = history.to_protocol_messages()
            while True:
                pair = await self._run_with_protocol_failure_retries(
                    callback, messages + dialogue.messages
@@ -85,7 +85,7 @@ async def run_chat(self, history: History, callback: ChainCallback):
                    dialogue,
                )
                if not dialogue.is_empty()
                else history.to_client_messages()
                else history.to_user_messages()
            )
            await self._generate_result(messages, callback)
        except InvalidRequestError as e:
70 changes: 37 additions & 33 deletions aidial_assistant/chain/history.py
@@ -65,29 +65,7 @@ def __init__(
        )

    def to_protocol_messages(self) -> list[Message]:
        messages: list[Message] = []
        for scoped_message in self.scoped_messages:
            message = scoped_message.message
            if (
                scoped_message.scope == MessageScope.USER
                and message.role == Role.ASSISTANT
            ):
                # Clients see replies in plain text, but the model should understand how to reply appropriately.
                content = commands_to_text(
                    [
                        CommandInvocation(
                            command=Reply.token(), args=[message.content]
                        )
                    ]
                )
                messages.append(Message.assistant(content=content))
            else:
                messages.append(message)

        return messages

    def to_protocol_messages_with_system_message(self) -> list[Message]:
        messages = self.to_protocol_messages()
        messages = self._format_assistant_commands()
        if messages[0].role == Role.SYSTEM:
            messages[0] = Message.system(
                self.assistant_system_message_template.render(
@@ -102,7 +80,7 @@ def to_protocol_messages_with_system_message(self) -> list[Message]:

        return messages

    def to_client_messages(self) -> list[Message]:
    def to_user_messages(self) -> list[Message]:
        return [
            scoped_message.message
            for scoped_message in self.scoped_messages
@@ -112,7 +90,7 @@ def to_client_messages(self) -> list[Message]:
    def to_best_effort_messages(
        self, error: str, dialogue: Dialogue
    ) -> list[Message]:
        messages = self.to_client_messages()
        messages = self.to_user_messages()

        last_message = messages[-1]
        messages[-1] = Message(
@@ -126,12 +104,15 @@ def to_best_effort_messages(

        return messages

    async def trim(
    async def truncate(
        self, max_prompt_tokens: int, model_client: ModelClient
    ) -> "History":
        extra_results_callback = ModelExtraResultsCallback()
        # TODO: This will be replaced with a dedicated truncation call on model client once implemented.
        stream = model_client.agenerate(
            self.to_protocol_messages(),
            # It is not expected for the user to include the assistant system message overhead
            # in the max_prompt_tokens parameter, as it is unknown to the user.
            self._format_assistant_commands(),
            extra_results_callback,
            max_prompt_tokens=max_prompt_tokens,
            max_tokens=1,
@@ -154,11 +135,37 @@ async def trim(

        return self

    def _skip_messages(self, message_count: int) -> list[ScopedMessage]:
        messages = []
    @property
    def user_message_count(self) -> int:
        return self._user_message_count

    def _format_assistant_commands(self) -> list[Message]:
        messages: list[Message] = []
        for scoped_message in self.scoped_messages:
            message = scoped_message.message
            if (
                scoped_message.scope == MessageScope.USER
                and message.role == Role.ASSISTANT
            ):
                # Clients see replies in plain text, but the model should understand how to reply appropriately.
                content = commands_to_text(
                    [
                        CommandInvocation(
                            command=Reply.token(), args=[message.content]
                        )
                    ]
                )
                messages.append(Message.assistant(content=content))
            else:
                messages.append(message)

        return messages

    def _skip_messages(self, discarded_messages: int) -> list[ScopedMessage]:
        messages: list[ScopedMessage] = []
        current_message = self.scoped_messages[0]
        message_iterator = iter(self.scoped_messages)
        for _ in range(message_count):
        for _ in range(discarded_messages):
            current_message = next(message_iterator)
            while current_message.message.role == Role.SYSTEM:
                # System messages should be kept in the history
@@ -182,6 +189,3 @@ def _skip_messages(self, message_count: int) -> list[ScopedMessage]:
        messages += remaining_messages

        return messages

    def user_message_count(self) -> int:
        return self._user_message_count
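
Alongside the trim-to-truncate rename, user_message_count becomes a property, so call sites read it as an attribute rather than calling a method. A rough sketch of how the caller-side arithmetic in chat_completion behaves with a property-based counter; PlainHistory below is a hypothetical stand-in, not the repository's History class.

class PlainHistory:
    """Hypothetical stand-in mimicking the property-based counter."""

    def __init__(self, user_messages: list[str]):
        self._user_message_count = len(user_messages)

    @property
    def user_message_count(self) -> int:
        return self._user_message_count


history = PlainHistory(["first question", "second question", "third question"])
original_size = history.user_message_count  # attribute access, no call parentheses

# Pretend truncation discarded the oldest user message.
history._user_message_count -= 1
truncated_size = history.user_message_count

discarded_messages = original_size - truncated_size
print(discarded_messages)  # 1
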
4 changes: 2 additions & 2 deletions aidial_assistant/utils/exceptions.py
@@ -26,7 +26,7 @@ def _to_http_exception(e: Exception) -> HTTPException:
            param=e.param,
        )

    if isinstance(e, OpenAIError) and e.error:
    if isinstance(e, OpenAIError):
        http_status = e.http_status or 500
        if e.error:
            return HTTPException(
@@ -37,7 +37,7 @@ def _to_http_exception(e: Exception) -> HTTPException:
                param=e.error.param,
            )

        return HTTPException(message=e.error, status_code=http_status)
        return HTTPException(message=str(e), status_code=http_status)

    return HTTPException(
        message=str(e), status_code=500, type="internal_server_error"
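
With the broadened isinstance check, an upstream error that carries no structured error payload is still mapped to its original status code, and the message falls back to str(e) instead of passing the raw error attribute through. A rough illustration of that fallback branch, using a hypothetical error object rather than the real openai exception types.

class FakeUpstreamError:
    """Hypothetical stand-in for an upstream error without a structured `error` payload."""

    def __init__(self, http_status: int | None = None, error=None):
        self.http_status = http_status
        self.error = error

    def __str__(self) -> str:
        return "upstream request failed"


def to_status_and_message(e: FakeUpstreamError) -> tuple[int, str]:
    # Mirrors the fallback branch: keep the upstream status, use the plain string message.
    http_status = e.http_status or 500
    if e.error is None:
        return http_status, str(e)
    return http_status, getattr(e.error, "message", str(e))


print(to_status_and_message(FakeUpstreamError(http_status=502)))  # (502, 'upstream request failed')
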
24 changes: 13 additions & 11 deletions tests/unit_tests/chain/test_history.py
@@ -13,7 +13,7 @@
    ReasonLengthException,
)

TRIMMING_TEST_DATA = [
TRUNCATION_TEST_DATA = [
    (0, [0, 1, 2, 3, 4, 5, 6]),
    (1, [0, 2, 3, 4, 5, 6]),
    (2, [0, 2, 6]),
@@ -36,9 +36,11 @@ async def agenerate(


@pytest.mark.asyncio
@pytest.mark.parametrize("message_count,expected_indices", TRIMMING_TEST_DATA)
async def test_history_trimming(
    message_count: int, expected_indices: list[int]
@pytest.mark.parametrize(
    "discarded_messages,expected_indices", TRUNCATION_TEST_DATA
)
async def test_history_truncation(
    discarded_messages: int, expected_indices: list[int]
):
    history = History(
        assistant_system_message_template=Template(""),
@@ -59,11 +61,11 @@ async def test_history_trimming(
        ],
    )

    side_effect = ModelSideEffect(discarded_messages=message_count)
    side_effect = ModelSideEffect(discarded_messages=discarded_messages)
    model_client = Mock(spec=ModelClient)
    model_client.agenerate.side_effect = side_effect.agenerate

    actual = await history.trim(MAX_PROMPT_TOKENS, model_client)
    actual = await history.truncate(MAX_PROMPT_TOKENS, model_client)

    assert (
        actual.assistant_system_message_template
@@ -80,7 +82,7 @@ async def test_history_trimming(


@pytest.mark.asyncio
async def test_trimming_overflow():
async def test_truncation_overflow():
    history = History(
        assistant_system_message_template=Template(""),
        best_effort_template=Template(""),
@@ -95,15 +97,15 @@ async def test_trimming_overflow():
    model_client.agenerate.side_effect = side_effect.agenerate

    with pytest.raises(Exception) as exc_info:
        await history.trim(MAX_PROMPT_TOKENS, model_client)
        await history.truncate(MAX_PROMPT_TOKENS, model_client)

    assert (
        str(exc_info.value) == "No user messages left after history truncation."
    )


@pytest.mark.asyncio
async def test_trimming_with_incorrect_message_sequence():
async def test_truncation_with_incorrect_message_sequence():
    history = History(
        assistant_system_message_template=Template(""),
        best_effort_template=Template(""),
@@ -120,7 +122,7 @@ async def test_trimming_with_incorrect_message_sequence():
    model_client.agenerate.side_effect = side_effect.agenerate

    with pytest.raises(Exception) as exc_info:
        await history.trim(MAX_PROMPT_TOKENS, model_client)
        await history.truncate(MAX_PROMPT_TOKENS, model_client)

    assert (
        str(exc_info.value)
@@ -144,7 +146,7 @@ def test_protocol_messages_with_system_message():
        ],
    )

    assert history.to_protocol_messages_with_system_message() == [
    assert history.to_protocol_messages() == [
        Message.system(f"system message={system_message}"),
        Message.user(user_message),
        Message.assistant(
10 changes: 5 additions & 5 deletions tests/unit_tests/chain/test_model_client.py
@@ -10,7 +10,7 @@
    ReasonLengthException,
)
from aidial_assistant.utils.text import join_string
from tests.utils.async_helper import to_async_list
from tests.utils.async_helper import to_async_iterator

API_METHOD = "openai.ChatCompletion.acreate"
MODEL_ARGS = {"model": "args"}
@@ -21,7 +21,7 @@
@pytest.mark.asyncio
async def test_discarded_messages(api):
    model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE)
    api.return_value = to_async_list(
    api.return_value = to_async_iterator(
        [
            {
                "choices": [{"delta": {"content": ""}}],
@@ -42,7 +42,7 @@ async def test_discarded_messages(api):
@pytest.mark.asyncio
async def test_content(api):
    model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE)
    api.return_value = to_async_list(
    api.return_value = to_async_iterator(
        [
            {"choices": [{"delta": {"content": "one, "}}]},
            {"choices": [{"delta": {"content": "two, "}}]},
@@ -57,7 +57,7 @@ async def test_content(api):
@pytest.mark.asyncio
async def test_reason_length_with_usage(api):
    model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE)
    api.return_value = to_async_list(
    api.return_value = to_async_iterator(
        [
            {"choices": [{"delta": {"content": "text"}}]},
            {
@@ -84,7 +84,7 @@ async def test_reason_length_with_usage(api):
@pytest.mark.asyncio
async def test_api_args(api):
    model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE)
    api.return_value = to_async_list([])
    api.return_value = to_async_iterator([])
    messages = [
        Message.system(content="a"),
        Message.user(content="b"),
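
The renamed test helper presumably wraps a plain list in an async iterator so it can stand in for the streaming OpenAI response in these mocks. A minimal sketch of such a helper, offered as an assumption about its shape rather than the repository's actual implementation:

from typing import AsyncIterator, Iterable, TypeVar

T = TypeVar("T")


async def to_async_iterator(items: Iterable[T]) -> AsyncIterator[T]:
    # Yield each item so the consumer can drive it with `async for`, like a streaming API response.
    for item in items:
        yield item


# Example usage inside an async test (mirroring the calls above):
#     api.return_value = to_async_iterator([{"choices": [{"delta": {"content": "hi"}}]}])
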