diff --git a/docs/docs/text/transformation.md b/docs/docs/text/transformation.md index 9695f80b7..f007260ee 100644 --- a/docs/docs/text/transformation.md +++ b/docs/docs/text/transformation.md @@ -66,6 +66,8 @@ marvin.cast('Mass.', target=str, instructions="The state's abbreviation") # MA ``` +Note that when providing instructions, the `target` field is assumed to be a string unless otherwise specified. If no instructions are provided, a target type is required. + ## Classification diff --git a/src/marvin/ai/text.py b/src/marvin/ai/text.py index 0e310af06..2620ba111 100644 --- a/src/marvin/ai/text.py +++ b/src/marvin/ai/text.py @@ -227,7 +227,7 @@ async def _generate_typed_llm_response_with_logit_bias( async def cast_async( data: str, - target: type[T], + target: type[T] = None, instructions: Optional[str] = None, model_kwargs: Optional[dict] = None, client: Optional[AsyncMarvinClient] = None, @@ -235,22 +235,32 @@ async def cast_async( """ Converts the input data into the specified type. - This function uses a language model to convert the input data into a specified type. - The conversion process can be guided by specific instructions. The function also - supports additional arguments for the language model. + This function uses a language model to convert the input data into a + specified type. The conversion process can be guided by specific + instructions. The function also supports additional arguments for the + language model. Args: data (str): The data to be converted. - target (type): The type to convert the data into. - instructions (str, optional): Specific instructions for the conversion. Defaults to None. - model_kwargs (dict, optional): Additional keyword arguments for the language model. Defaults to None. - client (AsyncMarvinClient, optional): The client to use for the AI function. + target (type): The type to convert the data into. If none is provided + but instructions are provided, `str` is assumed. + instructions (str, optional): Specific instructions for the conversion. + Defaults to None. + model_kwargs (dict, optional): Additional keyword arguments for the + language model. Defaults to None. + client (AsyncMarvinClient, optional): The client to use for the AI + function. Returns: T: The converted data of the specified type. """ model_kwargs = model_kwargs or {} + if target is None and instructions is None: + raise ValueError("Must provide either a target type or instructions.") + elif target is None: + target = str + # if the user provided a `to` type that represents a list of labels, we use # `classify()` for performance. if ( @@ -686,7 +696,7 @@ def __init__(self, *args, **kwargs): def cast( data: str, - target: type[T], + target: type[T] = None, instructions: Optional[str] = None, model_kwargs: Optional[dict] = None, client: Optional[AsyncMarvinClient] = None, @@ -694,16 +704,21 @@ def cast( """ Converts the input data into the specified type. - This function uses a language model to convert the input data into a specified type. - The conversion process can be guided by specific instructions. The function also - supports additional arguments for the language model. + This function uses a language model to convert the input data into a + specified type. The conversion process can be guided by specific + instructions. The function also supports additional arguments for the + language model. Args: data (str): The data to be converted. - target (type): The type to convert the data into. - instructions (str, optional): Specific instructions for the conversion. Defaults to None. - model_kwargs (dict, optional): Additional keyword arguments for the language model. Defaults to None. - client (AsyncMarvinClient, optional): The client to use for the AI function. + target (type): The type to convert the data into. If none is provided + but instructions are provided, `str` is assumed. + instructions (str, optional): Specific instructions for the conversion. + Defaults to None. + model_kwargs (dict, optional): Additional keyword arguments for the + language model. Defaults to None. + client (AsyncMarvinClient, optional): The client to use for the AI + function. Returns: T: The converted data of the specified type. @@ -882,7 +897,7 @@ def classify_map( async def cast_async_map( data: list[str], - target: type[T], + target: type[T] = None, instructions: Optional[str] = None, model_kwargs: Optional[dict] = None, client: Optional[AsyncMarvinClient] = None, @@ -901,7 +916,7 @@ async def cast_async_map( def cast_map( data: list[str], - target: type[T], + target: type[T] = None, instructions: Optional[str] = None, model_kwargs: Optional[dict] = None, client: Optional[AsyncMarvinClient] = None, diff --git a/src/marvin/beta/vision.py b/src/marvin/beta/vision.py index c9cfb8294..c433a1e2e 100644 --- a/src/marvin/beta/vision.py +++ b/src/marvin/beta/vision.py @@ -207,7 +207,7 @@ async def caption_async( async def cast_async( data: Union[str, Image], - target: type[T], + target: type[T] = None, instructions: str = None, images: list[Image] = None, vision_model_kwargs: dict = None, @@ -223,7 +223,8 @@ async def cast_async( Args: images (list[Image]): The images to be processed. data (str): The data to be converted. - target (type): The type to convert the data into. + target (type): The type to convert the data into. If not provided but + instructions are provided, assumed to be str. instructions (str, optional): Specific instructions for the conversion. Defaults to None. vision_model_kwargs (dict, optional): Additional keyword arguments for @@ -358,7 +359,7 @@ def caption( def cast( data: Union[str, Image], - target: type[T], + target: type[T] = None, instructions: str = None, images: list[Image] = None, vision_model_kwargs: dict = None, @@ -369,7 +370,8 @@ def cast( Args: data (Union[str, Image]): The data to be converted. - target (type[T]): The type to convert the data into. + target (type[T]): The type to convert the data into. If not provided but + instructions are provided, assumed to be str. instructions (str, optional): Specific instructions for the conversion. images (list[Image], optional): The images to be processed. vision_model_kwargs (dict, optional): Additional keyword arguments for the vision model. diff --git a/tests/ai/test_cast.py b/tests/ai/test_cast.py index 202bf1609..3cb5079a8 100644 --- a/tests/ai/test_cast.py +++ b/tests/ai/test_cast.py @@ -93,6 +93,17 @@ def test_cast_text_with_subtle_instructions(self, gpt_4): ) assert result == "My name is MARVIN" + def test_str_target_if_only_instructions_provided(self): + result = marvin.cast( + "one", instructions="the numerical representation of the word " + ) + assert isinstance(result, str) + assert result == "1" + + def test_error_if_no_target_and_no_instructions(self): + with pytest.raises(ValueError): + marvin.cast("one") + class TestCastCallsClassify: @patch("marvin.ai.text.classify_async") def test_cast_doesnt_call_classify_for_int(self, mock_classify):