Skip to content

Commit

Permalink
Merge branch 'master' of github.com:OpenSPG/KAG into kag_law_test
Browse files Browse the repository at this point in the history
  • Loading branch information
royzhao committed Jan 20, 2025
2 parents 474e9f7 + 1e57016 commit dd8233c
Show file tree
Hide file tree
Showing 18 changed files with 104 additions and 57 deletions.
32 changes: 31 additions & 1 deletion kag/bridge/spg_server_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


def init_kag_config(project_id: str, host_addr: str):
    """Point KAG at a project on a given server, then reinitialize config.

    Publishes *project_id* and *host_addr* through the well-known KAG
    environment variables before delegating to ``init_env`` so the global
    configuration is rebuilt against that server.
    """
    env_updates = {
        KAGConstants.ENV_KAG_PROJECT_ID: project_id,
        KAGConstants.ENV_KAG_PROJECT_HOST_ADDR: host_addr,
    }
    os.environ.update(env_updates)
    init_env()
Expand Down Expand Up @@ -47,3 +46,34 @@ def run_component(self, component_name, component_config, input_data):
if hasattr(instance.input_types, "from_dict"):
input_data = instance.input_types.from_dict(input_data)
return [x.to_dict() for x in instance.invoke(input_data, write_ckpt=False)]

def run_llm_config_check(self, llm_config):
from kag.common.llm.llm_config_checker import LLMConfigChecker

return LLMConfigChecker().check(llm_config)

def run_vectorizer_config_check(self, vec_config):
from kag.common.vectorize_model.vectorize_model_config_checker import (
VectorizeModelConfigChecker,
)

return VectorizeModelConfigChecker().check(vec_config)

def run_solver(
self,
project_id,
task_id,
query,
func_name="invoke",
is_report=True,
host_addr="http://127.0.0.1:8887",
):
from kag.solver.main_solver import SolverMain

return getattr(SolverMain(), func_name)(
project_id=project_id,
task_id=task_id,
query=query,
is_report=is_report,
host_addr=host_addr,
)
14 changes: 10 additions & 4 deletions kag/builder/component/extractor/schema_constraint_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def named_entity_recognition(self, passage: str):
Returns:
The result of the named entity recognition operation.
"""
ner_result = self.llm.invoke({"input": passage}, self.ner_prompt)
ner_result = self.llm.invoke(
{"input": passage}, self.ner_prompt, with_except=False
)
if self.external_graph:
extra_ner_result = self.external_graph.ner(passage)
else:
Expand Down Expand Up @@ -133,7 +135,9 @@ def named_entity_standardization(self, passage: str, entities: List[Dict]):
The result of the named entity standardization operation.
"""
return self.llm.invoke(
{"input": passage, "named_entities": entities}, self.std_prompt
{"input": passage, "named_entities": entities},
self.std_prompt,
with_except=False,
)

@retry(stop=stop_after_attempt(3))
Expand All @@ -153,7 +157,9 @@ def relations_extraction(self, passage: str, entities: List[Dict]):

return []
return self.llm.invoke(
{"input": passage, "entity_list": entities}, self.relation_prompt
{"input": passage, "entity_list": entities},
self.relation_prompt,
with_except=False,
)

@retry(stop=stop_after_attempt(3))
Expand All @@ -170,7 +176,7 @@ def event_extraction(self, passage: str):
if self.event_prompt is None:
logger.debug("Event extraction prompt not configured, skip.")
return []
return self.llm.invoke({"input": passage}, self.event_prompt)
return self.llm.invoke({"input": passage}, self.event_prompt, with_except=False)

def parse_nodes_and_edges(self, entities: List[Dict], category: str = None):
"""
Expand Down
12 changes: 9 additions & 3 deletions kag/builder/component/extractor/schema_free_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def named_entity_recognition(self, passage: str):
Returns:
The result of the named entity recognition operation.
"""
ner_result = self.llm.invoke({"input": passage}, self.ner_prompt)
ner_result = self.llm.invoke(
{"input": passage}, self.ner_prompt, with_except=False
)
if self.external_graph:
extra_ner_result = self.external_graph.ner(passage)
else:
Expand Down Expand Up @@ -140,7 +142,9 @@ def named_entity_standardization(self, passage: str, entities: List[Dict]):
Standardized entity information.
"""
return self.llm.invoke(
{"input": passage, "named_entities": entities}, self.std_prompt
{"input": passage, "named_entities": entities},
self.std_prompt,
with_except=False,
)

@retry(stop=stop_after_attempt(3))
Expand All @@ -154,7 +158,9 @@ def triples_extraction(self, passage: str, entities: List[Dict]):
The result of the triples extraction operation.
"""
return self.llm.invoke(
{"input": passage, "entity_list": entities}, self.triple_prompt
{"input": passage, "entity_list": entities},
self.triple_prompt,
with_except=False,
)

def assemble_sub_graph_with_spg_records(self, entities: List[Dict]):
Expand Down
7 changes: 4 additions & 3 deletions kag/builder/component/postprocessor/kag_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class KAGPostProcessor(PostProcessorABC):

def __init__(
self,
similarity_threshold: float = 0.9,
similarity_threshold: float = None,
external_graph: ExternalGraphLoaderABC = None,
):
"""
Expand Down Expand Up @@ -180,8 +180,9 @@ def _invoke(self, input, **kwargs):
origin_num_nodes = len(input.nodes)
origin_num_edges = len(input.edges)
new_graph = self.filter_invalid_data(input)
self.similarity_based_link(new_graph)
self.external_graph_based_link(new_graph)
if self.similarity_threshold is not None:
self.similarity_based_link(new_graph)
self.external_graph_based_link(new_graph)
new_num_nodes = len(new_graph.nodes)
new_num_edges = len(new_graph.edges)
logger.debug(
Expand Down
29 changes: 18 additions & 11 deletions kag/common/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,15 @@ def _closest_cfg(
return _closest_cfg(path.parent, path)


def load_config(prod: bool = False):
def validate_config_file(config_file: str) -> bool:
    """Return True if *config_file* is a non-empty path that exists on disk.

    Args:
        config_file: Candidate config file path; may be None or empty.

    Returns:
        bool: True when the path is truthy and exists, False otherwise.
    """
    # Single boolean expression replaces the original if/return chain;
    # short-circuit keeps os.path.exists off None/empty inputs.
    return bool(config_file) and os.path.exists(config_file)


def load_config(prod: bool = False, config_file: str = None):
"""
Get kag config file as a ConfigParser.
"""
Expand All @@ -121,7 +129,8 @@ def load_config(prod: bool = False):
config["vectorize_model"] = config["vectorizer"]
return config
else:
config_file = _closest_cfg()
if not validate_config_file(config_file):
config_file = _closest_cfg()
if os.path.exists(config_file) and os.path.isfile(config_file):
print(f"found config file: {config_file}")
with open(config_file, "r") as reader:
Expand All @@ -148,13 +157,11 @@ def init_log_config(self, config):
logging.getLogger("neo4j.io").setLevel(logging.INFO)
logging.getLogger("neo4j.pool").setLevel(logging.INFO)

def initialize(self, prod: bool = True):
config = load_config(prod)
def initialize(self, prod: bool = True, config_file: str = None):
config = load_config(prod, config_file)
if self._is_initialized:
print(
"Reinitialize the KAG configuration, an operation that should exclusively be triggered within the Java invocation context."
)
print(f"original config: {self.config}")
print("WARN: Reinitialize the KAG configuration.")
print(f"original config: {self.config}\n\n")
print(f"new config: {config}")
self.prod = prod
self.config = config
Expand All @@ -173,15 +180,15 @@ def all_config(self):
KAG_PROJECT_CONF = KAG_CONFIG.global_config


def init_env():
def init_env(config_file: str = None):
project_id = os.getenv(KAGConstants.ENV_KAG_PROJECT_ID)
host_addr = os.getenv(KAGConstants.ENV_KAG_PROJECT_HOST_ADDR)
if project_id and host_addr:
if project_id and host_addr and not validate_config_file(config_file):
prod = True
else:
prod = False
global KAG_CONFIG
KAG_CONFIG.initialize(prod)
KAG_CONFIG.initialize(prod, config_file)

if prod:
msg = "Done init config from server"
Expand Down
1 change: 0 additions & 1 deletion kag/examples/2wiki/kag_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ kag_builder_pipeline:
type: dict_reader # kag.builder.component.reader.dict_reader.DictReader
post_processor:
type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor
similarity_threshold: 0.9
splitter:
type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter
split_length: 100000
Expand Down
1 change: 0 additions & 1 deletion kag/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ kag_builder_pipeline:
type: dict_reader # kag.builder.component.reader.dict_reader.DictReader
post_processor:
type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor
similarity_threshold: 0.9
splitter:
type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter
split_length: 100000
Expand Down
1 change: 0 additions & 1 deletion kag/examples/README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ kag_builder_pipeline:
type: dict_reader # kag.builder.component.reader.dict_reader.DictReader
post_processor:
type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor
similarity_threshold: 0.9
splitter:
type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter
split_length: 100000
Expand Down
1 change: 0 additions & 1 deletion kag/examples/baike/kag_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ kag_builder_pipeline:
type: txt_reader # kag.builder.component.reader.txt_reader.TXTReader
post_processor:
type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor
similarity_threshold: 0.9
splitter:
type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter
split_length: 300
Expand Down
1 change: 0 additions & 1 deletion kag/examples/csqa/kag_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ kag_builder_pipeline:
type: txt_reader # kag.builder.component.reader.txt_reader.TXTReader
post_processor:
type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor
similarity_threshold: 0.9
splitter:
type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter
split_length: 4950
Expand Down
1 change: 0 additions & 1 deletion kag/examples/domain_kg/kag_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ kag_builder_pipeline:
type: txt_reader # kag.builder.component.reader.text_reader.TXTReader
post_processor:
type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor
similarity_threshold: 0.9
external_graph: *external_graph_loader
splitter:
type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter
Expand Down
1 change: 0 additions & 1 deletion kag/examples/example_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ kag_builder_pipeline:
type: dict_reader # kag.builder.component.reader.dict_reader.DictReader
post_processor:
type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor
similarity_threshold: 0.9
splitter:
type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter
split_length: 100000
Expand Down
1 change: 0 additions & 1 deletion kag/examples/hotpotqa/kag_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ kag_builder_pipeline:
type: dict_reader # kag.builder.component.reader.dict_reader.DictReader
post_processor:
type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor
similarity_threshold: 0.9
splitter:
type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter
split_length: 100000
Expand Down
1 change: 0 additions & 1 deletion kag/examples/medicine/kag_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ extract_runner:
name_col: title
post_processor:
type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor
similarity_threshold: 0.9
splitter:
type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter
split_length: 100000
Expand Down
1 change: 0 additions & 1 deletion kag/examples/musique/kag_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ kag_builder_pipeline:
type: dict_reader # kag.builder.component.reader.dict_reader.DictReader
post_processor:
type: kag_post_processor # kag.builder.component.postprocessor.kag_postprocessor.KAGPostProcessor
similarity_threshold: 0.9
splitter:
type: length_splitter # kag.builder.component.splitter.length_splitter.LengthSplitter
split_length: 100000
Expand Down
4 changes: 0 additions & 4 deletions kag/interface/builder/postprocessor_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,3 @@ def input_types(self):
@property
def output_types(self):
return SubGraph

@property
def ckpt_subdir(self):
return "postprocessor"
6 changes: 3 additions & 3 deletions kag/interface/common/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ def invoke(
logger.debug(f"Result: {result}")
except Exception as e:
import traceback

logger.error(f"Error {e} during invocation: {traceback.format_exc()}")
logger.info(f"Error {e} during invocation: {traceback.format_exc()}")
if with_except:
raise RuntimeError(
f"LLM invoke exception, info: {e}\nllm input: {input}\nllm output: {response}"
f"LLM invoke exception, info: {e}\nllm input: \n{prompt}\nllm output: \n{response}"
)

return result

def batch(
Expand Down
47 changes: 29 additions & 18 deletions kag/solver/retriever/impl/default_chunk_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def calculate_sim_scores(self, query: str, doc_nums: int):
Returns:
dict: A dictionary with keys as document chunk IDs and values as the vector similarity scores.
"""
scores = dict()
try:
scores = query_sim_doc_cache.get(query)
if scores:
Expand All @@ -186,6 +185,7 @@ def calculate_sim_scores(self, query: str, doc_nums: int):
scores = {item["node"]["id"]: item["score"] for item in top_k}
query_sim_doc_cache.put(query, scores)
except Exception as e:
scores = dict()
logger.error(f"run calculate_sim_scores failed, info: {e}", exc_info=True)
return scores

Expand Down Expand Up @@ -386,14 +386,20 @@ def convert_entity_data_to_ppr_cand(related_entities: List[EntityData]):
return matched_entities

    def _parse_ner_list(self, query):
        """Best-effort named-entity extraction for *query*, with caching.

        Returns the cached entity list when available; otherwise runs NER,
        optionally standardizes the entity names, and caches the result.
        Any failure is logged and a (possibly empty) list is returned
        instead of raising, so retrieval can proceed without entities.
        """
        ner_list = []
        try:
            # Cache hit: reuse previously computed entities for this query.
            ner_list = ner_cache.get(query)
            if ner_list:
                return ner_list
            ner_list = self.named_entity_recognition(query)
            if self.with_semantic:
                # Attach standardized ("official") names onto the raw NER output.
                std_ner_list = self.named_entity_standardization(query, ner_list)
                self.append_official_name(ner_list, std_ner_list)
            ner_cache.put(query, ner_list)
        except Exception as e:
            # ner_cache.get may have returned None before a later call raised;
            # normalize so callers always receive a list.
            if not ner_list:
                ner_list = []
            logger.warning(f"_parse_ner_list {query} failed {e}", exc_info=True)
        return ner_list

def recall_docs(
Expand Down Expand Up @@ -506,15 +512,20 @@ def get_all_docs_by_id(self, queries: List[str], doc_ids: list, top_k: int):
else:
doc_score = doc_ids[doc_id]
counter += 1
node = self.graph_api.get_entity_prop_by_id(
label=self.schema.get_label_within_prefix(CHUNK_TYPE),
biz_id=doc_id,
)
node_dict = dict(node.items())
matched_docs.append(
f"#{node_dict['name']}#{node_dict['content']}#{doc_score}"
)
hits_docs.add(node_dict["name"])
try:
node = self.graph_api.get_entity_prop_by_id(
label=self.schema.get_label_within_prefix(CHUNK_TYPE),
biz_id=doc_id,
)
node_dict = dict(node.items())
matched_docs.append(
f"#{node_dict['name']}#{node_dict['content']}#{doc_score}"
)
hits_docs.add(node_dict["name"])
except Exception as e:
logger.warning(
f"{doc_id} get_entity_prop_by_id failed: {e}", exc_info=True
)
query = "\n".join(queries)
try:
text_matched = self.search_api.search_text(
Expand Down

0 comments on commit dd8233c

Please sign in to comment.