feat(bridge): spg server bridge supports config check and run solver (#288)

* fix mix reader (#270)

* feat(builder): add Azure OpenAI Compatibility (#269)

* feat(llm): add Azure OpenAI client and vectorization support

* chore: add .DS_Store to .gitignore

* refactor(llm): add description for api_version and default value

* refactor(vectorize_model): added description for api_version and default values for some params

* refactor(openai_model): enhance docstring for Azure AD token and deployment parameters

* fix(builder): fix markdown reader for id (#273)

* fix builder init

* add pro commit

* rename graphalgoclient to graphclient

* first fix

* fix(examples): fix qa file name (#251)

* support custom kag config file (#279)

* x

* x (#280)

* bridge add solver

* x

* feat(bridge): spg server bridge (#283)

* x

* bridge add solver

* x

* add invoke

---------

Co-authored-by: joseosvaldo16 <[email protected]>
Co-authored-by: Xinhong Zhang <[email protected]>
3 people authored Jan 16, 2025
1 parent c2056ef commit fdea22b
Showing 8 changed files with 248 additions and 22 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -15,3 +15,4 @@
 .idea/
 .venv/
 __pycache__/
+.DS_Store
32 changes: 31 additions & 1 deletion kag/bridge/spg_server_bridge.py
@@ -16,7 +16,6 @@


 def init_kag_config(project_id: str, host_addr: str):
-
     os.environ[KAGConstants.ENV_KAG_PROJECT_ID] = project_id
     os.environ[KAGConstants.ENV_KAG_PROJECT_HOST_ADDR] = host_addr
     init_env()
@@ -47,3 +46,34 @@ def run_component(self, component_name, component_config, input_data):
         if hasattr(instance.input_types, "from_dict"):
             input_data = instance.input_types.from_dict(input_data)
         return [x.to_dict() for x in instance.invoke(input_data, write_ckpt=False)]
+
+    def run_llm_config_check(self, llm_config):
+        from kag.common.llm.llm_config_checker import LLMConfigChecker
+
+        return LLMConfigChecker().check(llm_config)
+
+    def run_vectorizer_config_check(self, vec_config):
+        from kag.common.vectorize_model.vectorize_model_config_checker import (
+            VectorizeModelConfigChecker,
+        )
+
+        return VectorizeModelConfigChecker().check(vec_config)
+
+    def run_solver(
+        self,
+        project_id,
+        task_id,
+        query,
+        func_name="invoke",
+        is_report=True,
+        host_addr="http://127.0.0.1:8887",
+    ):
+        from kag.solver.main_solver import SolverMain
+
+        return getattr(SolverMain(), func_name)(
+            project_id=project_id,
+            task_id=task_id,
+            query=query,
+            is_report=is_report,
+            host_addr=host_addr,
+        )
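
For reference, a minimal sketch of how the new bridge entry points might be called. The class name `SPGServerBridge` and the config dicts are assumptions (only the method names and signatures come from this diff); `run_solver` forwards to whichever `SolverMain` method `func_name` selects, `invoke` by default.

```python
from kag.bridge.spg_server_bridge import SPGServerBridge  # class name assumed

bridge = SPGServerBridge()

# Config checkers: each takes a config payload and delegates to a checker
# class imported lazily inside the method.
llm_conf = {"type": "openai", "model": "gpt-4o-mini", "api_key": "..."}  # illustrative
print(bridge.run_llm_config_check(llm_conf))

vec_conf = {"type": "openai", "model": "text-embedding-ada-002", "api_key": "..."}  # illustrative
print(bridge.run_vectorizer_config_check(vec_conf))

# Solver: resolves getattr(SolverMain(), "invoke") and calls it with the
# given project/task/query; reporting is enabled by default.
answer = bridge.run_solver(
    project_id="1",
    task_id="task-001",  # illustrative ids
    query="Who founded Ant Group?",
    is_report=False,
    host_addr="http://127.0.0.1:8887",
)
```
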
5 changes: 3 additions & 2 deletions kag/builder/component/reader/markdown_reader.py
@@ -21,6 +21,7 @@
 from typing import List, Dict
 
 
+from kag.common.utils import generate_hash_id
 from kag.interface import ReaderABC
 from kag.builder.model.chunk import Chunk
 from kag.interface import LLMClient
@@ -299,7 +300,7 @@ def collect_children_content(n: MarkdownNode):
             all_content.extend(child_content)
 
         current_output = Chunk(
-            id=f"{id}_{len(outputs)}",
+            id=f"{generate_hash_id(full_title)}",
             parent_id=parent_id,
             name=full_title,
             content="\n".join(filter(None, all_content)),
@@ -360,7 +361,7 @@ def collect_children_content(n: MarkdownNode):
             all_content.extend(child_content)
 
         current_output = Chunk(
-            id=f"{id}_{len(outputs)}",
+            id=f"{generate_hash_id(full_title)}",
             parent_id=parent_id,
             name=full_title,
             content="\n".join(filter(None, all_content)),
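
This id change replaces positional chunk ids (`f"{id}_{len(outputs)}"`), which shift whenever chunk ordering changes, with content-derived ids from `generate_hash_id(full_title)`. A toy stand-in showing why hash-based ids stay stable across runs (the real helper lives in `kag.common.utils` and may use a different digest or normalization):

```python
import hashlib

def hash_id(text: str) -> str:
    # Stand-in for generate_hash_id; illustration only.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

# Same section title -> same chunk id on every run, regardless of how many
# chunks were emitted before it.
assert hash_id("1. Introduction") == hash_id("1. Introduction")
```
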
8 changes: 4 additions & 4 deletions kag/builder/component/reader/mix_reader.py
@@ -52,7 +52,7 @@ def __init__(
             dict_reader (DictReader, optional): Reader for dictionary inputs. Defaults to None.
         """
         super().__init__()
-        self.parse_map = {
+        self.reader_map = {
             "txt": txt_reader,
             "pdf": pdf_reader,
             "docx": docx_reader,
@@ -83,11 +83,11 @@ def _invoke(self, input: Input, **kwargs) -> List[Output]:
             reader_type = "dict"
 
         else:
-            if os.path.exists(input):
+            if not os.path.exists(input):
                 raise FileNotFoundError(f"File {input} not found.")
 
             file_suffix = input.split(".")[-1]
-            if file_suffix not in self.parse_map:
+            if file_suffix not in self.reader_map:
                 raise NotImplementedError(
                     f"File suffix {file_suffix} not supported yet."
                 )
@@ -96,4 +96,4 @@ def _invoke(self, input: Input, **kwargs) -> List[Output]:
         reader = self.reader_map[reader_type]
         if reader is None:
             raise KeyError(f"{reader_type} reader not correctly configured.")
-        return self.parse_map[file_suffix]._invoke(input)
+        return reader._invoke(input)
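
This closes two bugs in the mix reader: the existence check was inverted (the old `if os.path.exists(input)` raised `FileNotFoundError` precisely when the file existed), and the final dispatch went through the stale `parse_map` keyed by file suffix instead of the `reader` already resolved from `reader_map`. A hedged sketch of the corrected failure modes (`MixReader` is inferred from the file name, and the no-argument constructor is an assumption):

```python
from kag.builder.component.reader.mix_reader import MixReader  # name inferred

reader = MixReader()  # sub-readers default to None unless configured

# A missing file now fails fast:
try:
    reader._invoke("does_not_exist.pdf")
except FileNotFoundError as e:
    print(e)  # File does_not_exist.pdf not found.

# An existing file whose reader was never configured raises
# KeyError("pdf reader not correctly configured.") instead of dispatching
# through the old parse_map.
```
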
29 changes: 18 additions & 11 deletions kag/common/conf.py
@@ -97,7 +97,15 @@ def _closest_cfg(
     return _closest_cfg(path.parent, path)
 
 
-def load_config(prod: bool = False):
+def validate_config_file(config_file: str):
+    if not config_file:
+        return False
+    if not os.path.exists(config_file):
+        return False
+    return True
+
+
+def load_config(prod: bool = False, config_file: str = None):
     """
     Get kag config file as a ConfigParser.
     """
@@ -121,7 +129,8 @@ def load_config(prod: bool = False):
         config["vectorize_model"] = config["vectorizer"]
         return config
     else:
-        config_file = _closest_cfg()
+        if not validate_config_file(config_file):
+            config_file = _closest_cfg()
         if os.path.exists(config_file) and os.path.isfile(config_file):
             print(f"found config file: {config_file}")
             with open(config_file, "r") as reader:
@@ -148,13 +157,11 @@ def init_log_config(self, config):
         logging.getLogger("neo4j.io").setLevel(logging.INFO)
         logging.getLogger("neo4j.pool").setLevel(logging.INFO)
 
-    def initialize(self, prod: bool = True):
-        config = load_config(prod)
+    def initialize(self, prod: bool = True, config_file: str = None):
+        config = load_config(prod, config_file)
         if self._is_initialized:
-            print(
-                "Reinitialize the KAG configuration, an operation that should exclusively be triggered within the Java invocation context."
-            )
-            print(f"original config: {self.config}")
+            print("WARN: Reinitialize the KAG configuration.")
+            print(f"original config: {self.config}\n\n")
             print(f"new config: {config}")
         self.prod = prod
         self.config = config
@@ -173,15 +180,15 @@ def all_config(self):
 KAG_PROJECT_CONF = KAG_CONFIG.global_config
 
 
-def init_env():
+def init_env(config_file: str = None):
     project_id = os.getenv(KAGConstants.ENV_KAG_PROJECT_ID)
     host_addr = os.getenv(KAGConstants.ENV_KAG_PROJECT_HOST_ADDR)
-    if project_id and host_addr:
+    if project_id and host_addr and not validate_config_file(config_file):
         prod = True
     else:
         prod = False
     global KAG_CONFIG
-    KAG_CONFIG.initialize(prod)
+    KAG_CONFIG.initialize(prod, config_file)
 
     if prod:
         msg = "Done init config from server"
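
Together these changes let callers hand KAG an explicit config file instead of relying on server environment variables or directory-walking discovery (`_closest_cfg`). A minimal sketch of the intended call pattern (the path is a placeholder):

```python
from kag.common.conf import init_env

# A valid config_file wins: even with the project env vars set,
# validate_config_file() returning True forces local (non-prod) config.
init_env(config_file="./my_kag_config.yaml")  # placeholder path

# No file, or a missing path, keeps the old behavior: server config when
# the env vars are set, otherwise _closest_cfg() discovery.
init_env()
```
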
121 changes: 120 additions & 1 deletion kag/common/llm/openai_client.py
@@ -12,19 +12,22 @@


 import json
-from openai import OpenAI
+from openai import OpenAI, AzureOpenAI
 import logging
 
 from kag.interface import LLMClient
 from tenacity import retry, stop_after_attempt
+from typing import Callable
 
 logging.getLogger("openai").setLevel(logging.ERROR)
 logging.getLogger("httpx").setLevel(logging.ERROR)
 logger = logging.getLogger(__name__)
 
+AzureADTokenProvider = Callable[[], str]
+
 @LLMClient.register("maas")
 @LLMClient.register("openai")
 class OpenAIClient(LLMClient):
     """
     A client class for interacting with the OpenAI API.
@@ -134,3 +137,119 @@ def call_with_json_parse(self, prompt):
         except:
             return rsp
         return json_result
+
+
+@LLMClient.register("azure_openai")
+class AzureOpenAIClient(LLMClient):
+    def __init__(
+        self,
+        api_key: str,
+        base_url: str,
+        model: str,
+        stream: bool = False,
+        api_version: str = "2024-12-01-preview",
+        temperature: float = 0.7,
+        azure_deployment: str = None,
+        timeout: float = None,
+        azure_ad_token: str = None,
+        azure_ad_token_provider: AzureADTokenProvider = None,
+    ):
+        """
+        Initializes the AzureOpenAIClient instance.
+
+        Args:
+            api_key (str): The API key for accessing the Azure OpenAI API.
+            base_url (str): The base URL for the Azure OpenAI API.
+            model (str): The default model to use for requests.
+            stream (bool, optional): Whether to stream the response. Defaults to False.
+            api_version (str): The API version for the Azure OpenAI API (e.g. "2024-12-01-preview", "2024-10-01-preview", "2024-05-01-preview").
+            temperature (float, optional): The temperature parameter for the model. Defaults to 0.7.
+            azure_deployment (str): A model deployment name. If given, sets the base client URL to include `/deployments/{azure_deployment}`.
+                Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.
+            timeout (float): The timeout duration for the service request. Defaults to None, meaning no timeout.
+            azure_ad_token: Your Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
+            azure_ad_token_provider: A function that returns an Azure Active Directory token; invoked on every request.
+        """
+        self.api_key = api_key
+        self.base_url = base_url
+        self.azure_deployment = azure_deployment
+        self.model = model
+        self.stream = stream
+        self.temperature = temperature
+        self.timeout = timeout
+        self.api_version = api_version
+        self.azure_ad_token = azure_ad_token
+        self.azure_ad_token_provider = azure_ad_token_provider
+        self.client = AzureOpenAI(
+            api_key=self.api_key,
+            base_url=self.base_url,
+            azure_deployment=self.azure_deployment,
+            model=self.model,
+            api_version=self.api_version,
+            azure_ad_token=self.azure_ad_token,
+            azure_ad_token_provider=self.azure_ad_token_provider,
+        )
+        self.check()
+
+    def __call__(self, prompt: str, image_url: str = None):
+        """
+        Executes a model request when the object is called and returns the result.
+
+        Parameters:
+            prompt (str): The prompt provided to the model.
+            image_url (str, optional): An image URL to include in the request.
+        Returns:
+            str: The response content generated by the model.
+        """
+        # Build the message list, attaching the image when one is provided.
+        if image_url:
+            message = [
+                {"role": "system", "content": "you are a helpful assistant"},
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": prompt},
+                        {"type": "image_url", "image_url": {"url": image_url}},
+                    ],
+                },
+            ]
+        else:
+            message = [
+                {"role": "system", "content": "you are a helpful assistant"},
+                {"role": "user", "content": prompt},
+            ]
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=message,
+            stream=self.stream,
+            temperature=self.temperature,
+            timeout=self.timeout,
+        )
+        return response.choices[0].message.content
+
+    @retry(stop=stop_after_attempt(3))
+    def call_with_json_parse(self, prompt):
+        """
+        Calls the model and attempts to parse the response into JSON format.
+
+        Parameters:
+            prompt (str): The prompt provided to the model.
+        Returns:
+            Union[dict, str]: If the response is valid JSON, returns the parsed object; otherwise, returns the original response.
+        """
+        rsp = self(prompt)
+        # Extract the fenced ```json block if present; otherwise parse the raw response.
+        _end = rsp.rfind("```")
+        _start = rsp.find("```json")
+        if _end != -1 and _start != -1:
+            json_str = rsp[_start + len("```json") : _end].strip()
+        else:
+            json_str = rsp
+        try:
+            json_result = json.loads(json_str)
+        except:
+            return rsp
+        return json_result
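
A hedged usage sketch for the new client; endpoint, deployment, and key are placeholders. Note that the constructor calls `self.check()`, so instantiation performs a live request against the service. Since the class registers as `azure_openai`, the same settings can presumably also be supplied through an `llm` config section with `type: azure_openai`.

```python
from kag.common.llm.openai_client import AzureOpenAIClient

llm = AzureOpenAIClient(
    api_key="<azure-api-key>",                        # placeholder
    base_url="https://my-resource.openai.azure.com",  # placeholder endpoint
    model="gpt-4o",
    azure_deployment="my-gpt4o-deployment",           # placeholder deployment
    api_version="2024-12-01-preview",
)
# __call__ sends a chat completion and returns the message content.
print(llm("Summarize KAG in one sentence."))
```
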
72 changes: 70 additions & 2 deletions kag/common/vectorize_model/openai_model.py
@@ -10,9 +10,9 @@
 # or implied.
 
 from typing import Union, Iterable
-from openai import OpenAI
+from openai import OpenAI, AzureOpenAI
 from kag.interface import VectorizeModelABC, EmbeddingVector
-
+from typing import Callable
 
 @VectorizeModelABC.register("openai")
 class OpenAIVectorizeModel(VectorizeModelABC):
@@ -65,3 +65,71 @@ def vectorize(
         else:
             assert len(results) == len(texts)
             return results
+
+
+@VectorizeModelABC.register("azure_openai")
+class AzureOpenAIVectorizeModel(VectorizeModelABC):
+    """
+    A class that extends the VectorizeModelABC base class.
+    It invokes Azure OpenAI or Azure OpenAI-compatible embedding services to convert texts into embedding vectors.
+    """
+
+    def __init__(
+        self,
+        base_url: str,
+        api_key: str,
+        model: str = "text-embedding-ada-002",
+        api_version: str = "2024-12-01-preview",
+        vector_dimensions: int = None,
+        timeout: float = None,
+        azure_deployment: str = None,
+        azure_ad_token: str = None,
+        azure_ad_token_provider: Callable = None,
+    ):
+        """
+        Initializes the AzureOpenAIVectorizeModel instance.
+
+        Args:
+            base_url (str): The base URL for the Azure OpenAI service.
+            api_key (str): The API key for accessing the Azure OpenAI service.
+            model (str, optional): The model to use for embedding. Defaults to "text-embedding-ada-002".
+            api_version (str): The API version for the Azure OpenAI API (e.g. "2024-12-01-preview", "2024-10-01-preview", "2024-05-01-preview").
+            vector_dimensions (int, optional): The number of dimensions for the embedding vectors. Defaults to None.
+            azure_ad_token: Your Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
+            azure_ad_token_provider: A function that returns an Azure Active Directory token; invoked on every request.
+            azure_deployment: A model deployment name. If given, sets the base client URL to include `/deployments/{azure_deployment}`.
+                Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.
+        """
+        super().__init__(vector_dimensions)
+        self.model = model
+        self.timeout = timeout
+        self.client = AzureOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            azure_deployment=azure_deployment,
+            model=model,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            azure_ad_token_provider=azure_ad_token_provider,
+        )
+
+    def vectorize(
+        self, texts: Union[str, Iterable[str]]
+    ) -> Union[EmbeddingVector, Iterable[EmbeddingVector]]:
+        """
+        Vectorizes a text string into an embedding vector or multiple text strings into multiple embedding vectors.
+
+        Args:
+            texts (Union[str, Iterable[str]]): The text or texts to vectorize.
+        Returns:
+            Union[EmbeddingVector, Iterable[EmbeddingVector]]: The embedding vector(s) of the text(s).
+        """
+        results = self.client.embeddings.create(
+            input=texts, model=self.model, timeout=self.timeout
+        )
+        results = [item.embedding for item in results.data]
+        if isinstance(texts, str):
+            assert len(results) == 1
+            return results[0]
+        else:
+            assert len(results) == len(texts)
+            return results
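
And a matching sketch for the embedding side (again with placeholder values). `vectorize` mirrors the `isinstance` branch above: a `str` input yields a single embedding vector, an iterable yields a list:

```python
from kag.common.vectorize_model.openai_model import AzureOpenAIVectorizeModel

vectorizer = AzureOpenAIVectorizeModel(
    base_url="https://my-resource.openai.azure.com",  # placeholder endpoint
    api_key="<azure-api-key>",                        # placeholder
    azure_deployment="my-embedding-deployment",       # placeholder deployment
)

vec = vectorizer.vectorize("knowledge graph")    # one EmbeddingVector
vecs = vectorizer.vectorize(["kag", "openspg"])  # list of EmbeddingVectors
```
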
2 changes: 1 addition & 1 deletion kag/examples/domain_kg/README.md
@@ -60,7 +60,7 @@ cd builder && python indexer.py && cd ..

 ### Step 6: Execute the QA tasks
 
-Execute [evaFor2wiki.py](./solver/evaFor2wiki.py) in the [solver](./solver) directory to generate the answer to the question.
+Execute [qa.py](./solver/qa.py) in the [solver](./solver) directory to generate the answer to the question.
 
 ```bash
 cd solver && python qa.py && cd ..
