Skip to content

Commit

Permalink
added google gen ai client, support gemini-1.0-pro and gemini-1.5-pro…
Browse files Browse the repository at this point in the history
…-latest
  • Loading branch information
liyin2015 committed May 19, 2024
1 parent 67ca83e commit 8639b55
Show file tree
Hide file tree
Showing 8 changed files with 755 additions and 222 deletions.
1 change: 1 addition & 0 deletions components/api_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .groq_client import *
from .anthropic_client import *
from .transformers_client import *
from .google_client import *
101 changes: 101 additions & 0 deletions components/api_client/google_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""
This demonstrates how to wrap the OpenAI API client to fit into LightRAG APIClient
"""

import os
from core.api_client import APIClient
from typing import Any, Dict, Sequence, Union, Optional, List
from core.data_classes import ModelType
import backoff

try:
import google.generativeai as genai
from google.api_core.exceptions import (
InternalServerError,
BadRequest,
GoogleAPICallError,
)
from google.generativeai.types import GenerateContentResponse
except ImportError:
raise ImportError("Please install google-generativeai to use GoogleGenAIClient")


class GoogleGenAIClient(APIClient):
    """LightRAG API client wrapping the ``google-generativeai`` SDK.

    Supports Gemini chat models through :meth:`call` with ``ModelType.LLM``.
    Tested against ``gemini-1.0-pro`` and ``gemini-1.5-pro-latest``.
    """

    def __init__(self):
        super().__init__()
        self.sync_client = self._init_sync_client()
        # Async client is created lazily, only if an async call is ever made.
        self.async_client = None
        self.tested_llm_models = ["gemini-1.0-pro", "gemini-1.5-pro-latest"]

    def _init_sync_client(self):
        """Configure the ``genai`` module with GOOGLE_API_KEY and return it.

        Raises:
            ValueError: If the GOOGLE_API_KEY environment variable is not set.
        """
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError("Environment variable GOOGLE_API_KEY must be set")
        genai.configure(api_key=api_key)
        return genai

    # def _init_async_client(self):
    #     api_key = os.getenv("GOOGLE_API_KEY")
    #     if not api_key:
    #         raise ValueError("Environment variable GOOGLE_API_KEY must be set")
    #     return AsyncOpenAI()

    def parse_chat_completion(self, completion: GenerateContentResponse) -> str:
        """Parse the completion into the structure your system standardizes on (here: str).

        # TODO: standardize the completion into a shared data class.
        """
        return completion.text
        # result["candidates"][0]["contents"]["parts"][0]["text"]

    def convert_input_to_api_kwargs(
        self,
        input: Union[str, Sequence],
        system_input: Optional[str] = None,
        combined_model_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ) -> Dict:
        r"""Convert the component's standard input into API-specific kwargs.

        Combines ``input``, ``system_input`` (chat models), and model kwargs
        into the dict consumed by :meth:`call`.

        Raises:
            ValueError: If ``model_type`` is neither EMBEDDER nor LLM.
        """
        # Copy so the caller's dict (or a shared default) is never mutated.
        final_model_kwargs = dict(combined_model_kwargs or {})
        if model_type == ModelType.EMBEDDER:
            if isinstance(input, str):
                input = [input]
            # Validate explicitly: ``assert`` would be stripped under ``python -O``.
            if not isinstance(input, Sequence):
                raise ValueError("input must be a sequence of text")
            final_model_kwargs["input"] = input
        elif model_type == ModelType.LLM:
            # Single-turn prompt combining the system instruction and user query.
            prompt: str = f"{system_input}\n\nUser query: {input}\n You:"
            final_model_kwargs["prompt"] = prompt
        else:
            raise ValueError(f"model_type {model_type} is not supported")
        return final_model_kwargs

    @backoff.on_exception(
        backoff.expo,
        (
            InternalServerError,
            BadRequest,
            GoogleAPICallError,
        ),
        max_time=5,
    )
    def call(self, api_kwargs: Optional[Dict] = None, model_type: ModelType = ModelType.UNDEFINED):
        """Execute a synchronous LLM call.

        ``api_kwargs`` is the combined input and model kwargs produced by
        :meth:`convert_input_to_api_kwargs` (must contain "model" and "prompt").

        Raises:
            ValueError: If ``model_type`` is not LLM.
        """
        if model_type != ModelType.LLM:
            raise ValueError(f"model_type {model_type} is not supported")

        # Copy before popping: the original code mutated the caller's dict
        # (and the shared mutable default) via pop().
        kwargs = dict(api_kwargs or {})
        model = kwargs.pop("model")
        prompt = kwargs.pop("prompt")

        # Remaining kwargs (temperature, top_p, ...) become the generation config.
        config = genai.GenerationConfig(**kwargs)
        llm = genai.GenerativeModel(model_name=model, generation_config=config)
        return llm.generate_content(contents=prompt)
15 changes: 11 additions & 4 deletions components/api_client/transformers_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def __init__(self) -> None:
def _init_sync_client(self):
    """Create the synchronous client: a local TransformerEmbedder instance (no API key needed)."""
    return TransformerEmbedder()

def _call(self, kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
return self.sync_client(**kwargs)
def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
    """Invoke the embedder with the prepared api_kwargs; ``model_type`` is not consulted here."""
    return self.sync_client(**api_kwargs)

def _combine_input_and_model_kwargs(
def convert_input_to_api_kwargs(
self,
input: Any,
combined_model_kwargs: dict = {},
Expand Down Expand Up @@ -131,9 +131,16 @@ def test_transformer_client():
"model": "thenlper/gte-base",
"mock": False,
}
api_kwargs = transformer_client.convert_input_to_api_kwargs(
input="Hello world",
combined_model_kwargs=kwargs,
model_type=ModelType.EMBEDDER,
)
print(api_kwargs)
output = transformer_client.call(
input="Hello world", model_type=ModelType.EMBEDDER, model_kwargs=kwargs
api_kwargs=api_kwargs, model_type=ModelType.EMBEDDER
)

print(transformer_client)
print(output)

Expand Down
Loading

0 comments on commit 8639b55

Please sign in to comment.