From fdea22b62b6532175f1afc97e87610bf42f9194d Mon Sep 17 00:00:00 2001
From: zhuzhongshu123 <152354526+zhuzhongshu123@users.noreply.github.com>
Date: Thu, 16 Jan 2025 14:42:26 +0800
Subject: [PATCH] feat(bridge): spg server bridge supports config check and run solver (#288)

* fix mix reader (#270)
* feat(builder): add Azure OpenAI Compatibility (#269)
* feat(llm): add Azure OpenAI client and vectorization support
* chore: add .DS_Store to .gitignore
* refactor(llm): add description for api_version and default value
* refactor(vectorize_model): added description for api_version and default values for some params
* refactor(openai_model): enhance docstring for Azure AD token and deployment parameters
* fix(builder): fix markdown reader for id (#273)
* fix builder init
* add pro commit
* rename graphalgoclient to graphclient
* first fix
* fix(examples): fix qa file name (#251)
* support custom kag config file (#279)
* x
* x (#280)
* bridge add solver
* x
* feat(bridge): spg server bridge (#283)
* x
* bridge add solver
* x
* add invoke

---------

Co-authored-by: joseosvaldo16
Co-authored-by: Xinhong Zhang
---
 .gitignore                                      |   1 +
 kag/bridge/spg_server_bridge.py                 |  32 ++++-
 .../component/reader/markdown_reader.py         |   5 +-
 kag/builder/component/reader/mix_reader.py      |   8 +-
 kag/common/conf.py                              |  29 +++--
 kag/common/llm/openai_client.py                 | 121 +++++++++++++++++-
 kag/common/vectorize_model/openai_model.py      |  72 ++++++++++-
 kag/examples/domain_kg/README.md                |   2 +-
 8 files changed, 248 insertions(+), 22 deletions(-)

diff --git a/.gitignore b/.gitignore
index 3dfa7d36..e7450b6e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,4 @@
 .idea/
 .venv/
 __pycache__/
+.DS_Store
\ No newline at end of file
diff --git a/kag/bridge/spg_server_bridge.py b/kag/bridge/spg_server_bridge.py
index 7fde8f72..51b0ca25 100644
--- a/kag/bridge/spg_server_bridge.py
+++ b/kag/bridge/spg_server_bridge.py
@@ -16,7 +16,6 @@


 def init_kag_config(project_id: str, host_addr: str):
-
     os.environ[KAGConstants.ENV_KAG_PROJECT_ID] = project_id
     os.environ[KAGConstants.ENV_KAG_PROJECT_HOST_ADDR] = host_addr
     init_env()
@@ -47,3 +46,34 @@ def run_component(self, component_name, component_config, input_data):
         if hasattr(instance.input_types, "from_dict"):
             input_data = instance.input_types.from_dict(input_data)
         return [x.to_dict() for x in instance.invoke(input_data, write_ckpt=False)]
+
+    def run_llm_config_check(self, llm_config):
+        from kag.common.llm.llm_config_checker import LLMConfigChecker
+
+        return LLMConfigChecker().check(llm_config)
+
+    def run_vectorizer_config_check(self, vec_config):
+        from kag.common.vectorize_model.vectorize_model_config_checker import (
+            VectorizeModelConfigChecker,
+        )
+
+        return VectorizeModelConfigChecker().check(vec_config)
+
+    def run_solver(
+        self,
+        project_id,
+        task_id,
+        query,
+        func_name="invoke",
+        is_report=True,
+        host_addr="http://127.0.0.1:8887",
+    ):
+        from kag.solver.main_solver import SolverMain
+
+        return getattr(SolverMain(), func_name)(
+            project_id=project_id,
+            task_id=task_id,
+            query=query,
+            is_report=is_report,
+            host_addr=host_addr,
+        )
diff --git a/kag/builder/component/reader/markdown_reader.py b/kag/builder/component/reader/markdown_reader.py
index ba212c8e..fba9d0b4 100644
--- a/kag/builder/component/reader/markdown_reader.py
+++ b/kag/builder/component/reader/markdown_reader.py
@@ -21,6 +21,7 @@

 from typing import List, Dict

+from kag.common.utils import generate_hash_id
 from kag.interface import ReaderABC
 from kag.builder.model.chunk import Chunk
 from kag.interface import LLMClient
@@ -299,7 +300,7 @@ def collect_children_content(n: MarkdownNode):
                 all_content.extend(child_content)

             current_output = Chunk(
-                id=f"{id}_{len(outputs)}",
+                id=f"{generate_hash_id(full_title)}",
                 parent_id=parent_id,
                 name=full_title,
                 content="\n".join(filter(None, all_content)),
@@ -360,7 +361,7 @@ def collect_children_content(n: MarkdownNode):
                 all_content.extend(child_content)

             current_output = Chunk(
-                id=f"{id}_{len(outputs)}",
+                id=f"{generate_hash_id(full_title)}",
                 parent_id=parent_id,
                 name=full_title,
                 content="\n".join(filter(None, all_content)),
diff --git a/kag/builder/component/reader/mix_reader.py b/kag/builder/component/reader/mix_reader.py
index 6af7380a..519d154f 100644
--- a/kag/builder/component/reader/mix_reader.py
+++ b/kag/builder/component/reader/mix_reader.py
@@ -52,7 +52,7 @@ def __init__(
             dict_reader (DictReader, optional): Reader for dictionary inputs. Defaults to None.
         """
         super().__init__()
-        self.parse_map = {
+        self.reader_map = {
             "txt": txt_reader,
             "pdf": pdf_reader,
             "docx": docx_reader,
@@ -83,11 +83,11 @@ def _invoke(self, input: Input, **kwargs) -> List[Output]:
             reader_type = "dict"

         else:
-            if os.path.exists(input):
+            if not os.path.exists(input):
                 raise FileNotFoundError(f"File {input} not found.")
             file_suffix = input.split(".")[-1]
-            if file_suffix not in self.parse_map:
+            if file_suffix not in self.reader_map:
                 raise NotImplementedError(
                     f"File suffix {file_suffix} not supported yet."
                 )
@@ -96,4 +96,4 @@ def _invoke(self, input: Input, **kwargs) -> List[Output]:
         reader = self.reader_map[reader_type]
         if reader is None:
             raise KeyError(f"{reader_type} reader not correctly configured.")
-        return self.parse_map[file_suffix]._invoke(input)
+        return reader._invoke(input)
diff --git a/kag/common/conf.py b/kag/common/conf.py
index d970054f..06be534a 100644
--- a/kag/common/conf.py
+++ b/kag/common/conf.py
@@ -97,7 +97,15 @@ def _closest_cfg(
     return _closest_cfg(path.parent, path)


-def load_config(prod: bool = False):
+def validate_config_file(config_file: str):
+    if not config_file:
+        return False
+    if not os.path.exists(config_file):
+        return False
+    return True
+
+
+def load_config(prod: bool = False, config_file: str = None):
     """
     Get kag config file as a ConfigParser.
     """
@@ -121,7 +129,8 @@ def load_config(prod: bool = False):
             config["vectorize_model"] = config["vectorizer"]
         return config
     else:
-        config_file = _closest_cfg()
+        if not validate_config_file(config_file):
+            config_file = _closest_cfg()
         if os.path.exists(config_file) and os.path.isfile(config_file):
             print(f"found config file: {config_file}")
             with open(config_file, "r") as reader:
@@ -148,13 +157,11 @@ def init_log_config(self, config):
         logging.getLogger("neo4j.io").setLevel(logging.INFO)
         logging.getLogger("neo4j.pool").setLevel(logging.INFO)

-    def initialize(self, prod: bool = True):
-        config = load_config(prod)
+    def initialize(self, prod: bool = True, config_file: str = None):
+        config = load_config(prod, config_file)
         if self._is_initialized:
-            print(
-                "Reinitialize the KAG configuration, an operation that should exclusively be triggered within the Java invocation context."
-            )
-            print(f"original config: {self.config}")
+            print("WARN: Reinitialize the KAG configuration.")
+            print(f"original config: {self.config}\n\n")
             print(f"new config: {config}")
         self.prod = prod
         self.config = config
@@ -173,15 +180,15 @@ def all_config(self):
 KAG_PROJECT_CONF = KAG_CONFIG.global_config


-def init_env():
+def init_env(config_file: str = None):
     project_id = os.getenv(KAGConstants.ENV_KAG_PROJECT_ID)
     host_addr = os.getenv(KAGConstants.ENV_KAG_PROJECT_HOST_ADDR)
-    if project_id and host_addr:
+    if project_id and host_addr and not validate_config_file(config_file):
         prod = True
     else:
         prod = False
     global KAG_CONFIG
-    KAG_CONFIG.initialize(prod)
+    KAG_CONFIG.initialize(prod, config_file)

     if prod:
         msg = "Done init config from server"
diff --git a/kag/common/llm/openai_client.py b/kag/common/llm/openai_client.py
index 0c8ff3aa..e4af7e2e 100644
--- a/kag/common/llm/openai_client.py
+++ b/kag/common/llm/openai_client.py
@@ -12,19 +12,22 @@

 import json

-from openai import OpenAI
+from openai import OpenAI, AzureOpenAI
 import logging
 from kag.interface import LLMClient
 from tenacity import retry, stop_after_attempt
+from typing import Callable

 logging.getLogger("openai").setLevel(logging.ERROR)
 logging.getLogger("httpx").setLevel(logging.ERROR)
 logger = logging.getLogger(__name__)

+AzureADTokenProvider = Callable[[], str]
+

 @LLMClient.register("maas")
 @LLMClient.register("openai")
 class OpenAIClient(LLMClient):
     """
     A client class for interacting with the OpenAI API.
@@ -134,3 +137,119 @@ def call_with_json_parse(self, prompt):
         except:
             return rsp
         return json_result
+
+
+@LLMClient.register("azure_openai")
+class AzureOpenAIClient(LLMClient):
+    def __init__(
+        self,
+        api_key: str,
+        base_url: str,
+        model: str,
+        stream: bool = False,
+        api_version: str = "2024-12-01-preview",
+        temperature: float = 0.7,
+        azure_deployment: str = None,
+        timeout: float = None,
+        azure_ad_token: str = None,
+        azure_ad_token_provider: AzureADTokenProvider = None,
+    ):
+        """
+        Initializes the AzureOpenAIClient instance.
+
+        Args:
+            api_key (str): The API key for accessing the Azure OpenAI API.
+            api_version (str): The API version for the Azure OpenAI API (e.g. "2024-12-01-preview", "2024-10-01-preview", "2024-05-01-preview").
+            base_url (str): The base URL for the Azure OpenAI API.
+            azure_deployment (str): The deployment name for the Azure OpenAI model. If given, sets the base client URL to include `/deployments/{azure_deployment}`.
+                Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.
+            model (str): The default model to use for requests.
+            stream (bool, optional): Whether to stream the response. Defaults to False.
+            temperature (float, optional): The temperature parameter for the model. Defaults to 0.7.
+            timeout (float): The timeout duration for the service request. Defaults to None, which means no timeout.
+            azure_ad_token: Your Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
+            azure_ad_token_provider: A function that returns an Azure Active Directory token; it will be invoked on every request.
+        """
+        self.api_key = api_key
+        self.base_url = base_url
+        self.azure_deployment = azure_deployment
+        self.model = model
+        self.stream = stream
+        self.temperature = temperature
+        self.timeout = timeout
+        self.api_version = api_version
+        self.azure_ad_token = azure_ad_token
+        self.azure_ad_token_provider = azure_ad_token_provider
+        # Note: the AzureOpenAI constructor takes no `model` argument; the model
+        # is passed per request in the chat.completions.create calls below.
+        self.client = AzureOpenAI(
+            api_key=self.api_key,
+            base_url=self.base_url,
+            azure_deployment=self.azure_deployment,
+            api_version=self.api_version,
+            azure_ad_token=self.azure_ad_token,
+            azure_ad_token_provider=self.azure_ad_token_provider,
+        )
+        self.check()
+
+    def __call__(self, prompt: str, image_url: str = None):
+        """
+        Executes a model request when the object is called and returns the result.
+
+        Parameters:
+            prompt (str): The prompt provided to the model.
+            image_url (str, optional): An optional image URL for multimodal prompts. Defaults to None.
+
+        Returns:
+            str: The response content generated by the model.
+        """
+        # Build the message list, attaching the image when one is provided.
+        if image_url:
+            message = [
+                {"role": "system", "content": "you are a helpful assistant"},
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": prompt},
+                        {"type": "image_url", "image_url": {"url": image_url}},
+                    ],
+                },
+            ]
+        else:
+            message = [
+                {"role": "system", "content": "you are a helpful assistant"},
+                {"role": "user", "content": prompt},
+            ]
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=message,
+            stream=self.stream,
+            temperature=self.temperature,
+            timeout=self.timeout,
+        )
+        rsp = response.choices[0].message.content
+        return rsp
+
+    @retry(stop=stop_after_attempt(3))
+    def call_with_json_parse(self, prompt):
+        """
+        Calls the model and attempts to parse the response into JSON format.
+
+        Parameters:
+            prompt (str): The prompt provided to the model.
+
+        Returns:
+            Union[dict, str]: If the response is valid JSON, returns the parsed dictionary; otherwise, returns the original response.
+        """
+        # Extract a ```json fenced block if present, otherwise parse the whole response.
+        rsp = self(prompt)
+        _end = rsp.rfind("```")
+        _start = rsp.find("```json")
+        if _end != -1 and _start != -1:
+            json_str = rsp[_start + len("```json") : _end].strip()
+        else:
+            json_str = rsp
+        try:
+            json_result = json.loads(json_str)
+        except json.JSONDecodeError:
+            return rsp
+        return json_result
\ No newline at end of file
diff --git a/kag/common/vectorize_model/openai_model.py b/kag/common/vectorize_model/openai_model.py
index 133b13a9..e03216de 100644
--- a/kag/common/vectorize_model/openai_model.py
+++ b/kag/common/vectorize_model/openai_model.py
@@ -10,9 +10,9 @@
 # or implied.

 from typing import Union, Iterable
-from openai import OpenAI
+from openai import OpenAI, AzureOpenAI
 from kag.interface import VectorizeModelABC, EmbeddingVector
-
+from typing import Callable
+

 @VectorizeModelABC.register("openai")
 class OpenAIVectorizeModel(VectorizeModelABC):
@@ -65,3 +65,71 @@ def vectorize(
     else:
         assert len(results) == len(texts)
         return results
+
+
+@VectorizeModelABC.register("azure_openai")
+class AzureOpenAIVectorizeModel(VectorizeModelABC):
+    """
+    A class that extends the VectorizeModelABC base class.
+    It invokes Azure OpenAI or Azure OpenAI-compatible embedding services to convert texts into embedding vectors.
+    """
+
+    def __init__(
+        self,
+        base_url: str,
+        api_key: str,
+        model: str = "text-embedding-ada-002",
+        api_version: str = "2024-12-01-preview",
+        vector_dimensions: int = None,
+        timeout: float = None,
+        azure_deployment: str = None,
+        azure_ad_token: str = None,
+        azure_ad_token_provider: Callable = None,
+    ):
+        """
+        Initializes the AzureOpenAIVectorizeModel instance.
+
+        Args:
+            model (str, optional): The model to use for embedding. Defaults to "text-embedding-ada-002".
+            api_key (str, optional): The API key for accessing the Azure OpenAI service. Defaults to "".
+            api_version (str): The API version for the Azure OpenAI API (e.g. "2024-12-01-preview", "2024-10-01-preview", "2024-05-01-preview").
+            base_url (str, optional): The base URL for the Azure OpenAI service. Defaults to "".
+            vector_dimensions (int, optional): The number of dimensions for the embedding vectors. Defaults to None.
+            azure_ad_token: Your Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
+            azure_ad_token_provider: A function that returns an Azure Active Directory token; it will be invoked on every request.
+            azure_deployment: The model deployment name. If given, sets the base client URL to include `/deployments/{azure_deployment}`.
+                Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.
+        """
+        super().__init__(vector_dimensions)
+        self.model = model
+        self.timeout = timeout
+        # Note: the AzureOpenAI constructor takes no `model` argument; the model
+        # is passed per request in the embeddings.create call below.
+        self.client = AzureOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            azure_ad_token_provider=azure_ad_token_provider,
+        )
+
+    def vectorize(
+        self, texts: Union[str, Iterable[str]]
+    ) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
+        """
+        Vectorizes a text string into an embedding vector or multiple text strings into multiple embedding vectors.
+
+        Args:
+            texts (Union[str, Iterable[str]]): The text or texts to vectorize.
+
+        Returns:
+            Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s) of the text(s).
+        """
+        results = self.client.embeddings.create(
+            input=texts, model=self.model, timeout=self.timeout
+        )
+        results = [item.embedding for item in results.data]
+        if isinstance(texts, str):
+            assert len(results) == 1
+            return results[0]
+        else:
+            assert len(results) == len(texts)
+            return results
\ No newline at end of file
diff --git a/kag/examples/domain_kg/README.md b/kag/examples/domain_kg/README.md
index 49d7a8d7..b8ae3ea7 100644
--- a/kag/examples/domain_kg/README.md
+++ b/kag/examples/domain_kg/README.md
@@ -60,7 +60,7 @@ cd builder && python indexer.py && cd ..

 ### Step 6: Execute the QA tasks

-Execute [evaFor2wiki.py](./solver/evaFor2wiki.py) in the [solver](./solver) directory to generate the answer to the question.
+Execute [qa.py](./solver/qa.py) in the [solver](./solver) directory to generate the answer to the question.

 ```bash
 cd solver && python qa.py && cd ..
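
The new bridge entry points can be exercised directly from Python. Below is a minimal sketch, not part of the patch: it assumes the class defined in kag/bridge/spg_server_bridge.py is named SPGServerBridge with a no-argument constructor (the hunk context shows only its methods), and that the config checkers accept plain config dicts; all config values are placeholders.

```python
from kag.bridge.spg_server_bridge import SPGServerBridge  # assumed class name

bridge = SPGServerBridge()

# Placeholder configs; real values come from the project's kag_config.yaml.
llm_config = {"type": "openai", "model": "gpt-4o-mini", "api_key": "<key>"}
vec_config = {"type": "openai", "model": "text-embedding-ada-002", "api_key": "<key>"}

bridge.run_llm_config_check(llm_config)          # returns the checker's verdict
bridge.run_vectorizer_config_check(vec_config)

# run_solver forwards to SolverMain via getattr; func_name picks the method.
answer = bridge.run_solver(
    project_id=1,
    task_id="task-1",
    query="What is KAG?",
    is_report=False,                             # skip reporting to the server
)
```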
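The conf.py changes make the config file an explicit, optional argument: init_env(config_file=...) prefers a valid user-supplied path and only falls back to the upward _closest_cfg() search (or to server-side config in prod mode) when validate_config_file rejects it. A minimal usage sketch; the path is a placeholder:

```python
from kag.common.conf import init_env

# Load KAG settings from an explicit file; an invalid or missing path falls
# back to searching upward from the working directory for the closest config.
init_env(config_file="/path/to/kag_config.yaml")
```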
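Constructing the new azure_openai LLM client mirrors the existing openai client; note that __init__ ends with self.check(), so instantiation issues a live request against the deployment. A sketch with placeholder endpoint, key, and deployment name:

```python
from kag.common.llm.openai_client import AzureOpenAIClient

client = AzureOpenAIClient(
    api_key="<azure-api-key>",                               # placeholder
    base_url="https://<resource>.openai.azure.com/openai/",  # placeholder endpoint
    model="gpt-4o",
    azure_deployment="gpt-4o",                               # your deployment name
    api_version="2024-12-01-preview",
)

print(client("Reply with one word: hello"))
parsed = client.call_with_json_parse('Return a JSON object like {"ok": true}')
```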
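The embedding counterpart follows the same pattern: vectorize() accepts a single string or an iterable of strings and returns one vector or a list of vectors accordingly. A sketch with placeholder credentials:

```python
from kag.common.vectorize_model.openai_model import AzureOpenAIVectorizeModel

vectorizer = AzureOpenAIVectorizeModel(
    base_url="https://<resource>.openai.azure.com/openai/",  # placeholder
    api_key="<azure-api-key>",                               # placeholder
    model="text-embedding-ada-002",
    azure_deployment="text-embedding-ada-002",               # your deployment name
)

single = vectorizer.vectorize("knowledge graph")   # -> one EmbeddingVector
batch = vectorizer.vectorize(["alpha", "beta"])    # -> list of EmbeddingVectors
```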