From a2e17e11c13b4d944af7bca71440edfcaf0fed05 Mon Sep 17 00:00:00 2001 From: Tatsiana Havina Date: Fri, 7 Feb 2025 15:39:19 +0100 Subject: [PATCH] feat(py): VertexAI options --- py/packages/genkit/src/genkit/ai/model.py | 8 +- .../src/genkit/plugins/vertex_ai/__init__.py | 67 +-------- .../src/genkit/plugins/vertex_ai/constants.py | 10 ++ .../src/genkit/plugins/vertex_ai/gemini.py | 140 ++++++++++++++++++ .../{models => modelgarden}/__init__.py | 2 +- .../src/genkit/plugins/vertex_ai/options.py | 67 +++++++++ .../genkit/plugins/vertex_ai/plugin_api.py | 42 ++++++ py/plugins/vertex-ai/tests/test_options.py | 96 ++++++++++++ py/samples/hello/hello.py | 12 +- pyproject.toml | 6 + 10 files changed, 381 insertions(+), 69 deletions(-) create mode 100644 py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/constants.py create mode 100644 py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py rename py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/{models => modelgarden}/__init__.py (77%) create mode 100644 py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/options.py create mode 100644 py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/plugin_api.py create mode 100644 py/plugins/vertex-ai/tests/test_options.py create mode 100644 pyproject.toml diff --git a/py/packages/genkit/src/genkit/ai/model.py b/py/packages/genkit/src/genkit/ai/model.py index c30d1825d..4f07ba95f 100644 --- a/py/packages/genkit/src/genkit/ai/model.py +++ b/py/packages/genkit/src/genkit/ai/model.py @@ -2,7 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Callable +from pydantic import BaseModel +from genkit.core.schemas import GenerateRequest, GenerateResponse, ModelInfo -from genkit.core.schemas import GenerateRequest, GenerateResponse ModelFn = Callable[[GenerateRequest], GenerateResponse] + + +class ModelReference(BaseModel): + name: str + info: ModelInfo diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py index 3bb387539..cb57c0eed 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py @@ -2,68 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 -""" -Google Cloud Vertex AI Plugin for Genkit. -""" +"""Google Cloud Vertex AI Plugin for Genkit.""" +from genkit.plugins.vertex_ai.plugin_api import vertexAI, gemini +from genkit.plugins.vertex_ai.options import PluginOptions -from collections.abc import Callable -import vertexai -from genkit.core.schemas import ( - GenerateRequest, - GenerateResponse, - Message, - Role, - TextPart, -) -from genkit.veneer.veneer import Genkit -from vertexai.generative_models import Content, GenerativeModel, Part - - -def package_name() -> str: - return 'genkit.plugins.vertex_ai' - - -def vertex_ai(project_id: str | None = None) -> Callable[[Genkit], None]: - def plugin(ai: Genkit) -> None: - vertexai.init(location='us-central1', project=project_id) - - def handle_gemini_request(request: GenerateRequest) -> GenerateResponse: - gemini_msgs: list[Content] = [] - for m in request.messages: - gemini_parts: list[Part] = [] - for p in m.content: - if p.root.text is not None: - gemini_parts.append(Part.from_text(p.root.text)) - else: - raise Exception('unsupported part type') - gemini_msgs.append( - Content(role=m.role.value, parts=gemini_parts) - ) - model = GenerativeModel('gemini-1.5-flash-002') - response = model.generate_content(contents=gemini_msgs) - return GenerateResponse( - message=Message( - role=Role.model, content=[TextPart(text=response.text)] - ) - ) - - ai.define_model( - name='vertexai/gemini-1.5-flash', - fn=handle_gemini_request, - metadata={ - 'model': { - 'label': 'banana', - 'supports': {'multiturn': True}, - } - }, - ) - - return plugin - - -def gemini(name: str) -> str: - return f'vertexai/{name}' - - -__all__ = ['package_name', 'vertex_ai', 'gemini'] +__all__ = ['package_name', 'vertexAI', 'gemini', 'PluginOptions'] diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/constants.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/constants.py new file mode 100644 index 000000000..e1be0e7f9 --- /dev/null +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/constants.py @@ -0,0 +1,10 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Constants for VertexAI plugin.""" +GCLOUD_PROJECT = 'GCLOUD_PROJECT' +GCLOUD_LOCATION = 'GCLOUD_LOCATION' +GCLOUD_PLATFORM_OAUTH_SCOPE = 'https://www.googleapis.com/auth/cloud-platform' +GCLOUD_SERVICE_ACCOUNT_CREDS = 'GCLOUD_SERVICE_ACCOUNT_CREDS' +FIREBASE_CONFIG = 'FIREBASE_CONFIG' +DEFAULT_REGION = 'us-central1' diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py new file mode 100644 index 000000000..ee4915ddc --- /dev/null +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py @@ -0,0 +1,140 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Processes Gemini request.""" +import logging + +from vertexai.generative_models import Content, GenerativeModel, Part + +from genkit.ai.model import ModelReference +from genkit.core.schemas import ( + GenerateRequest, + GenerateResponse, + Message, + ModelInfo, + Role, + Supports, + TextPart, +) + +LOG = logging.getLogger(__name__) + + +# Deprecated on 2/15/2025 +SUPPORTED_V1_MODELS = { + 'gemini-1.0-pro': ModelReference( + name='vertexai/gemini-1.0-pro', + info=ModelInfo( + versions=['gemini-1.0-pro-001', 'gemini-1.0-pro-002'], + label='Vertex AI - Gemini Pro', + supports=Supports( + multiturn=True, + media=False, + tools=True, + systemRole=True + ) +)) +} + +SUPPORTED_V15_MODELS = { + 'gemini-1.5-pro': ModelReference( + name='vertexai/gemini-1.5-pro', + info=ModelInfo( + versions=['gemini-1.5-pro-001', 'gemini-1.5-pro-002'], + label='Vertex AI - Gemini 1.5 Pro', + supports=Supports( + multiturn=True, + media=True, + tools=True, + systemRole=True + ) + )), + 'gemini-1.5-flash': ModelReference( + name='vertexai/gemini-1.5-flash', + info=ModelInfo( + versions=['gemini-1.5-flash-001', 'gemini-1.5-flash-002'], + label='Vertex AI - Gemini 1.5 Flash', + supports=Supports( + multiturn=True, + media=True, + tools=True, + systemRole=True + ) + )), + 'gemini-2.0-flash-001': ModelReference( + name='vertexai/gemini-2.0-flash-001', + info=ModelInfo( + versions=[], + label='Vertex AI - Gemini 2.0 Flash 001', + supports=Supports( + multiturn=True, + media=True, + tools=True, + systemRole=True + ) + )), + 'gemini-2.0-flash-lite-preview-02-05': ModelReference( + name='vertexai/gemini-2.0-flash-lite-preview-02-05', + info=ModelInfo( + versions=[], + label='Vertex AI - Gemini 2.0 Flash Lite Preview 02-05', + supports=Supports( + multiturn=True, + media=True, + tools=True, + systemRole=True + ) + )), + 'gemini-2.0-pro-exp-02-05': ModelReference( + name='vertexai/gemini-2.0-pro-exp-02-05', + info=ModelInfo( + versions=[], + label='Vertex AI - Gemini 2.0 Flash Pro Experimental 02-05', + supports=Supports( + multiturn=True, + media=True, + tools=True, + systemRole=True + ) + )), +} + +SUPPORTED_MODELS = SUPPORTED_V1_MODELS | SUPPORTED_V15_MODELS + + +def nearest_gemini_model(model_name): + model = SUPPORTED_MODELS.get(model_name) + if model: + return model + return ModelReference( + name=f'vertexai/{model_name}', + info=ModelInfo( + versions=[], + label='Vertex AI - Gemini', + supports=Supports( + multiturn=True, + media=True, + tools=True, + systemRole=True + ) + ) + ) + + +def execute_gemini_request(request: GenerateRequest) -> GenerateResponse: + messages: list[Content] = [] + for msg in request.messages: + parts: list[Part] = [] + for part in msg.content: + if hasattr(part, "text") and part.text: + parts.append(Part.from_text(part.text)) + else: + LOG.error("Unsupported message type.") + messages.append(Content(role=msg.role.value, parts=parts)) + model = GenerativeModel('gemini-1.5-flash-001') + response = model.generate_content(contents=messages) + return GenerateResponse( + message=Message( + role=Role.model, content=[TextPart(text=response.text)] + ) + ) diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/__init__.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/modelgarden/__init__.py similarity index 77% rename from py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/__init__.py rename to py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/modelgarden/__init__.py index 39639eca6..39e451bdd 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/models/__init__.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/modelgarden/__init__.py @@ -8,7 +8,7 @@ def package_name() -> str: - return 'genkit.plugins.vertex_ai.models' + return 'genkit.plugins.vertex_ai.modelgarden' __all__ = ['package_name'] diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/options.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/options.py new file mode 100644 index 000000000..cffda1219 --- /dev/null +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/options.py @@ -0,0 +1,67 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Common options for plugin configuration.""" +import dataclasses +import json +import logging +import os + +import google.oauth2.credentials as oauth2_creds +from google.auth import credentials as auth_credentials + +from genkit.core.schemas import ModelInfo +from genkit.plugins.vertex_ai import constants as const + +LOG = logging.getLogger(__name__) + + +@dataclasses.dataclass(slots=True) +class PluginOptions: + project_id: str | None = None + location: str | None = None + google_auth: auth_credentials.Credentials | None = None + models: list[str | ModelInfo] = dataclasses.field(default_factory=list) + + +def get_project_from_firebase_config() -> str | None: + config = os.getenv(const.FIREBASE_CONFIG) + if config: + try: + project_id = json.loads(config)['projectId'] + return project_id + except json.JSONDecodeError: + LOG.error('Invalid JSON syntax in %s environment variable', + const.FIREBASE_CONFIG) + except KeyError: + LOG.error('projectId key is not in %s environment variable', + const.FIREBASE_CONFIG) + + return None + + +def get_plugin_parameters(options: PluginOptions | None): + project_id = options.project_id + if not project_id: + # The project_id retrieval order: + # - defined in a code + # - defined in GOOGLE_CLOUD_PROJECT env variable + # - defined in firebase config variable + # - defined by gcloud auth application-default login + # (by VertexAI Python library) + project_id = (os.getenv(const.GCLOUD_PROJECT) + or get_project_from_firebase_config()) + + location = (options.location + or os.getenv(const.GCLOUD_LOCATION) + or const.DEFAULT_REGION) + + credentials = options.google_auth + + sa_env = os.getenv(const.GCLOUD_SERVICE_ACCOUNT_CREDS) + if not credentials and sa_env: + # Credentials from oauth2 inherit from auth module credentials + credentials = oauth2_creds.Credentials.from_authorized_user_file( + sa_env, scopes=[const.GCLOUD_PLATFORM_OAUTH_SCOPE]) + + return project_id, location, credentials diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/plugin_api.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/plugin_api.py new file mode 100644 index 000000000..27116807a --- /dev/null +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/plugin_api.py @@ -0,0 +1,42 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Google Cloud Vertex AI Plugin for Genkit.""" +import logging +from collections.abc import Callable + +import vertexai + +from genkit.plugins.vertex_ai.gemini import execute_gemini_request +from genkit.plugins.vertex_ai.options import ( + PluginOptions, + get_plugin_parameters, +) +from genkit.veneer.veneer import Genkit + +LOG = logging.getLogger(__name__) + + +def vertexAI(options: PluginOptions | None) -> Callable[[Genkit], None]: + + def plugin(ai: Genkit) -> None: + project_id, location, credentials = get_plugin_parameters(options) + vertexai.init(project=project_id, + location=location, + credentials=credentials) + + ai.define_model( + name=gemini('gemini-1.5-flash'), + fn=execute_gemini_request, + metadata={ + 'model': { + 'supports': {'multiturn': True}, + } + }, + ) + + return plugin + + +def gemini(name: str) -> str: + return f'vertexai/{name}' diff --git a/py/plugins/vertex-ai/tests/test_options.py b/py/plugins/vertex-ai/tests/test_options.py new file mode 100644 index 000000000..aee09c583 --- /dev/null +++ b/py/plugins/vertex-ai/tests/test_options.py @@ -0,0 +1,96 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Tests parameter assignment that is set to the Vertex AI.""" +import json + +from genkit.plugins.vertex_ai import constants as const +from genkit.plugins.vertex_ai.options import ( + PluginOptions, + get_plugin_parameters, +) + +GCLOUD_ENV_PROJECT_ID = 'gcp_project_id' +GCLOUD_ENV_REGION = 'asia-east1' +SAMPLE_FIREBASE_CONFIG = {'projectId': 'firebase_gcp_project_id'} + + +def test_empty_location(): + options = PluginOptions() + _, location, _ = get_plugin_parameters(options) + assert location == const.DEFAULT_REGION + + +def test_specific_location(): + region = 'asia-east2' + options = PluginOptions(location=region) + _, location, _ = get_plugin_parameters(options) + assert location == region + + +def test_location_from_env(monkeypatch): + monkeypatch.setenv(const.GCLOUD_LOCATION, GCLOUD_ENV_REGION) + + options = PluginOptions() + _, location, _ = get_plugin_parameters(options) + assert location == GCLOUD_ENV_REGION + + +def test_location_priority(monkeypatch): + monkeypatch.setenv(const.GCLOUD_LOCATION, GCLOUD_ENV_REGION) + + region = 'asia-east2' + options = PluginOptions(location=region) + _, location, _ = get_plugin_parameters(options) + + assert location == region + + +def test_no_project_id(): + options = PluginOptions() + project_id, _, _ = get_plugin_parameters(options) + assert not project_id + + +def test_specific_project_id(): + expected_project_id = 'parameter-project-id' + options = PluginOptions(project_id=expected_project_id) + project_id, _, _ = get_plugin_parameters(options) + assert project_id == expected_project_id + + +def test_project_id_from_env(monkeypatch): + monkeypatch.setenv(const.GCLOUD_PROJECT, GCLOUD_ENV_PROJECT_ID) + + options = PluginOptions() + project_id, _, _ = get_plugin_parameters(options) + assert project_id == GCLOUD_ENV_PROJECT_ID + + +def test_project_id_from_firebase_config(monkeypatch): + monkeypatch.setenv(const.FIREBASE_CONFIG, + json.dumps(SAMPLE_FIREBASE_CONFIG)) + options = PluginOptions() + project_id, _, _ = get_plugin_parameters(options) + assert project_id == SAMPLE_FIREBASE_CONFIG['projectId'] + + +def test_project_id_env_priority(monkeypatch): + monkeypatch.setenv(const.FIREBASE_CONFIG, + json.dumps(SAMPLE_FIREBASE_CONFIG)) + monkeypatch.setenv(const.GCLOUD_PROJECT, GCLOUD_ENV_PROJECT_ID) + + options = PluginOptions() + project_id, _, _ = get_plugin_parameters(options) + assert project_id == GCLOUD_ENV_PROJECT_ID + + +def test_project_id_parameter_priority(monkeypatch): + monkeypatch.setenv(const.FIREBASE_CONFIG, + json.dumps(SAMPLE_FIREBASE_CONFIG)) + monkeypatch.setenv(const.GCLOUD_PROJECT, GCLOUD_ENV_PROJECT_ID) + + expected_project_id = 'parameter-project-id' + options = PluginOptions(project_id=expected_project_id) + project_id, _, _ = get_plugin_parameters(options) + assert project_id == expected_project_id diff --git a/py/samples/hello/hello.py b/py/samples/hello/hello.py index ff8a39515..7a38349af 100644 --- a/py/samples/hello/hello.py +++ b/py/samples/hello/hello.py @@ -2,12 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Any -from genkit.core.schemas import GenerateRequest, Message, Role, TextPart -from genkit.plugins.vertex_ai import gemini, vertex_ai +from genkit.core.schemas import Message, TextPart, GenerateRequest, Role +from genkit.plugins.vertex_ai import vertexAI, gemini, PluginOptions + from genkit.veneer.veneer import Genkit + from pydantic import BaseModel, Field -ai = Genkit(plugins=[vertex_ai()], model=gemini('gemini-1.5-flash')) + +ai = Genkit(plugins=[vertexAI(PluginOptions())], + model=gemini('gemini-1.5-flash')) + class MyInput(BaseModel): @@ -25,7 +30,6 @@ def hi_fn(hi_input) -> GenerateRequest: ] ) - # hi = ai.define_prompt( # name="hi", # fn=hi_fn, diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..d2e120516 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[project] +name = "genkit" +version = "0.1.0" +description = "Add your description here" +requires-python = ">=3.12" +dependencies = []