Skip to content

Commit

Permalink
Merge pull request #92 from PrefectHQ/test-fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin authored Mar 29, 2023
2 parents 9c317f5 + 8ea63f7 commit a142fd9
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 64 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ env = [
# don't load plugins
'MARVIN_LOAD_BOT_DEFAULT_PLUGINS=0',
'MARVIN_OPENAI_MODEL_TEMPERATURE=0',
# use 3.5 for tests by default
'MARVIN_OPENAI_MODEL_NAME=gpt-3.5-turbo'

]

Expand Down
26 changes: 15 additions & 11 deletions src/marvin/bots/ai_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
{{ function_def }}
You can not see all of the function's source code. To assist you, the
user may have modified the function to return values that will help when
generating outputs. You will be provided any values returned from the
function but you should NOT assume they are actual outputs of the full
function. Treat any source code (and returned values) as preproccesing.
You can not see all of the function's source code, just its signature
and docstring. However, to assist you, the user may have modified the
function to return values that will help when generating outputs. You
will be provided any values returned from the function but you should
NOT assume they are actual outputs of the full function. Treat any
source code (and returned values) as preproccesing.
The user will give you inputs to this function and you must respond with
its result, in the appropriate form. Do not describe your process or
Expand All @@ -42,10 +44,13 @@
AI_FN_MESSAGE = jinja_env.from_string(
inspect.cleandoc(
"""
{% if input_binds %}} The user supplied the following inputs:
{%for desc in input_binds%}
{{ desc }}
{% endfor %}
{% if input_binds %}
The user supplied the following inputs:
{%for desc in input_binds%}
{{ desc }}
{% endfor %}
{% endif -%}
{% if return_value %}
Expand Down Expand Up @@ -133,7 +138,7 @@ def ai_fn_wrapper(*args, **kwargs) -> Any:
# get the function source code - it will include the @ai_fn decorator,
# which can confuse the AI, so we use regex to only get the function
# that is being decorated
function_def = inspect.getsource(fn)
function_def = inspect.cleandoc(inspect.getsource(fn))
function_def = re.search(
re.compile(r"(\bdef\b.*)", re.DOTALL), function_def
).group(0)
Expand Down Expand Up @@ -177,6 +182,5 @@ async def get_response():
return asyncio.run(get_response())

ai_fn_wrapper.fn = fn
ai_fn_wrapper.set_docstring = lambda doc: setattr(fn, "__doc__", doc)

return ai_fn_wrapper
72 changes: 38 additions & 34 deletions src/marvin/bots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,29 +289,31 @@ async def say(self, *args, response_format=None, **kwargs) -> BotResponse:
# validate response format
parsed_response = response
validated = False
validation_attempts = 0

try:
self.response_format.validate_response(response)
validated = True
except Exception as exc:
match self.response_format.on_error:
case "ignore":
pass
case "raise":
raise exc
case "reformat":
self.logger.debug(
"Response did not pass validation. Attempted to reformat:"
f" {response}"
)
reformatted_response = _reformat_response(
user_message=user_message.content,
ai_response=response,
error_message=repr(exc),
target_return_type=self.response_format.format,
)
response = str(reformatted_response)
validated = True
while not validated and validation_attempts < 3:
validation_attempts += 1
try:
self.response_format.validate_response(response)
validated = True
except Exception as exc:
match self.response_format.on_error:
case "ignore":
validated = True
case "raise":
raise exc
case "reformat":
self.logger.debug(
"Response did not pass validation. Attempted to reformat:"
f" {response}"
)
reformatted_response = _reformat_response(
user_message=user_message.content,
ai_response=response,
error_message=repr(exc),
target_return_type=self.response_format.format,
)
response = str(reformatted_response)

if validated:
parsed_response = self.response_format.parse_response(response)
Expand Down Expand Up @@ -508,22 +510,24 @@ def _reformat_response(
) -> str:
@marvin.ai_fn(
plugins=[],
bot_modifier=lambda bot: setattr(bot.response_format, "on_error", "raise"),
bot_modifier=lambda bot: setattr(bot.response_format, "on_error", "ignore"),
)
def reformat_response(response: str) -> target_return_type:
pass

# set docstring outside of function definition so it can access local variables
reformat_response.set_docstring(
f"""
The `response` contains an answer to the prompt: "{user_message}".
def reformat_response(
response: str, user_message: str, target_return_type: str, error_message: str
) -> target_return_type:
"""
The `response` contains an answer to the `user_prompt`.
However it could not be parsed into the correct return format
({target_return_type}). The associated error message was
"{error_message}".
(`target_return_type`). The associated error message was
`error_message`.
Extract the answer from the `response` and format it to be parsed
correctly.
"""
)

return reformat_response(ai_response)
return reformat_response(
response=ai_response,
user_message=user_message,
target_return_type=target_return_type,
error_message=error_message,
)
14 changes: 7 additions & 7 deletions src/marvin/bots/response_formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,18 @@ def __init__(self, type_: type = SENTINEL, **kwargs):
kwargs.update(
type_schema=schema,
format=(
"A JSON object that matches the following Python type signature:"
f" `{format_type_str(type_)}`. Make sure your response is valid"
" JSON, so return lists instead of sets or tuples; `true` and"
" `false` instead of `True` and `False`; and `null` instead of"
" `None`."
"A valid JSON object that can be cast to the following type"
f" signature: `{format_type_str(type_)}`. Make sure your response"
" is valid JSON, so use lists instead of sets or tuples; literal"
" `true` and `false` instead of `True` and `False`; literal `null`"
" instead of `None`; and double quotes instead of single quotes."
),
)
super().__init__(**kwargs)
if type_ is not SENTINEL:
self._cached_type = type_

def get_type(self) -> type | GenericAlias | pydantic.BaseModel:
def get_type(self) -> type | GenericAlias:
if self._cached_type is not SENTINEL:
return self._cached_type

Expand All @@ -96,7 +96,7 @@ def parse_response(self, response):

# handle GenericAlias and containers
if isinstance(type_, GenericAlias):
return pydantic.parse_raw_as(self.get_type(), response)
return pydantic.parse_raw_as(type_, response)

# handle basic types
else:
Expand Down
7 changes: 4 additions & 3 deletions tests/llm_tests/bots/test_ai_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from marvin import ai_fn
from marvin.utilities.tests import assert_llm

Expand Down Expand Up @@ -113,8 +115,7 @@ def test_extract_sentences_with_question_mark(self):
@ai_fn
def list_questions(email_body: str) -> list[str]:
"""
return any sentences that end with a question mark found in the
email_body
Returns a list of any questions in the email body.
"""

email_body = "Hi Taylor, It is nice outside today. What is your favorite color?"
Expand All @@ -137,7 +138,7 @@ def extract_colors(words: list[str]) -> set[str]:
class TestNone:
def test_none_response(self):
@ai_fn
def filter_with_none(words: list[str]) -> list[str | None]:
def filter_with_none(words: list[str]) -> list[Optional[str]]:
"""
takes a list of words and returns a list of equal length that
replaces any word except "blue" with None
Expand Down
24 changes: 15 additions & 9 deletions tests/llm_tests/bots/test_bots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from marvin import Bot
from marvin.utilities.tests import assert_approx_equal
from marvin.utilities.tests import assert_llm


class TestBotResponse:
Expand All @@ -11,13 +11,13 @@ class TestBotResponse:
async def test_simple_response(self, message, expected_response):
bot = Bot()
response = await bot.say(message)
assert_approx_equal(response.content, expected_response)
assert_llm(response.content, expected_response)

async def test_memory(self):
bot = Bot()
response = await bot.say("My favorite color is blue")
response = await bot.say("What is my favorite color?")
assert_approx_equal(
assert_llm(
response.content,
"You told me that your favorite color is blue",
)
Expand All @@ -30,16 +30,22 @@ async def test_int(self):
assert response.parsed_content == 2

async def test_list_str(self):
bot = Bot(instructions="solve the math problems", response_format=list[str])
bot = Bot(
instructions="solve the math problems and return only the answer",
response_format=list[str],
)
response = await bot.say("Problem 1: 1 + 1\n\nProblem 2: 2 + 2")
assert response.parsed_content == ["2", "4"]
assert_llm(response.parsed_content, ["2", "4"])
assert isinstance(response.parsed_content, list)
assert all(isinstance(x, str) for x in response.parsed_content)

async def test_natural_language_list(self):
bot = Bot(
instructions="solve the math problems", response_format="a list of strings"
instructions="solve the math problems and return only the answer",
response_format="a list of strings",
)
response = await bot.say("Problem 1: 1 + 1\n\nProblem 2: 2 + 2")
assert response.parsed_content == ["2", "4"]
assert_llm(response.parsed_content, '["2", "4"]')

async def test_natural_language_list_2(self):
bot = Bot(instructions="list the keywords", response_format="a list of strings")
Expand All @@ -48,8 +54,8 @@ async def test_natural_language_list_2(self):

async def test_natural_language_list_with_json_keyword(self):
bot = Bot(
instructions="solve the math problems",
instructions="solve the math problems and return only the answer",
response_format="a JSON list of strings",
)
response = await bot.say("Problem 1: 1 + 1\n\nProblem 2: 2 + 2")
assert response.parsed_content == ["2", "4"]
assert_llm(response.parsed_content, ["2", "4"])

0 comments on commit a142fd9

Please sign in to comment.