Skip to content

Commit

Permalink
Merge pull request #330 from PrefectHQ/mapping
Browse files Browse the repository at this point in the history
implements mapping for `ai_fn` and `ai_model`
  • Loading branch information
zzstoatzz authored May 26, 2023
2 parents 36368d0 + f07141a commit bf7bb2b
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 2 deletions.
25 changes: 24 additions & 1 deletion src/marvin/ai_functions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import re
import sys
from functools import partial
from typing import Callable, TypeVar
from typing import Any, Callable, Dict, List, TypeVar

from prefect import flow, task
from prefect.utilities.asyncutils import sync_compatible

from marvin.bot import Bot
from marvin.bot.history import InMemoryHistory
Expand Down Expand Up @@ -193,6 +196,26 @@ def run(self, *args, **kwargs):
"""
raise NotImplementedError()

@sync_compatible
async def map(
self,
*args,
task_kwargs: Dict[str, Any] = None,
flow_kwargs: Dict[str, Any] = None,
**kwargs,
) -> List[T]:
@task(**{"name": self.fn.__name__, **(task_kwargs or {})})
async def process_item(item: Any):
return await self._run(item, **kwargs)

@flow(**{"name": self.fn.__name__, **(flow_kwargs or {})})
async def mapped_ai_fn(*args, **kwargs):
return await process_item.map(*args, **kwargs)

return [
await state.result().get() for state in await mapped_ai_fn(*args, **kwargs)
]


def ai_fn(
fn: Callable[[A], T] = None,
Expand Down
39 changes: 38 additions & 1 deletion src/marvin/ai_models/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
from functools import partial, wraps
from typing import Optional, Type, TypeVar
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import pydantic
from prefect import flow, task
from prefect.utilities.asyncutils import sync_compatible

from marvin import ai_fn
from marvin.bot import Bot
from marvin.bot.response_formatters import PydanticFormatter

M = TypeVar("M", bound=pydantic.BaseModel)

Context = Union[
str,
Tuple[str, Optional[Dict[str, Any]]],
]


def AIModel(
cls: Optional[Type[M]],
Expand Down Expand Up @@ -60,6 +67,36 @@ def _ai_validator(cls, values):

# add _ai_validator as a pre root validator to run before any other root validators.
cls.__pre_root_validators__ = [_ai_validator, *cls.__pre_root_validators__]

@sync_compatible
async def map(
cls,
contexts: List[Context],
task_kwargs: Dict[str, Any] = None,
flow_kwargs: Dict[str, Any] = None,
) -> List[M]:
@task(**{"name": cls.__name__, **(task_kwargs or {})})
async def process_item(context: Context):
if isinstance(context, str):
return cls(context)
elif isinstance(context, tuple):
iter_context = iter(context)
unstructured = next(iter_context, None)
structured = next(iter_context, None)
return cls(unstructured, **(structured or {}))
else:
raise TypeError(
"`Context` must be a `str` or a"
f" `Tuple[str, Optional[Dict[str, Any]]]`, not {type(context)}"
)

@flow(**{"name": cls.__name__, **(flow_kwargs or {})})
async def mapped_ai_fn(contexts: List[Context]):
return await process_item.map(contexts)

return [await state.result().get() for state in await mapped_ai_fn(contexts)]

cls.map = classmethod(map)
return cls


Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .objects import *
from .rest_api import *
from .llms import *
from .prefect_utils import *
8 changes: 8 additions & 0 deletions tests/fixtures/prefect_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
from prefect.testing.utilities import prefect_test_harness


@pytest.fixture(scope="session")
def prefect_db():
with prefect_test_harness():
yield
16 changes: 16 additions & 0 deletions tests/llm_tests/ai_functions/test_ai_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,19 @@ def my_fn() -> int:

assert my_fn.name == "my_fn"
assert my_fn.description == "returns 1"


class TestAIFunctionMapping:
def test_mapping_sync(self, prefect_db):
@ai_fn
def opposite(thing: str) -> str:
"""returns the opposite of the input"""

assert opposite.map(["up", "happy"]) == ["down", "sad"]

async def test_mapping_async(self, prefect_db):
@ai_fn
async def opposite(thing: str) -> str:
"""returns the opposite of the input"""

assert await opposite.map(["up", "happy"]) == ["down", "sad"]
28 changes: 28 additions & 0 deletions tests/llm_tests/ai_models/test_ai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,31 @@ class Election(pydantic.BaseModel):
]
)
)


class TestAIModelsMapping:
def test_mapping_sync(self, prefect_db):
@ai_model
class CardinalDirection(pydantic.BaseModel):
"""use a single capital letter for each cardinal direction."""

direction: str
degrees: int

assert CardinalDirection.map(["sunrise", "sunset"]) == [
CardinalDirection(direction="E", degrees=90),
CardinalDirection(direction="W", degrees=270),
]

async def test_mapping_async(self, prefect_db):
@ai_model
class CardinalDirection(pydantic.BaseModel):
"""use a single capital letter for each cardinal direction."""

direction: str
degrees: int

assert await CardinalDirection.map(["sunrise", "sunset"]) == [
CardinalDirection(direction="E", degrees=90),
CardinalDirection(direction="W", degrees=270),
]

0 comments on commit bf7bb2b

Please sign in to comment.