Skip to content

Commit

Permalink
feat: support list models to return base models
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713744661
  • Loading branch information
hkt74 authored and copybara-github committed Jan 9, 2025
1 parent a1ed3fa commit 0f713f1
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 43 deletions.
16 changes: 14 additions & 2 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ class HttpOptions(BaseModel):
default=None,
description="""Timeout for the request in seconds.""",
)
skip_project_and_location_in_path: bool = Field(
default=False,
description="""If set to True, the project and location will not be appended to the path.""",
)


class HttpOptionsDict(TypedDict):
Expand All @@ -75,7 +79,8 @@ class HttpOptionsDict(TypedDict):
"""If set, the response payload will be returned int the supplied dict."""
timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for the request in seconds."""

skip_project_and_location_in_path: bool = False
"""If set to True, the project and location will not be appended to the path."""

HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]

Expand Down Expand Up @@ -266,7 +271,14 @@ def _build_request(
)
else:
patched_http_options = self._http_options
if self.vertexai and not path.startswith('projects/'):
skip_project_and_location_in_path_val = patched_http_options.get(
'skip_project_and_location_in_path', False
)
if (
self.vertexai
and not path.startswith('projects/')
and not skip_project_and_location_in_path_val
):
path = f'projects/{self.project}/locations/{self.location}/' + path
url = _join_url_path(
patched_http_options['base_url'],
Expand Down
5 changes: 5 additions & 0 deletions google/genai/_replay_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def _redact_request_url(url: str) -> str:
'{VERTEX_URL_PREFIX}/',
url,
)
result = re.sub(
r'.*-aiplatform.googleapis.com/[^/]+/',
'{VERTEX_URL_PREFIX}/',
result,
)
result = re.sub(
r'https://generativelanguage.googleapis.com/[^/]+',
'{MLDEV_URL_PREFIX}',
Expand Down
24 changes: 24 additions & 0 deletions google/genai/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,30 @@ def t_model(client: _api_client.ApiClient, model: str):
else:
return f'models/{model}'

def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
if api_client.vertexai:
if base_models:
return 'publishers/google/models'
else:
return 'models'
else:
if base_models:
return 'models'
else:
return 'tunedModels'


def t_extract_models(api_client: _api_client.ApiClient, response: dict) -> list[types.Model]:
if response.get('models') is not None:
return response.get('models')
elif response.get('tunedModels') is not None:
return response.get('tunedModels')
elif response.get('publisherModels') is not None:
return response.get('publisherModels')
else:
raise ValueError('Cannot determine the models type.')


def t_caches_model(api_client: _api_client.ApiClient, model: str):
model = t_model(api_client, model)
if not model:
Expand Down
91 changes: 65 additions & 26 deletions google/genai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from . import _extra_utils
from . import _transformers as t
from . import types
from ._api_client import ApiClient
from ._api_client import ApiClient, HttpOptionsDict
from ._common import get_value_by_path as getv
from ._common import set_value_by_path as setv
from .pagers import AsyncPager, Pager
Expand Down Expand Up @@ -2280,6 +2280,9 @@ def _ListModelsConfig_to_mldev(
parent_object: dict = None,
) -> dict:
to_object = {}
if getv(from_object, ['http_options']) is not None:
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))

if getv(from_object, ['page_size']) is not None:
setv(
parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size'])
Expand All @@ -2295,6 +2298,13 @@ def _ListModelsConfig_to_mldev(
if getv(from_object, ['filter']) is not None:
setv(parent_object, ['_query', 'filter'], getv(from_object, ['filter']))

if getv(from_object, ['query_base']) is not None:
setv(
parent_object,
['_url', 'models_url'],
t.t_models_url(api_client, getv(from_object, ['query_base'])),
)

return to_object


Expand All @@ -2304,6 +2314,9 @@ def _ListModelsConfig_to_vertex(
parent_object: dict = None,
) -> dict:
to_object = {}
if getv(from_object, ['http_options']) is not None:
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))

if getv(from_object, ['page_size']) is not None:
setv(
parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size'])
Expand All @@ -2319,6 +2332,13 @@ def _ListModelsConfig_to_vertex(
if getv(from_object, ['filter']) is not None:
setv(parent_object, ['_query', 'filter'], getv(from_object, ['filter']))

if getv(from_object, ['query_base']) is not None:
setv(
parent_object,
['_url', 'models_url'],
t.t_models_url(api_client, getv(from_object, ['query_base'])),
)

return to_object


Expand Down Expand Up @@ -3524,13 +3544,15 @@ def _ListModelsResponse_from_mldev(
if getv(from_object, ['nextPageToken']) is not None:
setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken']))

if getv(from_object, ['tunedModels']) is not None:
if getv(from_object, ['_self']) is not None:
setv(
to_object,
['models'],
[
_Model_from_mldev(api_client, item, to_object)
for item in getv(from_object, ['tunedModels'])
for item in t.t_extract_models(
api_client, getv(from_object, ['_self'])
)
],
)

Expand All @@ -3546,13 +3568,15 @@ def _ListModelsResponse_from_vertex(
if getv(from_object, ['nextPageToken']) is not None:
setv(to_object, ['next_page_token'], getv(from_object, ['nextPageToken']))

if getv(from_object, ['models']) is not None:
if getv(from_object, ['_self']) is not None:
setv(
to_object,
['models'],
[
_Model_from_vertex(api_client, item, to_object)
for item in getv(from_object, ['models'])
for item in t.t_extract_models(
api_client, getv(from_object, ['_self'])
)
],
)

Expand Down Expand Up @@ -4091,12 +4115,12 @@ def _list(
request_dict = _ListModelsParameters_to_vertex(
self.api_client, parameter_model
)
path = 'models'.format_map(request_dict.get('_url'))
path = '{models_url}'.format_map(request_dict.get('_url'))
else:
request_dict = _ListModelsParameters_to_mldev(
self.api_client, parameter_model
)
path = 'tunedModels'.format_map(request_dict.get('_url'))
path = '{models_url}'.format_map(request_dict.get('_url'))
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
Expand Down Expand Up @@ -4523,17 +4547,24 @@ def list(
types._ListModelsParameters(config=config).config
or types.ListModelsConfig()
)

if self.api_client.vertexai:
# Filter for tuning jobs artifacts by labels.
config = config.copy()
filter_value = config.filter
config.filter = (
filter_value + '&filter=labels.tune-type:*'
if filter_value
else 'labels.tune-type:*'
)

if config.query_base:
http_options = (
config.http_options if config.http_options else HttpOptionsDict()
)
http_options['skip_project_and_location_in_path'] = True
config.http_options = http_options
else:
# Filter for tuning jobs artifacts by labels.
filter_value = config.filter
config.filter = (
filter_value + '&filter=labels.tune-type:*'
if filter_value
else 'labels.tune-type:*'
)
if not config.query_base:
config.query_base = False
return Pager(
'models',
self._list,
Expand Down Expand Up @@ -4999,12 +5030,12 @@ async def _list(
request_dict = _ListModelsParameters_to_vertex(
self.api_client, parameter_model
)
path = 'models'.format_map(request_dict.get('_url'))
path = '{models_url}'.format_map(request_dict.get('_url'))
else:
request_dict = _ListModelsParameters_to_mldev(
self.api_client, parameter_model
)
path = 'tunedModels'.format_map(request_dict.get('_url'))
path = '{models_url}'.format_map(request_dict.get('_url'))
query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
Expand Down Expand Up @@ -5366,16 +5397,24 @@ async def list(
types._ListModelsParameters(config=config).config
or types.ListModelsConfig()
)

if self.api_client.vertexai:
# Filter for tuning jobs artifacts by labels.
config = config.copy()
filter_value = config.filter
config.filter = (
filter_value + '&filter=labels.tune-type:*'
if filter_value
else 'labels.tune-type:*'
)
if config.query_base:
http_options = (
config.http_options if config.http_options else HttpOptionsDict()
)
http_options['skip_project_and_location_in_path'] = True
config.http_options = http_options
else:
# Filter for tuning jobs artifacts by labels.
filter_value = config.filter
config.filter = (
filter_value + '&filter=labels.tune-type:*'
if filter_value
else 'labels.tune-type:*'
)
if not config.query_base:
config.query_base = False
return AsyncPager(
'models',
self._list,
Expand Down
48 changes: 33 additions & 15 deletions google/genai/tests/models/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@

test_table: list[pytest_helper.TestTableItem] = [
pytest_helper.TestTableItem(
name='test_list_models',
name='test_tuned_models',
parameters=types._ListModelsParameters(),
),
pytest_helper.TestTableItem(
name='test_list_models_with_config',
name='test_base_models',
parameters=types._ListModelsParameters(config={'query_base': True}),
),
pytest_helper.TestTableItem(
name='test_with_config',
parameters=types._ListModelsParameters(config={'page_size': 3}),
),
]
Expand All @@ -40,30 +44,44 @@
)


def test_pager(client):
models = client.models.list(config={'page_size': 10})
def test_tuned_models_pager(client):
pager = client.models.list(config={'page_size': 10})

assert pager.name == 'models'
assert pager.page_size == 10
assert len(pager) <= 10

# Iterate through all the pages. Then next_page() should raise an exception.
for _ in pager:
pass
with pytest.raises(IndexError, match='No more pages to fetch.'):
pager.next_page()


def test_base_models_pager(client):
pager = client.models.list(config={'page_size': 10, 'query_base': True})

assert models.name == 'models'
assert models.page_size == 10
assert len(models) <= 10
assert pager.name == 'models'
assert pager.page_size == 10
assert len(pager) <= 10

# Iterate through all the pages. Then next_page() should raise an exception.
for _ in models:
for _ in pager:
pass
with pytest.raises(IndexError, match='No more pages to fetch.'):
models.next_page()
pager.next_page()


@pytest.mark.asyncio
async def test_async_pager(client):
models = await client.aio.models.list(config={'page_size': 10})
pager = await client.aio.models.list(config={'page_size': 10})

assert models.name == 'models'
assert models.page_size == 10
assert len(models) <= 10
assert pager.name == 'models'
assert pager.page_size == 10
assert len(pager) <= 10

# Iterate through all the pages. Then next_page() should raise an exception.
async for _ in models:
async for _ in pager:
pass
with pytest.raises(IndexError, match='No more pages to fetch.'):
await models.next_page()
await pager.next_page()
13 changes: 13 additions & 0 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3863,13 +3863,23 @@ class ModelDict(TypedDict, total=False):

class ListModelsConfig(_common.BaseModel):

http_options: Optional[dict[str, Any]] = Field(
default=None, description="""Used to override HTTP request options."""
)
page_size: Optional[int] = Field(default=None, description="""""")
page_token: Optional[str] = Field(default=None, description="""""")
filter: Optional[str] = Field(default=None, description="""""")
query_base: Optional[bool] = Field(
default=None,
description="""Set true to list base models, false to list tuned models.""",
)


class ListModelsConfigDict(TypedDict, total=False):

http_options: Optional[dict[str, Any]]
"""Used to override HTTP request options."""

page_size: Optional[int]
""""""

Expand All @@ -3879,6 +3889,9 @@ class ListModelsConfigDict(TypedDict, total=False):
filter: Optional[str]
""""""

query_base: Optional[bool]
"""Set true to list base models, false to list tuned models."""


ListModelsConfigOrDict = Union[ListModelsConfig, ListModelsConfigDict]

Expand Down

0 comments on commit 0f713f1

Please sign in to comment.