Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: read json lazily to tolerate incorrect closing brackets #37

Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aidial_assistant/application/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def build(self, **kwargs) -> Template:
{
"command": "<command name>",
"args": [
// <array of arguments>
"<arg1>", "<arg2>", ...
]
}
]
Expand Down
62 changes: 31 additions & 31 deletions aidial_assistant/chain/command_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
skip_to_json_start,
)
from aidial_assistant.commands.base import Command, FinalCommand
from aidial_assistant.json_stream.characterstream import CharacterStream
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.exceptions import JsonParsingException
from aidial_assistant.json_stream.json_node import JsonNode
from aidial_assistant.json_stream.json_parser import JsonParser
Expand Down Expand Up @@ -166,39 +166,39 @@ async def _run_with_protocol_failure_retries(
async def _run_commands(
self, chunk_stream: AsyncIterator[str], callback: ChainCallback
) -> Tuple[list[CommandInvocation], list[CommandResult]]:
char_stream = CharacterStream(chunk_stream)
char_stream = ChunkedCharStream(chunk_stream)
await skip_to_json_start(char_stream)

async with JsonParser.parse(char_stream) as root_node:
commands: list[CommandInvocation] = []
responses: list[CommandResult] = []
request_reader = CommandsReader(root_node)
async for invocation in request_reader.parse_invocations():
command_name = await invocation.parse_name()
command = self._create_command(command_name)
args = invocation.parse_args()
if isinstance(command, FinalCommand):
if len(responses) > 0:
continue
message = await anext(args)
await CommandChain._to_result(
message
if isinstance(message, JsonString)
else message.to_string_chunks(),
callback.result_callback(),
)
break
else:
response = await CommandChain._execute_command(
command_name, command, args, callback
)
root_node = await JsonParser().parse(char_stream)
commands: list[CommandInvocation] = []
responses: list[CommandResult] = []
request_reader = CommandsReader(root_node)
async for invocation in request_reader.parse_invocations():
command_name = await invocation.parse_name()
command = self._create_command(command_name)
args = invocation.parse_args()
if isinstance(command, FinalCommand):
if len(responses) > 0:
continue
message = await anext(args)
await CommandChain._to_result(
message
if isinstance(message, JsonString)
else message.to_chunks(),
callback.result_callback(),
)
break
else:
response = await CommandChain._execute_command(
command_name, command, args, callback
)

commands.append(
cast(CommandInvocation, invocation.node.value())
)
responses.append(response)
commands.append(
cast(CommandInvocation, invocation.node.value())
)
responses.append(response)

return commands, responses
return commands, responses

def _create_command(self, name: str) -> Command:
if name not in self.command_dict:
Expand Down Expand Up @@ -237,7 +237,7 @@ async def _to_args(
arg_callback = args_callback.arg_callback()
arg_callback.on_arg_start()
result = ""
async for chunk in arg.to_string_chunks():
async for chunk in arg.to_chunks():
arg_callback.on_arg(chunk)
result += chunk
arg_callback.on_arg_end()
Expand Down
6 changes: 3 additions & 3 deletions aidial_assistant/chain/model_response_reader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import AsyncIterator

from aidial_assistant.json_stream.characterstream import AsyncPeekable
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.json_array import JsonArray
from aidial_assistant.json_stream.json_node import JsonNode
from aidial_assistant.json_stream.json_object import JsonObject
Expand All @@ -16,12 +16,12 @@ class AssistantProtocolException(Exception):
pass


async def skip_to_json_start(stream: AsyncPeekable[str]):
async def skip_to_json_start(stream: ChunkedCharStream):
# Some models tend to provide explanations for their replies regardless of what the prompt says.
try:
while True:
char = await stream.apeek()
if char == JsonObject.token():
if JsonObject.starts_with(char):
break

await anext(stream)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,10 @@
from abc import ABC, abstractmethod
from abc import ABC
from collections.abc import AsyncIterator
from typing import Generic, TypeVar

from typing_extensions import override

T = TypeVar("T")


class AsyncPeekable(ABC, Generic[T], AsyncIterator[T]):
@abstractmethod
async def apeek(self) -> T:
pass

async def askip(self) -> None:
await anext(self)


class CharacterStream(AsyncPeekable[str]):
class ChunkedCharStream(ABC, AsyncIterator[str]):
def __init__(self, source: AsyncIterator[str]):
self._source = source
self._chunk: str = ""
Expand All @@ -33,18 +21,29 @@ async def __anext__(self) -> str:
self._next_char_offset += 1
return result

@override
async def apeek(self) -> str:
while self._next_char_offset == len(self._chunk):
self._chunk_position += len(self._chunk)
self._chunk = await anext(self._source) # type: ignore
self._next_char_offset = 0
return self._chunk[self._next_char_offset]

async def askip(self):
await anext(self)

@property
def chunk_position(self) -> int:
return self._chunk_position

@property
def char_position(self) -> int:
return self._chunk_position + self._next_char_offset


async def skip_whitespaces(stream: ChunkedCharStream):
while True:
char = await stream.apeek()
if not str.isspace(char):
break

await stream.askip()
19 changes: 15 additions & 4 deletions aidial_assistant/json_stream/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
class JsonParsingException(Exception):
pass
def __init__(self, message: str, char_position: int):
super().__init__(
f"Failed to parse json string at position {char_position}: {message}"
)


def unexpected_symbol_error(
char: str, char_position: int
) -> JsonParsingException:
return JsonParsingException(
f"Failed to parse json string: unexpected symbol {char} at position {char_position}"
)
return JsonParsingException(f"Unexpected symbol {char}.", char_position)
adubovik marked this conversation as resolved.
Show resolved Hide resolved


def unexpected_end_of_stream_error(char_position: int) -> JsonParsingException:
return JsonParsingException("Unexpected end of stream.", char_position)


def invalid_sequence_error(
sequence: str, char_position: int
) -> JsonParsingException:
return JsonParsingException(f"Invalid sequence {sequence}.", char_position)
adubovik marked this conversation as resolved.
Show resolved Hide resolved
134 changes: 69 additions & 65 deletions aidial_assistant/json_stream/json_array.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,94 @@
from asyncio import Queue
from collections.abc import AsyncIterator
from typing import Any

from typing_extensions import override

from aidial_assistant.json_stream.characterstream import CharacterStream
from aidial_assistant.json_stream.exceptions import unexpected_symbol_error
from aidial_assistant.json_stream.chunked_char_stream import (
ChunkedCharStream,
skip_whitespaces,
)
from aidial_assistant.json_stream.exceptions import (
unexpected_end_of_stream_error,
unexpected_symbol_error,
)
from aidial_assistant.json_stream.json_node import (
ComplexNode,
CompoundNode,
JsonNode,
NodeResolver,
NodeParser,
)
from aidial_assistant.json_stream.json_normalizer import JsonNormalizer


class JsonArray(ComplexNode[list[Any]], AsyncIterator[JsonNode]):
def __init__(self, char_position: int):
super().__init__(char_position)
self.listener = Queue[JsonNode | None]()
self.array: list[JsonNode] = []
class JsonArray(CompoundNode[list[Any], JsonNode]):
def __init__(self, source: AsyncIterator[JsonNode], pos: int):
super().__init__(source, pos)
self._array: list[JsonNode] = []

@override
def type(self) -> str:
return "array"

@staticmethod
def token() -> str:
return "["

@override
def __aiter__(self) -> AsyncIterator[JsonNode]:
return self

@override
async def __anext__(self) -> JsonNode:
result = await self.listener.get()
if result is None:
raise StopAsyncIteration

self.array.append(result)
return result

@override
async def parse(
self, stream: CharacterStream, dependency_resolver: NodeResolver
):
normalised_stream = JsonNormalizer(stream)
char = await anext(normalised_stream)
self._char_position = stream.char_position
if not char == JsonArray.token():
raise unexpected_symbol_error(char, stream.char_position)

separate = False
while True:
char = await normalised_stream.apeek()
if char == "]":
await anext(normalised_stream)
break

if char == ",":
if not separate:
raise unexpected_symbol_error(char, stream.char_position)

await anext(normalised_stream)
separate = False
else:
value = await dependency_resolver.resolve(stream)
await self.listener.put(value)
if isinstance(value, ComplexNode):
await value.parse(stream, dependency_resolver)
separate = True

await self.listener.put(None)
async def read(
stream: ChunkedCharStream, node_parser: NodeParser
) -> AsyncIterator[JsonNode]:
try:
await skip_whitespaces(stream)
char = await anext(stream)
if not JsonArray.starts_with(char):
raise unexpected_symbol_error(char, stream.char_position)

is_comma_expected = False
while True:
await skip_whitespaces(stream)
char = await stream.apeek()
if char == "]":
await stream.askip()
break

if char == ",":
if not is_comma_expected:
raise unexpected_symbol_error(
char, stream.char_position
)

await stream.askip()
is_comma_expected = False
else:
value = await node_parser.parse(stream)
yield value

if isinstance(value, CompoundNode):
await value.read_to_end()
is_comma_expected = True
except StopAsyncIteration:
raise unexpected_end_of_stream_error(stream.char_position)

@override
adubovik marked this conversation as resolved.
Show resolved Hide resolved
async def to_string_chunks(self) -> AsyncIterator[str]:
yield JsonArray.token()
separate = False
async def to_chunks(self) -> AsyncIterator[str]:
yield "["
is_first_element = True
async for value in self:
if separate:
if not is_first_element:
yield ", "
async for chunk in value.to_string_chunks():
async for chunk in value.to_chunks():
yield chunk
separate = True
is_first_element = False
yield "]"

@override
def value(self) -> list[JsonNode]:
return [item.value() for item in self.array]
return [item.value() for item in self._array]

@override
def _accumulate(self, element: JsonNode):
self._array.append(element)

@classmethod
def parse(
cls, stream: ChunkedCharStream, node_parser: NodeParser
) -> "JsonArray":
return cls(JsonArray.read(stream, node_parser), stream.char_position)

@staticmethod
def starts_with(char: str) -> bool:
return char == "["
Loading