Address some review comments.
Oleksii-Klimov committed Dec 1, 2023
1 parent be78c58 commit b77f887
Showing 14 changed files with 199 additions and 218 deletions.
8 changes: 4 additions & 4 deletions aidial_assistant/chain/command_chain.py
@@ -29,10 +29,10 @@
skip_to_json_start,
)
from aidial_assistant.commands.base import Command, FinalCommand
from aidial_assistant.json_stream.characterstream import CharacterStream
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.exceptions import JsonParsingException
from aidial_assistant.json_stream.json_node import JsonNode
from aidial_assistant.json_stream.json_parser import JsonParser
from aidial_assistant.json_stream.json_parser import parse_json
from aidial_assistant.json_stream.json_string import JsonString
from aidial_assistant.utils.stream import CumulativeStream

@@ -166,10 +166,10 @@ async def _run_with_protocol_failure_retries(
async def _run_commands(
self, chunk_stream: AsyncIterator[str], callback: ChainCallback
) -> Tuple[list[CommandInvocation], list[CommandResult]]:
char_stream = CharacterStream(chunk_stream)
char_stream = ChunkedCharStream(chunk_stream)
await skip_to_json_start(char_stream)

root_node = await JsonParser.parse(char_stream)
root_node = await parse_json(char_stream)
commands: list[CommandInvocation] = []
responses: list[CommandResult] = []
request_reader = CommandsReader(root_node)
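Taken together, the updated call sites in command_chain.py follow the pattern below. This is a minimal sketch assuming a plain async generator as the chunk source; read_root_node is a made-up helper name, and the assistant's real wiring around callbacks and retries is omitted.

from aidial_assistant.chain.model_response_reader import skip_to_json_start
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.json_parser import parse_json


async def read_root_node(chunk_stream):
    # Wrap the model output (an async iterator of string chunks) into a
    # character-level stream.
    char_stream = ChunkedCharStream(chunk_stream)

    # Drop any free-form preamble the model emitted before the JSON object.
    await skip_to_json_start(char_stream)

    # parse_json is the module-level replacement for the former JsonParser.parse.
    return await parse_json(char_stream)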
6 changes: 3 additions & 3 deletions aidial_assistant/chain/model_response_reader.py
@@ -1,6 +1,6 @@
from collections.abc import AsyncIterator

from aidial_assistant.json_stream.characterstream import AsyncPeekable
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.json_array import JsonArray
from aidial_assistant.json_stream.json_node import JsonNode
from aidial_assistant.json_stream.json_object import JsonObject
@@ -16,12 +16,12 @@ class AssistantProtocolException(Exception):
pass


async def skip_to_json_start(stream: AsyncPeekable[str]):
async def skip_to_json_start(stream: ChunkedCharStream):
# Some models tend to provide explanations for their replies regardless of what the prompt says.
try:
while True:
char = await stream.apeek()
if char == JsonObject.token():
if JsonObject.starts_with(char):
break

await anext(stream)
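A usage sketch of the reworked helper: the _reply_chunks generator and main driver below are illustrative stand-ins, not code from the repository.

import asyncio

from aidial_assistant.chain.model_response_reader import skip_to_json_start
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream


async def _reply_chunks():
    # A model reply that starts with prose before the JSON payload.
    yield "Sure, here is the requested invocation:\n"
    yield '{"commands": []}'


async def main():
    stream = ChunkedCharStream(_reply_chunks())
    await skip_to_json_start(stream)
    # The stream is now positioned at the opening brace of the JSON object.
    assert await stream.apeek() == "{"


asyncio.run(main())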
aidial_assistant/json_stream/{characterstream.py → chunked_char_stream.py}
@@ -1,22 +1,10 @@
from abc import ABC, abstractmethod
from abc import ABC
from collections.abc import AsyncIterator
from typing import Generic, TypeVar

from typing_extensions import override

T = TypeVar("T")


class AsyncPeekable(ABC, Generic[T], AsyncIterator[T]):
@abstractmethod
async def apeek(self) -> T:
pass

async def askip(self) -> None:
await anext(self)


class CharacterStream(AsyncPeekable[str]):
class ChunkedCharStream(ABC, AsyncIterator[str]):
def __init__(self, source: AsyncIterator[str]):
self._source = source
self._chunk: str = ""
@@ -33,14 +21,25 @@ async def __anext__(self) -> str:
self._next_char_offset += 1
return result

@override
async def apeek(self) -> str:
while self._next_char_offset == len(self._chunk):
self._chunk_position += len(self._chunk)
self._chunk = await anext(self._source) # type: ignore
self._next_char_offset = 0
return self._chunk[self._next_char_offset]

async def askip(self):
await anext(self)

async def skip_whitespaces(self) -> "ChunkedCharStream":
while True:
char = await self.apeek()
if not str.isspace(char):
break
await self.askip()

return self

@property
def chunk_position(self) -> int:
return self._chunk_position
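A small sketch of how the renamed stream behaves across chunk boundaries; the _chunks generator and main driver are made up for illustration.

import asyncio

from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream


async def _chunks():
    # Characters arrive split across arbitrary chunk boundaries.
    yield '{"a'
    yield '": 1}'


async def main():
    stream = ChunkedCharStream(_chunks())

    # apeek() looks at the next character without consuming it.
    assert await stream.apeek() == "{"

    # __anext__ consumes one character at a time; askip() discards one.
    assert await anext(stream) == "{"
    await stream.askip()  # skips the opening quote
    assert await anext(stream) == "a"

    # Reading past the end of the current chunk transparently fetches the next one.
    assert await stream.apeek() == '"'


asyncio.run(main())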
19 changes: 12 additions & 7 deletions aidial_assistant/json_stream/exceptions.py
@@ -1,16 +1,21 @@
class JsonParsingException(Exception):
pass
def __init__(self, message: str, char_position: int):
super().__init__(
f"Failed to parse json string at position {char_position}: {message}"
)


def unexpected_symbol_error(
char: str, char_position: int
) -> JsonParsingException:
return JsonParsingException(
f"Failed to parse json string: unexpected symbol {char} at position {char_position}"
)
return JsonParsingException(f"Unexpected symbol {char}.", char_position)


def unexpected_end_of_stream_error(char_position: int) -> JsonParsingException:
return JsonParsingException(
f"Failed to parse json string: unexpected end of stream at position {char_position}"
)
return JsonParsingException("Unexpected end of stream.", char_position)


def invalid_sequence_error(
sequence: str, char_position: int
) -> JsonParsingException:
return JsonParsingException(f"Invalid sequence {sequence}.", char_position)
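With the position formatting moved into the base constructor, all three helpers now produce uniformly shaped messages. A quick illustration (not a test from the repository):

from aidial_assistant.json_stream.exceptions import (
    invalid_sequence_error,
    unexpected_end_of_stream_error,
    unexpected_symbol_error,
)

print(unexpected_symbol_error("!", 42))
# Failed to parse json string at position 42: Unexpected symbol !.

print(unexpected_end_of_stream_error(7))
# Failed to parse json string at position 7: Unexpected end of stream.

print(invalid_sequence_error("tru", 3))
# Failed to parse json string at position 3: Invalid sequence tru.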
54 changes: 25 additions & 29 deletions aidial_assistant/json_stream/json_array.py
@@ -3,20 +3,19 @@

from typing_extensions import override

from aidial_assistant.json_stream.characterstream import CharacterStream
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.exceptions import (
unexpected_end_of_stream_error,
unexpected_symbol_error,
)
from aidial_assistant.json_stream.json_node import (
CompoundNode,
JsonNode,
NodeResolver,
ReadableNode,
)
from aidial_assistant.json_stream.json_normalizer import JsonNormalizer


class JsonArray(ReadableNode[list[Any], JsonNode]):
class JsonArray(CompoundNode[list[Any], JsonNode]):
def __init__(self, source: AsyncIterator[JsonNode], char_position: int):
super().__init__(source, char_position)
self._array: list[JsonNode] = []
@@ -25,55 +24,50 @@ def __init__(self, source: AsyncIterator[JsonNode], char_position: int):
def type(self) -> str:
return "array"

@staticmethod
def token() -> str:
return "["

@staticmethod
async def read(
stream: CharacterStream, dependency_resolver: NodeResolver
stream: ChunkedCharStream, node_resolver: NodeResolver
) -> AsyncIterator[JsonNode]:
try:
normalised_stream = JsonNormalizer(stream)
char = await anext(normalised_stream)
if not char == JsonArray.token():
char = await anext(await stream.skip_whitespaces())
if not JsonArray.starts_with(char):
raise unexpected_symbol_error(char, stream.char_position)

separate = False
is_comma_expected = False
while True:
char = await normalised_stream.apeek()
char = await (await stream.skip_whitespaces()).apeek()
if char == "]":
await anext(normalised_stream)
await stream.askip()
break

if char == ",":
if not separate:
if not is_comma_expected:
raise unexpected_symbol_error(
char, stream.char_position
)

await anext(normalised_stream)
separate = False
await stream.askip()
is_comma_expected = False
else:
value = await dependency_resolver.resolve(stream)
value = await node_resolver.resolve(stream)
yield value

if isinstance(value, ReadableNode):
if isinstance(value, CompoundNode):
await value.read_to_end()
separate = True
is_comma_expected = True
except StopAsyncIteration:
raise unexpected_end_of_stream_error(stream.char_position)

@override
async def to_string_chunks(self) -> AsyncIterator[str]:
yield JsonArray.token()
separate = False
yield "["
is_comma_expected = False
async for value in self:
if separate:
if is_comma_expected:
yield ", "
async for chunk in value.to_string_chunks():
yield chunk
separate = True
is_comma_expected = True
yield "]"

@override
Expand All @@ -86,8 +80,10 @@ def _accumulate(self, element: JsonNode):

@classmethod
def parse(
cls, stream: CharacterStream, dependency_resolver: NodeResolver
cls, stream: ChunkedCharStream, node_resolver: NodeResolver
) -> "JsonArray":
return cls(
JsonArray.read(stream, dependency_resolver), stream.char_position
)
return cls(JsonArray.read(stream, node_resolver), stream.char_position)

@staticmethod
def starts_with(char: str) -> bool:
return char == "["
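The renamed is_comma_expected flag makes the separator rule explicit. Stripped of the streaming machinery, the rule the read loop enforces amounts to the simplified illustration below; check_separators is a standalone toy, not the library's code.

def check_separators(tokens: list[str]) -> None:
    # tokens is the array body between "[" and "]", already split into
    # values and "," separators.
    is_comma_expected = False
    for token in tokens:
        if token == ",":
            if not is_comma_expected:
                raise ValueError("Unexpected symbol ,")
            is_comma_expected = False
        else:
            # A value was just read, so a comma must come before the next one.
            is_comma_expected = True


check_separators(["1", ",", "2"])  # accepted
try:
    check_separators([",", "1"])  # rejected: a leading comma is unexpected
except ValueError as error:
    print(error)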
28 changes: 16 additions & 12 deletions aidial_assistant/json_stream/json_bool.py
@@ -1,31 +1,35 @@
import json

from typing_extensions import override

from aidial_assistant.json_stream.json_node import PrimitiveNode
from aidial_assistant.json_stream.exceptions import invalid_sequence_error
from aidial_assistant.json_stream.json_node import AtomicNode

TRUE_STRING = "true"
FALSE_STRING = "false"


class JsonBoolean(PrimitiveNode[bool]):
class JsonBoolean(AtomicNode[bool]):
def __init__(self, raw_data: str, char_position: int):
super().__init__(char_position)
self._raw_data = raw_data
self._value: bool = json.loads(raw_data)
super().__init__(raw_data, char_position)
self._value: bool = JsonBoolean._parse_boolean(raw_data, char_position)

@override
def type(self) -> str:
return "boolean"

@override
def raw_data(self) -> str:
return self._raw_data

@override
def value(self) -> bool:
return self._value

@staticmethod
def is_bool(char: str) -> bool:
def starts_with(char: str) -> bool:
return char == "t" or char == "f"

@staticmethod
def _parse_boolean(string: str, char_position: int) -> bool:
if string == TRUE_STRING:
return True

if string == FALSE_STRING:
return False

raise invalid_sequence_error(string, char_position)
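The literal is now validated explicitly instead of being routed through json.loads, so malformed input fails with a position-aware parsing error. A brief illustration:

from aidial_assistant.json_stream.json_bool import JsonBoolean

assert JsonBoolean("true", char_position=0).value() is True
assert JsonBoolean("false", char_position=0).value() is False

# JsonBoolean("tru", char_position=5) raises a JsonParsingException:
# "Failed to parse json string at position 5: Invalid sequence tru."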
23 changes: 14 additions & 9 deletions aidial_assistant/json_stream/json_node.py
@@ -4,15 +4,15 @@

from typing_extensions import override

from aidial_assistant.json_stream.characterstream import CharacterStream
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.exceptions import (
unexpected_end_of_stream_error,
)


class NodeResolver(ABC):
@abstractmethod
async def resolve(self, stream: CharacterStream) -> "JsonNode":
async def resolve(self, stream: ChunkedCharStream) -> "JsonNode":
pass


@@ -41,7 +41,7 @@ def value(self) -> TValue:
pass


class ReadableNode(
class CompoundNode(
JsonNode[TValue], AsyncIterator[TElement], ABC, Generic[TValue, TElement]
):
def __init__(self, source: AsyncIterator[TElement], char_position: int):
@@ -68,17 +68,22 @@ async def read_to_end(self):
pass


class PrimitiveNode(JsonNode[TValue], ABC, Generic[TValue]):
@abstractmethod
def raw_data(self) -> str:
pass
class AtomicNode(JsonNode[TValue], ABC, Generic[TValue]):
def __init__(self, raw_data: str, char_position: int):
super().__init__(char_position)
self._raw_data = raw_data

@override
async def to_string_chunks(self) -> AsyncIterator[str]:
yield self.raw_data()
yield self._raw_data

@classmethod
async def parse(cls, stream: ChunkedCharStream) -> "AtomicNode":
position = stream.char_position
return cls(await AtomicNode._read_all(stream), position)

@staticmethod
async def collect(stream: CharacterStream) -> str:
async def _read_all(stream: ChunkedCharStream) -> str:
try:
raw_data = ""
while True:
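Since AtomicNode now stores raw_data and owns the shared parse() classmethod, a concrete node only has to turn the collected characters into a value. The JsonNull subclass below is hypothetical and shown purely for illustration; the repository may already define an equivalent node.

from typing_extensions import override

from aidial_assistant.json_stream.json_node import AtomicNode


class JsonNull(AtomicNode[None]):
    def __init__(self, raw_data: str, char_position: int):
        super().__init__(raw_data, char_position)
        # Validation of the "null" literal would go here, mirroring JsonBoolean.

    @override
    def type(self) -> str:
        return "null"

    @override
    def value(self) -> None:
        return None

    @staticmethod
    def starts_with(char: str) -> bool:
        return char == "n"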
30 changes: 0 additions & 30 deletions aidial_assistant/json_stream/json_normalizer.py

This file was deleted.
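Its whitespace normalization appears to have been folded into ChunkedCharStream.skip_whitespaces(), as the json_array.py changes above suggest. A rough sketch of the new call pattern, with the old usage reconstructed from the lines removed in json_array.py; the _chunks generator and main driver are illustrative only.

import asyncio

from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream


async def _chunks():
    yield "  \n\t"
    yield '{"a": 1}'


async def main():
    stream = ChunkedCharStream(_chunks())
    # Previously: char = await anext(JsonNormalizer(stream))
    # Now the stream skips its own leading whitespace:
    char = await anext(await stream.skip_whitespaces())
    assert char == "{"


asyncio.run(main())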
