Skip to content

Commit

Permalink
Merge pull request #848 from PrefectHQ/cast-default
Browse files Browse the repository at this point in the history
Allow optional cast target when instructions are provided
  • Loading branch information
zzstoatzz authored Mar 7, 2024
2 parents 58dd1a3 + 7d54c36 commit 3bfaa07
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 22 deletions.
2 changes: 2 additions & 0 deletions docs/docs/text/transformation.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ marvin.cast('Mass.', target=str, instructions="The state's abbreviation")
# MA
```

Note that when providing instructions, the `target` field is assumed to be a string unless otherwise specified. If no instructions are provided, a target type is required.


## Classification

Expand Down
51 changes: 33 additions & 18 deletions src/marvin/ai/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,30 +227,40 @@ async def _generate_typed_llm_response_with_logit_bias(

async def cast_async(
data: str,
target: type[T],
target: type[T] = None,
instructions: Optional[str] = None,
model_kwargs: Optional[dict] = None,
client: Optional[AsyncMarvinClient] = None,
) -> T:
"""
Converts the input data into the specified type.
This function uses a language model to convert the input data into a specified type.
The conversion process can be guided by specific instructions. The function also
supports additional arguments for the language model.
This function uses a language model to convert the input data into a
specified type. The conversion process can be guided by specific
instructions. The function also supports additional arguments for the
language model.
Args:
data (str): The data to be converted.
target (type): The type to convert the data into.
instructions (str, optional): Specific instructions for the conversion. Defaults to None.
model_kwargs (dict, optional): Additional keyword arguments for the language model. Defaults to None.
client (AsyncMarvinClient, optional): The client to use for the AI function.
target (type): The type to convert the data into. If none is provided
but instructions are provided, `str` is assumed.
instructions (str, optional): Specific instructions for the conversion.
Defaults to None.
model_kwargs (dict, optional): Additional keyword arguments for the
language model. Defaults to None.
client (AsyncMarvinClient, optional): The client to use for the AI
function.
Returns:
T: The converted data of the specified type.
"""
model_kwargs = model_kwargs or {}

if target is None and instructions is None:
raise ValueError("Must provide either a target type or instructions.")
elif target is None:
target = str

# if the user provided a `to` type that represents a list of labels, we use
# `classify()` for performance.
if (
Expand Down Expand Up @@ -686,24 +696,29 @@ def __init__(self, *args, **kwargs):

def cast(
data: str,
target: type[T],
target: type[T] = None,
instructions: Optional[str] = None,
model_kwargs: Optional[dict] = None,
client: Optional[AsyncMarvinClient] = None,
) -> T:
"""
Converts the input data into the specified type.
This function uses a language model to convert the input data into a specified type.
The conversion process can be guided by specific instructions. The function also
supports additional arguments for the language model.
This function uses a language model to convert the input data into a
specified type. The conversion process can be guided by specific
instructions. The function also supports additional arguments for the
language model.
Args:
data (str): The data to be converted.
target (type): The type to convert the data into.
instructions (str, optional): Specific instructions for the conversion. Defaults to None.
model_kwargs (dict, optional): Additional keyword arguments for the language model. Defaults to None.
client (AsyncMarvinClient, optional): The client to use for the AI function.
target (type): The type to convert the data into. If none is provided
but instructions are provided, `str` is assumed.
instructions (str, optional): Specific instructions for the conversion.
Defaults to None.
model_kwargs (dict, optional): Additional keyword arguments for the
language model. Defaults to None.
client (AsyncMarvinClient, optional): The client to use for the AI
function.
Returns:
T: The converted data of the specified type.
Expand Down Expand Up @@ -882,7 +897,7 @@ def classify_map(

async def cast_async_map(
data: list[str],
target: type[T],
target: type[T] = None,
instructions: Optional[str] = None,
model_kwargs: Optional[dict] = None,
client: Optional[AsyncMarvinClient] = None,
Expand All @@ -901,7 +916,7 @@ async def cast_async_map(

def cast_map(
data: list[str],
target: type[T],
target: type[T] = None,
instructions: Optional[str] = None,
model_kwargs: Optional[dict] = None,
client: Optional[AsyncMarvinClient] = None,
Expand Down
10 changes: 6 additions & 4 deletions src/marvin/beta/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ async def caption_async(

async def cast_async(
data: Union[str, Image],
target: type[T],
target: type[T] = None,
instructions: str = None,
images: list[Image] = None,
vision_model_kwargs: dict = None,
Expand All @@ -223,7 +223,8 @@ async def cast_async(
Args:
images (list[Image]): The images to be processed.
data (str): The data to be converted.
target (type): The type to convert the data into.
target (type): The type to convert the data into. If not provided but
instructions are provided, assumed to be str.
instructions (str, optional): Specific instructions for the conversion.
Defaults to None.
vision_model_kwargs (dict, optional): Additional keyword arguments for
Expand Down Expand Up @@ -358,7 +359,7 @@ def caption(

def cast(
data: Union[str, Image],
target: type[T],
target: type[T] = None,
instructions: str = None,
images: list[Image] = None,
vision_model_kwargs: dict = None,
Expand All @@ -369,7 +370,8 @@ def cast(
Args:
data (Union[str, Image]): The data to be converted.
target (type[T]): The type to convert the data into.
target (type[T]): The type to convert the data into. If not provided but
instructions are provided, assumed to be str.
instructions (str, optional): Specific instructions for the conversion.
images (list[Image], optional): The images to be processed.
vision_model_kwargs (dict, optional): Additional keyword arguments for the vision model.
Expand Down
11 changes: 11 additions & 0 deletions tests/ai/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,17 @@ def test_cast_text_with_subtle_instructions(self, gpt_4):
)
assert result == "My name is MARVIN"

def test_str_target_if_only_instructions_provided(self):
result = marvin.cast(
"one", instructions="the numerical representation of the word "
)
assert isinstance(result, str)
assert result == "1"

def test_error_if_no_target_and_no_instructions(self):
with pytest.raises(ValueError):
marvin.cast("one")

class TestCastCallsClassify:
@patch("marvin.ai.text.classify_async")
def test_cast_doesnt_call_classify_for_int(self, mock_classify):
Expand Down

0 comments on commit 3bfaa07

Please sign in to comment.