feat(bridge): spg server bridge supports config check and run solver #288

Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -15,3 +15,4 @@
.idea/
.venv/
__pycache__/
.DS_Store
32 changes: 31 additions & 1 deletion kag/bridge/spg_server_bridge.py
@@ -16,7 +16,6 @@


def init_kag_config(project_id: str, host_addr: str):
-
    os.environ[KAGConstants.ENV_KAG_PROJECT_ID] = project_id
    os.environ[KAGConstants.ENV_KAG_PROJECT_HOST_ADDR] = host_addr
    init_env()
@@ -47,3 +46,34 @@ def run_component(self, component_name, component_config, input_data):
if hasattr(instance.input_types, "from_dict"):
input_data = instance.input_types.from_dict(input_data)
return [x.to_dict() for x in instance.invoke(input_data, write_ckpt=False)]

def run_llm_config_check(self, llm_config):
from kag.common.llm.llm_config_checker import LLMConfigChecker

return LLMConfigChecker().check(llm_config)

def run_vectorizer_config_check(self, vec_config):
from kag.common.vectorize_model.vectorize_model_config_checker import (
VectorizeModelConfigChecker,
)

return VectorizeModelConfigChecker().check(vec_config)

def run_solver(
self,
project_id,
task_id,
query,
func_name="invoke",
is_report=True,
host_addr="http://127.0.0.1:8887",
):
from kag.solver.main_solver import SolverMain

return getattr(SolverMain(), func_name)(
project_id=project_id,
task_id=task_id,
query=query,
is_report=is_report,
host_addr=host_addr,
)
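Taken together, these methods give the SPG server a single entry point for config validation and solver runs. A minimal usage sketch follows; the enclosing class name (SPGServerBridge) and the shape of the config payloads are assumptions, since neither appears in this diff:

# Hedged usage sketch -- SPGServerBridge and the config payloads below are
# assumptions; only the method bodies are shown in the diff above.
from kag.bridge.spg_server_bridge import SPGServerBridge, init_kag_config

init_kag_config(project_id="2", host_addr="http://127.0.0.1:8887")
bridge = SPGServerBridge()

# Fail fast on bad model configs before building a pipeline (hypothetical keys).
bridge.run_llm_config_check({"type": "openai", "model": "gpt-4o-mini", "api_key": "sk-..."})
bridge.run_vectorizer_config_check({"type": "openai", "model": "text-embedding-ada-002"})

# Launch a solver run; func_name defaults to "invoke", reporting stays on.
answer = bridge.run_solver(project_id=2, task_id="task-1", query="What is KAG?")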
5 changes: 3 additions & 2 deletions kag/builder/component/reader/markdown_reader.py
@@ -21,6 +21,7 @@
from typing import List, Dict


+ from kag.common.utils import generate_hash_id
from kag.interface import ReaderABC
from kag.builder.model.chunk import Chunk
from kag.interface import LLMClient
@@ -299,7 +300,7 @@ def collect_children_content(n: MarkdownNode):
all_content.extend(child_content)

current_output = Chunk(
id=f"{id}_{len(outputs)}",
id=f"{generate_hash_id(full_title)}",
parent_id=parent_id,
name=full_title,
content="\n".join(filter(None, all_content)),
@@ -360,7 +361,7 @@ def collect_children_content(n: MarkdownNode):
all_content.extend(child_content)

current_output = Chunk(
id=f"{id}_{len(outputs)}",
id=f"{generate_hash_id(full_title)}",
parent_id=parent_id,
name=full_title,
content="\n".join(filter(None, all_content)),
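The id change replaces an order-dependent chunk id (f"{id}_{len(outputs)}") with one derived from the chunk's full title, so re-running the reader yields stable ids. A sketch of the property, assuming generate_hash_id is a deterministic string hash (its implementation lives in kag/common/utils):

from kag.common.utils import generate_hash_id

# Same title -> same id on every run; the old f"{id}_{len(outputs)}" scheme
# depended on how many chunks happened to be emitted before this one.
assert generate_hash_id("Guide / Install") == generate_hash_id("Guide / Install")

One trade-off worth noting: two sections with identical full titles now map to the same chunk id.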
8 changes: 4 additions & 4 deletions kag/builder/component/reader/mix_reader.py
@@ -52,7 +52,7 @@ def __init__(
dict_reader (DictReader, optional): Reader for dictionary inputs. Defaults to None.
"""
super().__init__()
- self.parse_map = {
+ self.reader_map = {
"txt": txt_reader,
"pdf": pdf_reader,
"docx": docx_reader,
@@ -83,11 +83,11 @@ def _invoke(self, input: Input, **kwargs) -> List[Output]:
reader_type = "dict"

else:
- if os.path.exists(input):
+ if not os.path.exists(input):
raise FileNotFoundError(f"File {input} not found.")

file_suffix = input.split(".")[-1]
- if file_suffix not in self.parse_map:
+ if file_suffix not in self.reader_map:
raise NotImplementedError(
f"File suffix {file_suffix} not supported yet."
)
@@ -96,4 +96,4 @@ def _invoke(self, input: Input, **kwargs) -> List[Output]:
reader = self.reader_map[reader_type]
if reader is None:
raise KeyError(f"{reader_type} reader not correctly configured.")
- return self.parse_map[file_suffix]._invoke(input)
+ return reader._invoke(input)
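Two fixes land here: the existence check was inverted (the old code raised FileNotFoundError precisely when the file existed), and the final dispatch used the old map with the wrong key. A standalone sketch of the corrected flow, with illustrative names rather than the package's API:

import os

def dispatch(reader_map, input_data):
    # Dict inputs route to the "dict" reader; everything else is a file path.
    if isinstance(input_data, dict):
        reader_type = "dict"
    else:
        if not os.path.exists(input_data):  # fixed: raise only when the file is missing
            raise FileNotFoundError(f"File {input_data} not found.")
        reader_type = input_data.split(".")[-1]
        if reader_type not in reader_map:
            raise NotImplementedError(f"File suffix {reader_type} not supported yet.")
    reader = reader_map[reader_type]  # fixed: resolve by reader_type, then invoke it
    if reader is None:
        raise KeyError(f"{reader_type} reader not correctly configured.")
    return reader(input_data)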
29 changes: 18 additions & 11 deletions kag/common/conf.py
@@ -97,7 +97,15 @@ def _closest_cfg(
return _closest_cfg(path.parent, path)


- def load_config(prod: bool = False):
+ def validate_config_file(config_file: str):
+     if not config_file:
+         return False
+     if not os.path.exists(config_file):
+         return False
+     return True
+
+
+ def load_config(prod: bool = False, config_file: str = None):
"""
Get kag config file as a ConfigParser.
"""
@@ -121,7 +129,8 @@ def load_config(prod: bool = False):
config["vectorize_model"] = config["vectorizer"]
return config
else:
-         config_file = _closest_cfg()
+         if not validate_config_file(config_file):
+             config_file = _closest_cfg()
if os.path.exists(config_file) and os.path.isfile(config_file):
print(f"found config file: {config_file}")
with open(config_file, "r") as reader:
@@ -148,13 +157,11 @@ def init_log_config(self, config):
logging.getLogger("neo4j.io").setLevel(logging.INFO)
logging.getLogger("neo4j.pool").setLevel(logging.INFO)

-     def initialize(self, prod: bool = True):
-         config = load_config(prod)
+     def initialize(self, prod: bool = True, config_file: str = None):
+         config = load_config(prod, config_file)
if self._is_initialized:
-             print(
-                 "Reinitialize the KAG configuration, an operation that should exclusively be triggered within the Java invocation context."
-             )
-             print(f"original config: {self.config}")
+             print("WARN: Reinitialize the KAG configuration.")
+             print(f"original config: {self.config}\n\n")
print(f"new config: {config}")
self.prod = prod
self.config = config
@@ -173,15 +180,15 @@ def all_config(self):
KAG_PROJECT_CONF = KAG_CONFIG.global_config


- def init_env():
+ def init_env(config_file: str = None):
project_id = os.getenv(KAGConstants.ENV_KAG_PROJECT_ID)
host_addr = os.getenv(KAGConstants.ENV_KAG_PROJECT_HOST_ADDR)
-     if project_id and host_addr:
+     if project_id and host_addr and not validate_config_file(config_file):
prod = True
else:
prod = False
global KAG_CONFIG
-     KAG_CONFIG.initialize(prod)
+     KAG_CONFIG.initialize(prod, config_file)

if prod:
msg = "Done init config from server"
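Net effect of the conf.py changes: init_env and KAG_CONFIG.initialize now accept an optional config_file, and a path that passes validate_config_file wins over both the server-side (prod) route and the _closest_cfg() search. A hedged sketch of the three call patterns (the file name is illustrative):

from kag.common.conf import init_env

# 1) Explicit config file (new in this PR): used even when the project env
#    vars are set, because a valid config_file forces prod=False.
init_env(config_file="./kag_config.yaml")

# 2) Server mode: KAG project id / host addr env vars set, no file given.
init_env()

# 3) Local mode: no env vars either, so load_config falls back to _closest_cfg().
init_env()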
121 changes: 120 additions & 1 deletion kag/common/llm/openai_client.py
@@ -12,19 +12,22 @@


import json
- from openai import OpenAI
+ from openai import OpenAI, AzureOpenAI
import logging

from kag.interface import LLMClient
from tenacity import retry, stop_after_attempt
from typing import Callable

logging.getLogger("openai").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)

AzureADTokenProvider = Callable[[], str]

@LLMClient.register("maas")
@LLMClient.register("openai")

class OpenAIClient(LLMClient):
"""
A client class for interacting with the OpenAI API.
@@ -134,3 +137,119 @@ def call_with_json_parse(self, prompt):
except:
return rsp
return json_result
@LLMClient.register("azure_openai")
class AzureOpenAIClient (LLMClient):
def __init__(
self,
api_key: str,
base_url: str,
model: str,
stream: bool = False,
api_version: str = "2024-12-01-preview",
temperature: float = 0.7,
azure_deployment: str = None,
timeout: float = None,
azure_ad_token: str = None,
azure_ad_token_provider: AzureADTokenProvider = None,
):
"""
Initializes the AzureOpenAIClient instance.

Args:
api_key (str): The API key for accessing the Azure OpenAI API.
            api_version (str): The API version for the Azure OpenAI API (e.g. "2024-12-01-preview", "2024-10-01-preview", "2024-05-01-preview").
base_url (str): The base URL for the Azure OpenAI API.
            azure_deployment (str, optional): The deployment name for the Azure OpenAI model; if given, the client URL includes /deployments/{azure_deployment}. Note: non-deployment endpoints then become unavailable, and this is not supported with the Assistants API.
model (str): The default model to use for requests.
stream (bool, optional): Whether to stream the response. Defaults to False.
temperature (float, optional): The temperature parameter for the model. Defaults to 0.7.
            timeout (float, optional): The timeout duration for the service request. Defaults to None, which means no timeout.
            azure_ad_token (str, optional): Your Azure Active Directory token (see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id).
            azure_ad_token_provider (Callable, optional): A function that returns an Azure Active Directory token; invoked on every request.
"""

self.api_key = api_key
self.base_url = base_url
self.azure_deployment = azure_deployment
self.model = model
self.stream = stream
self.temperature = temperature
self.timeout = timeout
self.api_version = api_version
self.azure_ad_token = azure_ad_token
self.azure_ad_token_provider = azure_ad_token_provider
        self.client = AzureOpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            azure_deployment=self.azure_deployment,
            model=self.model,
            api_version=self.api_version,
            azure_ad_token=self.azure_ad_token,
            azure_ad_token_provider=self.azure_ad_token_provider,
        )
self.check()

def __call__(self, prompt: str, image_url: str = None):
"""
Executes a model request when the object is called and returns the result.

Parameters:
prompt (str): The prompt provided to the model.

Returns:
str: The response content generated by the model.
"""
# Call the model with the given prompt and return the response
if image_url:
message = [
{"role": "system", "content": "you are a helpful assistant"},
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": image_url}},
],
},
]
response = self.client.chat.completions.create(
model=self.model,
messages=message,
stream=self.stream,
temperature=self.temperature,
timeout=self.timeout,
)
rsp = response.choices[0].message.content
return rsp

else:
message = [
{"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": prompt},
]
response = self.client.chat.completions.create(
model=self.model,
messages=message,
stream=self.stream,
temperature=self.temperature,
timeout=self.timeout,
)
rsp = response.choices[0].message.content
return rsp
@retry(stop=stop_after_attempt(3))
def call_with_json_parse(self, prompt):
"""
Calls the model and attempts to parse the response into JSON format.

Parameters:
prompt (str): The prompt provided to the model.

Returns:
Union[dict, str]: If the response is valid JSON, returns the parsed dictionary; otherwise, returns the original response.
"""
# Call the model and attempt to parse the response into JSON format
rsp = self(prompt)
_end = rsp.rfind("```")
_start = rsp.find("```json")
if _end != -1 and _start != -1:
json_str = rsp[_start + len("```json") : _end].strip()
else:
json_str = rsp
        try:
            json_result = json.loads(json_str)
        except json.JSONDecodeError:
            return rsp
        return json_result
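For review convenience, a construction sketch for the new client; the endpoint, deployment name, and key are placeholders, and note that the constructor probes the service immediately via self.check():

from kag.common.llm.openai_client import AzureOpenAIClient

llm = AzureOpenAIClient(
    api_key="<azure-api-key>",                          # placeholder
    base_url="https://my-resource.openai.azure.com/",   # placeholder resource
    model="gpt-4o",
    azure_deployment="gpt-4o-prod",                     # placeholder deployment
    api_version="2024-12-01-preview",
)
print(llm("Summarize KAG in one sentence."))
structured = llm.call_with_json_parse("List three KAG components as a JSON array.")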
72 changes: 70 additions & 2 deletions kag/common/vectorize_model/openai_model.py
@@ -10,9 +10,9 @@
# or implied.

from typing import Union, Iterable
- from openai import OpenAI
+ from openai import OpenAI, AzureOpenAI
from kag.interface import VectorizeModelABC, EmbeddingVector

from typing import Callable

@VectorizeModelABC.register("openai")
class OpenAIVectorizeModel(VectorizeModelABC):
@@ -65,3 +65,71 @@ def vectorize(
else:
assert len(results) == len(texts)
return results

@VectorizeModelABC.register("azure_openai")
class AzureOpenAIVectorizeModel(VectorizeModelABC):
''' A class that extends the VectorizeModelABC base class.
It invokes Azure OpenAI or Azure OpenAI-compatible embedding services to convert texts into embedding vectors.
'''

def __init__(
self,
base_url: str,
api_key: str,
model: str = "text-embedding-ada-002",
api_version: str = "2024-12-01-preview",
vector_dimensions: int = None,
timeout: float = None,
azure_deployment: str = None,
azure_ad_token: str = None,
azure_ad_token_provider: Callable = None,
):
"""
Initializes the AzureOpenAIVectorizeModel instance.

Args:
            model (str, optional): The model to use for embedding. Defaults to "text-embedding-ada-002".
api_key (str, optional): The API key for accessing the Azure OpenAI service. Defaults to "".
            api_version (str): The API version for the Azure OpenAI API (e.g. "2024-12-01-preview", "2024-10-01-preview", "2024-05-01-preview").
base_url (str, optional): The base URL for the Azure OpenAI service. Defaults to "".
vector_dimensions (int, optional): The number of dimensions for the embedding vectors. Defaults to None.
            timeout (float, optional): The timeout duration for the service request. Defaults to None, which means no timeout.
            azure_ad_token (str, optional): Your Azure Active Directory token (see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id).
            azure_ad_token_provider (Callable, optional): A function that returns an Azure Active Directory token; invoked on every request.
            azure_deployment (str, optional): The deployment name; if given, the client URL includes /deployments/{azure_deployment}. Note: non-deployment endpoints then become unavailable, and this is not supported with the Assistants API.
"""
super().__init__(vector_dimensions)
self.model = model
self.timeout = timeout
self.client = AzureOpenAI(
api_key=api_key,
base_url=base_url,
azure_deployment=azure_deployment,
model=model,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
)

def vectorize(
self, texts: Union[str, Iterable[str]]
) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
"""
Vectorizes a text string into an embedding vector or multiple text strings into multiple embedding vectors.

Args:
texts (Union[str, Iterable[str]]): The text or texts to vectorize.

Returns:
Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s) of the text(s).
"""
results = self.client.embeddings.create(
input=texts, model=self.model, timeout=self.timeout
)
results = [item.embedding for item in results.data]
if isinstance(texts, str):
assert len(results) == 1
return results[0]
else:
assert len(results) == len(texts)
return results
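And the matching embedding-side sketch, again with placeholder resource names:

from kag.common.vectorize_model.openai_model import AzureOpenAIVectorizeModel

vectorizer = AzureOpenAIVectorizeModel(
    base_url="https://my-resource.openai.azure.com/",   # placeholder resource
    api_key="<azure-api-key>",                          # placeholder
    model="text-embedding-ada-002",
    azure_deployment="embeddings-prod",                 # placeholder deployment
)
single = vectorizer.vectorize("knowledge graphs meet LLMs")   # one vector
batch = vectorizer.vectorize(["first text", "second text"])   # list of vectors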