diff --git a/pyproject.toml b/pyproject.toml
index 0397afcea..18ba4b7af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -107,6 +107,8 @@ env = [
     # don't load plugins
     'MARVIN_LOAD_BOT_DEFAULT_PLUGINS=0',
     'MARVIN_OPENAI_MODEL_TEMPERATURE=0',
+    # use 3.5 for tests by default
+    'MARVIN_OPENAI_MODEL_NAME=gpt-3.5-turbo'
 ]
diff --git a/src/marvin/bots/ai_functions.py b/src/marvin/bots/ai_functions.py
index a7676c4a0..375d685c6 100644
--- a/src/marvin/bots/ai_functions.py
+++ b/src/marvin/bots/ai_functions.py
@@ -15,11 +15,13 @@
 
         {{ function_def }}
 
-        You can not see all of the function's source code. To assist you, the
-        user may have modified the function to return values that will help when
-        generating outputs. You will be provided any values returned from the
-        function but you should NOT assume they are actual outputs of the full
-        function. Treat any source code (and returned values) as preproccesing.
+
+        You can not see all of the function's source code, just its signature
+        and docstring. However, to assist you, the user may have modified the
+        function to return values that will help when generating outputs. You
+        will be provided any values returned from the function but you should
+        NOT assume they are actual outputs of the full function. Treat any
+        source code (and returned values) as preprocessing.
 
         The user will give you inputs to this function and you must respond with
         its result, in the appropriate form. Do not describe your process or
@@ -42,10 +44,13 @@
 AI_FN_MESSAGE = jinja_env.from_string(
     inspect.cleandoc(
         """
-        {% if input_binds %}} The user supplied the following inputs:
-        {%for desc in input_binds%}
-        {{ desc }}
-        {% endfor %}
+        {% if input_binds %}
+        The user supplied the following inputs:
+
+        {%for desc in input_binds%}
+        {{ desc }}
+
+        {% endfor %}
         {% endif -%}
 
         {% if return_value %}
@@ -133,7 +138,7 @@ def ai_fn_wrapper(*args, **kwargs) -> Any:
         # get the function source code - it will include the @ai_fn decorator,
         # which can confuse the AI, so we use regex to only get the function
         # that is being decorated
-        function_def = inspect.getsource(fn)
+        function_def = inspect.cleandoc(inspect.getsource(fn))
         function_def = re.search(
             re.compile(r"(\bdef\b.*)", re.DOTALL), function_def
         ).group(0)
@@ -177,6 +182,5 @@ async def get_response():
         return asyncio.run(get_response())
 
     ai_fn_wrapper.fn = fn
-    ai_fn_wrapper.set_docstring = lambda doc: setattr(fn, "__doc__", doc)
 
     return ai_fn_wrapper
diff --git a/src/marvin/bots/base.py b/src/marvin/bots/base.py
index 9808207d0..c81c6d96f 100644
--- a/src/marvin/bots/base.py
+++ b/src/marvin/bots/base.py
@@ -289,29 +289,31 @@ async def say(self, *args, response_format=None, **kwargs) -> BotResponse:
         # validate response format
         parsed_response = response
         validated = False
+        validation_attempts = 0
 
-        try:
-            self.response_format.validate_response(response)
-            validated = True
-        except Exception as exc:
-            match self.response_format.on_error:
-                case "ignore":
-                    pass
-                case "raise":
-                    raise exc
-                case "reformat":
-                    self.logger.debug(
-                        "Response did not pass validation. Attempted to reformat:"
-                        f" {response}"
-                    )
-                    reformatted_response = _reformat_response(
-                        user_message=user_message.content,
-                        ai_response=response,
-                        error_message=repr(exc),
-                        target_return_type=self.response_format.format,
-                    )
-                    response = str(reformatted_response)
-                    validated = True
+        while not validated and validation_attempts < 3:
+            validation_attempts += 1
+            try:
+                self.response_format.validate_response(response)
+                validated = True
+            except Exception as exc:
+                match self.response_format.on_error:
+                    case "ignore":
+                        validated = True
+                    case "raise":
+                        raise exc
+                    case "reformat":
+                        self.logger.debug(
+                            "Response did not pass validation. Attempted to reformat:"
+                            f" {response}"
+                        )
+                        reformatted_response = _reformat_response(
+                            user_message=user_message.content,
+                            ai_response=response,
+                            error_message=repr(exc),
+                            target_return_type=self.response_format.format,
+                        )
+                        response = str(reformatted_response)
 
         if validated:
             parsed_response = self.response_format.parse_response(response)
@@ -508,22 +510,24 @@ def _reformat_response(
 ) -> str:
     @marvin.ai_fn(
         plugins=[],
-        bot_modifier=lambda bot: setattr(bot.response_format, "on_error", "raise"),
+        bot_modifier=lambda bot: setattr(bot.response_format, "on_error", "ignore"),
     )
-    def reformat_response(response: str) -> target_return_type:
-        pass
-
-    # set docstring outside of function definition so it can access local variables
-    reformat_response.set_docstring(
-        f"""
-        The `response` contains an answer to the prompt: "{user_message}".
+    def reformat_response(
+        response: str, user_message: str, target_return_type: str, error_message: str
+    ) -> target_return_type:
+        """
+        The `response` contains an answer to the `user_message`.
         However it could not be parsed into the correct return format
-        ({target_return_type}). The associated error message was
-        "{error_message}".
+        (`target_return_type`). The associated error message was
+        `error_message`.
 
         Extract the answer from the `response` and format it to be parsed
         correctly.
         """
-    )
 
-    return reformat_response(ai_response)
+    return reformat_response(
+        response=ai_response,
+        user_message=user_message,
+        target_return_type=target_return_type,
+        error_message=error_message,
+    )
diff --git a/src/marvin/bots/response_formatters.py b/src/marvin/bots/response_formatters.py
index 2cbbf82a9..aa048c62a 100644
--- a/src/marvin/bots/response_formatters.py
+++ b/src/marvin/bots/response_formatters.py
@@ -71,18 +71,18 @@ def __init__(self, type_: type = SENTINEL, **kwargs):
         kwargs.update(
             type_schema=schema,
             format=(
-                "A JSON object that matches the following Python type signature:"
-                f" `{format_type_str(type_)}`. Make sure your response is valid"
-                " JSON, so return lists instead of sets or tuples; `true` and"
-                " `false` instead of `True` and `False`; and `null` instead of"
-                " `None`."
+                "A valid JSON object that can be cast to the following type"
+                f" signature: `{format_type_str(type_)}`. Make sure your response"
+                " is valid JSON, so use lists instead of sets or tuples; literal"
+                " `true` and `false` instead of `True` and `False`; literal `null`"
+                " instead of `None`; and double quotes instead of single quotes."
             ),
         )
 
         super().__init__(**kwargs)
         if type_ is not SENTINEL:
             self._cached_type = type_
 
-    def get_type(self) -> type | GenericAlias | pydantic.BaseModel:
+    def get_type(self) -> type | GenericAlias:
         if self._cached_type is not SENTINEL:
             return self._cached_type
@@ -96,7 +96,7 @@ def parse_response(self, response):
 
         # handle GenericAlias and containers
         if isinstance(type_, GenericAlias):
-            return pydantic.parse_raw_as(self.get_type(), response)
+            return pydantic.parse_raw_as(type_, response)
 
         # handle basic types
         else:
diff --git a/tests/llm_tests/bots/test_ai_functions.py b/tests/llm_tests/bots/test_ai_functions.py
index ab1d548dc..5499d1427 100644
--- a/tests/llm_tests/bots/test_ai_functions.py
+++ b/tests/llm_tests/bots/test_ai_functions.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from marvin import ai_fn
 from marvin.utilities.tests import assert_llm
 
@@ -113,8 +115,7 @@ def test_extract_sentences_with_question_mark(self):
         @ai_fn
         def list_questions(email_body: str) -> list[str]:
             """
-            return any sentences that end with a question mark found in the
-            email_body
+            Returns a list of any questions in the email body.
             """
 
         email_body = "Hi Taylor, It is nice outside today. What is your favorite color?"
@@ -137,7 +138,7 @@ def extract_colors(words: list[str]) -> set[str]:
 class TestNone:
     def test_none_response(self):
         @ai_fn
-        def filter_with_none(words: list[str]) -> list[str | None]:
+        def filter_with_none(words: list[str]) -> list[Optional[str]]:
             """
             takes a list of words and returns a list of equal length that
             replaces any word except "blue" with None
diff --git a/tests/llm_tests/bots/test_bots.py b/tests/llm_tests/bots/test_bots.py
index bf88bea1f..47a22d888 100644
--- a/tests/llm_tests/bots/test_bots.py
+++ b/tests/llm_tests/bots/test_bots.py
@@ -1,6 +1,6 @@
 import pytest
 from marvin import Bot
-from marvin.utilities.tests import assert_approx_equal
+from marvin.utilities.tests import assert_llm
 
 
 class TestBotResponse:
@@ -11,13 +11,13 @@ class TestBotResponse:
     async def test_simple_response(self, message, expected_response):
         bot = Bot()
         response = await bot.say(message)
-        assert_approx_equal(response.content, expected_response)
+        assert_llm(response.content, expected_response)
 
     async def test_memory(self):
         bot = Bot()
         response = await bot.say("My favorite color is blue")
         response = await bot.say("What is my favorite color?")
-        assert_approx_equal(
+        assert_llm(
             response.content,
             "You told me that your favorite color is blue",
         )
@@ -30,16 +30,22 @@ async def test_int(self):
         assert response.parsed_content == 2
 
     async def test_list_str(self):
-        bot = Bot(instructions="solve the math problems", response_format=list[str])
+        bot = Bot(
+            instructions="solve the math problems and return only the answer",
+            response_format=list[str],
+        )
         response = await bot.say("Problem 1: 1 + 1\n\nProblem 2: 2 + 2")
-        assert response.parsed_content == ["2", "4"]
+        assert_llm(response.parsed_content, ["2", "4"])
+        assert isinstance(response.parsed_content, list)
+        assert all(isinstance(x, str) for x in response.parsed_content)
 
     async def test_natural_language_list(self):
         bot = Bot(
-            instructions="solve the math problems", response_format="a list of strings"
+            instructions="solve the math problems and return only the answer",
+            response_format="a list of strings",
         )
         response = await bot.say("Problem 1: 1 + 1\n\nProblem 2: 2 + 2")
-        assert response.parsed_content == ["2", "4"]
+        assert_llm(response.parsed_content, '["2", "4"]')
 
     async def test_natural_language_list_2(self):
         bot = Bot(instructions="list the keywords", response_format="a list of strings")
@@ -48,8 +54,8 @@ async def test_natural_language_list_with_json_keyword(self):
         bot = Bot(
-            instructions="solve the math problems",
+            instructions="solve the math problems and return only the answer",
             response_format="a JSON list of strings",
         )
         response = await bot.say("Problem 1: 1 + 1\n\nProblem 2: 2 + 2")
-        assert response.parsed_content == ["2", "4"]
+        assert_llm(response.parsed_content, ["2", "4"])