Skip to content

Commit

Permalink
feat: support function_calls quick accessor in GenerateContentRespons…
Browse files Browse the repository at this point in the history
…e class

PiperOrigin-RevId: 712976996
  • Loading branch information
yyyu-google authored and copybara-github committed Jan 7, 2025
1 parent 8159502 commit 81b8a23
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 0 deletions.
110 changes: 110 additions & 0 deletions google/genai/tests/types/test_part_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,113 @@ def test_non_text_part_and_text_part_text():

with pytest.raises(ValueError):
response.text


def test_candidates_none_function_calls():
response = types.GenerateContentResponse()
assert response.function_calls is None


def test_candidates_empty_function_calls():
response = types.GenerateContentResponse(candidates=[])
assert response.function_calls is None


def test_content_none_function_calls():
response = types.GenerateContentResponse(candidates=[types.Candidate()])
assert response.function_calls is None


def test_parts_none_function_calls():
response = types.GenerateContentResponse(
candidates=[types.Candidate(content=types.Content())]
)
assert response.function_calls is None


def test_parts_empty_function_calls():
response = types.GenerateContentResponse(
candidates=[
types.Candidate(content=types.Content(parts=[])),
]
)
assert response.function_calls is None


def test_multiple_candidates_function_calls():
response = types.GenerateContentResponse(
candidates=[
types.Candidate(
content=types.Content(
parts=[
types.Part(
function_call=types.FunctionCall.model_validate({
'args': {'key1': 'value1'},
'name': 'funcCall1',
})
)
]
)
),
types.Candidate(
content=types.Content(
parts=[
types.Part(
function_call=types.FunctionCall.model_validate({
'args': {'key2': 'value2'},
'name': 'funcCall2',
})
)
]
)
),
]
)
assert response.function_calls == [
types.FunctionCall(name='funcCall1', args={'key1': 'value1'})
]


def test_multiple_function_calls():
response = types.GenerateContentResponse(
candidates=[
types.Candidate(
content=types.Content(
parts=[
types.Part(
function_call=types.FunctionCall.model_validate({
'args': {'key1': 'value1'},
'name': 'funcCall1',
})
),
types.Part(
function_call=types.FunctionCall.model_validate({
'args': {'key2': 'value2'},
'name': 'funcCall2',
})
),
]
)
),
]
)
assert response.function_calls == [
types.FunctionCall(name='funcCall1', args={'key1': 'value1'}),
types.FunctionCall(name='funcCall2', args={'key2': 'value2'}),
]


def test_no_function_calls():
response = types.GenerateContentResponse(
candidates=[
types.Candidate(
content=types.Content(
parts=[
types.Part(text='Hello1'),
types.Part(text='World1'),
]
)
),
]
)
assert response.function_calls is None
22 changes: 22 additions & 0 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2492,6 +2492,28 @@ def text(self) -> Optional[str]:
# part.text == '' is different from part.text is None
return text if any_text_part_text else None

@property
def function_calls(self) -> Optional[list[FunctionCall]]:
"""Returns the list of function calls in the response."""
if (
not self.candidates
or not self.candidates[0].content
or not self.candidates[0].content.parts
):
return None
if len(self.candidates) > 1:
logging.warning(
"Warning: there are multiple candidates in the response, returning"
" function calls from the first one."
)
function_calls = [
part.function_call
for part in self.candidates[0].content.parts
if part.function_call is not None
]

return function_calls if function_calls else None

@classmethod
def _from_response(
cls, response: dict[str, object], kwargs: dict[str, object]
Expand Down

0 comments on commit 81b8a23

Please sign in to comment.