Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: migrate latest fixes #18

Merged
merged 5 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions aidial_assistant/application/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
import logging

logger = logging.getLogger(__name__)
37 changes: 26 additions & 11 deletions aidial_assistant/application/assistant_application.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
import logging
from pathlib import Path

from aidial_sdk import HTTPException
from aidial_sdk.chat_completion import FinishReason
from aidial_sdk.chat_completion.base import ChatCompletion
from aidial_sdk.chat_completion.request import Addon, Request
from aidial_sdk.chat_completion.response import Response
from aiohttp import hdrs
from openai import InvalidRequestError, OpenAIError

from aidial_assistant.application import logger
from aidial_assistant.application.args import parse_args
from aidial_assistant.application.prompts import (
MAIN_SYSTEM_DIALOG_MESSAGE,
RESP_DIALOG_PROMPT,
)
from aidial_assistant.application.server_callback import ServerChainCallback
from aidial_assistant.chain.command_chain import CommandChain, CommandDict
from aidial_assistant.chain.model_client import ModelClient, UsagePublisher
from aidial_assistant.chain.model_client import (
ModelClient,
ReasonLengthException,
UsagePublisher,
)
from aidial_assistant.commands.reply import Reply
from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
from aidial_assistant.utils.open_ai_plugin import (
Expand All @@ -25,6 +30,8 @@
)
from aidial_assistant.utils.state import get_system_prefix, parse_history

logger = logging.getLogger(__name__)


def get_request_args(request: Request) -> dict[str, str]:
args = {
Expand Down Expand Up @@ -112,19 +119,27 @@ async def chat_completion(
request.messages,
system_message,
)
with response.create_single_choice() as choice:
callback = ServerChainCallback(choice)
try:
await chain.run_chat(history, callback, usage_publisher)
except OpenAIError as e:
logger.exception("Request processing has failed.")
choice = response.create_single_choice()
choice.open()

callback = ServerChainCallback(choice)
finish_reason = FinishReason.STOP
try:
await chain.run_chat(history, callback, usage_publisher)
except ReasonLengthException:
finish_reason = FinishReason.LENGTH
except OpenAIError as e:
if e.error:
raise HTTPException(
str(e),
e.error.message,
status_code=e.http_status or 500,
code=e.code,
code=e.error.code,
)

choice.set_state(callback.state)
raise

choice.set_state(callback.state)
choice.close(finish_reason)

response.set_usage(
usage_publisher.prompt_tokens, usage_publisher.completion_tokens
Expand Down
3 changes: 0 additions & 3 deletions aidial_assistant/chain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
import logging

logger = logging.getLogger(__name__)
40 changes: 15 additions & 25 deletions aidial_assistant/chain/command_chain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
import logging
from typing import Any, AsyncIterator, Callable, List

from aidial_sdk.chat_completion.request import Message, Role
from jinja2 import Template

from aidial_assistant.chain import logger
from aidial_assistant.chain.callbacks.chain_callback import ChainCallback
from aidial_assistant.chain.callbacks.command_callback import CommandCallback
from aidial_assistant.chain.callbacks.result_callback import ResultCallback
Expand All @@ -19,15 +19,15 @@
CommandsReader,
)
from aidial_assistant.commands.base import Command, FinalCommand
from aidial_assistant.json_stream.json_node import (
JsonNode,
JsonParsingException,
)
from aidial_assistant.json_stream.exceptions import JsonParsingException
from aidial_assistant.json_stream.json_node import JsonNode
from aidial_assistant.json_stream.json_object import JsonObject
from aidial_assistant.json_stream.json_parser import JsonParser
from aidial_assistant.json_stream.json_string import JsonString
from aidial_assistant.json_stream.tokenator import AsyncPeekable, Tokenator

logger = logging.getLogger(__name__)

MAX_MESSAGE_COUNT = 20
MAX_RETRY_COUNT = 2

Expand Down Expand Up @@ -70,7 +70,7 @@ async def run_chat(
history: List[Message],
callback: ChainCallback,
usage_publisher: UsagePublisher,
) -> str:
):
for message in history[:-1]:
self._log_message(message.role, message.content)

Expand All @@ -95,20 +95,17 @@ async def run_chat(
if isinstance(command, FinalCommand):
if len(responses) > 0:
continue
arg = await anext(args)
result = await CommandChain._to_result(
arg
if isinstance(arg, JsonString)
else arg.to_string_tokens(),
message = await anext(args)
await CommandChain._to_result(
message
if isinstance(message, JsonString)
else message.to_string_tokens(),
# Some relatively large number to avoid CxSAST warning about potential DoS attack.
# Later, the upper limit will be provided by the DIAL Core (proxy).
32000,
callback.result_callback(),
)
self._log_message(
Role.ASSISTANT, json.dumps(root_node.value())
)
return result
return
else:
response = await CommandChain._execute_command(
command_name, command, args, callback
Expand All @@ -118,11 +115,7 @@ async def run_chat(
responses.append(response)

if len(responses) == 0:
# Assume the model has nothing to say
self._log_message(
Role.ASSISTANT, json.dumps(root_node.value())
)
return ""
return

normalized_model_response = json.dumps({"commands": commands})
history.append(
Expand Down Expand Up @@ -205,19 +198,16 @@ async def _to_result(
arg: AsyncIterator[str],
max_model_completion_tokens: int,
callback: ResultCallback,
) -> str:
result = ""
):
try:
for _ in range(max_model_completion_tokens):
token = await anext(arg)
callback.on_result(token)
result += token
logger.warn(
logger.warning(
f"Max token count of {max_model_completion_tokens} exceeded in the reply"
)
except StopAsyncIteration:
pass
return result

@staticmethod
async def _execute_command(
Expand Down
10 changes: 9 additions & 1 deletion aidial_assistant/chain/model_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from aiohttp import ClientSession


class ReasonLengthException(Exception):
    """Signals that the model stopped generating because the token limit was hit.

    Raised by ModelClient.agenerate when a streamed chunk reports
    ``finish_reason == "length"``; callers translate it into
    ``FinishReason.LENGTH`` (or return the partial result collected so far).
    """

    pass


class UsagePublisher:
def __init__(self):
self.total_usage = defaultdict(int)
Expand Down Expand Up @@ -55,6 +59,10 @@ async def agenerate(
if usage:
usage_publisher.publish(usage)

text = chunk["choices"][0]["delta"].get("content")
choice = chunk["choices"][0]
text = choice["delta"].get("content")
if text:
yield text

if choice.get("finish_reason") == "length":
raise ReasonLengthException()
11 changes: 10 additions & 1 deletion aidial_assistant/commands/plugin_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,15 @@ def on_result(self, token):
class PluginChainCallback(ChainCallback):
def __init__(self, callback: Callable[[str], None]):
self.callback = callback
self._result = ""

@override
def command_callback(self) -> PluginCommandCallback:
return PluginCommandCallback(self.callback)

@override
def result_callback(self) -> ResultCallback:
return PluginResultCallback(self.callback)
return PluginResultCallback(self._on_result)

@override
def on_state(self, request: str, response: str):
Expand All @@ -79,3 +80,11 @@ def on_state(self, request: str, response: str):
@override
def on_error(self, title: str, error: Exception):
pass

@property
def result(self) -> str:
return self._result

def _on_result(self, token):
self._result += token
self.callback(token)
21 changes: 12 additions & 9 deletions aidial_assistant/commands/run_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
CommandChain,
CommandConstructor,
)
from aidial_assistant.chain.model_client import ModelClient, UsagePublisher
from aidial_assistant.chain.model_client import (
ModelClient,
ReasonLengthException,
UsagePublisher,
)
from aidial_assistant.commands.base import (
Command,
ExecutionCallback,
JsonResult,
ResultObject,
TextResult,
)
from aidial_assistant.commands.open_api import OpenAPIChatCommand
from aidial_assistant.commands.plugin_callback import PluginChainCallback
Expand Down Expand Up @@ -111,10 +115,9 @@ def create_command(op: APIOperation):
command_dict=command_dict,
)

return JsonResult(
await chat.run_chat(
init_messages,
PluginChainCallback(execution_callback),
usage_publisher,
)
)
callback = PluginChainCallback(execution_callback)
try:
await chat.run_chat(init_messages, callback, usage_publisher)
return TextResult(callback.result)
except ReasonLengthException:
return TextResult(callback.result)
3 changes: 0 additions & 3 deletions aidial_assistant/json_stream/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
import logging

logger = logging.getLogger(__name__)
10 changes: 10 additions & 0 deletions aidial_assistant/json_stream/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class JsonParsingException(Exception):
    """Raised when the streamed JSON input cannot be parsed."""


def unexpected_symbol_error(
    char: str, char_position: int
) -> JsonParsingException:
    """Build (without raising) a JsonParsingException for a symbol that
    is not valid at the given position in the stream.

    :param char: the offending character
    :param char_position: its 0-based position within the stream
    :return: an exception instance ready to be raised by the caller
    """
    message = (
        f"Failed to parse json string: unexpected symbol {char}"
        f" at position {char_position}"
    )
    return JsonParsingException(message)
65 changes: 30 additions & 35 deletions aidial_assistant/json_stream/json_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from typing_extensions import override

from aidial_assistant.json_stream.exceptions import unexpected_symbol_error
from aidial_assistant.json_stream.json_node import (
ComplexNode,
JsonNode,
NodeResolver,
unexpected_symbol_error,
)
from aidial_assistant.json_stream.json_normalizer import JsonNormalizer
from aidial_assistant.json_stream.tokenator import Tokenator
Expand All @@ -17,7 +17,7 @@
class JsonArray(ComplexNode[list[Any]], AsyncIterator[JsonNode]):
def __init__(self, char_position: int):
super().__init__(char_position)
self.listener = Queue[JsonNode | None | BaseException]()
self.listener = Queue[JsonNode | None]()
self.array: list[JsonNode] = []

@override
Expand All @@ -34,7 +34,7 @@ def __aiter__(self) -> AsyncIterator[JsonNode]:

@override
async def __anext__(self) -> JsonNode:
result = ComplexNode.throw_if_exception(await self.listener.get())
result = await self.listener.get()
if result is None:
raise StopAsyncIteration

Expand All @@ -43,38 +43,33 @@ async def __anext__(self) -> JsonNode:

@override
async def parse(self, stream: Tokenator, dependency_resolver: NodeResolver):
try:
normalised_stream = JsonNormalizer(stream)
char = await anext(normalised_stream)
self._char_position = stream.char_position
if not char == JsonArray.token():
raise unexpected_symbol_error(char, stream.char_position)

separate = False
while True:
char = await normalised_stream.apeek()
if char == "]":
await anext(normalised_stream)
break

if char == ",":
if not separate:
raise unexpected_symbol_error(
char, stream.char_position
)

await anext(normalised_stream)
separate = False
else:
value = await dependency_resolver.resolve(stream)
await self.listener.put(value)
if isinstance(value, ComplexNode):
await value.parse(stream, dependency_resolver)
separate = True

await self.listener.put(None)
except BaseException as e:
await self.listener.put(e)
normalised_stream = JsonNormalizer(stream)
char = await anext(normalised_stream)
self._char_position = stream.char_position
if not char == JsonArray.token():
raise unexpected_symbol_error(char, stream.char_position)

separate = False
while True:
char = await normalised_stream.apeek()
if char == "]":
await anext(normalised_stream)
break

if char == ",":
if not separate:
raise unexpected_symbol_error(char, stream.char_position)

await anext(normalised_stream)
separate = False
else:
value = await dependency_resolver.resolve(stream)
await self.listener.put(value)
if isinstance(value, ComplexNode):
await value.parse(stream, dependency_resolver)
separate = True

await self.listener.put(None)

@override
async def to_string_tokens(self) -> AsyncIterator[str]:
Expand Down
Loading