From ca31351971ea0d3e397525ac968af34da023fa9b Mon Sep 17 00:00:00 2001 From: zhuzhongshu123 <152354526+zhuzhongshu123@users.noreply.github.com> Date: Wed, 15 Jan 2025 18:10:55 +0800 Subject: [PATCH 1/5] support custom kag config file (#279) --- kag/common/conf.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/kag/common/conf.py b/kag/common/conf.py index d970054f..06be534a 100644 --- a/kag/common/conf.py +++ b/kag/common/conf.py @@ -97,7 +97,15 @@ def _closest_cfg( return _closest_cfg(path.parent, path) -def load_config(prod: bool = False): +def validate_config_file(config_file: str): + if not config_file: + return False + if not os.path.exists(config_file): + return False + return True + + +def load_config(prod: bool = False, config_file: str = None): """ Get kag config file as a ConfigParser. """ @@ -121,7 +129,8 @@ def load_config(prod: bool = False): config["vectorize_model"] = config["vectorizer"] return config else: - config_file = _closest_cfg() + if not validate_config_file(config_file): + config_file = _closest_cfg() if os.path.exists(config_file) and os.path.isfile(config_file): print(f"found config file: {config_file}") with open(config_file, "r") as reader: @@ -148,13 +157,11 @@ def init_log_config(self, config): logging.getLogger("neo4j.io").setLevel(logging.INFO) logging.getLogger("neo4j.pool").setLevel(logging.INFO) - def initialize(self, prod: bool = True): - config = load_config(prod) + def initialize(self, prod: bool = True, config_file: str = None): + config = load_config(prod, config_file) if self._is_initialized: - print( - "Reinitialize the KAG configuration, an operation that should exclusively be triggered within the Java invocation context." - ) - print(f"original config: {self.config}") + print("WARN: Reinitialize the KAG configuration.") + print(f"original config: {self.config}\n\n") print(f"new config: {config}") self.prod = prod self.config = config @@ -173,15 +180,15 @@ def all_config(self): KAG_PROJECT_CONF = KAG_CONFIG.global_config -def init_env(): +def init_env(config_file: str = None): project_id = os.getenv(KAGConstants.ENV_KAG_PROJECT_ID) host_addr = os.getenv(KAGConstants.ENV_KAG_PROJECT_HOST_ADDR) - if project_id and host_addr: + if project_id and host_addr and not validate_config_file(config_file): prod = True else: prod = False global KAG_CONFIG - KAG_CONFIG.initialize(prod) + KAG_CONFIG.initialize(prod, config_file) if prod: msg = "Done init config from server" From deae27751010af386035532777c6cbd0c9c860df Mon Sep 17 00:00:00 2001 From: zhuzhongshu123 <152354526+zhuzhongshu123@users.noreply.github.com> Date: Fri, 17 Jan 2025 13:52:00 +0800 Subject: [PATCH 2/5] feat(bridge): spg server bridge supports config check and run solver (#287) * x * x (#280) * bridge add solver * x * feat(bridge): spg server bridge (#283) * x * bridge add solver * x * add invoke * llm client catch error --- kag/bridge/spg_server_bridge.py | 32 +++++++++++++++++++++++++++++- kag/interface/common/llm_client.py | 6 +++--- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/kag/bridge/spg_server_bridge.py b/kag/bridge/spg_server_bridge.py index 7fde8f72..51b0ca25 100644 --- a/kag/bridge/spg_server_bridge.py +++ b/kag/bridge/spg_server_bridge.py @@ -16,7 +16,6 @@ def init_kag_config(project_id: str, host_addr: str): - os.environ[KAGConstants.ENV_KAG_PROJECT_ID] = project_id os.environ[KAGConstants.ENV_KAG_PROJECT_HOST_ADDR] = host_addr init_env() @@ -47,3 +46,34 @@ def run_component(self, component_name, component_config, input_data): if hasattr(instance.input_types, "from_dict"): input_data = instance.input_types.from_dict(input_data) return [x.to_dict() for x in instance.invoke(input_data, write_ckpt=False)] + + def run_llm_config_check(self, llm_config): + from kag.common.llm.llm_config_checker import LLMConfigChecker + + return LLMConfigChecker().check(llm_config) + + def run_vectorizer_config_check(self, vec_config): + from kag.common.vectorize_model.vectorize_model_config_checker import ( + VectorizeModelConfigChecker, + ) + + return VectorizeModelConfigChecker().check(vec_config) + + def run_solver( + self, + project_id, + task_id, + query, + func_name="invoke", + is_report=True, + host_addr="http://127.0.0.1:8887", + ): + from kag.solver.main_solver import SolverMain + + return getattr(SolverMain(), func_name)( + project_id=project_id, + task_id=task_id, + query=query, + is_report=is_report, + host_addr=host_addr, + ) diff --git a/kag/interface/common/llm_client.py b/kag/interface/common/llm_client.py index e6816896..f9571a71 100644 --- a/kag/interface/common/llm_client.py +++ b/kag/interface/common/llm_client.py @@ -77,7 +77,7 @@ def invoke( variables: Dict[str, Any], prompt_op: PromptABC, with_json_parse: bool = True, - with_except: bool = True, + with_except: bool = False, ): """ Call the model and process the result. @@ -109,10 +109,10 @@ def invoke( except Exception as e: import traceback - logger.error(f"Error {e} during invocation: {traceback.format_exc()}") + logger.debug(f"Error {e} during invocation: {traceback.format_exc()}") if with_except: raise RuntimeError( - f"LLM invoke exception, info: {e}\nllm input: {input}\nllm output: {response}" + f"LLM invoke exception, info: {e}\nllm input: \n{prompt}\nllm output: \n{response}" ) return result From 7666ca40dd99ce4a7b7ee3ca82bec60923c3659b Mon Sep 17 00:00:00 2001 From: zhuzhongshu123 <152354526+zhuzhongshu123@users.noreply.github.com> Date: Fri, 17 Jan 2025 17:11:51 +0800 Subject: [PATCH 3/5] feat(kag): catch unexpected exceptions (#298) * x (#280) * feat(bridge): spg server bridge (#283) * x * bridge add solver * x * feat(bridge): Spg server bridge check (#285) * x * bridge add solver * x * add invoke * feat(common): llm client catch exception (#294) * x * bridge add solver * x * add invoke * llm client catch error * feat(solver): catch chunk retriever exception (#297) * x * bridge add solver * x * add invoke * llm client catch error * catch exception * feat(common):llm except (#299) * x * bridge add solver * x * add invoke * llm client catch error * catch exception * print llm invoke error info * with except * feat(common): force raise except (#300) * x * bridge add solver * x * add invoke * llm client catch error * catch exception * print llm invoke error info * with except * force raise except --- .../extractor/schema_constraint_extractor.py | 14 ++++-- .../extractor/schema_free_extractor.py | 12 +++-- kag/interface/common/llm_client.py | 6 +-- .../retriever/impl/default_chunk_retrieval.py | 47 ++++++++++++------- 4 files changed, 51 insertions(+), 28 deletions(-) diff --git a/kag/builder/component/extractor/schema_constraint_extractor.py b/kag/builder/component/extractor/schema_constraint_extractor.py index 4dfbb2ac..fd857e6f 100644 --- a/kag/builder/component/extractor/schema_constraint_extractor.py +++ b/kag/builder/component/extractor/schema_constraint_extractor.py @@ -90,7 +90,9 @@ def named_entity_recognition(self, passage: str): Returns: The result of the named entity recognition operation. """ - ner_result = self.llm.invoke({"input": passage}, self.ner_prompt) + ner_result = self.llm.invoke( + {"input": passage}, self.ner_prompt, with_except=False + ) if self.external_graph: extra_ner_result = self.external_graph.ner(passage) else: @@ -133,7 +135,9 @@ def named_entity_standardization(self, passage: str, entities: List[Dict]): The result of the named entity standardization operation. """ return self.llm.invoke( - {"input": passage, "named_entities": entities}, self.std_prompt + {"input": passage, "named_entities": entities}, + self.std_prompt, + with_except=False, ) @retry(stop=stop_after_attempt(3)) @@ -153,7 +157,9 @@ def relations_extraction(self, passage: str, entities: List[Dict]): return [] return self.llm.invoke( - {"input": passage, "entity_list": entities}, self.relation_prompt + {"input": passage, "entity_list": entities}, + self.relation_prompt, + with_except=False, ) @retry(stop=stop_after_attempt(3)) @@ -170,7 +176,7 @@ def event_extraction(self, passage: str): if self.event_prompt is None: logger.debug("Event extraction prompt not configured, skip.") return [] - return self.llm.invoke({"input": passage}, self.event_prompt) + return self.llm.invoke({"input": passage}, self.event_prompt, with_except=False) def parse_nodes_and_edges(self, entities: List[Dict], category: str = None): """ diff --git a/kag/builder/component/extractor/schema_free_extractor.py b/kag/builder/component/extractor/schema_free_extractor.py index ccf29128..da932265 100644 --- a/kag/builder/component/extractor/schema_free_extractor.py +++ b/kag/builder/component/extractor/schema_free_extractor.py @@ -98,7 +98,9 @@ def named_entity_recognition(self, passage: str): Returns: The result of the named entity recognition operation. """ - ner_result = self.llm.invoke({"input": passage}, self.ner_prompt) + ner_result = self.llm.invoke( + {"input": passage}, self.ner_prompt, with_except=False + ) if self.external_graph: extra_ner_result = self.external_graph.ner(passage) else: @@ -140,7 +142,9 @@ def named_entity_standardization(self, passage: str, entities: List[Dict]): Standardized entity information. """ return self.llm.invoke( - {"input": passage, "named_entities": entities}, self.std_prompt + {"input": passage, "named_entities": entities}, + self.std_prompt, + with_except=False, ) @retry(stop=stop_after_attempt(3)) @@ -154,7 +158,9 @@ def triples_extraction(self, passage: str, entities: List[Dict]): The result of the triples extraction operation. """ return self.llm.invoke( - {"input": passage, "entity_list": entities}, self.triple_prompt + {"input": passage, "entity_list": entities}, + self.triple_prompt, + with_except=False, ) def assemble_sub_graph_with_spg_records(self, entities: List[Dict]): diff --git a/kag/interface/common/llm_client.py b/kag/interface/common/llm_client.py index f9571a71..aba82756 100644 --- a/kag/interface/common/llm_client.py +++ b/kag/interface/common/llm_client.py @@ -77,7 +77,7 @@ def invoke( variables: Dict[str, Any], prompt_op: PromptABC, with_json_parse: bool = True, - with_except: bool = False, + with_except: bool = True, ): """ Call the model and process the result. @@ -108,12 +108,12 @@ def invoke( logger.debug(f"Result: {result}") except Exception as e: import traceback - - logger.debug(f"Error {e} during invocation: {traceback.format_exc()}") + logger.info(f"Error {e} during invocation: {traceback.format_exc()}") if with_except: raise RuntimeError( f"LLM invoke exception, info: {e}\nllm input: \n{prompt}\nllm output: \n{response}" ) + return result def batch( diff --git a/kag/solver/retriever/impl/default_chunk_retrieval.py b/kag/solver/retriever/impl/default_chunk_retrieval.py index 90037122..b8d33696 100644 --- a/kag/solver/retriever/impl/default_chunk_retrieval.py +++ b/kag/solver/retriever/impl/default_chunk_retrieval.py @@ -171,7 +171,6 @@ def calculate_sim_scores(self, query: str, doc_nums: int): Returns: dict: A dictionary with keys as document chunk IDs and values as the vector similarity scores. """ - scores = dict() try: scores = query_sim_doc_cache.get(query) if scores: @@ -186,6 +185,7 @@ def calculate_sim_scores(self, query: str, doc_nums: int): scores = {item["node"]["id"]: item["score"] for item in top_k} query_sim_doc_cache.put(query, scores) except Exception as e: + scores = dict() logger.error(f"run calculate_sim_scores failed, info: {e}", exc_info=True) return scores @@ -386,14 +386,20 @@ def convert_entity_data_to_ppr_cand(related_entities: List[EntityData]): return matched_entities def _parse_ner_list(self, query): - ner_list = ner_cache.get(query) - if ner_list: - return ner_list - ner_list = self.named_entity_recognition(query) - if self.with_semantic: - std_ner_list = self.named_entity_standardization(query, ner_list) - self.append_official_name(ner_list, std_ner_list) - ner_cache.put(query, ner_list) + ner_list = [] + try: + ner_list = ner_cache.get(query) + if ner_list: + return ner_list + ner_list = self.named_entity_recognition(query) + if self.with_semantic: + std_ner_list = self.named_entity_standardization(query, ner_list) + self.append_official_name(ner_list, std_ner_list) + ner_cache.put(query, ner_list) + except Exception as e: + if not ner_list: + ner_list = [] + logger.warning(f"_parse_ner_list {query} failed {e}", exc_info=True) return ner_list def recall_docs( @@ -504,15 +510,20 @@ def get_all_docs_by_id(self, queries: List[str], doc_ids: list, top_k: int): else: doc_score = doc_ids[doc_id] counter += 1 - node = self.graph_api.get_entity_prop_by_id( - label=self.schema.get_label_within_prefix(CHUNK_TYPE), - biz_id=doc_id, - ) - node_dict = dict(node.items()) - matched_docs.append( - f"#{node_dict['name']}#{node_dict['content']}#{doc_score}" - ) - hits_docs.add(node_dict["name"]) + try: + node = self.graph_api.get_entity_prop_by_id( + label=self.schema.get_label_within_prefix(CHUNK_TYPE), + biz_id=doc_id, + ) + node_dict = dict(node.items()) + matched_docs.append( + f"#{node_dict['name']}#{node_dict['content']}#{doc_score}" + ) + hits_docs.add(node_dict["name"]) + except Exception as e: + logger.warning( + f"{doc_id} get_entity_prop_by_id failed: {e}", exc_info=True + ) query = "\n".join(queries) try: text_matched = self.search_api.search_text( From 4ad5bded265b768820c93d14294bb7661838a456 Mon Sep 17 00:00:00 2001 From: zhuzhongshu123 <152354526+zhuzhongshu123@users.noreply.github.com> Date: Sat, 18 Jan 2025 12:05:31 +0800 Subject: [PATCH 4/5] delete checkpoint of postprocess (#302) --- kag/interface/builder/postprocessor_abc.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/kag/interface/builder/postprocessor_abc.py b/kag/interface/builder/postprocessor_abc.py index 71240464..fedc0d84 100644 --- a/kag/interface/builder/postprocessor_abc.py +++ b/kag/interface/builder/postprocessor_abc.py @@ -28,7 +28,3 @@ def input_types(self): @property def output_types(self): return SubGraph - - @property - def ckpt_subdir(self): - return "postprocessor" From 1e5701637390f43c0b52743add27bd0b00d6ea8e Mon Sep 17 00:00:00 2001 From: zhuzhongshu123 <152354526+zhuzhongshu123@users.noreply.github.com> Date: Mon, 20 Jan 2025 11:19:44 +0800 Subject: [PATCH 5/5] disable entity linking in postprocess by default (#304) --- kag/builder/component/postprocessor/kag_postprocessor.py | 7 ++++--- kag/examples/2wiki/kag_config.yaml | 1 - kag/examples/README.md | 1 - kag/examples/README_cn.md | 1 - kag/examples/baike/kag_config.yaml | 1 - kag/examples/csqa/kag_config.yaml | 1 - kag/examples/domain_kg/kag_config.yaml | 1 - kag/examples/example_config.yaml | 1 - kag/examples/hotpotqa/kag_config.yaml | 1 - kag/examples/medicine/kag_config.yaml | 1 - kag/examples/musique/kag_config.yaml | 1 - 11 files changed, 4 insertions(+), 13 deletions(-) diff --git a/kag/builder/component/postprocessor/kag_postprocessor.py b/kag/builder/component/postprocessor/kag_postprocessor.py index 8af36b06..5fbdbdd3 100644 --- a/kag/builder/component/postprocessor/kag_postprocessor.py +++ b/kag/builder/component/postprocessor/kag_postprocessor.py @@ -35,7 +35,7 @@ class KAGPostProcessor(PostProcessorABC): def __init__( self, - similarity_threshold: float = 0.9, + similarity_threshold: float = None, external_graph: ExternalGraphLoaderABC = None, ): """ @@ -180,8 +180,9 @@ def _invoke(self, input, **kwargs): origin_num_nodes = len(input.nodes) origin_num_edges = len(input.edges) new_graph = self.filter_invalid_data(input) - self.similarity_based_link(new_graph) - self.external_graph_based_link(new_graph) + if self.similarity_threshold is not None: + self.similarity_based_link(new_graph) + self.external_graph_based_link(new_graph) new_num_nodes = len(new_graph.nodes) new_num_edges = len(new_graph.edges) logger.debug( diff --git a/kag/examples/2wiki/kag_config.yaml b/kag/examples/2wiki/kag_config.yaml index ac2c8110..5f558aab 100644 --- a/kag/examples/2wiki/kag_config.yaml +++ b/kag/examples/2wiki/kag_config.yaml @@ -47,7 +47,6 @@ kag_builder_pipeline: type: dict_reader # kag.builder.component.reader.dict_reader.DictReader post_processor: type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor - similarity_threshold: 0.9 splitter: type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter split_length: 100000 diff --git a/kag/examples/README.md b/kag/examples/README.md index a6f1aeac..d2b70754 100644 --- a/kag/examples/README.md +++ b/kag/examples/README.md @@ -73,7 +73,6 @@ kag_builder_pipeline: type: dict_reader # kag.builder.component.reader.dict_reader.DictReader post_processor: type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor - similarity_threshold: 0.9 splitter: type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter split_length: 100000 diff --git a/kag/examples/README_cn.md b/kag/examples/README_cn.md index 2ad523a0..7f4f1a2d 100644 --- a/kag/examples/README_cn.md +++ b/kag/examples/README_cn.md @@ -73,7 +73,6 @@ kag_builder_pipeline: type: dict_reader # kag.builder.component.reader.dict_reader.DictReader post_processor: type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor - similarity_threshold: 0.9 splitter: type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter split_length: 100000 diff --git a/kag/examples/baike/kag_config.yaml b/kag/examples/baike/kag_config.yaml index e4ebca6d..33eb9f5a 100644 --- a/kag/examples/baike/kag_config.yaml +++ b/kag/examples/baike/kag_config.yaml @@ -49,7 +49,6 @@ kag_builder_pipeline: type: txt_reader # kag.builder.component.reader.txt_reader.TXTReader post_processor: type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor - similarity_threshold: 0.9 splitter: type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter split_length: 300 diff --git a/kag/examples/csqa/kag_config.yaml b/kag/examples/csqa/kag_config.yaml index 4cc695ef..04a2a14b 100644 --- a/kag/examples/csqa/kag_config.yaml +++ b/kag/examples/csqa/kag_config.yaml @@ -47,7 +47,6 @@ kag_builder_pipeline: type: txt_reader # kag.builder.component.reader.txt_reader.TXTReader post_processor: type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor - similarity_threshold: 0.9 splitter: type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter split_length: 4950 diff --git a/kag/examples/domain_kg/kag_config.yaml b/kag/examples/domain_kg/kag_config.yaml index 43a97e19..5bb5c1f9 100644 --- a/kag/examples/domain_kg/kag_config.yaml +++ b/kag/examples/domain_kg/kag_config.yaml @@ -68,7 +68,6 @@ kag_builder_pipeline: type: txt_reader # kag.builder.component.reader.text_reader.TXTReader post_processor: type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor - similarity_threshold: 0.9 external_graph: *external_graph_loader splitter: type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter diff --git a/kag/examples/example_config.yaml b/kag/examples/example_config.yaml index e3d2bcef..b8cedc94 100644 --- a/kag/examples/example_config.yaml +++ b/kag/examples/example_config.yaml @@ -47,7 +47,6 @@ kag_builder_pipeline: type: dict_reader # kag.builder.component.reader.dict_reader.DictReader post_processor: type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor - similarity_threshold: 0.9 splitter: type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter split_length: 100000 diff --git a/kag/examples/hotpotqa/kag_config.yaml b/kag/examples/hotpotqa/kag_config.yaml index 3b1985b7..3217d3b9 100644 --- a/kag/examples/hotpotqa/kag_config.yaml +++ b/kag/examples/hotpotqa/kag_config.yaml @@ -47,7 +47,6 @@ kag_builder_pipeline: type: dict_reader # kag.builder.component.reader.dict_reader.DictReader post_processor: type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor - similarity_threshold: 0.9 splitter: type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter split_length: 100000 diff --git a/kag/examples/medicine/kag_config.yaml b/kag/examples/medicine/kag_config.yaml index e5e4035c..ff9c2478 100644 --- a/kag/examples/medicine/kag_config.yaml +++ b/kag/examples/medicine/kag_config.yaml @@ -51,7 +51,6 @@ extract_runner: name_col: title post_processor: type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor - similarity_threshold: 0.9 splitter: type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter split_length: 100000 diff --git a/kag/examples/musique/kag_config.yaml b/kag/examples/musique/kag_config.yaml index 122021ec..8e29f5db 100644 --- a/kag/examples/musique/kag_config.yaml +++ b/kag/examples/musique/kag_config.yaml @@ -47,7 +47,6 @@ kag_builder_pipeline: type: dict_reader # kag.builder.component.reader.dict_reader.DictReader post_processor: type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor - similarity_threshold: 0.9 splitter: type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter split_length: 100000