From 1ae434e3da6e74147564f86c6d7279896648ff11 Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Fri, 26 May 2023 15:37:13 -0600 Subject: [PATCH 1/2] add basic test coverage for mapping --- src/marvin/ai_functions/base.py | 25 +++++++++++- src/marvin/ai_models/base.py | 39 ++++++++++++++++++- tests/fixtures/__init__.py | 1 + tests/fixtures/prefect_utils.py | 8 ++++ .../ai_functions/test_ai_functions.py | 16 ++++++++ tests/llm_tests/ai_models/test_ai_models.py | 26 +++++++++++++ 6 files changed, 113 insertions(+), 2 deletions(-) create mode 100644 tests/fixtures/prefect_utils.py diff --git a/src/marvin/ai_functions/base.py b/src/marvin/ai_functions/base.py index eb04824f3..6783c4d3b 100644 --- a/src/marvin/ai_functions/base.py +++ b/src/marvin/ai_functions/base.py @@ -3,7 +3,10 @@ import re import sys from functools import partial -from typing import Callable, TypeVar +from typing import Any, Callable, Dict, List, TypeVar + +from prefect import flow, task +from prefect.utilities.asyncutils import sync_compatible from marvin.bot import Bot from marvin.bot.history import InMemoryHistory @@ -190,6 +193,26 @@ def run(self, *args, **kwargs): """ raise NotImplementedError() + @sync_compatible + async def map( + self, + *args, + task_kwargs: Dict[str, Any] = None, + flow_kwargs: Dict[str, Any] = None, + **kwargs, + ) -> List[T]: + @task(**{"name": self.fn.__name__, **(task_kwargs or {})}) + async def process_item(item: Any): + return await self._run(item, **kwargs) + + @flow(**{"name": self.fn.__name__, **(flow_kwargs or {})}) + async def mapped_ai_fn(*args, **kwargs): + return await process_item.map(*args, **kwargs) + + return [ + await state.result().get() for state in await mapped_ai_fn(*args, **kwargs) + ] + def ai_fn( fn: Callable[[A], T] = None, diff --git a/src/marvin/ai_models/base.py b/src/marvin/ai_models/base.py index b31457e8d..774ce066f 100644 --- a/src/marvin/ai_models/base.py +++ b/src/marvin/ai_models/base.py @@ -1,7 +1,9 @@ from functools import partial, wraps -from typing import Optional, Type, TypeVar +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import pydantic +from prefect import flow, task +from prefect.utilities.asyncutils import sync_compatible from marvin import ai_fn from marvin.bot import Bot @@ -9,6 +11,11 @@ M = TypeVar("M", bound=pydantic.BaseModel) +Context = Union[ + str, + Tuple[str, Optional[Dict[str, Any]]], +] + def AIModel( cls: Optional[Type[M]], @@ -60,6 +67,36 @@ def _ai_validator(cls, values): # add _ai_validator as a pre root validator to run before any other root validators. cls.__pre_root_validators__ = [_ai_validator, *cls.__pre_root_validators__] + + @sync_compatible + async def map( + cls, + contexts: List[Context], + task_kwargs: Dict[str, Any] = None, + flow_kwargs: Dict[str, Any] = None, + ) -> List[M]: + @task(**{"name": cls.__name__, **(task_kwargs or {})}) + async def process_item(context: Context): + if isinstance(context, str): + return cls(context) + elif isinstance(context, tuple): + iter_context = iter(context) + unstructured = next(iter_context, None) + structured = next(iter_context, None) + return cls(unstructured, **(structured or {})) + else: + raise TypeError( + "`Context` must be a `str` or a" + f" `Tuple[str, Optional[Dict[str, Any]]]`, not {type(context)}" + ) + + @flow(**{"name": cls.__name__, **(flow_kwargs or {})}) + async def mapped_ai_fn(contexts: List[Context]): + return await process_item.map(contexts) + + return [await state.result().get() for state in await mapped_ai_fn(contexts)] + + cls.map = classmethod(map) return cls diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 0ffdf7e3c..a93114298 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -2,3 +2,4 @@ from .objects import * from .rest_api import * from .llms import * +from .prefect_utils import * diff --git a/tests/fixtures/prefect_utils.py b/tests/fixtures/prefect_utils.py new file mode 100644 index 000000000..bbbfb72fa --- /dev/null +++ b/tests/fixtures/prefect_utils.py @@ -0,0 +1,8 @@ +import pytest +from prefect.testing.utilities import prefect_test_harness + + +@pytest.fixture(scope="session") +def prefect_db(): + with prefect_test_harness(): + yield diff --git a/tests/llm_tests/ai_functions/test_ai_functions.py b/tests/llm_tests/ai_functions/test_ai_functions.py index a3567a7df..04aa9dec2 100644 --- a/tests/llm_tests/ai_functions/test_ai_functions.py +++ b/tests/llm_tests/ai_functions/test_ai_functions.py @@ -417,3 +417,19 @@ def my_fn() -> int: assert my_fn.name == "my_fn" assert my_fn.description == "returns 1" + + +class TestAIFunctionMapping: + def test_mapping_sync(self, prefect_db): + @ai_fn + def opposite(thing: str) -> str: + """returns the opposite of the input""" + + assert opposite.map(["up", "happy"]) == ["down", "sad"] + + async def test_mapping_async(self, prefect_db): + @ai_fn + async def opposite(thing: str) -> str: + """returns the opposite of the input""" + + assert await opposite.map(["up", "happy"]) == ["down", "sad"] diff --git a/tests/llm_tests/ai_models/test_ai_models.py b/tests/llm_tests/ai_models/test_ai_models.py index 5c90ddf18..ce4e5de6a 100644 --- a/tests/llm_tests/ai_models/test_ai_models.py +++ b/tests/llm_tests/ai_models/test_ai_models.py @@ -131,3 +131,29 @@ class Election(pydantic.BaseModel): ] ) ) + + +class TestAIModelsMapping: + def test_mapping_sync(self, prefect_db): + @ai_model + class CardinalDirection(pydantic.BaseModel): + """use a single capital letter for each cardinal direction.""" + + direction: str + + assert CardinalDirection.map(["sunrise", "sunset"]) == [ + CardinalDirection(direction="E"), + CardinalDirection(direction="W"), + ] + + async def test_mapping_async(self, prefect_db): + @ai_model + class CardinalDirection(pydantic.BaseModel): + """use a single capital letter for each cardinal direction.""" + + direction: str + + assert await CardinalDirection.map(["sunrise", "sunset"]) == [ + CardinalDirection(direction="E"), + CardinalDirection(direction="W"), + ] From 61f7db31240c17a5c6ff2e559adc033e57c68e6b Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Fri, 26 May 2023 15:44:00 -0600 Subject: [PATCH 2/2] improve test slightly --- tests/llm_tests/ai_models/test_ai_models.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/llm_tests/ai_models/test_ai_models.py b/tests/llm_tests/ai_models/test_ai_models.py index ce4e5de6a..ca903cc5f 100644 --- a/tests/llm_tests/ai_models/test_ai_models.py +++ b/tests/llm_tests/ai_models/test_ai_models.py @@ -140,10 +140,11 @@ class CardinalDirection(pydantic.BaseModel): """use a single capital letter for each cardinal direction.""" direction: str + degrees: int assert CardinalDirection.map(["sunrise", "sunset"]) == [ - CardinalDirection(direction="E"), - CardinalDirection(direction="W"), + CardinalDirection(direction="E", degrees=90), + CardinalDirection(direction="W", degrees=270), ] async def test_mapping_async(self, prefect_db): @@ -152,8 +153,9 @@ class CardinalDirection(pydantic.BaseModel): """use a single capital letter for each cardinal direction.""" direction: str + degrees: int assert await CardinalDirection.map(["sunrise", "sunset"]) == [ - CardinalDirection(direction="E"), - CardinalDirection(direction="W"), + CardinalDirection(direction="E", degrees=90), + CardinalDirection(direction="W", degrees=270), ]