Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712651252
  • Loading branch information
sararob authored and copybara-github committed Jan 6, 2025
1 parent e1e5897 commit 8159502
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 24 deletions.
76 changes: 58 additions & 18 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,56 @@
import google.auth
import google.auth.credentials
from google.auth.transport.requests import AuthorizedSession
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field, ValidationError
import requests

from . import errors


class HttpOptions(TypedDict):
class HttpOptions(BaseModel):
"""HTTP options for the api client."""
model_config = ConfigDict(extra='forbid')

base_url: Optional[str] = Field(
default=None,
description="""The base URL for the AI platform service endpoint.""",
)
api_version: Optional[str] = Field(
default=None,
description="""Specifies the version of the API to use.""",
)
headers: Optional[dict[str, str]] = Field(
default=None,
description="""Additional HTTP headers to be sent with the request.""",
)
response_payload: Optional[dict] = Field(
default=None,
description="""If set, the response payload will be returned int the supplied dict.""",
)
timeout: Optional[Union[float, Tuple[float, float]]] = Field(
default=None,
description="""Timeout for the request in seconds.""",
)


class HttpOptionsDict(TypedDict):
"""HTTP options for the api client."""

base_url: str = None
base_url: Optional[str] = None
"""The base URL for the AI platform service endpoint."""
api_version: str = None
api_version: Optional[str] = None
"""Specifies the version of the API to use."""
headers: dict[str, Union[str, list[str]]] = None
headers: Optional[dict[str, Union[str, list[str]]]] = None
"""Additional HTTP headers to be sent with the request."""
response_payload: dict = None
response_payload: Optional[dict] = None
"""If set, the response payload will be returned int the supplied dict."""
timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for the request in seconds."""


HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]


def _append_library_version_headers(headers: dict[str, str]) -> None:
"""Appends the telemetry header to the headers dict."""
# TODO: Automate revisions to the SDK library version.
Expand All @@ -73,10 +102,10 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:


def _patch_http_options(
options: HttpOptions, patch_options: HttpOptions
) -> HttpOptions:
options: HttpOptionsDict, patch_options: HttpOptionsDict
) -> HttpOptionsDict:
# use shallow copy so we don't override the original objects.
copy_option = HttpOptions()
copy_option = HttpOptionsDict()
copy_option.update(options)
for patch_key, patch_value in patch_options.items():
# if both are dicts, update the copy.
Expand Down Expand Up @@ -154,7 +183,7 @@ def __init__(
credentials: google.auth.credentials.Credentials = None,
project: Union[str, None] = None,
location: Union[str, None] = None,
http_options: HttpOptions = None,
http_options: HttpOptionsOrDict = None,
):
self.vertexai = vertexai
if self.vertexai is None:
Expand All @@ -170,11 +199,20 @@ def __init__(
'Project/location and API key are mutually exclusive in the client initializer.'
)

# Validate http_options if a dict is provided.
if isinstance(http_options, dict):
try:
HttpOptions.model_validate(http_options)
except ValidationError as e:
raise ValueError(f'Invalid http_options: {e}')
elif(isinstance(http_options, HttpOptions)):
http_options = http_options.model_dump()

self.api_key: Optional[str] = None
self.project = project or os.environ.get('GOOGLE_CLOUD_PROJECT', None)
self.location = location or os.environ.get('GOOGLE_CLOUD_LOCATION', None)
self._credentials = credentials
self._http_options = HttpOptions()
self._http_options = HttpOptionsDict()

if self.vertexai:
if not self.project:
Expand Down Expand Up @@ -215,7 +253,7 @@ def _build_request(
http_method: str,
path: str,
request_dict: dict[str, object],
http_options: HttpOptions = None,
http_options: HttpOptionsDict = None,
) -> HttpRequest:
# Remove all special dict keys such as _url and _query.
keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
Expand Down Expand Up @@ -316,8 +354,10 @@ async def _async_request(
stream=stream,
)

def get_read_only_http_options(self) -> HttpOptions:
copied = HttpOptions()
def get_read_only_http_options(self) -> HttpOptionsDict:
copied = HttpOptionsDict()
if isinstance(self._http_options, BaseModel):
self._http_options = self._http_options.model_dump()
copied.update(self._http_options)
return copied

Expand All @@ -326,7 +366,7 @@ def request(
http_method: str,
path: str,
request_dict: dict[str, object],
http_options: HttpOptions = None,
http_options: HttpOptionsDict = None,
):
http_request = self._build_request(
http_method, path, request_dict, http_options
Expand All @@ -341,7 +381,7 @@ def request_streamed(
http_method: str,
path: str,
request_dict: dict[str, object],
http_options: HttpOptions = None,
http_options: HttpOptionsDict = None,
):
http_request = self._build_request(
http_method, path, request_dict, http_options
Expand All @@ -358,7 +398,7 @@ async def async_request(
http_method: str,
path: str,
request_dict: dict[str, object],
http_options: HttpOptions = None,
http_options: HttpOptionsDict = None,
) -> dict[str, object]:
http_request = self._build_request(
http_method, path, request_dict, http_options
Expand All @@ -374,7 +414,7 @@ async def async_request_streamed(
http_method: str,
path: str,
request_dict: dict[str, object],
http_options: HttpOptions = None,
http_options: HttpOptionsDict = None,
):
http_request = self._build_request(
http_method, path, request_dict, http_options
Expand Down
9 changes: 6 additions & 3 deletions google/genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
#

import os
from typing import Optional
from typing import Optional, Union

import google.auth
import pydantic

from ._api_client import ApiClient, HttpOptions
from ._api_client import ApiClient, HttpOptions, HttpOptionsDict
from ._replay_api_client import ReplayApiClient
from .batches import AsyncBatches, Batches
from .caches import AsyncCaches, Caches
Expand Down Expand Up @@ -145,7 +145,7 @@ def __init__(
project: Optional[str] = None,
location: Optional[str] = None,
debug_config: Optional[DebugConfig] = None,
http_options: Optional[HttpOptions] = None,
http_options: Optional[Union[HttpOptions, HttpOptionsDict]] = None,
):
"""Initializes the client.
Expand Down Expand Up @@ -179,6 +179,9 @@ def __init__(
debug_config (DebugConfig):
Config settings that control network
behavior of the client. This is typically used when running test code.
http_options (Union[HttpOptions, HttpOptionsDict]):
Http options to use for the client. Response_payload can't be
set when passing to the client constructor.
"""

self._debug_config = debug_config or DebugConfig()
Expand Down
93 changes: 93 additions & 0 deletions google/genai/tests/client/test_client_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,99 @@ def test_constructor_with_response_payload_in_http_options():
)


def test_constructor_with_invalid_http_options_key():
mldev_http_options = {
"invalid_version_key": "v1",
"base_url": "https://placeholder-fake-url.com/",
"headers": {"X-Custom-Header": "custom_value"},
}
vertexai_http_options = {
"api_version": "v1",
"base_url": (
"https://{self.location}-aiplatform.googleapis.com/{{api_version}}/"
),
"invalid_header_key": {"X-Custom-Header": "custom_value"},
}

# Expect value error when HTTPOptions is provided as a dict and contains
# an invalid key.
try:
_ = Client(api_key="google_api_key", http_options=mldev_http_options)
except Exception as e:
assert isinstance(e, ValueError)
assert "Invalid http_options" in str(e)

# Expect value error when HTTPOptions is provided as a dict and contains
# an invalid key.
try:
_ = Client(
vertexai=True,
project="fake_project_id",
location="fake-location",
http_options=vertexai_http_options,
)
except Exception as e:
assert isinstance(e, ValueError)
assert "Invalid http_options" in str(e)


def test_constructor_with_http_options_as_pydantic_type():
mldev_http_options = api_client.HttpOptions(
api_version="v1",
base_url="https://placeholder-fake-url.com/",
headers={"X-Custom-Header": "custom_value"},
)
vertexai_http_options = api_client.HttpOptions(
api_version="v1",
base_url=(
"https://{self.location}-aiplatform.googleapis.com/{{api_version}}/"
),
headers={"X-Custom-Header": "custom_value"},
)

# Test http_options for mldev client.
mldev_client = Client(
api_key="google_api_key", http_options=mldev_http_options
)
assert not mldev_client.models.api_client.vertexai
assert (
mldev_client.models.api_client.get_read_only_http_options()["base_url"]
== mldev_http_options.base_url
)
assert (
mldev_client.models.api_client.get_read_only_http_options()["api_version"]
== mldev_http_options.api_version
)

assert mldev_client.models.api_client.get_read_only_http_options()["headers"][
"X-Custom-Header"] == mldev_http_options.headers["X-Custom-Header"]

# Test http_options for vertexai client.
vertexai_client = Client(
vertexai=True,
project="fake_project_id",
location="fake-location",
http_options=vertexai_http_options,
)
assert vertexai_client.models.api_client.vertexai
assert (
vertexai_client.models.api_client.get_read_only_http_options()["base_url"]
== vertexai_http_options.base_url
)
assert (
vertexai_client.models.api_client.get_read_only_http_options()[
"api_version"
]
== vertexai_http_options.api_version
)
assert (
vertexai_client.models.api_client.get_read_only_http_options()["headers"][
"X-Custom-Header"
]
== vertexai_http_options.headers["X-Custom-Header"]
)


def test_vertexai_from_env_1(monkeypatch):
project_id = "fake_project_id"
location = "fake-location"
Expand Down
6 changes: 3 additions & 3 deletions google/genai/tests/client/test_client_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_build_request_appends_to_user_agent_headers(monkeypatch):
'GET',
'test/path',
{'key': 'value'},
api_client.HttpOptions(
api_client.HttpOptionsDict(
url='test/url',
api_version='1',
headers={'user-agent': 'test-user-agent'},
Expand All @@ -115,7 +115,7 @@ def test_build_request_appends_to_goog_api_client_headers(monkeypatch):
'GET',
'test/path',
{'key': 'value'},
api_client.HttpOptions(
api_client.HttpOptionsDict(
url='test/url',
api_version='1',
headers={'x-goog-api-client': 'test-goog-api-client'},
Expand All @@ -137,7 +137,7 @@ def test_build_request_keeps_sdk_version_headers(monkeypatch):
'GET',
'test/path',
{'key': 'value'},
api_client.HttpOptions(
api_client.HttpOptionsDict(
url='test/url',
api_version='1',
headers=headers_to_inject,
Expand Down

0 comments on commit 8159502

Please sign in to comment.