Skip to content

Commit

Permalink
feat(py): VertexAI options
Browse files Browse the repository at this point in the history
  • Loading branch information
Irillit committed Feb 12, 2025
1 parent a881e74 commit a2e17e1
Show file tree
Hide file tree
Showing 10 changed files with 381 additions and 69 deletions.
8 changes: 7 additions & 1 deletion py/packages/genkit/src/genkit/ai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable
from pydantic import BaseModel
from genkit.core.schemas import GenerateRequest, GenerateResponse, ModelInfo

from genkit.core.schemas import GenerateRequest, GenerateResponse

ModelFn = Callable[[GenerateRequest], GenerateResponse]


class ModelReference(BaseModel):
    """A named reference to a model together with its capability metadata."""

    # Fully-qualified registry name, e.g. 'vertexai/gemini-1.5-flash'.
    name: str
    # Capability/metadata record (versions, label, supported features).
    info: ModelInfo
67 changes: 4 additions & 63 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,68 +2,9 @@
# SPDX-License-Identifier: Apache-2.0


"""
Google Cloud Vertex AI Plugin for Genkit.
"""
"""Google Cloud Vertex AI Plugin for Genkit."""
from genkit.plugins.vertex_ai.plugin_api import vertexAI, gemini
from genkit.plugins.vertex_ai.options import PluginOptions

from collections.abc import Callable

import vertexai
from genkit.core.schemas import (
GenerateRequest,
GenerateResponse,
Message,
Role,
TextPart,
)
from genkit.veneer.veneer import Genkit
from vertexai.generative_models import Content, GenerativeModel, Part


def package_name() -> str:
    """Return the import path of the Vertex AI plugin package."""
    return 'genkit.plugins.vertex_ai'


def vertex_ai(project_id: str | None = None) -> Callable[[Genkit], None]:
    """Build a Genkit plugin that registers a Vertex AI Gemini model.

    Args:
        project_id: GCP project to use; when None the Vertex AI SDK falls
            back to its own project resolution.

    Returns:
        A callable that, given a Genkit instance, initializes the Vertex AI
        SDK and registers the Gemini request handler.
    """
    def plugin(ai: Genkit) -> None:
        # Region is hard-coded to us-central1 here.
        vertexai.init(location='us-central1', project=project_id)

        def handle_gemini_request(request: GenerateRequest) -> GenerateResponse:
            # Translate Genkit messages into Vertex AI Content objects.
            gemini_msgs: list[Content] = []
            for m in request.messages:
                gemini_parts: list[Part] = []
                for p in m.content:
                    if p.root.text is not None:
                        gemini_parts.append(Part.from_text(p.root.text))
                    else:
                        # Only text parts are supported; anything else aborts.
                        raise Exception('unsupported part type')
                gemini_msgs.append(
                    Content(role=m.role.value, parts=gemini_parts)
                )
            # Model version is pinned regardless of the request contents.
            model = GenerativeModel('gemini-1.5-flash-002')
            response = model.generate_content(contents=gemini_msgs)
            # Wrap the single text reply back into a Genkit response.
            return GenerateResponse(
                message=Message(
                    role=Role.model, content=[TextPart(text=response.text)]
                )
            )

        ai.define_model(
            name='vertexai/gemini-1.5-flash',
            fn=handle_gemini_request,
            metadata={
                'model': {
                    # NOTE(review): 'banana' looks like a leftover debug
                    # label — confirm before shipping.
                    'label': 'banana',
                    'supports': {'multiturn': True},
                }
            },
        )

    return plugin


def gemini(name: str) -> str:
    """Map a bare Gemini model name to its Vertex AI registry key."""
    return 'vertexai/' + name


__all__ = ['package_name', 'vertex_ai', 'gemini']
__all__ = ['package_name', 'vertexAI', 'gemini', 'PluginOptions']
10 changes: 10 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

"""Constants for VertexAI plugin."""
# Env variable holding the GCP project id.
GCLOUD_PROJECT = 'GCLOUD_PROJECT'
# Env variable holding the GCP region.
GCLOUD_LOCATION = 'GCLOUD_LOCATION'
# OAuth scope required for Google Cloud Platform APIs.
GCLOUD_PLATFORM_OAUTH_SCOPE = 'https://www.googleapis.com/auth/cloud-platform'
# Env variable pointing at a credentials JSON file (loaded via
# Credentials.from_authorized_user_file in options.py).
GCLOUD_SERVICE_ACCOUNT_CREDS = 'GCLOUD_SERVICE_ACCOUNT_CREDS'
# Env variable containing the Firebase config JSON (provides 'projectId').
FIREBASE_CONFIG = 'FIREBASE_CONFIG'
# Region used when none is configured anywhere else.
DEFAULT_REGION = 'us-central1'
140 changes: 140 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

"""Processes Gemini request."""
import logging

from vertexai.generative_models import Content, GenerativeModel, Part

from genkit.ai.model import ModelReference
from genkit.core.schemas import (
GenerateRequest,
GenerateResponse,
Message,
ModelInfo,
Role,
Supports,
TextPart,
)

LOG = logging.getLogger(__name__)


# Catalog of Gemini models this plugin knows about, keyed by bare model
# name. Each entry records the registry name plus capability metadata.

# Deprecated on 2/15/2025
SUPPORTED_V1_MODELS = {
    'gemini-1.0-pro': ModelReference(
        name='vertexai/gemini-1.0-pro',
        info=ModelInfo(
            versions=['gemini-1.0-pro-001', 'gemini-1.0-pro-002'],
            label='Vertex AI - Gemini Pro',
            supports=Supports(
                multiturn=True,
                media=False,
                tools=True,
                systemRole=True
            )
        ))
}

# Gemini 1.5 and 2.0 model generations (all support media, tools and a
# system role, unlike the 1.0 entry above).
SUPPORTED_V15_MODELS = {
    'gemini-1.5-pro': ModelReference(
        name='vertexai/gemini-1.5-pro',
        info=ModelInfo(
            versions=['gemini-1.5-pro-001', 'gemini-1.5-pro-002'],
            label='Vertex AI - Gemini 1.5 Pro',
            supports=Supports(
                multiturn=True,
                media=True,
                tools=True,
                systemRole=True
            )
        )),
    'gemini-1.5-flash': ModelReference(
        name='vertexai/gemini-1.5-flash',
        info=ModelInfo(
            versions=['gemini-1.5-flash-001', 'gemini-1.5-flash-002'],
            label='Vertex AI - Gemini 1.5 Flash',
            supports=Supports(
                multiturn=True,
                media=True,
                tools=True,
                systemRole=True
            )
        )),
    'gemini-2.0-flash-001': ModelReference(
        name='vertexai/gemini-2.0-flash-001',
        info=ModelInfo(
            versions=[],
            label='Vertex AI - Gemini 2.0 Flash 001',
            supports=Supports(
                multiturn=True,
                media=True,
                tools=True,
                systemRole=True
            )
        )),
    'gemini-2.0-flash-lite-preview-02-05': ModelReference(
        name='vertexai/gemini-2.0-flash-lite-preview-02-05',
        info=ModelInfo(
            versions=[],
            label='Vertex AI - Gemini 2.0 Flash Lite Preview 02-05',
            supports=Supports(
                multiturn=True,
                media=True,
                tools=True,
                systemRole=True
            )
        )),
    'gemini-2.0-pro-exp-02-05': ModelReference(
        name='vertexai/gemini-2.0-pro-exp-02-05',
        info=ModelInfo(
            versions=[],
            label='Vertex AI - Gemini 2.0 Flash Pro Experimental 02-05',
            supports=Supports(
                multiturn=True,
                media=True,
                tools=True,
                systemRole=True
            )
        )),
}

# Merged lookup table used by nearest_gemini_model.
SUPPORTED_MODELS = SUPPORTED_V1_MODELS | SUPPORTED_V15_MODELS


def nearest_gemini_model(model_name: str) -> ModelReference:
    """Look up the ModelReference for *model_name*.

    Args:
        model_name: Bare Gemini model name (e.g. 'gemini-1.5-flash').

    Returns:
        The catalog entry from SUPPORTED_MODELS when known; otherwise a
        generic fallback reference that optimistically claims full support
        (multiturn, media, tools, system role).
    """
    model = SUPPORTED_MODELS.get(model_name)
    # Explicit None check: a falsy-but-present catalog entry must still win.
    if model is not None:
        return model
    # Unknown model: synthesize a permissive generic reference.
    return ModelReference(
        name=f'vertexai/{model_name}',
        info=ModelInfo(
            versions=[],
            label='Vertex AI - Gemini',
            supports=Supports(
                multiturn=True,
                media=True,
                tools=True,
                systemRole=True
            )
        )
    )


def execute_gemini_request(request: GenerateRequest) -> GenerateResponse:
    """Send the conversation in *request* to Gemini and wrap the reply.

    Args:
        request: Genkit generate request whose messages are translated into
            Vertex AI Content objects.

    Returns:
        A GenerateResponse containing the model's text reply as a single
        model-role message.
    """
    messages: list[Content] = []
    for msg in request.messages:
        parts: list[Part] = []
        for part in msg.content:
            if hasattr(part, "text") and part.text:
                parts.append(Part.from_text(part.text))
            else:
                # NOTE(review): non-text parts are logged and silently
                # dropped; the request proceeds without them — confirm
                # this is intended rather than raising.
                LOG.error("Unsupported message type.")
        messages.append(Content(role=msg.role.value, parts=parts))
    # NOTE(review): model is hard-coded to gemini-1.5-flash-001 even though
    # SUPPORTED_MODELS lists many others — confirm this is intentional.
    model = GenerativeModel('gemini-1.5-flash-001')
    response = model.generate_content(contents=messages)
    return GenerateResponse(
        message=Message(
            role=Role.model, content=[TextPart(text=response.text)]
        )
    )
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


def package_name() -> str:
    """Return the import path of the Vertex AI Model Garden package."""
    return 'genkit.plugins.vertex_ai.modelgarden'


__all__ = ['package_name']
67 changes: 67 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

"""Common options for plugin configuration."""
import dataclasses
import json
import logging
import os

import google.oauth2.credentials as oauth2_creds
from google.auth import credentials as auth_credentials

from genkit.core.schemas import ModelInfo
from genkit.plugins.vertex_ai import constants as const

LOG = logging.getLogger(__name__)


@dataclasses.dataclass(slots=True)
class PluginOptions:
    """User-supplied configuration for the Vertex AI plugin."""

    # GCP project id; when None it is resolved from the environment or
    # Firebase config by get_plugin_parameters.
    project_id: str | None = None
    # GCP region; when None falls back to GCLOUD_LOCATION or DEFAULT_REGION.
    location: str | None = None
    # Explicit credentials; when None a credentials file named by
    # GCLOUD_SERVICE_ACCOUNT_CREDS (or application defaults) is used.
    google_auth: auth_credentials.Credentials | None = None
    # Models to expose; presumably consumed by the plugin registration
    # code — not referenced in this file, TODO confirm usage.
    models: list[str | ModelInfo] = dataclasses.field(default_factory=list)


def get_project_from_firebase_config() -> str | None:
    """Extract the GCP project id from the FIREBASE_CONFIG env variable.

    Returns:
        The 'projectId' value from the Firebase config JSON, or None when
        the variable is unset, malformed, or lacks the key.
    """
    raw = os.getenv(const.FIREBASE_CONFIG)
    if not raw:
        return None
    try:
        return json.loads(raw)['projectId']
    except json.JSONDecodeError:
        LOG.error('Invalid JSON syntax in %s environment variable',
                  const.FIREBASE_CONFIG)
    except KeyError:
        LOG.error('projectId key is not in %s environment variable',
                  const.FIREBASE_CONFIG)
    return None


def get_plugin_parameters(
    options: PluginOptions | None,
) -> tuple[str | None, str | None, auth_credentials.Credentials | None]:
    """Resolve (project_id, location, credentials) from options and env.

    Args:
        options: Plugin configuration, or None to resolve everything from
            the environment and defaults.

    Returns:
        A (project_id, location, credentials) tuple; project_id and
        credentials may be None, in which case the Vertex AI library applies
        its own defaults.
    """
    if options is None:
        # The signature accepts None; previously this raised AttributeError.
        # Treat it as an empty options object so every value falls through
        # to the environment/defaults.
        options = PluginOptions()

    project_id = options.project_id
    if not project_id:
        # The project_id retrieval order:
        # - defined in code (options.project_id)
        # - defined in the GCLOUD_PROJECT env variable
        # - defined in the Firebase config env variable
        # - defined by gcloud auth application-default login
        #   (resolved later by the VertexAI Python library itself)
        project_id = (os.getenv(const.GCLOUD_PROJECT)
                      or get_project_from_firebase_config())

    location = (options.location
                or os.getenv(const.GCLOUD_LOCATION)
                or const.DEFAULT_REGION)

    credentials = options.google_auth

    sa_env = os.getenv(const.GCLOUD_SERVICE_ACCOUNT_CREDS)
    if not credentials and sa_env:
        # Credentials from oauth2 inherit from auth module credentials,
        # so this satisfies the google_auth option type.
        credentials = oauth2_creds.Credentials.from_authorized_user_file(
            sa_env, scopes=[const.GCLOUD_PLATFORM_OAUTH_SCOPE])

    return project_id, location, credentials
42 changes: 42 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/plugin_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

"""Google Cloud Vertex AI Plugin for Genkit."""
import logging
from collections.abc import Callable

import vertexai

from genkit.plugins.vertex_ai.gemini import execute_gemini_request
from genkit.plugins.vertex_ai.options import (
PluginOptions,
get_plugin_parameters,
)
from genkit.veneer.veneer import Genkit

LOG = logging.getLogger(__name__)


def vertexAI(options: PluginOptions | None) -> Callable[[Genkit], None]:
    """Build the Vertex AI plugin initializer for Genkit.

    Args:
        options: Optional plugin configuration; project, location and
            credentials are resolved from it (and the environment) by
            get_plugin_parameters.

    Returns:
        A callable that, given a Genkit instance, initializes the Vertex AI
        SDK and registers the Gemini request handler.
    """

    def plugin(ai: Genkit) -> None:
        # Resolve configuration, then initialize global Vertex AI SDK state
        # used by all subsequent model calls.
        project_id, location, credentials = get_plugin_parameters(options)
        vertexai.init(project=project_id,
                      location=location,
                      credentials=credentials)

        # Register a single Gemini model; `gemini` (defined below, resolved
        # at call time) builds the 'vertexai/<name>' registry key.
        ai.define_model(
            name=gemini('gemini-1.5-flash'),
            fn=execute_gemini_request,
            metadata={
                'model': {
                    'supports': {'multiturn': True},
                }
            },
        )

    return plugin


def gemini(name: str) -> str:
    """Map a bare Gemini model name to its Vertex AI registry key."""
    return 'vertexai/' + name
Loading

0 comments on commit a2e17e1

Please sign in to comment.