Skip to content

Commit

Permalink
feat(py): Refactored Plugins API to follow generic Plugin interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
kirgrim committed Feb 12, 2025
1 parent a881e74 commit f0c4bb7
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 45 deletions.
60 changes: 60 additions & 0 deletions py/packages/genkit/src/genkit/core/plugin_abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import abc
import typing

from genkit.core.schemas import GenerateRequest, GenerateResponse

if typing.TYPE_CHECKING:
from genkit.veneer import Genkit


class Plugin(abc.ABC):
"""
Abstract class defining common interface
for the Genkit Plugin implementation
"""

@abc.abstractmethod
def attach_to_veneer(self, veneer: Genkit) -> None:
"""
Entrypoint for attaching the plugin to the requested Genkit Veneer
:param veneer: requested `genkit.veneer.Genkit` instance
"""
pass

def add_model_to_veneer(
self, veneer: Genkit, name: str, metadata: dict | None = None
) -> None:
"""
Generic method for defining arbitrary plugin's model
in the Genkit Registry
Uses self._model_callback as a generic callback wrapper,
the actual implementation is up to inherited Plugin's implementation
:param veneer: requested `genkit.veneer.Genkit` instance
:param name: name of the model to attach
:param metadata: metadata information associated
with the provided model (optional)
"""
if not metadata:
metadata = {}
veneer.define_model(
name=name, fn=self._model_callback, metadata=metadata
)

@abc.abstractmethod
def _model_callback(self, request: GenerateRequest) -> GenerateResponse:
"""
Wrapper around any plugin's model callback.
Is considered an entrypoint for any model's request
:param request: incoming request as generic
`genkit.core.schemas.GenerateRequest` instance
:returns: model response represented as generic
`genkit.core.schemas.GenerateResponse` instance
"""
pass
4 changes: 2 additions & 2 deletions py/packages/genkit/src/genkit/core/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import sys
from collections.abc import Sequence
from typing import Any, cast
from typing import Any

import requests # type: ignore[import-untyped]
from opentelemetry import trace as trace_api
Expand Down Expand Up @@ -37,7 +37,7 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
'startTime': span.start_time / 1000000,
'endTime': span.end_time / 1000000,
'attributes': convert_attributes(
attributes=cast(span.attributes, dict), # type: ignore
attributes=span.attributes, # type: ignore
),
'displayName': span.name,
# "links": span.links,
Expand Down
7 changes: 3 additions & 4 deletions py/packages/genkit/src/genkit/veneer/veneer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
from genkit.ai.model import ModelFn
from genkit.ai.prompt import PromptFn
from genkit.core.action import Action
from genkit.core.plugin_abc import Plugin
from genkit.core.reflection import make_reflection_server
from genkit.core.registry import Registry
from genkit.core.schemas import GenerateRequest, GenerateResponse, Message

Plugin = Callable[['Genkit'], None]


class Genkit:
"""An entrypoint for a user that encapsulate the SDK functionality."""
Expand Down Expand Up @@ -66,9 +65,9 @@ def delete_runtime_file() -> None:
self.thread = threading.Thread(target=self.start_server)
self.thread.start()

if plugins is not None:
if plugins:
for plugin in plugins:
plugin(self)
plugin.attach_to_veneer(veneer=self)

def start_server(self) -> None:
httpd = HTTPServer(
Expand Down
95 changes: 58 additions & 37 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
Google Cloud Vertex AI Plugin for Genkit.
"""

from collections.abc import Callable
from typing import Any

import vertexai
from genkit.core.plugin_abc import Plugin
from genkit.core.schemas import (
GenerateRequest,
GenerateResponse,
Expand All @@ -24,46 +25,66 @@ def package_name() -> str:
return 'genkit.plugins.vertex_ai'


def vertex_ai(project_id: str | None = None) -> Callable[[Genkit], None]:
def plugin(ai: Genkit) -> None:
vertexai.init(location='us-central1', project=project_id)

def handle_gemini_request(request: GenerateRequest) -> GenerateResponse:
gemini_msgs: list[Content] = []
for m in request.messages:
gemini_parts: list[Part] = []
for p in m.content:
if p.root.text is not None:
gemini_parts.append(Part.from_text(p.root.text))
else:
raise Exception('unsupported part type')
gemini_msgs.append(
Content(role=m.role.value, parts=gemini_parts)
)
model = GenerativeModel('gemini-1.5-flash-002')
response = model.generate_content(contents=gemini_msgs)
return GenerateResponse(
message=Message(
role=Role.model, content=[TextPart(text=response.text)]
)
)
def gemini(name: str) -> str:
return f'vertexai/{name}'


class VertexAI(Plugin):
LOCATION = 'us-central1'
VERTEX_AI_MODEL_NAME = gemini('gemini-1.5-flash')
VERTEX_AI_GENERATIVE_MODEL_NAME = 'gemini-1.5-flash-002'

def __init__(self, project_id: str | None = None):
self.project_id = project_id
vertexai.init(location=self.LOCATION, project=self.project_id)

def attach_to_veneer(self, veneer: 'Genkit') -> None:
self.add_model_to_veneer(veneer=veneer)

ai.define_model(
name='vertexai/gemini-1.5-flash',
fn=handle_gemini_request,
metadata={
'model': {
'label': 'banana',
'supports': {'multiturn': True},
}
},
def add_model_to_veneer(self, veneer: 'Genkit', **kwargs) -> None:
return super().add_model_to_veneer(
veneer=veneer,
name=self.VERTEX_AI_MODEL_NAME,
metadata=self.vertex_ai_model_metadata,
)

return plugin
@property
def vertex_ai_model_metadata(self) -> dict[str, dict[str, Any]]:
return {
'model': {
'label': 'banana',
'supports': {'multiturn': True},
}
}

def _model_callback(self, request: GenerateRequest) -> GenerateResponse:
return self._handle_gemini_request(request=request)

def gemini(name: str) -> str:
return f'vertexai/{name}'
def _handle_gemini_request(
self, request: GenerateRequest
) -> GenerateResponse:
gemini_msgs: list[Content] = []
for m in request.messages:
gemini_parts: list[Part] = []
for p in m.content:
if p.root.text is not None:
gemini_parts.append(Part.from_text(p.root.text))
else:
raise Exception('unsupported part type')
gemini_msgs.append(Content(role=m.role.value, parts=gemini_parts))
response = self.vertex_ai_generative_model.generate_content(
contents=gemini_msgs
)
return GenerateResponse(
message=Message(
role=Role.model,
content=[TextPart(text=response.text)],
)
)

@property
def vertex_ai_generative_model(self) -> GenerativeModel:
return GenerativeModel(self.VERTEX_AI_GENERATIVE_MODEL_NAME)


__all__ = ['package_name', 'vertex_ai', 'gemini']
__all__ = ['package_name', 'VertexAI', 'gemini']
4 changes: 2 additions & 2 deletions py/samples/hello/hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from typing import Any

from genkit.core.schemas import GenerateRequest, Message, Role, TextPart
from genkit.plugins.vertex_ai import gemini, vertex_ai
from genkit.plugins.vertex_ai import VertexAI
from genkit.veneer.veneer import Genkit
from pydantic import BaseModel, Field

ai = Genkit(plugins=[vertex_ai()], model=gemini('gemini-1.5-flash'))
ai = Genkit(plugins=[VertexAI()], model=VertexAI.VERTEX_AI_MODEL_NAME)


class MyInput(BaseModel):
Expand Down

0 comments on commit f0c4bb7

Please sign in to comment.