diff --git a/.github/workflows/cloud-code-scan.yml b/.github/workflows/cloud-code-scan.yml
new file mode 100644
index 00000000..a8afaf2c
--- /dev/null
+++ b/.github/workflows/cloud-code-scan.yml
@@ -0,0 +1,22 @@
+name: Alipay Cloud Devops Codescan
+on:
+ pull_request_target:
+jobs:
+ stc:
+ runs-on: ubuntu-latest
+ steps:
+ - name: codeScan
+ uses: layotto/alipay-cloud-devops-codescan@main
+ with:
+ parent_uid: ${{ secrets.ALI_PID }}
+ private_key: ${{ secrets.ALI_PK }}
+ scan_type: stc
+ sca:
+ runs-on: ubuntu-latest
+ steps:
+ - name: codeScan
+ uses: layotto/alipay-cloud-devops-codescan@main
+ with:
+ parent_uid: ${{ secrets.ALI_PID }}
+ private_key: ${{ secrets.ALI_PK }}
+ scan_type: sca
\ No newline at end of file
diff --git a/.github/workflows/code-format-check.yml b/.github/workflows/code-format-check.yml
new file mode 100644
index 00000000..f0e3cabf
--- /dev/null
+++ b/.github/workflows/code-format-check.yml
@@ -0,0 +1,28 @@
+name: Code Format Check
+
+on:
+ push:
+ pull_request:
+ workflow_dispatch:
+ repository_dispatch:
+ types: [my_event]
+jobs:
+ format-check:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.10'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install pre-commit
+
+ - name: Run pre-commit
+ run: pre-commit run --all-files
\ No newline at end of file
diff --git a/.github/workflows/license-checker.yml b/.github/workflows/license-checker.yml
new file mode 100644
index 00000000..fa38b09f
--- /dev/null
+++ b/.github/workflows/license-checker.yml
@@ -0,0 +1,25 @@
+name: License Checker
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+
+jobs:
+ check:
+ name: "License Validation"
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+ - name: Check License Header
+ uses: apache/skywalking-eyes@main
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ log: info
+ - name: Check Dependencies' License
+ uses: apache/skywalking-eyes/dependency@main
\ No newline at end of file
diff --git a/.github/workflows/pr-title-check.yml b/.github/workflows/pr-title-check.yml
new file mode 100644
index 00000000..ae7befd4
--- /dev/null
+++ b/.github/workflows/pr-title-check.yml
@@ -0,0 +1,28 @@
+name: "Lint PR"
+
+on:
+ pull_request_target:
+ types:
+ - opened
+ - edited
+ - synchronize
+
+jobs:
+ main:
+ name: Validate PR title
+ runs-on: ubuntu-latest
+ steps:
+ # https://www.conventionalcommits.org/en/v1.0.0/#summary
+ - uses: amannn/action-semantic-pull-request@v5
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ requireScope: true
+ subjectPattern: ^(?![A-Z]).+$
+ # If `subjectPattern` is configured, you can use this property to override
+ # the default error message that is shown when the pattern doesn't match.
+ # The variables `subject` and `title` can be used within the message.
+ subjectPatternError: |
+ The subject "{subject}" found in the pull request title "{title}"
+ didn't match the configured pattern. Please ensure that the subject
+ doesn't start with an uppercase character.
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..44b7bcd3
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,12 @@
+repos:
+ - repo: https://github.com/psf/black
+ rev: 22.3.0
+ hooks:
+ - id: black
+ files: ^kag/.*\.py$
+ - repo: https://github.com/pycqa/flake8
+ rev: 4.0.1
+ hooks:
+ - id: flake8
+ files: ^kag/.*\.py$
+
\ No newline at end of file
diff --git a/kag/builder/component/aligner/__init__.py b/kag/builder/component/aligner/__init__.py
index 123acd8d..93aa6cd4 100644
--- a/kag/builder/component/aligner/__init__.py
+++ b/kag/builder/component/aligner/__init__.py
@@ -9,4 +9,3 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
-
diff --git a/kag/builder/component/base.py b/kag/builder/component/base.py
index 047867c9..e51b370c 100644
--- a/kag/builder/component/base.py
+++ b/kag/builder/component/base.py
@@ -54,9 +54,9 @@ def _init_llm(self) -> LLMClient:
try:
config = ProjectClient().get_config(project_id)
llm_config.update(config.get("llm", {}))
- except:
+ except Exception as e:
logging.warning(
- f"Failed to get project config for project id: {project_id}"
+ f"Failed to get project config for project id: {project_id}, info: {e}"
)
llm = LLMClient.from_config(llm_config)
return llm
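
A quick aside on the base.py change above: replacing the bare `except:` with `except Exception as e` keeps KeyboardInterrupt/SystemExit propagating and lets the warning carry the failure reason. A minimal, self-contained sketch of that pattern (the loader function and config values are illustrative stand-ins, not KAG APIs):

import logging

def load_project_config(loader):
    """Illustrative only: mirrors the try/except pattern in _init_llm above."""
    config = {}
    try:
        config = loader()  # stand-in for ProjectClient().get_config(project_id)
    except Exception as e:  # a bare `except:` would also swallow KeyboardInterrupt/SystemExit
        logging.warning("Failed to get project config, using defaults: %s", e)
    return config

if __name__ == "__main__":
    print(load_project_config(lambda: {"llm": {"model": "demo"}}))  # -> {'llm': {'model': 'demo'}}

    def broken_loader():
        raise RuntimeError("server unreachable")

    print(load_project_config(broken_loader))  # logs a warning, then -> {}
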
diff --git a/kag/builder/component/extractor/kag_extractor.py b/kag/builder/component/extractor/kag_extractor.py
index 9e00e54d..bb1be45e 100644
--- a/kag/builder/component/extractor/kag_extractor.py
+++ b/kag/builder/component/extractor/kag_extractor.py
@@ -40,10 +40,16 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self.llm = self._init_llm()
self.prompt_config = self.config.get("prompt", {})
- self.biz_scene = self.prompt_config.get("biz_scene") or os.getenv("KAG_PROMPT_BIZ_SCENE", "default")
- self.language = self.prompt_config.get("language") or os.getenv("KAG_PROMPT_LANGUAGE", "en")
+ self.biz_scene = self.prompt_config.get("biz_scene") or os.getenv(
+ "KAG_PROMPT_BIZ_SCENE", "default"
+ )
+ self.language = self.prompt_config.get("language") or os.getenv(
+ "KAG_PROMPT_LANGUAGE", "en"
+ )
self.schema = SchemaClient(project_id=self.project_id).load()
- self.ner_prompt = PromptOp.load(self.biz_scene, "ner")(language=self.language, project_id=self.project_id)
+ self.ner_prompt = PromptOp.load(self.biz_scene, "ner")(
+ language=self.language, project_id=self.project_id
+ )
self.std_prompt = PromptOp.load(self.biz_scene, "std")(language=self.language)
self.triple_prompt = PromptOp.load(self.biz_scene, "triple")(
language=self.language
@@ -60,7 +66,9 @@ def __init__(self, **kwargs):
self.kg_types.append(type_name)
break
if self.kg_types:
- self.kg_prompt = SPG_KGPrompt(self.kg_types, language=self.language, project_id=self.project_id)
+ self.kg_prompt = SPG_KGPrompt(
+ self.kg_types, language=self.language, project_id=self.project_id
+ )
@property
def input_types(self) -> Type[Input]:
@@ -130,6 +138,7 @@ def assemble_sub_graph_with_spg_records(self, entities: List[Dict]):
continue
if prop_name in spg_type.properties:
from knext.schema.model.property import Property
+
prop: Property = spg_type.properties.get(prop_name)
o_label = prop.object_type_name_en
if o_label not in BASIC_TYPES:
@@ -137,10 +146,18 @@ def assemble_sub_graph_with_spg_records(self, entities: List[Dict]):
prop_value = [prop_value]
for o_name in prop_value:
sub_graph.add_node(id=o_name, name=o_name, label=o_label)
- sub_graph.add_edge(s_id=s_name, s_label=s_label, p=prop_name, o_id=o_name, o_label=o_label)
+ sub_graph.add_edge(
+ s_id=s_name,
+ s_label=s_label,
+ p=prop_name,
+ o_id=o_name,
+ o_label=o_label,
+ )
tmp_properties.pop(prop_name)
record["properties"] = tmp_properties
- sub_graph.add_node(id=s_name, name=s_name, label=s_label, properties=properties)
+ sub_graph.add_node(
+ id=s_name, name=s_name, label=s_label, properties=properties
+ )
return sub_graph, entities
@staticmethod
@@ -174,10 +191,9 @@ def get_category(entities_data, entity_name):
if o_category is None:
o_category = OTHER_TYPE
sub_graph.add_node(tri[2], tri[2], o_category)
-
- sub_graph.add_edge(
- tri[0], s_category, to_camel_case(tri[1]), tri[2], o_category
- )
+ edge_type = to_camel_case(tri[1])
+ if edge_type:
+ sub_graph.add_edge(tri[0], s_category, edge_type, tri[2], o_category)
return sub_graph
@@ -199,14 +215,18 @@ def assemble_sub_graph_with_chunk(sub_graph: SubGraph, chunk: Chunk):
"id": chunk.id,
"name": chunk.name,
"content": f"{chunk.name}\n{chunk.content}",
- **chunk.kwargs
+ **chunk.kwargs,
},
)
sub_graph.id = chunk.id
return sub_graph
def assemble_sub_graph(
- self, sub_graph: SubGraph, chunk: Chunk, entities: List[Dict], triples: List[list]
+ self,
+ sub_graph: SubGraph,
+ chunk: Chunk,
+ entities: List[Dict],
+ triples: List[list],
):
"""
Integrates entity and triple information into a subgraph, and associates it with a chunk of text.
@@ -311,7 +331,10 @@ def invoke(self, input: Input, **kwargs) -> List[Output]:
try:
entities = self.named_entity_recognition(passage)
sub_graph, entities = self.assemble_sub_graph_with_spg_records(entities)
- filtered_entities = [{k: v for k, v in ent.items() if k in ["entity", "category"]} for ent in entities]
+ filtered_entities = [
+ {k: v for k, v in ent.items() if k in ["entity", "category"]}
+ for ent in entities
+ ]
triples = self.triples_extraction(passage, filtered_entities)
std_entities = self.named_entity_standardization(passage, filtered_entities)
self.append_official_name(entities, std_entities)
diff --git a/kag/builder/component/extractor/spg_extractor.py b/kag/builder/component/extractor/spg_extractor.py
index c4369f21..dc160e06 100644
--- a/kag/builder/component/extractor/spg_extractor.py
+++ b/kag/builder/component/extractor/spg_extractor.py
@@ -42,7 +42,9 @@ def __init__(self, **kwargs):
self.spg_ner_types.append(type_name)
continue
self.kag_ner_types.append(type_name)
- self.kag_ner_prompt = PromptOp.load(self.biz_scene, "ner")(language=self.language, project_id=self.project_id)
+ self.kag_ner_prompt = PromptOp.load(self.biz_scene, "ner")(
+ language=self.language, project_id=self.project_id
+ )
self.spg_ner_prompt = SPG_KGPrompt(self.spg_ner_types, self.language)
@retry(stop=stop_after_attempt(3))
@@ -72,6 +74,7 @@ def assemble_sub_graph_with_spg_records(self, entities: List[Dict]):
continue
if prop_name in spg_type.properties:
from knext.schema.model.property import Property
+
prop: Property = spg_type.properties.get(prop_name)
o_label = prop.object_type_name_en
if o_label not in BASIC_TYPES:
@@ -79,10 +82,18 @@ def assemble_sub_graph_with_spg_records(self, entities: List[Dict]):
prop_value = [prop_value]
for o_name in prop_value:
sub_graph.add_node(id=o_name, name=o_name, label=o_label)
- sub_graph.add_edge(s_id=s_name, s_label=s_label, p=prop_name, o_id=o_name, o_label=o_label)
+ sub_graph.add_edge(
+ s_id=s_name,
+ s_label=s_label,
+ p=prop_name,
+ o_id=o_name,
+ o_label=o_label,
+ )
tmp_properties.pop(prop_name)
record["properties"] = tmp_properties
- sub_graph.add_node(id=s_name, name=s_name, label=s_label, properties=properties)
+ sub_graph.add_node(
+ id=s_name, name=s_name, label=s_label, properties=properties
+ )
return sub_graph, entities
def invoke(self, input: Input, **kwargs) -> List[Output]:
@@ -102,7 +113,10 @@ def invoke(self, input: Input, **kwargs) -> List[Output]:
try:
entities = self.named_entity_recognition(passage)
sub_graph, entities = self.assemble_sub_graph_with_spg_records(entities)
- filtered_entities = [{k: v for k, v in ent.items() if k in ["entity", "category"]} for ent in entities]
+ filtered_entities = [
+ {k: v for k, v in ent.items() if k in ["entity", "category"]}
+ for ent in entities
+ ]
triples = self.triples_extraction(passage, filtered_entities)
std_entities = self.named_entity_standardization(passage, filtered_entities)
self.append_official_name(entities, std_entities)
diff --git a/kag/builder/component/mapping/relation_mapping.py b/kag/builder/component/mapping/relation_mapping.py
index 47fa9f64..637e3d4f 100644
--- a/kag/builder/component/mapping/relation_mapping.py
+++ b/kag/builder/component/mapping/relation_mapping.py
@@ -40,7 +40,7 @@ def __init__(
subject_name: SPGTypeName,
predicate_name: RelationName,
object_name: SPGTypeName,
- **kwargs
+ **kwargs,
):
super().__init__(**kwargs)
schema = SchemaClient(project_id=self.project_id).load()
diff --git a/kag/builder/component/mapping/spg_type_mapping.py b/kag/builder/component/mapping/spg_type_mapping.py
index 49400f70..5ddac4ba 100644
--- a/kag/builder/component/mapping/spg_type_mapping.py
+++ b/kag/builder/component/mapping/spg_type_mapping.py
@@ -39,7 +39,9 @@ class SPGTypeMapping(MappingABC):
fuse_op (FuseOpABC, optional): The user-defined fuse operator. Defaults to None.
"""
- def __init__(self, spg_type_name: SPGTypeName, fuse_func: FuseFunc = None, **kwargs):
+ def __init__(
+ self, spg_type_name: SPGTypeName, fuse_func: FuseFunc = None, **kwargs
+ ):
super().__init__(**kwargs)
self.schema = SchemaClient(project_id=self.project_id).load()
assert (
diff --git a/kag/builder/component/mapping/spo_mapping.py b/kag/builder/component/mapping/spo_mapping.py
index 2ab11c93..8ff7eff3 100644
--- a/kag/builder/component/mapping/spo_mapping.py
+++ b/kag/builder/component/mapping/spo_mapping.py
@@ -20,7 +20,6 @@
class SPOMapping(MappingABC):
-
def __init__(self):
super().__init__()
self.s_type_col = None
@@ -39,7 +38,14 @@ def input_types(self) -> Type[Input]:
def output_types(self) -> Type[Output]:
return SubGraph
- def add_field_mappings(self, s_id_col: str, p_type_col: str, o_id_col: str, s_type_col: str = None, o_type_col: str = None):
+ def add_field_mappings(
+ self,
+ s_id_col: str,
+ p_type_col: str,
+ o_id_col: str,
+ s_type_col: str = None,
+ o_type_col: str = None,
+ ):
self.s_type_col = s_type_col
self.s_id_col = s_id_col
self.p_type_col = p_type_col
@@ -86,14 +92,21 @@ def assemble_sub_graph(self, record: Dict[str, str]):
sub_graph.add_node(id=o_id, name=o_id, label=o_type)
sub_properties = {}
if self.sub_property_col:
- sub_properties = json.loads(record.get(self.sub_property_col, '{}'))
+ sub_properties = json.loads(record.get(self.sub_property_col, "{}"))
sub_properties = {k: str(v) for k, v in sub_properties.items()}
else:
for target_name, source_names in self.sub_property_mapping.items():
for source_name in source_names:
value = record.get(source_name)
sub_properties[target_name] = value
- sub_graph.add_edge(s_id=s_id, s_label=s_type, p=p, o_id=o_id, o_label=o_type, properties=sub_properties)
+ sub_graph.add_edge(
+ s_id=s_id,
+ s_label=s_type,
+ p=p,
+ o_id=o_id,
+ o_label=o_type,
+ properties=sub_properties,
+ )
return sub_graph
def invoke(self, input: Input, **kwargs) -> List[Output]:
diff --git a/kag/builder/component/reader/__init__.py b/kag/builder/component/reader/__init__.py
index df6c45b5..d2235151 100644
--- a/kag/builder/component/reader/__init__.py
+++ b/kag/builder/component/reader/__init__.py
@@ -16,7 +16,11 @@
from kag.builder.component.reader.markdown_reader import MarkDownReader
from kag.builder.component.reader.docx_reader import DocxReader
from kag.builder.component.reader.txt_reader import TXTReader
-from kag.builder.component.reader.dataset_reader import HotpotqaCorpusReader, TwowikiCorpusReader, MusiqueCorpusReader
+from kag.builder.component.reader.dataset_reader import (
+ HotpotqaCorpusReader,
+ TwowikiCorpusReader,
+ MusiqueCorpusReader,
+)
from kag.builder.component.reader.yuque_reader import YuqueReader
__all__ = [
diff --git a/kag/builder/component/reader/csv_reader.py b/kag/builder/component/reader/csv_reader.py
index 9c7c157d..98f2dc23 100644
--- a/kag/builder/component/reader/csv_reader.py
+++ b/kag/builder/component/reader/csv_reader.py
@@ -75,13 +75,18 @@ def invoke(self, input: Input, **kwargs) -> List[Output]:
chunks = []
basename, _ = os.path.splitext(os.path.basename(input))
for idx, row in enumerate(data.to_dict(orient="records")):
- kwargs = {k: v for k, v in row.items() if k not in [self.id_col, self.name_col, self.content_col]}
+ kwargs = {
+ k: v
+ for k, v in row.items()
+ if k not in [self.id_col, self.name_col, self.content_col]
+ }
chunks.append(
Chunk(
- id=row.get(self.id_col) or Chunk.generate_hash_id(f"{input}#{idx}"),
+ id=row.get(self.id_col)
+ or Chunk.generate_hash_id(f"{input}#{idx}"),
name=row.get(self.name_col) or f"{basename}#{idx}",
content=row[self.content_col],
- **kwargs
+ **kwargs,
)
)
return chunks
diff --git a/kag/builder/component/reader/docx_reader.py b/kag/builder/component/reader/docx_reader.py
index f0006e97..952735c8 100644
--- a/kag/builder/component/reader/docx_reader.py
+++ b/kag/builder/component/reader/docx_reader.py
@@ -15,7 +15,6 @@
from docx import Document
-from kag.builder.component.reader import MarkDownReader
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
from knext.common.base.runnable import Input, Output
diff --git a/kag/builder/component/reader/json_reader.py b/kag/builder/component/reader/json_reader.py
index e3752796..54237021 100644
--- a/kag/builder/component/reader/json_reader.py
+++ b/kag/builder/component/reader/json_reader.py
@@ -14,7 +14,6 @@
import os
from typing import List, Type, Dict, Union
-from kag.builder.component.reader.markdown_reader import MarkDownReader
from kag.builder.model.chunk import Chunk
from kag.interface.builder.reader_abc import SourceReaderABC
from knext.common.base.runnable import Input, Output
diff --git a/kag/builder/component/reader/markdown_reader.py b/kag/builder/component/reader/markdown_reader.py
index adfcffbd..a0ce7cf0 100644
--- a/kag/builder/component/reader/markdown_reader.py
+++ b/kag/builder/component/reader/markdown_reader.py
@@ -94,30 +94,28 @@ def tag_to_text(self, tag: bs4.element.Tag):
html_table = str(tag)
table_df = pd.read_html(html_table)[0]
return f"{self.TABLE_CHUCK_FLAG}{table_df.to_markdown(index=False)}{self.TABLE_CHUCK_FLAG}"
- except:
- logging.warning("parse table tag to text error", exc_info=True)
+ except Exception as e:
+ logging.warning(f"parse table tag to text error: {e}", exc_info=True)
return tag.text
@retry(stop=stop_after_attempt(5))
- def analyze_table(self, table,analyze_mathod="human"):
+ def analyze_table(self, table, analyze_mathod="human"):
if analyze_mathod == "llm":
- if self.llm_module == None:
+ if self.llm_module is None:
logging.INFO("llm_module is None, cannot use analyze_table")
return table
- variables = {
- "table": table
- }
+ variables = {"table": table}
response = self.llm_module.invoke(
- variables = variables,
- prompt_op = self.analyze_table_prompt,
- with_json_parse=False
+ variables=variables,
+ prompt_op=self.analyze_table_prompt,
+ with_json_parse=False,
)
- if response is None or response == "" or response == []:
+ if response is None or response == "" or response == []:
raise Exception("llm_module return None")
return response
else:
- from io import StringIO
import pandas as pd
+
try:
df = pd.read_html(StringIO(table))[0]
except Exception as e:
@@ -125,18 +123,16 @@ def analyze_table(self, table,analyze_mathod="human"):
return table
content = ""
for index, row in df.iterrows():
- content+=f"第{index+1}行的数据如下:"
+ content += f"第{index+1}行的数据如下:"
for col_name, value in row.items():
- content+=f"{col_name}的值为{value},"
- content+='\n'
+ content += f"{col_name}的值为{value},"
+ content += "\n"
return content
-
@retry(stop=stop_after_attempt(5))
def analyze_img(self, img_url):
response = requests.get(img_url)
response.raise_for_status()
- image_data = response.content
pass
@@ -188,11 +184,11 @@ def extract_table(self, level_tags, header=""):
return tables
def parse_level_tags(
- self,
- level_tags: list,
- level: str,
- parent_header: str = "",
- cur_header: str = "",
+ self,
+ level_tags: list,
+ level: str,
+ parent_header: str = "",
+ cur_header: str = "",
):
"""
Recursively parses level tags to organize them into a structured format.
@@ -264,10 +260,14 @@ def cut(self, level_tags, cur_level, final_level):
if cur_level == final_level:
cur_prefix = []
for sublevel_tags in level_tags:
- if (
- isinstance(sublevel_tags, tuple)
- ):
- cur_prefix.append(self.to_text([sublevel_tags,]))
+ if isinstance(sublevel_tags, tuple):
+ cur_prefix.append(
+ self.to_text(
+ [
+ sublevel_tags,
+ ]
+ )
+ )
else:
break
cur_prefix = "\n".join(cur_prefix)
@@ -281,9 +281,7 @@ def cut(self, level_tags, cur_level, final_level):
else:
cur_prefix = []
for sublevel_tags in level_tags:
- if (
- isinstance(sublevel_tags, tuple)
- ):
+ if isinstance(sublevel_tags, tuple):
cur_prefix.append(sublevel_tags[1].text)
else:
break
@@ -296,7 +294,9 @@ def cut(self, level_tags, cur_level, final_level):
output += self.cut(sublevel_tags, cur_level + 1, final_level)
return output
- def solve_content(self, id: str, title: str, content: str, **kwargs) -> List[Output]:
+ def solve_content(
+ self, id: str, title: str, content: str, **kwargs
+ ) -> List[Output]:
"""
Converts Markdown content into structured chunks.
@@ -352,7 +352,9 @@ def solve_content(self, id: str, title: str, content: str, **kwargs) -> List[Out
chunks.append(chunk)
return chunks
- def get_table_chuck(self, table_chunk_str: str, title: str, id: str, idx: int) -> Chunk:
+ def get_table_chuck(
+ self, table_chunk_str: str, title: str, id: str, idx: int
+ ) -> Chunk:
"""
convert table chunk
:param table_chunk_str:
@@ -369,7 +371,9 @@ def get_table_chuck(self, table_chunk_str: str, title: str, id: str, idx: int) -
content=table_chunk_str,
)
table_markdown_str = matches[0]
- html_table_str = markdown.markdown(table_markdown_str, extensions=["markdown.extensions.tables"])
+ html_table_str = markdown.markdown(
+ table_markdown_str, extensions=["markdown.extensions.tables"]
+ )
try:
df = pd.read_html(html_table_str)[0]
except Exception as e:
@@ -377,7 +381,9 @@ def get_table_chuck(self, table_chunk_str: str, title: str, id: str, idx: int) -
df = pd.DataFrame()
        # Confirm this is a table chunk and strip TABLE_CHUCK_FLAG from the content
- replaced_table_text = re.sub(pattern, f'\n{table_markdown_str}\n', table_chunk_str, flags=re.DOTALL)
+ replaced_table_text = re.sub(
+ pattern, f"\n{table_markdown_str}\n", table_chunk_str, flags=re.DOTALL
+ )
return Chunk(
id=Chunk.generate_hash_id(f"{id}#{idx}"),
name=f"{title}#{idx}",
diff --git a/kag/builder/component/reader/pdf_reader.py b/kag/builder/component/reader/pdf_reader.py
index c60020d8..ffa40f2e 100644
--- a/kag/builder/component/reader/pdf_reader.py
+++ b/kag/builder/component/reader/pdf_reader.py
@@ -14,8 +14,7 @@
import re
from typing import List, Sequence, Type, Union
-from langchain_community.document_loaders import PyPDFLoader
-import pdfminer.layout
+import pdfminer.layout # noqa
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SourceReaderABC
@@ -27,13 +26,7 @@
from pdfminer.layout import LTTextContainer, LTPage
from pdfminer.pdfparser import PDFParser
from pdfminer.pdfdocument import PDFDocument
-from pdfminer.layout import LAParams,LTTextBox
-from pdfminer.pdfpage import PDFPage
-from pdfminer.pdfparser import PDFParser
-from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter
-from pdfminer.converter import PDFPageAggregator
-from pdfminer.pdfpage import PDFTextExtractionNotAllowed
-import pdfminer
+import pdfminer # noqa
import logging
@@ -59,7 +52,6 @@ def __init__(self, **kwargs):
language = os.getenv("KAG_PROMPT_LANGUAGE", "zh")
self.prompt = OutlinePrompt(language)
-
@property
def input_types(self) -> Type[Input]:
return str
@@ -67,8 +59,8 @@ def input_types(self) -> Type[Input]:
@property
def output_types(self) -> Type[Output]:
return Chunk
-
- def outline_chunk(self, chunk: Union[Chunk, List[Chunk]],basename) -> List[Chunk]:
+
+ def outline_chunk(self, chunk: Union[Chunk, List[Chunk]], basename) -> List[Chunk]:
if isinstance(chunk, Chunk):
chunk = [chunk]
outlines = []
@@ -76,26 +68,28 @@ def outline_chunk(self, chunk: Union[Chunk, List[Chunk]],basename) -> List[Chunk
outline = self.llm.invoke({"input": c.content}, self.prompt)
outlines.extend(outline)
content = "\n".join([c.content for c in chunk])
- chunks = self.sep_by_outline(content, outlines,basename)
+ chunks = self.sep_by_outline(content, outlines, basename)
return chunks
-
- def sep_by_outline(self,content,outlines,basename):
+
+ def sep_by_outline(self, content, outlines, basename):
position_check = []
for outline in outlines:
start = content.find(outline)
- position_check.append((outline,start))
+ position_check.append((outline, start))
chunks = []
- for idx,pc in enumerate(position_check):
+ for idx, pc in enumerate(position_check):
chunk = Chunk(
- id = Chunk.generate_hash_id(f"{basename}#{pc[0]}"),
+ id=Chunk.generate_hash_id(f"{basename}#{pc[0]}"),
name=f"{basename}#{pc[0]}",
- content=content[pc[1]:position_check[idx+1][1] if idx+1 < len(position_check) else len(position_check)],
+ content=content[
+ pc[1] : position_check[idx + 1][1]
+ if idx + 1 < len(position_check)
+ else len(position_check)
+ ],
)
chunks.append(chunk)
return chunks
-
-
@staticmethod
def _process_single_page(
page: str,
@@ -170,22 +164,19 @@ def invoke(self, input: str, **kwargs) -> Sequence[Output]:
if not os.path.isfile(input):
raise FileNotFoundError(f"The file {input} does not exist.")
-
self.fd = open(input, "rb")
self.parser = PDFParser(self.fd)
self.document = PDFDocument(self.parser)
chunks = []
basename, _ = os.path.splitext(os.path.basename(input))
-
# get outline
try:
outlines = self.document.get_outlines()
except Exception as e:
logger.warning(f"loading PDF file: {e}")
self.outline_flag = False
-
-
+
if not self.outline_flag:
with open(input, "rb") as file:
@@ -201,18 +192,18 @@ def invoke(self, input: str, **kwargs) -> Sequence[Output]:
)
chunks.append(chunk)
try:
- outline_chunks = self.outline_chunk(chunks, basename)
+ outline_chunks = self.outline_chunk(chunks, basename)
except Exception as e:
raise RuntimeError(f"Error loading PDF file: {e}")
if len(outline_chunks) > 0:
chunks = outline_chunks
-
+
else:
split_words = []
-
+
for item in outlines:
level, title, dest, a, se = item
- split_words.append(title.strip().replace(" ",""))
+ split_words.append(title.strip().replace(" ", ""))
# save the outline position in content
try:
text = extract_text(input)
@@ -228,27 +219,33 @@ def invoke(self, input: str, **kwargs) -> Sequence[Output]:
sentences += cleaned_page
content = "".join(sentences)
- positions = [(input,0)]
+ positions = [(input, 0)]
for split_word in split_words:
pattern = re.compile(split_word)
- for i,match in enumerate(re.finditer(pattern, content)):
+ for i, match in enumerate(re.finditer(pattern, content)):
if i == 1:
start, end = match.span()
- positions.append((split_word,start))
-
- for idx,position in enumerate(positions):
+ positions.append((split_word, start))
+
+ for idx, position in enumerate(positions):
chunk = Chunk(
- id = Chunk.generate_hash_id(f"{basename}#{position[0]}"),
+ id=Chunk.generate_hash_id(f"{basename}#{position[0]}"),
name=f"{basename}#{position[0]}",
- content=content[position[1]:positions[idx+1][1] if idx+1 < len(positions) else None],
+ content=content[
+ position[1] : positions[idx + 1][1]
+ if idx + 1 < len(positions)
+ else None
+ ],
)
chunks.append(chunk)
return chunks
-if __name__ == '__main__':
+if __name__ == "__main__":
reader = PDFReader(split_using_outline=True)
- pdf_path = os.path.join(os.path.dirname(__file__),"../../../../tests/builder/data/aiwen.pdf")
+ pdf_path = os.path.join(
+ os.path.dirname(__file__), "../../../../tests/builder/data/aiwen.pdf"
+ )
chunk = reader.invoke(pdf_path)
- print(chunk)
\ No newline at end of file
+ print(chunk)
diff --git a/kag/builder/component/reader/txt_reader.py b/kag/builder/component/reader/txt_reader.py
index 6f9d7a08..15a84326 100644
--- a/kag/builder/component/reader/txt_reader.py
+++ b/kag/builder/component/reader/txt_reader.py
@@ -51,7 +51,7 @@ def invoke(self, input: Input, **kwargs) -> List[Output]:
try:
if os.path.exists(input):
- with open(input, "r", encoding='utf-8') as f:
+ with open(input, "r", encoding="utf-8") as f:
content = f.read()
else:
content = input
diff --git a/kag/builder/component/splitter/base_table_splitter.py b/kag/builder/component/splitter/base_table_splitter.py
index 7ccef439..7a775f3e 100644
--- a/kag/builder/component/splitter/base_table_splitter.py
+++ b/kag/builder/component/splitter/base_table_splitter.py
@@ -10,9 +10,6 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
-from abc import ABC
-from typing import Type, List, Union
-
from kag.builder.model.chunk import Chunk
from kag.interface.builder import SplitterABC
@@ -27,7 +24,9 @@ def split_table(self, org_chunk: Chunk, chunk_size: int = 2000, sep: str = "\n")
split markdown format table into smaller markdown table
"""
try:
- return self._split_table(org_chunk=org_chunk, chunk_size=chunk_size, sep=sep)
+ return self._split_table(
+ org_chunk=org_chunk, chunk_size=chunk_size, sep=sep
+ )
except Exception:
return None
@@ -63,7 +62,7 @@ def _split_table(self, org_chunk: Chunk, chunk_size: int = 2000, sep: str = "\n"
name=f"{org_chunk.name}#{idx}",
content=sep.join(sentences),
type=org_chunk.type,
- **org_chunk.kwargs
+ **org_chunk.kwargs,
)
output.append(chunk)
return output
diff --git a/kag/builder/component/splitter/length_splitter.py b/kag/builder/component/splitter/length_splitter.py
index 2e9dcfcd..85beb784 100644
--- a/kag/builder/component/splitter/length_splitter.py
+++ b/kag/builder/component/splitter/length_splitter.py
@@ -10,7 +10,7 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
-from typing import Type, List, Union
+from typing import Type, List
from kag.builder.model.chunk import Chunk, ChunkTypeEnum
from knext.common.base.runnable import Input, Output
@@ -55,7 +55,7 @@ def split_sentence(self, content):
for idx, char in enumerate(content):
if char in sentence_delimiters:
end = idx
- tmp = content[start: end + 1].strip()
+ tmp = content[start : end + 1].strip()
if len(tmp) > 0:
output.append(tmp)
start = idx + 1
@@ -65,11 +65,11 @@ def split_sentence(self, content):
return output
def slide_window_chunk(
- self,
- org_chunk: Chunk,
- chunk_size: int = 2000,
- window_length: int = 300,
- sep: str = "\n",
+ self,
+ org_chunk: Chunk,
+ chunk_size: int = 2000,
+ window_length: int = 300,
+ sep: str = "\n",
) -> List[Chunk]:
"""
Splits the content into chunks using a sliding window approach.
@@ -84,7 +84,9 @@ def slide_window_chunk(
List[Chunk]: A list of Chunk objects.
"""
if org_chunk.type == ChunkTypeEnum.Table:
- table_chunks = self.split_table(org_chunk=org_chunk, chunk_size=chunk_size, sep=sep)
+ table_chunks = self.split_table(
+ org_chunk=org_chunk, chunk_size=chunk_size, sep=sep
+ )
if table_chunks is not None:
return table_chunks
content = self.split_sentence(org_chunk.content)
@@ -116,7 +118,7 @@ def slide_window_chunk(
name=f"{org_chunk.name}",
content=sep.join(sentences),
type=org_chunk.type,
- **org_chunk.kwargs
+ **org_chunk.kwargs,
)
output.append(chunk)
return output
@@ -133,17 +135,13 @@ def invoke(self, input: Chunk, **kwargs) -> List[Output]:
List[Output]: A list of split chunks.
"""
cutted = []
- if isinstance(input,list):
+ if isinstance(input, list):
for item in input:
cutted.extend(
- self.slide_window_chunk(
- item, self.split_length, self.window_length
- )
+ self.slide_window_chunk(item, self.split_length, self.window_length)
)
else:
cutted.extend(
- self.slide_window_chunk(
- input, self.split_length, self.window_length
- )
+ self.slide_window_chunk(input, self.split_length, self.window_length)
)
return cutted
diff --git a/kag/builder/component/splitter/outline_splitter.py b/kag/builder/component/splitter/outline_splitter.py
index 22796e9b..15481565 100644
--- a/kag/builder/component/splitter/outline_splitter.py
+++ b/kag/builder/component/splitter/outline_splitter.py
@@ -11,7 +11,6 @@
# or implied.
import logging
import os
-import re
from typing import List, Type, Union
from kag.interface.builder import SplitterABC
@@ -70,28 +69,3 @@ def sep_by_outline(self, content, outlines):
def invoke(self, input: Input, **kwargs) -> List[Chunk]:
chunks = self.outline_chunk(input)
return chunks
-
-
-if __name__ == "__main__":
- from kag.builder.component.splitter.length_splitter import LengthSplitter
- from kag.builder.component.splitter.outline_splitter import OutlineSplitter
- from kag.builder.component.reader.docx_reader import DocxReader
- from kag.common.env import init_kag_config
-
- init_kag_config(
- os.path.join(
- os.path.dirname(__file__),
- "../../../../tests/builder/component/test_config.cfg",
- )
- )
- docx_reader = DocxReader()
- length_splitter = LengthSplitter(split_length=8000)
- outline_splitter = OutlineSplitter()
- docx_path = os.path.join(
- os.path.dirname(__file__), "../../../../tests/builder/data/test_docx.docx"
- )
- # chain = docx_reader >> length_splitter >> outline_splitter
- chunk = docx_reader.invoke(docx_path)
- chunks = length_splitter.invoke(chunk)
- chunks = outline_splitter.invoke(chunks)
- print(chunks)
diff --git a/kag/builder/component/splitter/pattern_splitter.py b/kag/builder/component/splitter/pattern_splitter.py
index 0b72f265..df894c23 100644
--- a/kag/builder/component/splitter/pattern_splitter.py
+++ b/kag/builder/component/splitter/pattern_splitter.py
@@ -14,7 +14,7 @@
import re
import os
-from kag.builder.model.chunk import Chunk, ChunkTypeEnum
+from kag.builder.model.chunk import Chunk
from kag.interface.builder.splitter_abc import SplitterABC
from knext.common.base.runnable import Input, Output
diff --git a/kag/builder/component/splitter/semantic_splitter.py b/kag/builder/component/splitter/semantic_splitter.py
index e9cbdba1..3c7deaea 100644
--- a/kag/builder/component/splitter/semantic_splitter.py
+++ b/kag/builder/component/splitter/semantic_splitter.py
@@ -18,7 +18,6 @@
from kag.builder.prompt.semantic_seg_prompt import SemanticSegPrompt
from kag.builder.model.chunk import Chunk
from knext.common.base.runnable import Input, Output
-from kag.common.llm import LLMClient
logger = logging.getLogger(__name__)
diff --git a/kag/builder/component/vectorizer/batch_vectorizer.py b/kag/builder/component/vectorizer/batch_vectorizer.py
index 019f4a71..1b785c4f 100644
--- a/kag/builder/component/vectorizer/batch_vectorizer.py
+++ b/kag/builder/component/vectorizer/batch_vectorizer.py
@@ -18,7 +18,6 @@
from kag.common.vectorizer import Vectorizer
from kag.interface.builder.vectorizer_abc import VectorizerABC
from knext.schema.client import SchemaClient
-from knext.project.client import ProjectClient
from knext.schema.model.base import IndexTypeEnum
diff --git a/kag/builder/component/writer/kg_writer.py b/kag/builder/component/writer/kg_writer.py
index 155bf1bf..e3d9c1cd 100644
--- a/kag/builder/component/writer/kg_writer.py
+++ b/kag/builder/component/writer/kg_writer.py
@@ -50,7 +50,10 @@ def output_types(self) -> Type[Output]:
return None
def invoke(
- self, input: Input, alter_operation: str = AlterOperationEnum.Upsert, lead_to_builder: bool = False
+ self,
+ input: Input,
+ alter_operation: str = AlterOperationEnum.Upsert,
+ lead_to_builder: bool = False,
) -> List[Output]:
"""
Invokes the specified operation (upsert or delete) on the graph store.
@@ -63,11 +66,15 @@ def invoke(
Returns:
List[Output]: A list of output objects (currently always [None]).
"""
- self.client.write_graph(sub_graph=input.to_dict(), operation=alter_operation, lead_to_builder=lead_to_builder)
+ self.client.write_graph(
+ sub_graph=input.to_dict(),
+ operation=alter_operation,
+ lead_to_builder=lead_to_builder,
+ )
return [None]
def _handle(self, input: Dict, alter_operation: str, **kwargs):
"""The calling interface provided for SPGServer."""
_input = self.input_types.from_dict(input)
- _output = self.invoke(_input, alter_operation)
+ _output = self.invoke(_input, alter_operation) # noqa
return None
diff --git a/kag/builder/default_chain.py b/kag/builder/default_chain.py
index ab04aff9..aa3ce185 100644
--- a/kag/builder/default_chain.py
+++ b/kag/builder/default_chain.py
@@ -27,9 +27,11 @@
def get_reader(file_path: str):
file = os.path.basename(file_path)
suffix = file.split(".")[-1]
- assert suffix.lower() in READER_MAPPING, f"{suffix} is not supported. Supported suffixes are: {list(READER_MAPPING.keys())}"
+ assert (
+ suffix.lower() in READER_MAPPING
+ ), f"{suffix} is not supported. Supported suffixes are: {list(READER_MAPPING.keys())}"
reader_path = READER_MAPPING.get(suffix.lower())
- mod_path, class_name = reader_path.rsplit('.', 1)
+ mod_path, class_name = reader_path.rsplit(".", 1)
module = importlib.import_module(mod_path)
reader_class = getattr(module, class_name)
@@ -138,7 +140,14 @@ def build(self, **kwargs) -> Chain:
chain = source >> splitter >> extractor >> vectorizer >> sink
return chain
- def invoke(self, file_path: str, split_length: int = 500, window_length: int = 100, max_workers=10, **kwargs):
+ def invoke(
+ self,
+ file_path: str,
+ split_length: int = 500,
+ window_length: int = 100,
+ max_workers=10,
+ **kwargs,
+ ):
logger.info(f"begin processing file_path:{file_path}")
"""
Invokes the processing chain with the given file path and optional parameters.
@@ -154,4 +163,10 @@ def invoke(self, file_path: str, split_length: int = 500, window_length: int = 1
Returns:
The result of invoking the processing chain.
"""
- return super().invoke(file_path=file_path, max_workers=max_workers, split_length=window_length, window_length=window_length, **kwargs)
+ return super().invoke(
+ file_path=file_path,
+ max_workers=max_workers,
+ split_length=window_length,
+ window_length=window_length,
+ **kwargs,
+ )
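
For reference, the get_reader helper touched above resolves a reader class from a dotted "module.ClassName" path via importlib. A minimal sketch of that lookup pattern (the mapping below is illustrative and uses a stdlib class, not the real READER_MAPPING):

import importlib

# Hypothetical suffix -> dotted-path mapping, mirroring the READER_MAPPING lookup in get_reader.
DEMO_MAPPING = {
    "json": "json.JSONDecoder",  # stdlib class used purely for illustration
}

def load_class(suffix: str):
    path = DEMO_MAPPING.get(suffix.lower())
    if path is None:
        raise ValueError(f"{suffix} is not supported; supported: {list(DEMO_MAPPING)}")
    mod_path, class_name = path.rsplit(".", 1)  # split "pkg.mod.Class" into module and class
    module = importlib.import_module(mod_path)  # import the module lazily
    return getattr(module, class_name)          # fetch the class object by name

if __name__ == "__main__":
    print(load_class("json"))  # <class 'json.decoder.JSONDecoder'>
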
diff --git a/kag/builder/model/chunk.py b/kag/builder/model/chunk.py
index a5db11c3..fe5d3a70 100644
--- a/kag/builder/model/chunk.py
+++ b/kag/builder/model/chunk.py
@@ -26,7 +26,7 @@ def __init__(
name: str,
content: str,
type: ChunkTypeEnum = ChunkTypeEnum.Text,
- **kwargs
+ **kwargs,
):
self.id = id
self.name = name
@@ -59,7 +59,9 @@ def to_dict(self):
"id": self.id,
"name": self.name,
"content": self.content,
- "type": self.type.value if isinstance(self.type, ChunkTypeEnum) else self.type,
+ "type": self.type.value
+ if isinstance(self.type, ChunkTypeEnum)
+ else self.type,
"properties": self.kwargs,
}
diff --git a/kag/builder/model/spg_record.py b/kag/builder/model/spg_record.py
index 5c5b6825..2dd1d2f8 100644
--- a/kag/builder/model/spg_record.py
+++ b/kag/builder/model/spg_record.py
@@ -129,9 +129,9 @@ def append_property(self, property_name: PropertyName, value: str):
"""
property_value = self.get_property(property_name)
if property_value:
- property_value_list = property_value.split(',')
+ property_value_list = property_value.split(",")
if value not in property_value_list:
- self.properties[property_name] = property_value + ',' + value
+ self.properties[property_name] = property_value + "," + value
else:
self.properties[property_name] = value
return self
diff --git a/kag/builder/model/sub_graph.py b/kag/builder/model/sub_graph.py
index ff4ebb7f..b359ca2b 100644
--- a/kag/builder/model/sub_graph.py
+++ b/kag/builder/model/sub_graph.py
@@ -41,7 +41,7 @@ def from_spg_record(cls, idx, spg_record: SPGRecord):
@staticmethod
def unique_key(spg_record):
- return spg_record.spg_type_name + '_' + spg_record.get_property("name", "")
+ return spg_record.spg_type_name + "_" + spg_record.get_property("name", "")
def to_dict(self):
return {
@@ -61,7 +61,11 @@ def from_dict(cls, input: Dict):
)
def __eq__(self, other):
- return self.name == other.name and self.label == other.label and self.properties == other.properties
+ return (
+ self.name == other.name
+ and self.label == other.label
+ and self.properties == other.properties
+ )
class Edge(object):
@@ -74,7 +78,12 @@ class Edge(object):
properties: Dict[str, str]
def __init__(
- self, _id: str, from_node: Node, to_node: Node, label: str, properties: Dict[str, str]
+ self,
+ _id: str,
+ from_node: Node,
+ to_node: Node,
+ label: str,
+ properties: Dict[str, str],
):
self.from_id = from_node.id
self.from_type = from_node.label
@@ -88,12 +97,19 @@ def __init__(
@classmethod
def from_spg_record(
- cls, s_idx, subject_record: SPGRecord, o_idx, object_record: SPGRecord, label: str
+ cls,
+ s_idx,
+ subject_record: SPGRecord,
+ o_idx,
+ object_record: SPGRecord,
+ label: str,
):
from_node = Node.from_spg_record(s_idx, subject_record)
to_node = Node.from_spg_record(o_idx, object_record)
- return cls(_id="", from_node=from_node, to_node=to_node, label=label, properties={})
+ return cls(
+ _id="", from_node=from_node, to_node=to_node, label=label, properties={}
+ )
def to_dict(self):
return {
@@ -110,14 +126,28 @@ def to_dict(self):
def from_dict(cls, input: Dict):
return cls(
_id=input["id"],
- from_node=Node(_id=input["from"], name=input["from"],label=input["fromType"], properties={}),
- to_node=Node(_id=input["to"], name=input["to"], label=input["toType"], properties={}),
+ from_node=Node(
+ _id=input["from"],
+ name=input["from"],
+ label=input["fromType"],
+ properties={},
+ ),
+ to_node=Node(
+ _id=input["to"], name=input["to"], label=input["toType"], properties={}
+ ),
label=input["label"],
properties=input["properties"],
)
def __eq__(self, other):
- return self.from_id == other.from_id and self.to_id == other.to_id and self.label == other.label and self.properties == other.properties and self.from_type == other.from_type and self.to_type == other.to_type
+ return (
+ self.from_id == other.from_id
+ and self.to_id == other.to_id
+ and self.label == other.label
+ and self.properties == other.properties
+ and self.from_type == other.from_type
+ and self.to_type == other.to_type
+ )
class SubGraph(object):
@@ -135,12 +165,18 @@ def add_node(self, id: str, name: str, label: str, properties=None):
self.nodes.append(Node(_id=id, name=name, label=label, properties=properties))
return self
- def add_edge(self, s_id: str, s_label: str, p: str, o_id: str, o_label: str, properties=None):
+ def add_edge(
+ self, s_id: str, s_label: str, p: str, o_id: str, o_label: str, properties=None
+ ):
if not properties:
properties = dict()
s_node = Node(_id=s_id, name=s_id, label=s_label, properties={})
o_node = Node(_id=o_id, name=o_id, label=o_label, properties={})
- self.edges.append(Edge(_id="", from_node=s_node, to_node=o_node, label=p, properties=properties))
+ self.edges.append(
+ Edge(
+ _id="", from_node=s_node, to_node=o_node, label=p, properties=properties
+ )
+ )
return self
def to_dict(self):
@@ -152,7 +188,7 @@ def to_dict(self):
def __repr__(self):
return pprint.pformat(self.to_dict())
- def merge(self, sub_graph: 'SubGraph'):
+ def merge(self, sub_graph: "SubGraph"):
self.nodes.extend(sub_graph.nodes)
self.edges.extend(sub_graph.edges)
@@ -164,21 +200,30 @@ def from_spg_record(
for record in spg_records:
s_id = record.id
s_name = record.name
- s_label = record.spg_type_name.split('.')[-1]
+ s_label = record.spg_type_name.split(".")[-1]
properties = record.properties
spg_type = spg_types.get(record.spg_type_name)
for prop_name, prop_value in record.properties.items():
if prop_name in spg_type.properties:
from knext.schema.model.property import Property
+
prop: Property = spg_type.properties.get(prop_name)
- o_label = prop.object_type_name.split('.')[-1]
+ o_label = prop.object_type_name.split(".")[-1]
if o_label not in BASIC_TYPES:
- prop_value_list = prop_value.split(',')
+ prop_value_list = prop_value.split(",")
for o_id in prop_value_list:
- sub_graph.add_edge(s_id=s_id, s_label=s_label, p=prop_name, o_id=o_id, o_label=o_label)
+ sub_graph.add_edge(
+ s_id=s_id,
+ s_label=s_label,
+ p=prop_name,
+ o_id=o_id,
+ o_label=o_label,
+ )
properties.pop(prop_name)
- sub_graph.add_node(id=s_id, name=s_name, label=s_label, properties=properties)
+ sub_graph.add_node(
+ id=s_id, name=s_name, label=s_label, properties=properties
+ )
return sub_graph
diff --git a/kag/builder/operator/__init__.py b/kag/builder/operator/__init__.py
index 123acd8d..93aa6cd4 100644
--- a/kag/builder/operator/__init__.py
+++ b/kag/builder/operator/__init__.py
@@ -9,4 +9,3 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
-
diff --git a/kag/builder/prompt/analyze_table_prompt.py b/kag/builder/prompt/analyze_table_prompt.py
index 00b9ade0..36325f83 100644
--- a/kag/builder/prompt/analyze_table_prompt.py
+++ b/kag/builder/prompt/analyze_table_prompt.py
@@ -18,8 +18,6 @@
logger = logging.getLogger(__name__)
-
-
class AnalyzeTablePrompt(PromptOp):
template_zh: str = """你是一个分析表格的专家, 从table中提取信息并分析,最后返回表格有效信息"""
template_en: str = """You are an expert in knowledge graph extraction. Based on the schema defined by the constraint, extract all entities and their attributes from the input. Return NAN for attributes not explicitly mentioned in the input. Output the results in standard JSON format, as a list."""
@@ -36,11 +34,10 @@ def build_prompt(self, variables) -> str:
return json.dumps(
{
"instruction": self.template,
- "table": variables.get("table",""),
+ "table": variables.get("table", ""),
},
ensure_ascii=False,
)
def parse_response(self, response: str, **kwargs):
return response
-
diff --git a/kag/builder/prompt/default/ner.py b/kag/builder/prompt/default/ner.py
index 1cc92310..7240572f 100644
--- a/kag/builder/prompt/default/ner.py
+++ b/kag/builder/prompt/default/ner.py
@@ -146,9 +146,7 @@ class OpenIENERPrompt(PromptOp):
}
"""
- def __init__(
- self, language: Optional[str] = "en", **kwargs
- ):
+ def __init__(self, language: Optional[str] = "en", **kwargs):
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
diff --git a/kag/builder/prompt/medical/ner.py b/kag/builder/prompt/medical/ner.py
index 07c6298a..7a9a9333 100644
--- a/kag/builder/prompt/medical/ner.py
+++ b/kag/builder/prompt/medical/ner.py
@@ -45,9 +45,7 @@ class OpenIENERPrompt(PromptOp):
template_en = template_zh
- def __init__(
- self, language: Optional[str] = "en", **kwargs
- ):
+ def __init__(self, language: Optional[str] = "en", **kwargs):
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
diff --git a/kag/builder/prompt/medical/triple.py b/kag/builder/prompt/medical/triple.py
index 2b5aaff8..d925d963 100644
--- a/kag/builder/prompt/medical/triple.py
+++ b/kag/builder/prompt/medical/triple.py
@@ -11,7 +11,7 @@
# or implied.
import json
-from typing import Optional, List, Dict, Any
+from typing import Optional, List
from kag.common.base.prompt_op import PromptOp
diff --git a/kag/builder/prompt/spg_prompt.py b/kag/builder/prompt/spg_prompt.py
index 98f47255..12c48af9 100644
--- a/kag/builder/prompt/spg_prompt.py
+++ b/kag/builder/prompt/spg_prompt.py
@@ -27,7 +27,15 @@
class SPGPrompt(PromptOp, ABC):
spg_types: Dict[str, BaseSpgType]
ignored_types: List[str] = ["Chunk"]
- ignored_properties: List[str] = ["id", "name", "description", "stdId", "eventTime", "desc", "semanticType"]
+ ignored_properties: List[str] = [
+ "id",
+ "name",
+ "description",
+ "stdId",
+ "eventTime",
+ "desc",
+ "semanticType",
+ ]
ignored_relations: List[str] = ["isA"]
basic_types = {"Text": "文本", "Integer": "整型", "Float": "浮点型"}
@@ -43,7 +51,9 @@ def __init__(
if not spg_type_names:
self.spg_types = self.all_schema_types
else:
- self.spg_types = {k: v for k, v in self.all_schema_types.items() if k in spg_type_names}
+ self.spg_types = {
+ k: v for k, v in self.all_schema_types.items() if k in spg_type_names
+ }
self.schema_list = []
self._init_render_variables()
@@ -138,16 +148,9 @@ class SPG_KGPrompt(SPGPrompt):
template_en: str = template_zh
def __init__(
- self,
- spg_type_names: List[SPGTypeName],
- language: str = "zh",
- **kwargs
+ self, spg_type_names: List[SPGTypeName], language: str = "zh", **kwargs
):
- super().__init__(
- spg_type_names=spg_type_names,
- language=language,
- **kwargs
- )
+ super().__init__(spg_type_names=spg_type_names, language=language, **kwargs)
self._render()
def build_prompt(self, variables: Dict[str, str]) -> str:
@@ -173,13 +176,19 @@ def parse_response(self, response: str, **kwargs) -> List[SPGRecord]:
def _render(self):
spo_list = []
for type_name, spg_type in self.spg_types.items():
- if spg_type.spg_type_enum not in [SpgTypeEnum.Entity, SpgTypeEnum.Concept, SpgTypeEnum.Event]:
+ if spg_type.spg_type_enum not in [
+ SpgTypeEnum.Entity,
+ SpgTypeEnum.Concept,
+ SpgTypeEnum.Event,
+ ]:
continue
constraint = {}
properties = {}
properties.update(
{
- v.name: (f"{v.name_zh}" if not v.desc else f"{v.name_zh},{v.desc}") if self.language == "zh" else (f"{v.name}" if not v.desc else f"{v.name}, {v.desc}")
+ v.name: (f"{v.name_zh}" if not v.desc else f"{v.name_zh},{v.desc}")
+ if self.language == "zh"
+ else (f"{v.name}" if not v.desc else f"{v.name}, {v.desc}")
for k, v in spg_type.properties.items()
if k not in self.ignored_properties
}
@@ -190,7 +199,9 @@ def _render(self):
f"{v.name_zh},类型是{v.object_type_name_zh}"
if not v.desc
else f"{v.name_zh},{v.desc},类型是{v.object_type_name_zh}"
- ) if self.language == "zh" else (
+ )
+ if self.language == "zh"
+ else (
f"{v.name}, the type is {v.object_type_name_en}"
if not v.desc
else f"{v.name},{v.desc}, the type is {v.object_type_name_en}"
diff --git a/kag/common/__init__.py b/kag/common/__init__.py
index 123acd8d..93aa6cd4 100644
--- a/kag/common/__init__.py
+++ b/kag/common/__init__.py
@@ -9,4 +9,3 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
-
diff --git a/kag/common/arks_pb2.py b/kag/common/arks_pb2.py
index 0a693f00..01462624 100644
--- a/kag/common/arks_pb2.py
+++ b/kag/common/arks_pb2.py
@@ -6,191 +6,166 @@
# Reference docs: https://yuque.antfin-inc.com/ai-infra/ndhopc/smk38dcs9zqr1ssh#Kb7e0
import sys
-_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
+
+_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1"))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
+
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
-
-
DESCRIPTOR = _descriptor.FileDescriptor(
- name='arks.proto',
- package='arks',
- syntax='proto2',
- serialized_options=_b('\n\025com.alipay.arks.proto'),
- serialized_pb=_b('\n\narks.proto\x12\x04\x61rks\"\xfc\x01\n\x13InferTensorContents\x12\x14\n\x0cstring_value\x18\x01 \x03(\t\x12\x12\n\nbool_value\x18\x02 \x03(\x08\x12\x11\n\tint_value\x18\x03 \x03(\x05\x12\x13\n\x0bint64_value\x18\x04 \x03(\x03\x12\x12\n\nuint_value\x18\x05 \x03(\r\x12\x14\n\x0cuint64_value\x18\x06 \x03(\x04\x12\x12\n\nfp32_value\x18\x07 \x03(\x02\x12\x12\n\nfp64_value\x18\x08 \x03(\x01\x12\x12\n\nbyte_value\x18\t \x03(\x0c\x12-\n\x04type\x18\n \x01(\x0e\x32\x11.arks.ContentType:\x0cTYPE_INVALID\"q\n\x04Pair\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\x12+\n\x08\x63ontents\x18\x03 \x01(\x0b\x32\x19.arks.InferTensorContents\x12\x10\n\x08pb_value\x18\x04 \x03(\x0c\x12\x0e\n\x06shapes\x18\x05 \x03(\x05\"\x97\x01\n\x06RowKey\x12\x0f\n\x07row_key\x18\x01 \x01(\t\x12\x10\n\x08versions\x18\x02 \x03(\x03\x12\x1a\n\x12\x61nt_fea_track_info\x18\x03 \x01(\t\x12\'\n\npartitions\x18\x04 \x03(\x0b\x32\x13.arks.PartitionInfo\x12%\n\x11realtime_features\x18\x05 \x03(\x0b\x32\n.arks.Pair\",\n\rPartitionInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\xb9\x01\n\x04Item\x12\x0f\n\x07item_id\x18\x01 \x02(\t\x12\x1c\n\x08\x66\x65\x61tures\x18\x02 \x03(\x0b\x32\n.arks.Pair\x12\x1e\n\nattributes\x18\x03 \x03(\x0b\x32\n.arks.Pair\x12\r\n\x05score\x18\x04 \x01(\x02\x12 \n\tsub_items\x18\x05 \x03(\x0b\x32\r.arks.SubItem\x12\x1d\n\x11is_features_valid\x18\x06 \x03(\x08\x42\x02\x10\x01\x12\x12\n\x06scores\x18\x07 \x03(\x02\x42\x02\x10\x01\"\x9a\x01\n\x07SubItem\x12\x0f\n\x07item_id\x18\x01 \x01(\t\x12\x1c\n\x08\x66\x65\x61tures\x18\x02 \x03(\x0b\x32\n.arks.Pair\x12\r\n\x05score\x18\x03 \x01(\x02\x12\x1d\n\x11is_features_valid\x18\x04 \x03(\x08\x42\x02\x10\x01\x12\x12\n\x06scores\x18\x05 \x03(\x02\x42\x02\x10\x01\x12\x1e\n\nattributes\x18\x06 \x03(\x0b\x32\n.arks.Pair\"\xc7\x03\n\x08SeekPlan\x12\x14\n\x0cstorage_type\x18\x01 \x01(\t\x12\r\n\x05table\x18\x02 \x01(\t\x12\x15\n\rcolumn_family\x18\x03 \x01(\t\x12\x0f\n\x07\x63olumns\x18\x04 \x03(\t\x12\x12\n\nkvpair_sep\x18\x05 \x01(\t\x12\x0e\n\x06kv_sep\x18\x06 \x01(\t\x12\x0f\n\x07\x63luster\x18\x07 \x01(\t\x12\x1e\n\x08row_keys\x18\x08 \x03(\x0b\x32\x0c.arks.RowKey\x12\x12\n\ntimeout_ms\x18\t \x01(\x05\x12\x1b\n\x13\x63\x61\x63he_expire_second\x18\n \x01(\x05\x12\x10\n\x08url_user\x18\x0b \x01(\t\x12\x10\n\x08url_item\x18\x0c \x01(\t\x12\x17\n\x0f\x61nt_feature_req\x18\r \x01(\x0c\x12\n\n\x02id\x18\x0e \x01(\t\x12\x16\n\x0ekb_feature_req\x18\x0f \x01(\x0c\x12\x11\n\tdebuginfo\x18\x10 \x01(\t\x12\x11\n\tseparator\x18\x11 \x01(\t\x12=\n\x12item_sequence_type\x18\x12 \x01(\x0e\x32\x16.arks.ItemSequenceType:\tTYPE_NONE\x12\"\n\x0emissing_values\x18\x13 \x03(\x0b\x32\n.arks.Pair\"\x8f\x01\n\x0b\x44umpReqInfo\x12\x0e\n\x06time_s\x18\x01 \x01(\x05\x12\x0e\n\x06oss_id\x18\x02 \x01(\t\x12\x0f\n\x07oss_key\x18\x03 \x01(\t\x12\x13\n\x0btarget_addr\x18\x04 \x01(\t\x12\x10\n\x08query_id\x18\x05 \x01(\x03\x12\r\n\x05token\x18\x06 \x01(\t\x12\x0b\n\x03\x61pp\x18\x07 \x01(\t\x12\x0c\n\x04host\x18\x08 \x01(\t\"\xb3\x04\n\x0b\x41rksRequest\x12\x12\n\x07version\x18\x01 \x01(\x05:\x01\x31\x12\r\n\x05\x64\x65\x62ug\x18\x02 \x01(\x05\x12\x0f\n\x07is_ping\x18\x03 \x01(\x08\x12\x12\n\nsession_id\x18\x04 \x01(\t\x12\x13\n\x0b\x62ucket_name\x18\x05 \x01(\t\x12\x0b\n\x03uid\x18\x06 \x01(\t\x12 \n\x0cuser_profile\x18\x07 \x03(\x0b\x32\n.arks.Pair\x12\"\n\x0escene_features\x18\x08 \x03(\x0b\x32\n.arks.Pair\x12\x19\n\x05items\x18\t \x03(\x0b\x32\n.arks.Item\x12\x15\n\x07is_sort\x18\n 
\x01(\x08:\x04true\x12\x11\n\x05\x63ount\x18\x0b \x01(\x05:\x02\x31\x30\x12.\n\nout_format\x18\x0c \x01(\x0e\x32\x16.arks.OutputFormatType:\x02PB\x12\x12\n\nchain_name\x18\r \x01(\t\x12\x0b\n\x03scm\x18\x0e \x01(\t\x12\x12\n\nscene_name\x18\x0f \x01(\t\x12\x14\n\x0citem_schemas\x18\x10 \x03(\t\x12\x18\n\x10sub_item_schemas\x18\x11 \x03(\t\x12\"\n\nseek_plans\x18\x12 \x03(\x0b\x32\x0e.arks.SeekPlan\x12(\n\rdump_req_info\x18\x13 \x01(\x0b\x32\x11.arks.DumpReqInfo\x12\x10\n\x08\x61pp_name\x18\x14 \x01(\t\x12\x16\n\x0ereq_timeout_ms\x18\x15 \x01(\x04\x12\x16\n\x0e\x63lient_version\x18\x16 \x01(\t\x12\n\n\x02ip\x18\x17 \x01(\t\"\xba\x02\n\x0c\x41rksResponse\x12,\n\nerror_code\x18\x01 \x01(\x0e\x32\x0f.arks.ErrorCode:\x07SUCCESS\x12\x12\n\nsession_id\x18\x02 \x01(\t\x12\x13\n\x0b\x62ucket_name\x18\x03 \x01(\t\x12 \n\x0cuser_profile\x18\x04 \x03(\x0b\x32\n.arks.Pair\x12\x19\n\x05items\x18\x05 \x03(\x0b\x32\n.arks.Item\x12\x11\n\tdebug_msg\x18\x06 \x01(\t\x12\x0b\n\x03scm\x18\x07 \x01(\t\x12\"\n\nseek_plans\x18\x08 \x03(\x0b\x32\x0e.arks.SeekPlan\x12\x0f\n\x07\x65rr_msg\x18\t \x01(\t\x12\x10\n\x08\x61lgo_ret\x18\n \x01(\x05\x12\x10\n\x08\x61lgo_msg\x18\x0b \x01(\t\x12\x11\n\ttrace_msg\x18\x0c \x01(\t\x12\n\n\x02rt\x18\r \x01(\x05*T\n\x10OutputFormatType\x12\x06\n\x02PB\x10\x01\x12\x08\n\x04JSON\x10\x02\x12\x08\n\x04TEXT\x10\x03\x12\r\n\tSNAPPY_PB\x10\x04\x12\x06\n\x02\x46\x42\x10\x05\x12\r\n\tSNAPPY_FB\x10\x06*\x86\x01\n\tErrorCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\x0b\n\x07TIMEOUT\x10\x01\x12\r\n\tSCENE_ERR\x10\x02\x12\r\n\tPARAM_ERR\x10\x03\x12\x0e\n\nSYSTEM_ERR\x10\x04\x12\x0f\n\x0bSERVICE_ERR\x10\x05\x12\x10\n\x0c\x46LOW_CONTROL\x10\x06\x12\x0e\n\nOTHERS_ERR\x10\x07*\xae\x01\n\x0b\x43ontentType\x12\x10\n\x0cTYPE_INVALID\x10\x00\x12\r\n\tTYPE_BOOL\x10\x01\x12\x0e\n\nTYPE_INT32\x10\x02\x12\x0e\n\nTYPE_INT64\x10\x03\x12\x0f\n\x0bTYPE_UINT32\x10\x04\x12\x0f\n\x0bTYPE_UINT64\x10\x05\x12\r\n\tTYPE_FP32\x10\x06\x12\r\n\tTYPE_FP64\x10\x07\x12\x0f\n\x0bTYPE_STRING\x10\x08\x12\r\n\tTYPE_BYTE\x10\t*A\n\x10ItemSequenceType\x12\r\n\tTYPE_NONE\x10\x00\x12\x0f\n\x0bTYPE_CONCAT\x10\x01\x12\r\n\tTYPE_FLAT\x10\x02\x42\x17\n\x15\x63om.alipay.arks.proto')
+ name="arks.proto",
+ package="arks",
+ syntax="proto2",
+ serialized_options=_b("\n\025com.alipay.arks.proto"),
+ serialized_pb=_b(
+ '\n\narks.proto\x12\x04\x61rks"\xfc\x01\n\x13InferTensorContents\x12\x14\n\x0cstring_value\x18\x01 \x03(\t\x12\x12\n\nbool_value\x18\x02 \x03(\x08\x12\x11\n\tint_value\x18\x03 \x03(\x05\x12\x13\n\x0bint64_value\x18\x04 \x03(\x03\x12\x12\n\nuint_value\x18\x05 \x03(\r\x12\x14\n\x0cuint64_value\x18\x06 \x03(\x04\x12\x12\n\nfp32_value\x18\x07 \x03(\x02\x12\x12\n\nfp64_value\x18\x08 \x03(\x01\x12\x12\n\nbyte_value\x18\t \x03(\x0c\x12-\n\x04type\x18\n \x01(\x0e\x32\x11.arks.ContentType:\x0cTYPE_INVALID"q\n\x04Pair\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\x12+\n\x08\x63ontents\x18\x03 \x01(\x0b\x32\x19.arks.InferTensorContents\x12\x10\n\x08pb_value\x18\x04 \x03(\x0c\x12\x0e\n\x06shapes\x18\x05 \x03(\x05"\x97\x01\n\x06RowKey\x12\x0f\n\x07row_key\x18\x01 \x01(\t\x12\x10\n\x08versions\x18\x02 \x03(\x03\x12\x1a\n\x12\x61nt_fea_track_info\x18\x03 \x01(\t\x12\'\n\npartitions\x18\x04 \x03(\x0b\x32\x13.arks.PartitionInfo\x12%\n\x11realtime_features\x18\x05 \x03(\x0b\x32\n.arks.Pair",\n\rPartitionInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t"\xb9\x01\n\x04Item\x12\x0f\n\x07item_id\x18\x01 \x02(\t\x12\x1c\n\x08\x66\x65\x61tures\x18\x02 \x03(\x0b\x32\n.arks.Pair\x12\x1e\n\nattributes\x18\x03 \x03(\x0b\x32\n.arks.Pair\x12\r\n\x05score\x18\x04 \x01(\x02\x12 \n\tsub_items\x18\x05 \x03(\x0b\x32\r.arks.SubItem\x12\x1d\n\x11is_features_valid\x18\x06 \x03(\x08\x42\x02\x10\x01\x12\x12\n\x06scores\x18\x07 \x03(\x02\x42\x02\x10\x01"\x9a\x01\n\x07SubItem\x12\x0f\n\x07item_id\x18\x01 \x01(\t\x12\x1c\n\x08\x66\x65\x61tures\x18\x02 \x03(\x0b\x32\n.arks.Pair\x12\r\n\x05score\x18\x03 \x01(\x02\x12\x1d\n\x11is_features_valid\x18\x04 \x03(\x08\x42\x02\x10\x01\x12\x12\n\x06scores\x18\x05 \x03(\x02\x42\x02\x10\x01\x12\x1e\n\nattributes\x18\x06 \x03(\x0b\x32\n.arks.Pair"\xc7\x03\n\x08SeekPlan\x12\x14\n\x0cstorage_type\x18\x01 \x01(\t\x12\r\n\x05table\x18\x02 \x01(\t\x12\x15\n\rcolumn_family\x18\x03 \x01(\t\x12\x0f\n\x07\x63olumns\x18\x04 \x03(\t\x12\x12\n\nkvpair_sep\x18\x05 \x01(\t\x12\x0e\n\x06kv_sep\x18\x06 \x01(\t\x12\x0f\n\x07\x63luster\x18\x07 \x01(\t\x12\x1e\n\x08row_keys\x18\x08 \x03(\x0b\x32\x0c.arks.RowKey\x12\x12\n\ntimeout_ms\x18\t \x01(\x05\x12\x1b\n\x13\x63\x61\x63he_expire_second\x18\n \x01(\x05\x12\x10\n\x08url_user\x18\x0b \x01(\t\x12\x10\n\x08url_item\x18\x0c \x01(\t\x12\x17\n\x0f\x61nt_feature_req\x18\r \x01(\x0c\x12\n\n\x02id\x18\x0e \x01(\t\x12\x16\n\x0ekb_feature_req\x18\x0f \x01(\x0c\x12\x11\n\tdebuginfo\x18\x10 \x01(\t\x12\x11\n\tseparator\x18\x11 \x01(\t\x12=\n\x12item_sequence_type\x18\x12 \x01(\x0e\x32\x16.arks.ItemSequenceType:\tTYPE_NONE\x12"\n\x0emissing_values\x18\x13 \x03(\x0b\x32\n.arks.Pair"\x8f\x01\n\x0b\x44umpReqInfo\x12\x0e\n\x06time_s\x18\x01 \x01(\x05\x12\x0e\n\x06oss_id\x18\x02 \x01(\t\x12\x0f\n\x07oss_key\x18\x03 \x01(\t\x12\x13\n\x0btarget_addr\x18\x04 \x01(\t\x12\x10\n\x08query_id\x18\x05 \x01(\x03\x12\r\n\x05token\x18\x06 \x01(\t\x12\x0b\n\x03\x61pp\x18\x07 \x01(\t\x12\x0c\n\x04host\x18\x08 \x01(\t"\xb3\x04\n\x0b\x41rksRequest\x12\x12\n\x07version\x18\x01 \x01(\x05:\x01\x31\x12\r\n\x05\x64\x65\x62ug\x18\x02 \x01(\x05\x12\x0f\n\x07is_ping\x18\x03 \x01(\x08\x12\x12\n\nsession_id\x18\x04 \x01(\t\x12\x13\n\x0b\x62ucket_name\x18\x05 \x01(\t\x12\x0b\n\x03uid\x18\x06 \x01(\t\x12 \n\x0cuser_profile\x18\x07 \x03(\x0b\x32\n.arks.Pair\x12"\n\x0escene_features\x18\x08 \x03(\x0b\x32\n.arks.Pair\x12\x19\n\x05items\x18\t \x03(\x0b\x32\n.arks.Item\x12\x15\n\x07is_sort\x18\n \x01(\x08:\x04true\x12\x11\n\x05\x63ount\x18\x0b 
\x01(\x05:\x02\x31\x30\x12.\n\nout_format\x18\x0c \x01(\x0e\x32\x16.arks.OutputFormatType:\x02PB\x12\x12\n\nchain_name\x18\r \x01(\t\x12\x0b\n\x03scm\x18\x0e \x01(\t\x12\x12\n\nscene_name\x18\x0f \x01(\t\x12\x14\n\x0citem_schemas\x18\x10 \x03(\t\x12\x18\n\x10sub_item_schemas\x18\x11 \x03(\t\x12"\n\nseek_plans\x18\x12 \x03(\x0b\x32\x0e.arks.SeekPlan\x12(\n\rdump_req_info\x18\x13 \x01(\x0b\x32\x11.arks.DumpReqInfo\x12\x10\n\x08\x61pp_name\x18\x14 \x01(\t\x12\x16\n\x0ereq_timeout_ms\x18\x15 \x01(\x04\x12\x16\n\x0e\x63lient_version\x18\x16 \x01(\t\x12\n\n\x02ip\x18\x17 \x01(\t"\xba\x02\n\x0c\x41rksResponse\x12,\n\nerror_code\x18\x01 \x01(\x0e\x32\x0f.arks.ErrorCode:\x07SUCCESS\x12\x12\n\nsession_id\x18\x02 \x01(\t\x12\x13\n\x0b\x62ucket_name\x18\x03 \x01(\t\x12 \n\x0cuser_profile\x18\x04 \x03(\x0b\x32\n.arks.Pair\x12\x19\n\x05items\x18\x05 \x03(\x0b\x32\n.arks.Item\x12\x11\n\tdebug_msg\x18\x06 \x01(\t\x12\x0b\n\x03scm\x18\x07 \x01(\t\x12"\n\nseek_plans\x18\x08 \x03(\x0b\x32\x0e.arks.SeekPlan\x12\x0f\n\x07\x65rr_msg\x18\t \x01(\t\x12\x10\n\x08\x61lgo_ret\x18\n \x01(\x05\x12\x10\n\x08\x61lgo_msg\x18\x0b \x01(\t\x12\x11\n\ttrace_msg\x18\x0c \x01(\t\x12\n\n\x02rt\x18\r \x01(\x05*T\n\x10OutputFormatType\x12\x06\n\x02PB\x10\x01\x12\x08\n\x04JSON\x10\x02\x12\x08\n\x04TEXT\x10\x03\x12\r\n\tSNAPPY_PB\x10\x04\x12\x06\n\x02\x46\x42\x10\x05\x12\r\n\tSNAPPY_FB\x10\x06*\x86\x01\n\tErrorCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\x0b\n\x07TIMEOUT\x10\x01\x12\r\n\tSCENE_ERR\x10\x02\x12\r\n\tPARAM_ERR\x10\x03\x12\x0e\n\nSYSTEM_ERR\x10\x04\x12\x0f\n\x0bSERVICE_ERR\x10\x05\x12\x10\n\x0c\x46LOW_CONTROL\x10\x06\x12\x0e\n\nOTHERS_ERR\x10\x07*\xae\x01\n\x0b\x43ontentType\x12\x10\n\x0cTYPE_INVALID\x10\x00\x12\r\n\tTYPE_BOOL\x10\x01\x12\x0e\n\nTYPE_INT32\x10\x02\x12\x0e\n\nTYPE_INT64\x10\x03\x12\x0f\n\x0bTYPE_UINT32\x10\x04\x12\x0f\n\x0bTYPE_UINT64\x10\x05\x12\r\n\tTYPE_FP32\x10\x06\x12\r\n\tTYPE_FP64\x10\x07\x12\x0f\n\x0bTYPE_STRING\x10\x08\x12\r\n\tTYPE_BYTE\x10\t*A\n\x10ItemSequenceType\x12\r\n\tTYPE_NONE\x10\x00\x12\x0f\n\x0bTYPE_CONCAT\x10\x01\x12\r\n\tTYPE_FLAT\x10\x02\x42\x17\n\x15\x63om.alipay.arks.proto'
+ ),
)
_OUTPUTFORMATTYPE = _descriptor.EnumDescriptor(
- name='OutputFormatType',
- full_name='arks.OutputFormatType',
- filename=None,
- file=DESCRIPTOR,
- values=[
- _descriptor.EnumValueDescriptor(
- name='PB', index=0, number=1,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='JSON', index=1, number=2,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TEXT', index=2, number=3,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='SNAPPY_PB', index=3, number=4,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='FB', index=4, number=5,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='SNAPPY_FB', index=5, number=6,
- serialized_options=None,
- type=None),
- ],
- containing_type=None,
- serialized_options=None,
- serialized_start=2422,
- serialized_end=2506,
+ name="OutputFormatType",
+ full_name="arks.OutputFormatType",
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name="PB", index=0, number=1, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="JSON", index=1, number=2, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TEXT", index=2, number=3, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="SNAPPY_PB", index=3, number=4, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="FB", index=4, number=5, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="SNAPPY_FB", index=5, number=6, serialized_options=None, type=None
+ ),
+ ],
+ containing_type=None,
+ serialized_options=None,
+ serialized_start=2422,
+ serialized_end=2506,
)
_sym_db.RegisterEnumDescriptor(_OUTPUTFORMATTYPE)
OutputFormatType = enum_type_wrapper.EnumTypeWrapper(_OUTPUTFORMATTYPE)
_ERRORCODE = _descriptor.EnumDescriptor(
- name='ErrorCode',
- full_name='arks.ErrorCode',
- filename=None,
- file=DESCRIPTOR,
- values=[
- _descriptor.EnumValueDescriptor(
- name='SUCCESS', index=0, number=0,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TIMEOUT', index=1, number=1,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='SCENE_ERR', index=2, number=2,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='PARAM_ERR', index=3, number=3,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='SYSTEM_ERR', index=4, number=4,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='SERVICE_ERR', index=5, number=5,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='FLOW_CONTROL', index=6, number=6,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='OTHERS_ERR', index=7, number=7,
- serialized_options=None,
- type=None),
- ],
- containing_type=None,
- serialized_options=None,
- serialized_start=2509,
- serialized_end=2643,
+ name="ErrorCode",
+ full_name="arks.ErrorCode",
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name="SUCCESS", index=0, number=0, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TIMEOUT", index=1, number=1, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="SCENE_ERR", index=2, number=2, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="PARAM_ERR", index=3, number=3, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="SYSTEM_ERR", index=4, number=4, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="SERVICE_ERR", index=5, number=5, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="FLOW_CONTROL", index=6, number=6, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="OTHERS_ERR", index=7, number=7, serialized_options=None, type=None
+ ),
+ ],
+ containing_type=None,
+ serialized_options=None,
+ serialized_start=2509,
+ serialized_end=2643,
)
_sym_db.RegisterEnumDescriptor(_ERRORCODE)
ErrorCode = enum_type_wrapper.EnumTypeWrapper(_ERRORCODE)
_CONTENTTYPE = _descriptor.EnumDescriptor(
- name='ContentType',
- full_name='arks.ContentType',
- filename=None,
- file=DESCRIPTOR,
- values=[
- _descriptor.EnumValueDescriptor(
- name='TYPE_INVALID', index=0, number=0,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TYPE_BOOL', index=1, number=1,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TYPE_INT32', index=2, number=2,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TYPE_INT64', index=3, number=3,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TYPE_UINT32', index=4, number=4,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TYPE_UINT64', index=5, number=5,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TYPE_FP32', index=6, number=6,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TYPE_FP64', index=7, number=7,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TYPE_STRING', index=8, number=8,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TYPE_BYTE', index=9, number=9,
- serialized_options=None,
- type=None),
- ],
- containing_type=None,
- serialized_options=None,
- serialized_start=2646,
- serialized_end=2820,
+ name="ContentType",
+ full_name="arks.ContentType",
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_INVALID", index=0, number=0, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_BOOL", index=1, number=1, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_INT32", index=2, number=2, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_INT64", index=3, number=3, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_UINT32", index=4, number=4, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_UINT64", index=5, number=5, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_FP32", index=6, number=6, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_FP64", index=7, number=7, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_STRING", index=8, number=8, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_BYTE", index=9, number=9, serialized_options=None, type=None
+ ),
+ ],
+ containing_type=None,
+ serialized_options=None,
+ serialized_start=2646,
+ serialized_end=2820,
)
_sym_db.RegisterEnumDescriptor(_CONTENTTYPE)
ContentType = enum_type_wrapper.EnumTypeWrapper(_CONTENTTYPE)
_ITEMSEQUENCETYPE = _descriptor.EnumDescriptor(
- name='ItemSequenceType',
- full_name='arks.ItemSequenceType',
- filename=None,
- file=DESCRIPTOR,
- values=[
- _descriptor.EnumValueDescriptor(
- name='TYPE_NONE', index=0, number=0,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TYPE_CONCAT', index=1, number=1,
- serialized_options=None,
- type=None),
- _descriptor.EnumValueDescriptor(
- name='TYPE_FLAT', index=2, number=2,
- serialized_options=None,
- type=None),
- ],
- containing_type=None,
- serialized_options=None,
- serialized_start=2822,
- serialized_end=2887,
+ name="ItemSequenceType",
+ full_name="arks.ItemSequenceType",
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_NONE", index=0, number=0, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_CONCAT", index=1, number=1, serialized_options=None, type=None
+ ),
+ _descriptor.EnumValueDescriptor(
+ name="TYPE_FLAT", index=2, number=2, serialized_options=None, type=None
+ ),
+ ],
+ containing_type=None,
+ serialized_options=None,
+ serialized_start=2822,
+ serialized_end=2887,
)
_sym_db.RegisterEnumDescriptor(_ITEMSEQUENCETYPE)
@@ -224,1044 +199,2131 @@
TYPE_FLAT = 2
-
_INFERTENSORCONTENTS = _descriptor.Descriptor(
- name='InferTensorContents',
- full_name='arks.InferTensorContents',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='string_value', full_name='arks.InferTensorContents.string_value', index=0,
- number=1, type=9, cpp_type=9, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='bool_value', full_name='arks.InferTensorContents.bool_value', index=1,
- number=2, type=8, cpp_type=7, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='int_value', full_name='arks.InferTensorContents.int_value', index=2,
- number=3, type=5, cpp_type=1, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='int64_value', full_name='arks.InferTensorContents.int64_value', index=3,
- number=4, type=3, cpp_type=2, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='uint_value', full_name='arks.InferTensorContents.uint_value', index=4,
- number=5, type=13, cpp_type=3, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='uint64_value', full_name='arks.InferTensorContents.uint64_value', index=5,
- number=6, type=4, cpp_type=4, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='fp32_value', full_name='arks.InferTensorContents.fp32_value', index=6,
- number=7, type=2, cpp_type=6, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='fp64_value', full_name='arks.InferTensorContents.fp64_value', index=7,
- number=8, type=1, cpp_type=5, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='byte_value', full_name='arks.InferTensorContents.byte_value', index=8,
- number=9, type=12, cpp_type=9, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='type', full_name='arks.InferTensorContents.type', index=9,
- number=10, type=14, cpp_type=8, label=1,
- has_default_value=True, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- serialized_options=None,
- is_extendable=False,
- syntax='proto2',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=21,
- serialized_end=273,
+ name="InferTensorContents",
+ full_name="arks.InferTensorContents",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="string_value",
+ full_name="arks.InferTensorContents.string_value",
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="bool_value",
+ full_name="arks.InferTensorContents.bool_value",
+ index=1,
+ number=2,
+ type=8,
+ cpp_type=7,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="int_value",
+ full_name="arks.InferTensorContents.int_value",
+ index=2,
+ number=3,
+ type=5,
+ cpp_type=1,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="int64_value",
+ full_name="arks.InferTensorContents.int64_value",
+ index=3,
+ number=4,
+ type=3,
+ cpp_type=2,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="uint_value",
+ full_name="arks.InferTensorContents.uint_value",
+ index=4,
+ number=5,
+ type=13,
+ cpp_type=3,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="uint64_value",
+ full_name="arks.InferTensorContents.uint64_value",
+ index=5,
+ number=6,
+ type=4,
+ cpp_type=4,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="fp32_value",
+ full_name="arks.InferTensorContents.fp32_value",
+ index=6,
+ number=7,
+ type=2,
+ cpp_type=6,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="fp64_value",
+ full_name="arks.InferTensorContents.fp64_value",
+ index=7,
+ number=8,
+ type=1,
+ cpp_type=5,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="byte_value",
+ full_name="arks.InferTensorContents.byte_value",
+ index=8,
+ number=9,
+ type=12,
+ cpp_type=9,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="type",
+ full_name="arks.InferTensorContents.type",
+ index=9,
+ number=10,
+ type=14,
+ cpp_type=8,
+ label=1,
+ has_default_value=True,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ serialized_options=None,
+ is_extendable=False,
+ syntax="proto2",
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=21,
+ serialized_end=273,
)
_PAIR = _descriptor.Descriptor(
- name='Pair',
- full_name='arks.Pair',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='key', full_name='arks.Pair.key', index=0,
- number=1, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='value', full_name='arks.Pair.value', index=1,
- number=2, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='contents', full_name='arks.Pair.contents', index=2,
- number=3, type=11, cpp_type=10, label=1,
- has_default_value=False, default_value=None,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='pb_value', full_name='arks.Pair.pb_value', index=3,
- number=4, type=12, cpp_type=9, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='shapes', full_name='arks.Pair.shapes', index=4,
- number=5, type=5, cpp_type=1, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- serialized_options=None,
- is_extendable=False,
- syntax='proto2',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=275,
- serialized_end=388,
+ name="Pair",
+ full_name="arks.Pair",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="key",
+ full_name="arks.Pair.key",
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="value",
+ full_name="arks.Pair.value",
+ index=1,
+ number=2,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="contents",
+ full_name="arks.Pair.contents",
+ index=2,
+ number=3,
+ type=11,
+ cpp_type=10,
+ label=1,
+ has_default_value=False,
+ default_value=None,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="pb_value",
+ full_name="arks.Pair.pb_value",
+ index=3,
+ number=4,
+ type=12,
+ cpp_type=9,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="shapes",
+ full_name="arks.Pair.shapes",
+ index=4,
+ number=5,
+ type=5,
+ cpp_type=1,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ serialized_options=None,
+ is_extendable=False,
+ syntax="proto2",
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=275,
+ serialized_end=388,
)
_ROWKEY = _descriptor.Descriptor(
- name='RowKey',
- full_name='arks.RowKey',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='row_key', full_name='arks.RowKey.row_key', index=0,
- number=1, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='versions', full_name='arks.RowKey.versions', index=1,
- number=2, type=3, cpp_type=2, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='ant_fea_track_info', full_name='arks.RowKey.ant_fea_track_info', index=2,
- number=3, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='partitions', full_name='arks.RowKey.partitions', index=3,
- number=4, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='realtime_features', full_name='arks.RowKey.realtime_features', index=4,
- number=5, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- serialized_options=None,
- is_extendable=False,
- syntax='proto2',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=391,
- serialized_end=542,
+ name="RowKey",
+ full_name="arks.RowKey",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="row_key",
+ full_name="arks.RowKey.row_key",
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="versions",
+ full_name="arks.RowKey.versions",
+ index=1,
+ number=2,
+ type=3,
+ cpp_type=2,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="ant_fea_track_info",
+ full_name="arks.RowKey.ant_fea_track_info",
+ index=2,
+ number=3,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="partitions",
+ full_name="arks.RowKey.partitions",
+ index=3,
+ number=4,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="realtime_features",
+ full_name="arks.RowKey.realtime_features",
+ index=4,
+ number=5,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ serialized_options=None,
+ is_extendable=False,
+ syntax="proto2",
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=391,
+ serialized_end=542,
)
_PARTITIONINFO = _descriptor.Descriptor(
- name='PartitionInfo',
- full_name='arks.PartitionInfo',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='name', full_name='arks.PartitionInfo.name', index=0,
- number=1, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='value', full_name='arks.PartitionInfo.value', index=1,
- number=2, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- serialized_options=None,
- is_extendable=False,
- syntax='proto2',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=544,
- serialized_end=588,
+ name="PartitionInfo",
+ full_name="arks.PartitionInfo",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="name",
+ full_name="arks.PartitionInfo.name",
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="value",
+ full_name="arks.PartitionInfo.value",
+ index=1,
+ number=2,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ serialized_options=None,
+ is_extendable=False,
+ syntax="proto2",
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=544,
+ serialized_end=588,
)
_ITEM = _descriptor.Descriptor(
- name='Item',
- full_name='arks.Item',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='item_id', full_name='arks.Item.item_id', index=0,
- number=1, type=9, cpp_type=9, label=2,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='features', full_name='arks.Item.features', index=1,
- number=2, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='attributes', full_name='arks.Item.attributes', index=2,
- number=3, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='score', full_name='arks.Item.score', index=3,
- number=4, type=2, cpp_type=6, label=1,
- has_default_value=False, default_value=float(0),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='sub_items', full_name='arks.Item.sub_items', index=4,
- number=5, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='is_features_valid', full_name='arks.Item.is_features_valid', index=5,
- number=6, type=8, cpp_type=7, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=_b('\020\001'), file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='scores', full_name='arks.Item.scores', index=6,
- number=7, type=2, cpp_type=6, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=_b('\020\001'), file=DESCRIPTOR),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- serialized_options=None,
- is_extendable=False,
- syntax='proto2',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=591,
- serialized_end=776,
+ name="Item",
+ full_name="arks.Item",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="item_id",
+ full_name="arks.Item.item_id",
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=2,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="features",
+ full_name="arks.Item.features",
+ index=1,
+ number=2,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="attributes",
+ full_name="arks.Item.attributes",
+ index=2,
+ number=3,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="score",
+ full_name="arks.Item.score",
+ index=3,
+ number=4,
+ type=2,
+ cpp_type=6,
+ label=1,
+ has_default_value=False,
+ default_value=float(0),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="sub_items",
+ full_name="arks.Item.sub_items",
+ index=4,
+ number=5,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="is_features_valid",
+ full_name="arks.Item.is_features_valid",
+ index=5,
+ number=6,
+ type=8,
+ cpp_type=7,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=_b("\020\001"),
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="scores",
+ full_name="arks.Item.scores",
+ index=6,
+ number=7,
+ type=2,
+ cpp_type=6,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=_b("\020\001"),
+ file=DESCRIPTOR,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ serialized_options=None,
+ is_extendable=False,
+ syntax="proto2",
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=591,
+ serialized_end=776,
)
_SUBITEM = _descriptor.Descriptor(
- name='SubItem',
- full_name='arks.SubItem',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='item_id', full_name='arks.SubItem.item_id', index=0,
- number=1, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='features', full_name='arks.SubItem.features', index=1,
- number=2, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='score', full_name='arks.SubItem.score', index=2,
- number=3, type=2, cpp_type=6, label=1,
- has_default_value=False, default_value=float(0),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='is_features_valid', full_name='arks.SubItem.is_features_valid', index=3,
- number=4, type=8, cpp_type=7, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=_b('\020\001'), file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='scores', full_name='arks.SubItem.scores', index=4,
- number=5, type=2, cpp_type=6, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=_b('\020\001'), file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='attributes', full_name='arks.SubItem.attributes', index=5,
- number=6, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- serialized_options=None,
- is_extendable=False,
- syntax='proto2',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=779,
- serialized_end=933,
+ name="SubItem",
+ full_name="arks.SubItem",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="item_id",
+ full_name="arks.SubItem.item_id",
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="features",
+ full_name="arks.SubItem.features",
+ index=1,
+ number=2,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="score",
+ full_name="arks.SubItem.score",
+ index=2,
+ number=3,
+ type=2,
+ cpp_type=6,
+ label=1,
+ has_default_value=False,
+ default_value=float(0),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="is_features_valid",
+ full_name="arks.SubItem.is_features_valid",
+ index=3,
+ number=4,
+ type=8,
+ cpp_type=7,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=_b("\020\001"),
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="scores",
+ full_name="arks.SubItem.scores",
+ index=4,
+ number=5,
+ type=2,
+ cpp_type=6,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=_b("\020\001"),
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="attributes",
+ full_name="arks.SubItem.attributes",
+ index=5,
+ number=6,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ serialized_options=None,
+ is_extendable=False,
+ syntax="proto2",
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=779,
+ serialized_end=933,
)
_SEEKPLAN = _descriptor.Descriptor(
- name='SeekPlan',
- full_name='arks.SeekPlan',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='storage_type', full_name='arks.SeekPlan.storage_type', index=0,
- number=1, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='table', full_name='arks.SeekPlan.table', index=1,
- number=2, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='column_family', full_name='arks.SeekPlan.column_family', index=2,
- number=3, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='columns', full_name='arks.SeekPlan.columns', index=3,
- number=4, type=9, cpp_type=9, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='kvpair_sep', full_name='arks.SeekPlan.kvpair_sep', index=4,
- number=5, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='kv_sep', full_name='arks.SeekPlan.kv_sep', index=5,
- number=6, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='cluster', full_name='arks.SeekPlan.cluster', index=6,
- number=7, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='row_keys', full_name='arks.SeekPlan.row_keys', index=7,
- number=8, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='timeout_ms', full_name='arks.SeekPlan.timeout_ms', index=8,
- number=9, type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='cache_expire_second', full_name='arks.SeekPlan.cache_expire_second', index=9,
- number=10, type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='url_user', full_name='arks.SeekPlan.url_user', index=10,
- number=11, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='url_item', full_name='arks.SeekPlan.url_item', index=11,
- number=12, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='ant_feature_req', full_name='arks.SeekPlan.ant_feature_req', index=12,
- number=13, type=12, cpp_type=9, label=1,
- has_default_value=False, default_value=_b(""),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='id', full_name='arks.SeekPlan.id', index=13,
- number=14, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='kb_feature_req', full_name='arks.SeekPlan.kb_feature_req', index=14,
- number=15, type=12, cpp_type=9, label=1,
- has_default_value=False, default_value=_b(""),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='debuginfo', full_name='arks.SeekPlan.debuginfo', index=15,
- number=16, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='separator', full_name='arks.SeekPlan.separator', index=16,
- number=17, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='item_sequence_type', full_name='arks.SeekPlan.item_sequence_type', index=17,
- number=18, type=14, cpp_type=8, label=1,
- has_default_value=True, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='missing_values', full_name='arks.SeekPlan.missing_values', index=18,
- number=19, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- serialized_options=None,
- is_extendable=False,
- syntax='proto2',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=936,
- serialized_end=1391,
+ name="SeekPlan",
+ full_name="arks.SeekPlan",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="storage_type",
+ full_name="arks.SeekPlan.storage_type",
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="table",
+ full_name="arks.SeekPlan.table",
+ index=1,
+ number=2,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="column_family",
+ full_name="arks.SeekPlan.column_family",
+ index=2,
+ number=3,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="columns",
+ full_name="arks.SeekPlan.columns",
+ index=3,
+ number=4,
+ type=9,
+ cpp_type=9,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="kvpair_sep",
+ full_name="arks.SeekPlan.kvpair_sep",
+ index=4,
+ number=5,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="kv_sep",
+ full_name="arks.SeekPlan.kv_sep",
+ index=5,
+ number=6,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="cluster",
+ full_name="arks.SeekPlan.cluster",
+ index=6,
+ number=7,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="row_keys",
+ full_name="arks.SeekPlan.row_keys",
+ index=7,
+ number=8,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="timeout_ms",
+ full_name="arks.SeekPlan.timeout_ms",
+ index=8,
+ number=9,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=False,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="cache_expire_second",
+ full_name="arks.SeekPlan.cache_expire_second",
+ index=9,
+ number=10,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=False,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="url_user",
+ full_name="arks.SeekPlan.url_user",
+ index=10,
+ number=11,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="url_item",
+ full_name="arks.SeekPlan.url_item",
+ index=11,
+ number=12,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="ant_feature_req",
+ full_name="arks.SeekPlan.ant_feature_req",
+ index=12,
+ number=13,
+ type=12,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b(""),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="id",
+ full_name="arks.SeekPlan.id",
+ index=13,
+ number=14,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="kb_feature_req",
+ full_name="arks.SeekPlan.kb_feature_req",
+ index=14,
+ number=15,
+ type=12,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b(""),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="debuginfo",
+ full_name="arks.SeekPlan.debuginfo",
+ index=15,
+ number=16,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="separator",
+ full_name="arks.SeekPlan.separator",
+ index=16,
+ number=17,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="item_sequence_type",
+ full_name="arks.SeekPlan.item_sequence_type",
+ index=17,
+ number=18,
+ type=14,
+ cpp_type=8,
+ label=1,
+ has_default_value=True,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="missing_values",
+ full_name="arks.SeekPlan.missing_values",
+ index=18,
+ number=19,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ serialized_options=None,
+ is_extendable=False,
+ syntax="proto2",
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=936,
+ serialized_end=1391,
)
_DUMPREQINFO = _descriptor.Descriptor(
- name='DumpReqInfo',
- full_name='arks.DumpReqInfo',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='time_s', full_name='arks.DumpReqInfo.time_s', index=0,
- number=1, type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='oss_id', full_name='arks.DumpReqInfo.oss_id', index=1,
- number=2, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='oss_key', full_name='arks.DumpReqInfo.oss_key', index=2,
- number=3, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='target_addr', full_name='arks.DumpReqInfo.target_addr', index=3,
- number=4, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='query_id', full_name='arks.DumpReqInfo.query_id', index=4,
- number=5, type=3, cpp_type=2, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='token', full_name='arks.DumpReqInfo.token', index=5,
- number=6, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='app', full_name='arks.DumpReqInfo.app', index=6,
- number=7, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='host', full_name='arks.DumpReqInfo.host', index=7,
- number=8, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- serialized_options=None,
- is_extendable=False,
- syntax='proto2',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=1394,
- serialized_end=1537,
+ name="DumpReqInfo",
+ full_name="arks.DumpReqInfo",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="time_s",
+ full_name="arks.DumpReqInfo.time_s",
+ index=0,
+ number=1,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=False,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="oss_id",
+ full_name="arks.DumpReqInfo.oss_id",
+ index=1,
+ number=2,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="oss_key",
+ full_name="arks.DumpReqInfo.oss_key",
+ index=2,
+ number=3,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="target_addr",
+ full_name="arks.DumpReqInfo.target_addr",
+ index=3,
+ number=4,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="query_id",
+ full_name="arks.DumpReqInfo.query_id",
+ index=4,
+ number=5,
+ type=3,
+ cpp_type=2,
+ label=1,
+ has_default_value=False,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="token",
+ full_name="arks.DumpReqInfo.token",
+ index=5,
+ number=6,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="app",
+ full_name="arks.DumpReqInfo.app",
+ index=6,
+ number=7,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="host",
+ full_name="arks.DumpReqInfo.host",
+ index=7,
+ number=8,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ serialized_options=None,
+ is_extendable=False,
+ syntax="proto2",
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=1394,
+ serialized_end=1537,
)
_ARKSREQUEST = _descriptor.Descriptor(
- name='ArksRequest',
- full_name='arks.ArksRequest',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='version', full_name='arks.ArksRequest.version', index=0,
- number=1, type=5, cpp_type=1, label=1,
- has_default_value=True, default_value=1,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='debug', full_name='arks.ArksRequest.debug', index=1,
- number=2, type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='is_ping', full_name='arks.ArksRequest.is_ping', index=2,
- number=3, type=8, cpp_type=7, label=1,
- has_default_value=False, default_value=False,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='session_id', full_name='arks.ArksRequest.session_id', index=3,
- number=4, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='bucket_name', full_name='arks.ArksRequest.bucket_name', index=4,
- number=5, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='uid', full_name='arks.ArksRequest.uid', index=5,
- number=6, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='user_profile', full_name='arks.ArksRequest.user_profile', index=6,
- number=7, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='scene_features', full_name='arks.ArksRequest.scene_features', index=7,
- number=8, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='items', full_name='arks.ArksRequest.items', index=8,
- number=9, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='is_sort', full_name='arks.ArksRequest.is_sort', index=9,
- number=10, type=8, cpp_type=7, label=1,
- has_default_value=True, default_value=True,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='count', full_name='arks.ArksRequest.count', index=10,
- number=11, type=5, cpp_type=1, label=1,
- has_default_value=True, default_value=10,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='out_format', full_name='arks.ArksRequest.out_format', index=11,
- number=12, type=14, cpp_type=8, label=1,
- has_default_value=True, default_value=1,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='chain_name', full_name='arks.ArksRequest.chain_name', index=12,
- number=13, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='scm', full_name='arks.ArksRequest.scm', index=13,
- number=14, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='scene_name', full_name='arks.ArksRequest.scene_name', index=14,
- number=15, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='item_schemas', full_name='arks.ArksRequest.item_schemas', index=15,
- number=16, type=9, cpp_type=9, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='sub_item_schemas', full_name='arks.ArksRequest.sub_item_schemas', index=16,
- number=17, type=9, cpp_type=9, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='seek_plans', full_name='arks.ArksRequest.seek_plans', index=17,
- number=18, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='dump_req_info', full_name='arks.ArksRequest.dump_req_info', index=18,
- number=19, type=11, cpp_type=10, label=1,
- has_default_value=False, default_value=None,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='app_name', full_name='arks.ArksRequest.app_name', index=19,
- number=20, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='req_timeout_ms', full_name='arks.ArksRequest.req_timeout_ms', index=20,
- number=21, type=4, cpp_type=4, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='client_version', full_name='arks.ArksRequest.client_version', index=21,
- number=22, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='ip', full_name='arks.ArksRequest.ip', index=22,
- number=23, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- serialized_options=None,
- is_extendable=False,
- syntax='proto2',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=1540,
- serialized_end=2103,
+ name="ArksRequest",
+ full_name="arks.ArksRequest",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="version",
+ full_name="arks.ArksRequest.version",
+ index=0,
+ number=1,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=1,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="debug",
+ full_name="arks.ArksRequest.debug",
+ index=1,
+ number=2,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=False,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="is_ping",
+ full_name="arks.ArksRequest.is_ping",
+ index=2,
+ number=3,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=False,
+ default_value=False,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="session_id",
+ full_name="arks.ArksRequest.session_id",
+ index=3,
+ number=4,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="bucket_name",
+ full_name="arks.ArksRequest.bucket_name",
+ index=4,
+ number=5,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="uid",
+ full_name="arks.ArksRequest.uid",
+ index=5,
+ number=6,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="user_profile",
+ full_name="arks.ArksRequest.user_profile",
+ index=6,
+ number=7,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="scene_features",
+ full_name="arks.ArksRequest.scene_features",
+ index=7,
+ number=8,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="items",
+ full_name="arks.ArksRequest.items",
+ index=8,
+ number=9,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="is_sort",
+ full_name="arks.ArksRequest.is_sort",
+ index=9,
+ number=10,
+ type=8,
+ cpp_type=7,
+ label=1,
+ has_default_value=True,
+ default_value=True,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="count",
+ full_name="arks.ArksRequest.count",
+ index=10,
+ number=11,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=10,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="out_format",
+ full_name="arks.ArksRequest.out_format",
+ index=11,
+ number=12,
+ type=14,
+ cpp_type=8,
+ label=1,
+ has_default_value=True,
+ default_value=1,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="chain_name",
+ full_name="arks.ArksRequest.chain_name",
+ index=12,
+ number=13,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="scm",
+ full_name="arks.ArksRequest.scm",
+ index=13,
+ number=14,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="scene_name",
+ full_name="arks.ArksRequest.scene_name",
+ index=14,
+ number=15,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="item_schemas",
+ full_name="arks.ArksRequest.item_schemas",
+ index=15,
+ number=16,
+ type=9,
+ cpp_type=9,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="sub_item_schemas",
+ full_name="arks.ArksRequest.sub_item_schemas",
+ index=16,
+ number=17,
+ type=9,
+ cpp_type=9,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="seek_plans",
+ full_name="arks.ArksRequest.seek_plans",
+ index=17,
+ number=18,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="dump_req_info",
+ full_name="arks.ArksRequest.dump_req_info",
+ index=18,
+ number=19,
+ type=11,
+ cpp_type=10,
+ label=1,
+ has_default_value=False,
+ default_value=None,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="app_name",
+ full_name="arks.ArksRequest.app_name",
+ index=19,
+ number=20,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="req_timeout_ms",
+ full_name="arks.ArksRequest.req_timeout_ms",
+ index=20,
+ number=21,
+ type=4,
+ cpp_type=4,
+ label=1,
+ has_default_value=False,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="client_version",
+ full_name="arks.ArksRequest.client_version",
+ index=21,
+ number=22,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="ip",
+ full_name="arks.ArksRequest.ip",
+ index=22,
+ number=23,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ serialized_options=None,
+ is_extendable=False,
+ syntax="proto2",
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=1540,
+ serialized_end=2103,
)
_ARKSRESPONSE = _descriptor.Descriptor(
- name='ArksResponse',
- full_name='arks.ArksResponse',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='error_code', full_name='arks.ArksResponse.error_code', index=0,
- number=1, type=14, cpp_type=8, label=1,
- has_default_value=True, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='session_id', full_name='arks.ArksResponse.session_id', index=1,
- number=2, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='bucket_name', full_name='arks.ArksResponse.bucket_name', index=2,
- number=3, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='user_profile', full_name='arks.ArksResponse.user_profile', index=3,
- number=4, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='items', full_name='arks.ArksResponse.items', index=4,
- number=5, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='debug_msg', full_name='arks.ArksResponse.debug_msg', index=5,
- number=6, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='scm', full_name='arks.ArksResponse.scm', index=6,
- number=7, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='seek_plans', full_name='arks.ArksResponse.seek_plans', index=7,
- number=8, type=11, cpp_type=10, label=3,
- has_default_value=False, default_value=[],
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='err_msg', full_name='arks.ArksResponse.err_msg', index=8,
- number=9, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='algo_ret', full_name='arks.ArksResponse.algo_ret', index=9,
- number=10, type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='algo_msg', full_name='arks.ArksResponse.algo_msg', index=10,
- number=11, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='trace_msg', full_name='arks.ArksResponse.trace_msg', index=11,
- number=12, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- _descriptor.FieldDescriptor(
- name='rt', full_name='arks.ArksResponse.rt', index=12,
- number=13, type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- serialized_options=None, file=DESCRIPTOR),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- serialized_options=None,
- is_extendable=False,
- syntax='proto2',
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=2106,
- serialized_end=2420,
+ name="ArksResponse",
+ full_name="arks.ArksResponse",
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name="error_code",
+ full_name="arks.ArksResponse.error_code",
+ index=0,
+ number=1,
+ type=14,
+ cpp_type=8,
+ label=1,
+ has_default_value=True,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="session_id",
+ full_name="arks.ArksResponse.session_id",
+ index=1,
+ number=2,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="bucket_name",
+ full_name="arks.ArksResponse.bucket_name",
+ index=2,
+ number=3,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="user_profile",
+ full_name="arks.ArksResponse.user_profile",
+ index=3,
+ number=4,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="items",
+ full_name="arks.ArksResponse.items",
+ index=4,
+ number=5,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="debug_msg",
+ full_name="arks.ArksResponse.debug_msg",
+ index=5,
+ number=6,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="scm",
+ full_name="arks.ArksResponse.scm",
+ index=6,
+ number=7,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="seek_plans",
+ full_name="arks.ArksResponse.seek_plans",
+ index=7,
+ number=8,
+ type=11,
+ cpp_type=10,
+ label=3,
+ has_default_value=False,
+ default_value=[],
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="err_msg",
+ full_name="arks.ArksResponse.err_msg",
+ index=8,
+ number=9,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="algo_ret",
+ full_name="arks.ArksResponse.algo_ret",
+ index=9,
+ number=10,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=False,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="algo_msg",
+ full_name="arks.ArksResponse.algo_msg",
+ index=10,
+ number=11,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="trace_msg",
+ full_name="arks.ArksResponse.trace_msg",
+ index=11,
+ number=12,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode("utf-8"),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ _descriptor.FieldDescriptor(
+ name="rt",
+ full_name="arks.ArksResponse.rt",
+ index=12,
+ number=13,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=False,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ serialized_options=None,
+ file=DESCRIPTOR,
+ ),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ serialized_options=None,
+ is_extendable=False,
+ syntax="proto2",
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=2106,
+ serialized_end=2420,
)
-_INFERTENSORCONTENTS.fields_by_name['type'].enum_type = _CONTENTTYPE
-_PAIR.fields_by_name['contents'].message_type = _INFERTENSORCONTENTS
-_ROWKEY.fields_by_name['partitions'].message_type = _PARTITIONINFO
-_ROWKEY.fields_by_name['realtime_features'].message_type = _PAIR
-_ITEM.fields_by_name['features'].message_type = _PAIR
-_ITEM.fields_by_name['attributes'].message_type = _PAIR
-_ITEM.fields_by_name['sub_items'].message_type = _SUBITEM
-_SUBITEM.fields_by_name['features'].message_type = _PAIR
-_SUBITEM.fields_by_name['attributes'].message_type = _PAIR
-_SEEKPLAN.fields_by_name['row_keys'].message_type = _ROWKEY
-_SEEKPLAN.fields_by_name['item_sequence_type'].enum_type = _ITEMSEQUENCETYPE
-_SEEKPLAN.fields_by_name['missing_values'].message_type = _PAIR
-_ARKSREQUEST.fields_by_name['user_profile'].message_type = _PAIR
-_ARKSREQUEST.fields_by_name['scene_features'].message_type = _PAIR
-_ARKSREQUEST.fields_by_name['items'].message_type = _ITEM
-_ARKSREQUEST.fields_by_name['out_format'].enum_type = _OUTPUTFORMATTYPE
-_ARKSREQUEST.fields_by_name['seek_plans'].message_type = _SEEKPLAN
-_ARKSREQUEST.fields_by_name['dump_req_info'].message_type = _DUMPREQINFO
-_ARKSRESPONSE.fields_by_name['error_code'].enum_type = _ERRORCODE
-_ARKSRESPONSE.fields_by_name['user_profile'].message_type = _PAIR
-_ARKSRESPONSE.fields_by_name['items'].message_type = _ITEM
-_ARKSRESPONSE.fields_by_name['seek_plans'].message_type = _SEEKPLAN
-DESCRIPTOR.message_types_by_name['InferTensorContents'] = _INFERTENSORCONTENTS
-DESCRIPTOR.message_types_by_name['Pair'] = _PAIR
-DESCRIPTOR.message_types_by_name['RowKey'] = _ROWKEY
-DESCRIPTOR.message_types_by_name['PartitionInfo'] = _PARTITIONINFO
-DESCRIPTOR.message_types_by_name['Item'] = _ITEM
-DESCRIPTOR.message_types_by_name['SubItem'] = _SUBITEM
-DESCRIPTOR.message_types_by_name['SeekPlan'] = _SEEKPLAN
-DESCRIPTOR.message_types_by_name['DumpReqInfo'] = _DUMPREQINFO
-DESCRIPTOR.message_types_by_name['ArksRequest'] = _ARKSREQUEST
-DESCRIPTOR.message_types_by_name['ArksResponse'] = _ARKSRESPONSE
-DESCRIPTOR.enum_types_by_name['OutputFormatType'] = _OUTPUTFORMATTYPE
-DESCRIPTOR.enum_types_by_name['ErrorCode'] = _ERRORCODE
-DESCRIPTOR.enum_types_by_name['ContentType'] = _CONTENTTYPE
-DESCRIPTOR.enum_types_by_name['ItemSequenceType'] = _ITEMSEQUENCETYPE
+_INFERTENSORCONTENTS.fields_by_name["type"].enum_type = _CONTENTTYPE
+_PAIR.fields_by_name["contents"].message_type = _INFERTENSORCONTENTS
+_ROWKEY.fields_by_name["partitions"].message_type = _PARTITIONINFO
+_ROWKEY.fields_by_name["realtime_features"].message_type = _PAIR
+_ITEM.fields_by_name["features"].message_type = _PAIR
+_ITEM.fields_by_name["attributes"].message_type = _PAIR
+_ITEM.fields_by_name["sub_items"].message_type = _SUBITEM
+_SUBITEM.fields_by_name["features"].message_type = _PAIR
+_SUBITEM.fields_by_name["attributes"].message_type = _PAIR
+_SEEKPLAN.fields_by_name["row_keys"].message_type = _ROWKEY
+_SEEKPLAN.fields_by_name["item_sequence_type"].enum_type = _ITEMSEQUENCETYPE
+_SEEKPLAN.fields_by_name["missing_values"].message_type = _PAIR
+_ARKSREQUEST.fields_by_name["user_profile"].message_type = _PAIR
+_ARKSREQUEST.fields_by_name["scene_features"].message_type = _PAIR
+_ARKSREQUEST.fields_by_name["items"].message_type = _ITEM
+_ARKSREQUEST.fields_by_name["out_format"].enum_type = _OUTPUTFORMATTYPE
+_ARKSREQUEST.fields_by_name["seek_plans"].message_type = _SEEKPLAN
+_ARKSREQUEST.fields_by_name["dump_req_info"].message_type = _DUMPREQINFO
+_ARKSRESPONSE.fields_by_name["error_code"].enum_type = _ERRORCODE
+_ARKSRESPONSE.fields_by_name["user_profile"].message_type = _PAIR
+_ARKSRESPONSE.fields_by_name["items"].message_type = _ITEM
+_ARKSRESPONSE.fields_by_name["seek_plans"].message_type = _SEEKPLAN
+DESCRIPTOR.message_types_by_name["InferTensorContents"] = _INFERTENSORCONTENTS
+DESCRIPTOR.message_types_by_name["Pair"] = _PAIR
+DESCRIPTOR.message_types_by_name["RowKey"] = _ROWKEY
+DESCRIPTOR.message_types_by_name["PartitionInfo"] = _PARTITIONINFO
+DESCRIPTOR.message_types_by_name["Item"] = _ITEM
+DESCRIPTOR.message_types_by_name["SubItem"] = _SUBITEM
+DESCRIPTOR.message_types_by_name["SeekPlan"] = _SEEKPLAN
+DESCRIPTOR.message_types_by_name["DumpReqInfo"] = _DUMPREQINFO
+DESCRIPTOR.message_types_by_name["ArksRequest"] = _ARKSREQUEST
+DESCRIPTOR.message_types_by_name["ArksResponse"] = _ARKSRESPONSE
+DESCRIPTOR.enum_types_by_name["OutputFormatType"] = _OUTPUTFORMATTYPE
+DESCRIPTOR.enum_types_by_name["ErrorCode"] = _ERRORCODE
+DESCRIPTOR.enum_types_by_name["ContentType"] = _CONTENTTYPE
+DESCRIPTOR.enum_types_by_name["ItemSequenceType"] = _ITEMSEQUENCETYPE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
-InferTensorContents = _reflection.GeneratedProtocolMessageType('InferTensorContents', (_message.Message,), dict(
- DESCRIPTOR = _INFERTENSORCONTENTS,
- __module__ = 'arks_pb2'
- # @@protoc_insertion_point(class_scope:arks.InferTensorContents)
- ))
+InferTensorContents = _reflection.GeneratedProtocolMessageType(
+ "InferTensorContents",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_INFERTENSORCONTENTS,
+ __module__="arks_pb2"
+ # @@protoc_insertion_point(class_scope:arks.InferTensorContents)
+ ),
+)
_sym_db.RegisterMessage(InferTensorContents)
-Pair = _reflection.GeneratedProtocolMessageType('Pair', (_message.Message,), dict(
- DESCRIPTOR = _PAIR,
- __module__ = 'arks_pb2'
- # @@protoc_insertion_point(class_scope:arks.Pair)
- ))
+Pair = _reflection.GeneratedProtocolMessageType(
+ "Pair",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_PAIR,
+ __module__="arks_pb2"
+ # @@protoc_insertion_point(class_scope:arks.Pair)
+ ),
+)
_sym_db.RegisterMessage(Pair)
-RowKey = _reflection.GeneratedProtocolMessageType('RowKey', (_message.Message,), dict(
- DESCRIPTOR = _ROWKEY,
- __module__ = 'arks_pb2'
- # @@protoc_insertion_point(class_scope:arks.RowKey)
- ))
+RowKey = _reflection.GeneratedProtocolMessageType(
+ "RowKey",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_ROWKEY,
+ __module__="arks_pb2"
+ # @@protoc_insertion_point(class_scope:arks.RowKey)
+ ),
+)
_sym_db.RegisterMessage(RowKey)
-PartitionInfo = _reflection.GeneratedProtocolMessageType('PartitionInfo', (_message.Message,), dict(
- DESCRIPTOR = _PARTITIONINFO,
- __module__ = 'arks_pb2'
- # @@protoc_insertion_point(class_scope:arks.PartitionInfo)
- ))
+PartitionInfo = _reflection.GeneratedProtocolMessageType(
+ "PartitionInfo",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_PARTITIONINFO,
+ __module__="arks_pb2"
+ # @@protoc_insertion_point(class_scope:arks.PartitionInfo)
+ ),
+)
_sym_db.RegisterMessage(PartitionInfo)
-Item = _reflection.GeneratedProtocolMessageType('Item', (_message.Message,), dict(
- DESCRIPTOR = _ITEM,
- __module__ = 'arks_pb2'
- # @@protoc_insertion_point(class_scope:arks.Item)
- ))
+Item = _reflection.GeneratedProtocolMessageType(
+ "Item",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_ITEM,
+ __module__="arks_pb2"
+ # @@protoc_insertion_point(class_scope:arks.Item)
+ ),
+)
_sym_db.RegisterMessage(Item)
-SubItem = _reflection.GeneratedProtocolMessageType('SubItem', (_message.Message,), dict(
- DESCRIPTOR = _SUBITEM,
- __module__ = 'arks_pb2'
- # @@protoc_insertion_point(class_scope:arks.SubItem)
- ))
+SubItem = _reflection.GeneratedProtocolMessageType(
+ "SubItem",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_SUBITEM,
+ __module__="arks_pb2"
+ # @@protoc_insertion_point(class_scope:arks.SubItem)
+ ),
+)
_sym_db.RegisterMessage(SubItem)
-SeekPlan = _reflection.GeneratedProtocolMessageType('SeekPlan', (_message.Message,), dict(
- DESCRIPTOR = _SEEKPLAN,
- __module__ = 'arks_pb2'
- # @@protoc_insertion_point(class_scope:arks.SeekPlan)
- ))
+SeekPlan = _reflection.GeneratedProtocolMessageType(
+ "SeekPlan",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_SEEKPLAN,
+ __module__="arks_pb2"
+ # @@protoc_insertion_point(class_scope:arks.SeekPlan)
+ ),
+)
_sym_db.RegisterMessage(SeekPlan)
-DumpReqInfo = _reflection.GeneratedProtocolMessageType('DumpReqInfo', (_message.Message,), dict(
- DESCRIPTOR = _DUMPREQINFO,
- __module__ = 'arks_pb2'
- # @@protoc_insertion_point(class_scope:arks.DumpReqInfo)
- ))
+DumpReqInfo = _reflection.GeneratedProtocolMessageType(
+ "DumpReqInfo",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_DUMPREQINFO,
+ __module__="arks_pb2"
+ # @@protoc_insertion_point(class_scope:arks.DumpReqInfo)
+ ),
+)
_sym_db.RegisterMessage(DumpReqInfo)
-ArksRequest = _reflection.GeneratedProtocolMessageType('ArksRequest', (_message.Message,), dict(
- DESCRIPTOR = _ARKSREQUEST,
- __module__ = 'arks_pb2'
- # @@protoc_insertion_point(class_scope:arks.ArksRequest)
- ))
+ArksRequest = _reflection.GeneratedProtocolMessageType(
+ "ArksRequest",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_ARKSREQUEST,
+ __module__="arks_pb2"
+ # @@protoc_insertion_point(class_scope:arks.ArksRequest)
+ ),
+)
_sym_db.RegisterMessage(ArksRequest)
-ArksResponse = _reflection.GeneratedProtocolMessageType('ArksResponse', (_message.Message,), dict(
- DESCRIPTOR = _ARKSRESPONSE,
- __module__ = 'arks_pb2'
- # @@protoc_insertion_point(class_scope:arks.ArksResponse)
- ))
+ArksResponse = _reflection.GeneratedProtocolMessageType(
+ "ArksResponse",
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_ARKSRESPONSE,
+ __module__="arks_pb2"
+ # @@protoc_insertion_point(class_scope:arks.ArksResponse)
+ ),
+)
_sym_db.RegisterMessage(ArksResponse)
DESCRIPTOR._options = None
-_ITEM.fields_by_name['is_features_valid']._options = None
-_ITEM.fields_by_name['scores']._options = None
-_SUBITEM.fields_by_name['is_features_valid']._options = None
-_SUBITEM.fields_by_name['scores']._options = None
+_ITEM.fields_by_name["is_features_valid"]._options = None
+_ITEM.fields_by_name["scores"]._options = None
+_SUBITEM.fields_by_name["is_features_valid"]._options = None
+_SUBITEM.fields_by_name["scores"]._options = None
# @@protoc_insertion_point(module_scope)
diff --git a/kag/common/base/prompt_op.py b/kag/common/base/prompt_op.py
index 057e35bf..946b5abb 100644
--- a/kag/common/base/prompt_op.py
+++ b/kag/common/base/prompt_op.py
@@ -73,9 +73,9 @@ def template_variables(self) -> List[str]:
)
def process_template_string_to_avoid_dollar_problem(self, template_string):
- new_template_str = template_string.replace('$', '$$')
+ new_template_str = template_string.replace("$", "$$")
for var in self.template_variables:
- new_template_str = new_template_str.replace(f'$${var}', f'${var}')
+ new_template_str = new_template_str.replace(f"$${var}", f"${var}")
return new_template_str
def build_prompt(self, variables) -> str:
@@ -93,7 +93,9 @@ def build_prompt(self, variables) -> str:
"""
self.template_variables_value = variables
- template_string = self.process_template_string_to_avoid_dollar_problem(self.template)
+ template_string = self.process_template_string_to_avoid_dollar_problem(
+ self.template
+ )
template = Template(template_string)
return template.substitute(**variables)
@@ -134,10 +136,10 @@ def load(cls, biz_scene: str, type: str):
os.path.join(os.getenv("KAG_PROJECT_ROOT_PATH", ""), "solver", "prompt"),
]
module_paths = [
- '.'.join([BUILDER_PROMPT_PATH, biz_scene, type]),
- '.'.join([SOLVER_PROMPT_PATH, biz_scene, type]),
- '.'.join([BUILDER_PROMPT_PATH, 'default', type]),
- '.'.join([SOLVER_PROMPT_PATH, 'default', type]),
+ ".".join([BUILDER_PROMPT_PATH, biz_scene, type]),
+ ".".join([SOLVER_PROMPT_PATH, biz_scene, type]),
+ ".".join([BUILDER_PROMPT_PATH, "default", type]),
+ ".".join([SOLVER_PROMPT_PATH, "default", type]),
]
def find_class_from_dir(dir, type):
@@ -160,7 +162,11 @@ def find_class_from_module(module):
classes = inspect.getmembers(module, inspect.isclass)
for class_name, class_obj in classes:
import kag
- if issubclass(class_obj, kag.common.base.prompt_op.PromptOp) and inspect.getmodule(class_obj) == module:
+
+ if (
+ issubclass(class_obj, kag.common.base.prompt_op.PromptOp)
+ and inspect.getmodule(class_obj) == module
+ ):
return class_obj
return None
@@ -181,4 +187,6 @@ def find_class_from_module(module):
except ModuleNotFoundError:
continue
- raise ValueError(f'Not support prompt with biz_scene[{biz_scene}] and type[{type}]')
+ raise ValueError(
+ f"Not support prompt with biz_scene[{biz_scene}] and type[{type}]"
+ )
diff --git a/kag/common/benchmarks/evaUtils.py b/kag/common/benchmarks/evaUtils.py
index f443e8a0..69547517 100644
--- a/kag/common/benchmarks/evaUtils.py
+++ b/kag/common/benchmarks/evaUtils.py
@@ -17,15 +17,16 @@ def normalize_answer(s):
Returns:
str: The standardized answer string.
"""
+
def remove_articles(text):
- return re.sub(r'\b(a|an|the)\b', ' ', text)
+ return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
- return ' '.join(text.split())
+ return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
- return ''.join(ch for ch in text if ch not in exclude)
+ return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return str(text).lower()
@@ -52,10 +53,16 @@ def f1_score(prediction, ground_truth):
ZERO_METRIC = (0, 0, 0)
- if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+ if (
+ normalized_prediction in ["yes", "no", "noanswer"]
+ and normalized_prediction != normalized_ground_truth
+ ):
return ZERO_METRIC
- if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+ if (
+ normalized_ground_truth in ["yes", "no", "noanswer"]
+ and normalized_prediction != normalized_ground_truth
+ ):
return ZERO_METRIC
prediction_tokens = normalized_prediction.split()
@@ -78,35 +85,36 @@ def f1_score(prediction, ground_truth):
def exact_match_score(prediction, ground_truth):
"""
Calculates the exact match score between a predicted answer and the ground truth answer.
-
+
This function normalizes both the predicted answer and the ground truth answer before comparing them.
Normalization is performed to ensure that non-essential differences such as spaces and case are ignored.
-
+
Parameters:
prediction (str): The predicted answer string.
ground_truth (str): The ground truth answer string.
-
+
Returns:
int: 1 if the predicted answer exactly matches the ground truth answer, otherwise 0.
"""
return 1 if normalize_answer(prediction) == normalize_answer(ground_truth) else 0
+
def get_em_f1(prediction, gold):
"""
Calculates the Exact Match (EM) score and F1 score between the prediction and the gold standard.
-
+
This function evaluates the performance of a model in text similarity tasks by calculating the EM score and F1 score to measure the accuracy of the predictions.
-
+
Parameters:
prediction (str): The output predicted by the model.
gold (str): The gold standard output (i.e., the correct output).
-
+
Returns:
tuple: A tuple containing two floats, the EM score and the F1 score. The EM score represents the exact match accuracy, while the F1 score is a combination of precision and recall.
"""
em = exact_match_score(prediction, gold)
f1, precision, recall = f1_score(prediction, gold)
-
- return float(em), f1
\ No newline at end of file
+
+ return float(em), f1
diff --git a/kag/common/benchmarks/evaluate.py b/kag/common/benchmarks/evaluate.py
index 4b920f93..8c525d27 100644
--- a/kag/common/benchmarks/evaluate.py
+++ b/kag/common/benchmarks/evaluate.py
@@ -1,22 +1,22 @@
-
from typing import List
from .evaUtils import get_em_f1
-class Evaluate():
+class Evaluate:
"""
provide evaluation for benchmarks, such as em、f1、answer_similarity, answer_correctness
"""
- def __init__(self, embedding_factory = "text-embedding-ada-002"):
+
+ def __init__(self, embedding_factory="text-embedding-ada-002"):
self.embedding_factory = embedding_factory
def evaForSimilarity(self, predictionlist: List[str], goldlist: List[str]):
"""
evaluate the similarity between prediction and gold #TODO
"""
- # data_samples = {
+ # data_samples = {
# 'question': [],
# 'answer': predictionlist,
# 'ground_truth': goldlist
@@ -29,7 +29,6 @@ def evaForSimilarity(self, predictionlist: List[str], goldlist: List[str]):
# return np.average(score.to_pandas()[['answer_similarity']])
return 0.0
-
def getBenchMark(self, predictionlist: List[str], goldlist: List[str]):
"""
Calculates and returns evaluation metrics between predictions and ground truths.
@@ -45,21 +44,24 @@ def getBenchMark(self, predictionlist: List[str], goldlist: List[str]):
dict: Dictionary containing EM, F1 score, and answer similarity.
"""
# Initialize total metrics
- total_metrics = {'em': 0.0, 'f1': 0.0, 'answer_similarity': 0.0}
-
+ total_metrics = {"em": 0.0, "f1": 0.0, "answer_similarity": 0.0}
+
# Iterate over prediction and gold lists to calculate EM and F1 scores
for prediction, gold in zip(predictionlist, goldlist):
- em, f1 = get_em_f1(prediction, gold) # Call external function to calculate EM and F1
- total_metrics['em'] += em # Accumulate EM score
- total_metrics['f1'] += f1 # Accumulate F1 score
-
+ em, f1 = get_em_f1(
+ prediction, gold
+ ) # Call external function to calculate EM and F1
+ total_metrics["em"] += em # Accumulate EM score
+ total_metrics["f1"] += f1 # Accumulate F1 score
+
# Calculate average EM and F1 scores
- total_metrics['em'] /= len(predictionlist)
- total_metrics['f1'] /= len(predictionlist)
-
+ total_metrics["em"] /= len(predictionlist)
+ total_metrics["f1"] /= len(predictionlist)
+
# Call method to calculate answer similarity
- total_metrics['answer_similarity'] = self.evaForSimilarity(predictionlist, goldlist)
+ total_metrics["answer_similarity"] = self.evaForSimilarity(
+ predictionlist, goldlist
+ )
# Return evaluation metrics dictionary
return total_metrics
-
diff --git a/kag/common/env.py b/kag/common/env.py
index 33904c93..62839159 100644
--- a/kag/common/env.py
+++ b/kag/common/env.py
@@ -108,7 +108,8 @@ def init_kag_config(config_path: Union[str, Path] = None):
kag_cfg_server_side = ProjectClient(host_addr=host_addr).get_config(
int(project_id)
)
- except:
+ except Exception as e:
+ print(f"Failed to get configuration from server, info: {e}")
kag_cfg_server_side = {}
for section in kag_cfg.sections():
@@ -134,8 +135,3 @@ def init_kag_config(config_path: Union[str, Path] = None):
logging.getLogger("neo4j.notifications").setLevel(logging.ERROR)
logging.getLogger("neo4j.io").setLevel(logging.INFO)
logging.getLogger("neo4j.pool").setLevel(logging.INFO)
-
-
-def merge_server_kag_config():
-
- config = ProjectClient().get_config(self.project_id)
diff --git a/kag/common/graphstore/graph_store.py b/kag/common/graphstore/graph_store.py
index 8877ad2b..1cc65f83 100644
--- a/kag/common/graphstore/graph_store.py
+++ b/kag/common/graphstore/graph_store.py
@@ -49,7 +49,9 @@ def upsert_node(self, label, properties, id_key="id", extra_labels=("Entity",)):
pass
@abstractmethod
- def upsert_nodes(self, label, properties_list, id_key="id", extra_labels=("Entity",)):
+ def upsert_nodes(
+ self, label, properties_list, id_key="id", extra_labels=("Entity",)
+ ):
"""
Insert or update multiple nodes.
@@ -112,10 +114,18 @@ def delete_nodes(self, label, id_values, id_key="id"):
pass
@abstractmethod
- def upsert_relationship(self, start_node_label, start_node_id_value,
- end_node_label, end_node_id_value,
- rel_type, properties, upsert_nodes=True,
- start_node_id_key="id", end_node_id_key="id"):
+ def upsert_relationship(
+ self,
+ start_node_label,
+ start_node_id_value,
+ end_node_label,
+ end_node_id_value,
+ rel_type,
+ properties,
+ upsert_nodes=True,
+ start_node_id_key="id",
+ end_node_id_key="id",
+ ):
"""
Insert or update a relationship.
@@ -133,9 +143,16 @@ def upsert_relationship(self, start_node_label, start_node_id_value,
pass
@abstractmethod
- def upsert_relationships(self, start_node_label, end_node_label, rel_type,
- relationships, upsert_nodes=True, start_node_id_key="id",
- end_node_id_key="id"):
+ def upsert_relationships(
+ self,
+ start_node_label,
+ end_node_label,
+ rel_type,
+ relationships,
+ upsert_nodes=True,
+ start_node_id_key="id",
+ end_node_id_key="id",
+ ):
"""
Insert or update multiple relationships.
@@ -151,9 +168,16 @@ def upsert_relationships(self, start_node_label, end_node_label, rel_type,
pass
@abstractmethod
- def delete_relationship(self, start_node_label, start_node_id_value,
- end_node_label, end_node_id_value,
- rel_type, start_node_id_key="id", end_node_id_key="id"):
+ def delete_relationship(
+ self,
+ start_node_label,
+ start_node_id_value,
+ end_node_label,
+ end_node_id_value,
+ rel_type,
+ start_node_id_key="id",
+ end_node_id_key="id",
+ ):
"""
Delete a specified relationship.
@@ -169,9 +193,16 @@ def delete_relationship(self, start_node_label, start_node_id_value,
pass
@abstractmethod
- def delete_relationships(self, start_node_label, start_node_id_values,
- end_node_label, end_node_id_values, rel_type,
- start_node_id_key="id", end_node_id_key="id"):
+ def delete_relationships(
+ self,
+ start_node_label,
+ start_node_id_values,
+ end_node_label,
+ end_node_id_values,
+ rel_type,
+ start_node_id_key="id",
+ end_node_id_key="id",
+ ):
"""
Delete multiple relationships.
@@ -211,9 +242,16 @@ def create_text_index(self, labels, property_keys, index_name=None):
pass
@abstractmethod
- def create_vector_index(self, label, property_key, index_name=None,
- vector_dimensions=768, metric_type="cosine",
- hnsw_m=None, hnsw_ef_construction=None):
+ def create_vector_index(
+ self,
+ label,
+ property_key,
+ index_name=None,
+ vector_dimensions=768,
+ metric_type="cosine",
+ hnsw_m=None,
+ hnsw_ef_construction=None,
+ ):
"""
Create a vector index.
@@ -239,7 +277,9 @@ def delete_index(self, index_name):
pass
@abstractmethod
- def text_search(self, query_string, label_constraints=None, topk=10, index_name=None):
+ def text_search(
+ self, query_string, label_constraints=None, topk=10, index_name=None
+ ):
"""
Perform a text search.
@@ -255,7 +295,15 @@ def text_search(self, query_string, label_constraints=None, topk=10, index_name=
pass
@abstractmethod
- def vector_search(self, label, property_key, query_text_or_vector, topk=10, index_name=None, ef_search=None):
+ def vector_search(
+ self,
+ label,
+ property_key,
+ query_text_or_vector,
+ topk=10,
+ index_name=None,
+ ef_search=None,
+ ):
"""
Perform a vector search.
diff --git a/kag/common/graphstore/neo4j_graph_store.py b/kag/common/graphstore/neo4j_graph_store.py
index 33b46d9d..97bd5c47 100644
--- a/kag/common/graphstore/neo4j_graph_store.py
+++ b/kag/common/graphstore/neo4j_graph_store.py
@@ -10,7 +10,6 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
-import os
import re
import threading
import time
@@ -25,18 +24,20 @@
logger = logging.getLogger(__name__)
+
class SingletonMeta(ABCMeta):
"""
Thread-safe Singleton metaclass
"""
+
_instances = {}
_lock = threading.Lock()
def __call__(cls, *args, **kwargs):
- uri = kwargs.get('uri')
- user = kwargs.get('user')
- password = kwargs.get('password')
- database = kwargs.get('database', 'neo4j')
+ uri = kwargs.get("uri")
+ user = kwargs.get("user")
+ password = kwargs.get("password")
+ database = kwargs.get("database", "neo4j")
key = (cls, uri, user, password, database)
with cls._lock:
@@ -46,12 +47,19 @@ def __call__(cls, *args, **kwargs):
class Neo4jClient(GraphStore, metaclass=SingletonMeta):
-
- def __init__(self, uri, user, password, database="neo4j", init_type="write", interval_minutes=10):
+ def __init__(
+ self,
+ uri,
+ user,
+ password,
+ database="neo4j",
+ init_type="write",
+ interval_minutes=10,
+ ):
self._driver = GraphDatabase.driver(uri, auth=(user, password))
logger.info(f"init Neo4jClient uri: {uri} database: {database}")
self._database = database
- self._lucene_special_chars = "\\+-!():^[]\"{}~*?|&/"
+ self._lucene_special_chars = '\\+-!():^[]"{}~*?|&/'
self._lucene_pattern = self._get_lucene_pattern()
self._simple_ident = "[A-Za-z_][A-Za-z0-9_]*"
self._simple_ident_pattern = re.compile(self._simple_ident)
@@ -71,14 +79,16 @@ def close(self):
self._driver.close()
def schedule_constraint(self, interval_minutes):
-
def job():
try:
self._labels = self._create_unique_constraint()
self._update_pagerank_graph()
except Exception as e:
import traceback
- logger.error(f"Error run scheduled job: {traceback.format_exc()}")
+
+ logger.error(
+ f"Error run scheduled job, info: {e},\ntraceback:\n {traceback.format_exc()}"
+ )
def run_scheduled_tasks():
while True:
@@ -116,7 +126,9 @@ def _create_unique_index_constraint(self, label, session):
try:
result = session.run(create_constraint_query)
result.consume()
- logger.debug(f"Unique constraint created for constraint_name: {constraint_name}")
+ logger.debug(
+ f"Unique constraint created for constraint_name: {constraint_name}"
+ )
except Exception as e:
logger.debug(f"warn creating constraint for {constraint_name}: {e}")
self._create_index_constraint(self, label, session)
@@ -186,7 +198,12 @@ def _collect_text_index_info(self, schema_types):
label_property_keys = {}
for property_key in properties:
index_type = properties[property_key].index_type
- if property_key == "name" or index_type and index_type in (IndexTypeEnum.Text, IndexTypeEnum.TextAndVector):
+ if (
+ property_key == "name"
+ or index_type
+ and index_type
+ in (IndexTypeEnum.Text, IndexTypeEnum.TextAndVector)
+ ):
label_property_keys[property_key] = True
if label_property_keys:
labels[label] = True
@@ -199,9 +216,13 @@ def upsert_node(self, label, properties, id_key="id", extra_labels=("Entity",)):
if label not in self._labels:
self._create_unique_index_constraint(self, label, session)
try:
- return session.execute_write(self._upsert_node, self, label, id_key, properties, extra_labels)
+ return session.execute_write(
+ self._upsert_node, self, label, id_key, properties, extra_labels
+ )
except Exception as e:
- logger.error(f"upsert_node label:{label} properties:{properties} Exception: {e}")
+ logger.error(
+ f"upsert_node label:{label} properties:{properties} Exception: {e}"
+ )
return None
@staticmethod
@@ -209,23 +230,36 @@ def _upsert_node(tx, self, label, id_key, properties, extra_labels):
if not label:
logger.warning("label cannot be None or empty strings")
return None
- query = (f"MERGE (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: $properties.{self._escape_neo4j(id_key)}}}) "
- "SET n += $properties ")
+ query = (
+ f"MERGE (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: $properties.{self._escape_neo4j(id_key)}}}) "
+ "SET n += $properties "
+ )
if extra_labels:
query += f", n:{':'.join(self._escape_neo4j(extra_label) for extra_label in extra_labels)} "
query += "RETURN n"
result = tx.run(query, properties=properties)
return result.single()[0]
- def upsert_nodes(self, label, properties_list, id_key="id", extra_labels=("Entity",)):
+ def upsert_nodes(
+ self, label, properties_list, id_key="id", extra_labels=("Entity",)
+ ):
self._preprocess_node_properties_list(label, properties_list, extra_labels)
with self._driver.session(database=self._database) as session:
if label not in self._labels:
self._create_unique_index_constraint(self, label, session)
try:
- return session.execute_write(self._upsert_nodes, self, label, properties_list, id_key, extra_labels)
+ return session.execute_write(
+ self._upsert_nodes,
+ self,
+ label,
+ properties_list,
+ id_key,
+ extra_labels,
+ )
except Exception as e:
- logger.error(f"upsert_nodes label:{label} properties:{properties_list} Exception: {e}")
+ logger.error(
+ f"upsert_nodes label:{label} properties:{properties_list} Exception: {e}"
+ )
return None
@staticmethod
@@ -233,14 +267,16 @@ def _upsert_nodes(tx, self, label, properties_list, id_key, extra_labels):
if not label:
logger.warning("label cannot be None or empty strings")
return None
- query = ("UNWIND $properties_list AS properties "
- f"MERGE (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: properties.{self._escape_neo4j(id_key)}}}) "
- "SET n += properties ")
+ query = (
+ "UNWIND $properties_list AS properties "
+ f"MERGE (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: properties.{self._escape_neo4j(id_key)}}}) "
+ "SET n += properties "
+ )
if extra_labels:
query += f", n:{':'.join(self._escape_neo4j(extra_label) for extra_label in extra_labels)} "
query += "RETURN n"
result = tx.run(query, properties_list=properties_list)
- return [record['n'] for record in result]
+ return [record["n"] for record in result]
def _get_embedding_vector(self, properties, vector_field):
for property_key, property_value in properties.items():
@@ -256,7 +292,9 @@ def _get_embedding_vector(self, properties, vector_field):
vector = self.vectorizer.vectorize(property_value)
return vector
except Exception as e:
- logger.info(f"An error occurred while vectorizing property {property_key!r}: {e}")
+ logger.info(
+ f"An error occurred while vectorizing property {property_key!r}: {e}"
+ )
return None
return None
@@ -287,7 +325,9 @@ def batch_preprocess_node_properties(self, node_batch, extra_labels=("Entity",))
return
class EmbeddingVectorPlaceholder(object):
- def __init__(self, number, properties, vector_field, property_key, property_value):
+ def __init__(
+ self, number, properties, vector_field, property_key, property_value
+ ):
self._number = number
self._properties = properties
self._vector_field = vector_field
@@ -317,7 +357,9 @@ def get_placeholder(self, graph_store, properties, vector_field):
message = f"property {property_key!r} must be string to generate embedding vector"
raise RuntimeError(message)
num = len(self._placeholders)
- placeholder = EmbeddingVectorPlaceholder(num, properties, vector_field, property_key, property_value)
+ placeholder = EmbeddingVectorPlaceholder(
+ num, properties, vector_field, property_key, property_value
+ )
self._placeholders.append(placeholder)
return placeholder
return None
@@ -364,7 +406,9 @@ def patch(self):
for vector_field in vec_meta[label]:
if vector_field in properties:
continue
- placeholder = manager.get_placeholder(self, properties, vector_field)
+ placeholder = manager.get_placeholder(
+ self, properties, vector_field
+ )
if placeholder is not None:
properties[vector_field] = placeholder
manager.batch_vectorize(self._vectorizer)
@@ -406,25 +450,58 @@ def _delete_nodes(tx, self, label, id_key, id_values):
query = f"UNWIND $id_values AS id_value MATCH (n:{self._escape_neo4j(label)} {{{self._escape_neo4j(id_key)}: id_value}}) DETACH DELETE n"
tx.run(query, id_values=id_values)
- def upsert_relationship(self, start_node_label, start_node_id_value,
- end_node_label, end_node_id_value, rel_type,
- properties, upsert_nodes=True, start_node_id_key="id", end_node_id_key="id"):
+ def upsert_relationship(
+ self,
+ start_node_label,
+ start_node_id_value,
+ end_node_label,
+ end_node_id_value,
+ rel_type,
+ properties,
+ upsert_nodes=True,
+ start_node_id_key="id",
+ end_node_id_key="id",
+ ):
rel_type = self._escape_neo4j(rel_type)
with self._driver.session(database=self._database) as session:
try:
- return session.execute_write(self._upsert_relationship, self, start_node_label, start_node_id_key,
- start_node_id_value, end_node_label, end_node_id_key,
- end_node_id_value, rel_type, properties, upsert_nodes)
+ return session.execute_write(
+ self._upsert_relationship,
+ self,
+ start_node_label,
+ start_node_id_key,
+ start_node_id_value,
+ end_node_label,
+ end_node_id_key,
+ end_node_id_value,
+ rel_type,
+ properties,
+ upsert_nodes,
+ )
except Exception as e:
- logger.error(f"upsert_relationship rel_type:{rel_type} properties:{properties} Exception: {e}")
+ logger.error(
+ f"upsert_relationship rel_type:{rel_type} properties:{properties} Exception: {e}"
+ )
return None
@staticmethod
- def _upsert_relationship(tx, self, start_node_label, start_node_id_key, start_node_id_value,
- end_node_label, end_node_id_key, end_node_id_value,
- rel_type, properties, upsert_nodes):
+ def _upsert_relationship(
+ tx,
+ self,
+ start_node_label,
+ start_node_id_key,
+ start_node_id_value,
+ end_node_label,
+ end_node_id_key,
+ end_node_id_value,
+ rel_type,
+ properties,
+ upsert_nodes,
+ ):
if not start_node_label or not end_node_label or not rel_type:
- logger.warning("start_node_label, end_node_label, and rel_type cannot be None or empty strings")
+ logger.warning(
+ "start_node_label, end_node_label, and rel_type cannot be None or empty strings"
+ )
return None
if upsert_nodes:
query = (
@@ -438,25 +515,59 @@ def _upsert_relationship(tx, self, start_node_label, start_node_id_key, start_no
f"(b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: $end_node_id_value}}) "
f"MERGE (a)-[r:{self._escape_neo4j(rel_type)}]->(b) SET r += $properties RETURN r"
)
- result = tx.run(query, start_node_id_value=start_node_id_value,
- end_node_id_value=end_node_id_value, properties=properties)
+ result = tx.run(
+ query,
+ start_node_id_value=start_node_id_value,
+ end_node_id_value=end_node_id_value,
+ properties=properties,
+ )
return result.single()
- def upsert_relationships(self, start_node_label, end_node_label, rel_type, relations,
- upsert_nodes=True, start_node_id_key="id", end_node_id_key="id"):
+ def upsert_relationships(
+ self,
+ start_node_label,
+ end_node_label,
+ rel_type,
+ relations,
+ upsert_nodes=True,
+ start_node_id_key="id",
+ end_node_id_key="id",
+ ):
with self._driver.session(database=self._database) as session:
try:
- return session.execute_write(self._upsert_relationships, self, relations, start_node_label,
- start_node_id_key, end_node_label, end_node_id_key, rel_type, upsert_nodes)
+ return session.execute_write(
+ self._upsert_relationships,
+ self,
+ relations,
+ start_node_label,
+ start_node_id_key,
+ end_node_label,
+ end_node_id_key,
+ rel_type,
+ upsert_nodes,
+ )
except Exception as e:
- logger.error(f"upsert_relationships rel_type:{rel_type} relations:{relations} Exception: {e}")
+ logger.error(
+ f"upsert_relationships rel_type:{rel_type} relations:{relations} Exception: {e}"
+ )
return None
@staticmethod
- def _upsert_relationships(tx, self, relations, start_node_label, start_node_id_key,
- end_node_label, end_node_id_key, rel_type, upsert_nodes):
+ def _upsert_relationships(
+ tx,
+ self,
+ relations,
+ start_node_label,
+ start_node_id_key,
+ end_node_label,
+ end_node_id_key,
+ rel_type,
+ upsert_nodes,
+ ):
if not start_node_label or not end_node_label or not rel_type:
- logger.warning("start_node_label, end_node_label, and rel_type cannot be None or empty strings")
+ logger.warning(
+ "start_node_label, end_node_label, and rel_type cannot be None or empty strings"
+ )
return None
if upsert_nodes:
query = (
@@ -473,51 +584,111 @@ def _upsert_relationships(tx, self, relations, start_node_label, start_node_id_k
f"MERGE (a)-[r:{self._escape_neo4j(rel_type)}]->(b) SET r += relationship.properties RETURN r"
)
- result = tx.run(query, relations=relations,
- start_node_label=start_node_label, start_node_id_key=start_node_id_key,
- end_node_label=end_node_label, end_node_id_key=end_node_id_key,
- rel_type=rel_type)
- return [record['r'] for record in result]
-
- def delete_relationship(self, start_node_label, start_node_id_value,
- end_node_label, end_node_id_value, rel_type,
- start_node_id_key="id", end_node_id_key="id"):
+ result = tx.run(
+ query,
+ relations=relations,
+ start_node_label=start_node_label,
+ start_node_id_key=start_node_id_key,
+ end_node_label=end_node_label,
+ end_node_id_key=end_node_id_key,
+ rel_type=rel_type,
+ )
+ return [record["r"] for record in result]
+
+ def delete_relationship(
+ self,
+ start_node_label,
+ start_node_id_value,
+ end_node_label,
+ end_node_id_value,
+ rel_type,
+ start_node_id_key="id",
+ end_node_id_key="id",
+ ):
with self._driver.session(database=self._database) as session:
try:
- session.execute_write(self._delete_relationship, self, start_node_label, start_node_id_key,
- start_node_id_value, end_node_label, end_node_id_key,
- end_node_id_value, rel_type)
+ session.execute_write(
+ self._delete_relationship,
+ self,
+ start_node_label,
+ start_node_id_key,
+ start_node_id_value,
+ end_node_label,
+ end_node_id_key,
+ end_node_id_value,
+ rel_type,
+ )
except Exception as e:
logger.error(f"delete_relationship rel_type:{rel_type} Exception: {e}")
-
@staticmethod
- def _delete_relationship(tx, self, start_node_label, start_node_id_key, start_node_id_value,
- end_node_label, end_node_id_key, end_node_id_value, rel_type):
+ def _delete_relationship(
+ tx,
+ self,
+ start_node_label,
+ start_node_id_key,
+ start_node_id_value,
+ end_node_label,
+ end_node_id_key,
+ end_node_id_value,
+ rel_type,
+ ):
query = (
f"MATCH (a:{self._escape_neo4j(start_node_label)} {{{self._escape_neo4j(start_node_id_key)}: $start_node_id_value}})-[r:{self._escape_neo4j(rel_type)}]->"
f"(b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: $end_node_id_value}}) DELETE r"
)
- tx.run(query, start_node_id_value=start_node_id_value, end_node_id_value=end_node_id_value)
+ tx.run(
+ query,
+ start_node_id_value=start_node_id_value,
+ end_node_id_value=end_node_id_value,
+ )
- def delete_relationships(self, start_node_label, start_node_id_values,
- end_node_label, end_node_id_values, rel_type,
- start_node_id_key="id", end_node_id_key="id"):
+ def delete_relationships(
+ self,
+ start_node_label,
+ start_node_id_values,
+ end_node_label,
+ end_node_id_values,
+ rel_type,
+ start_node_id_key="id",
+ end_node_id_key="id",
+ ):
with self._driver.session(database=self._database) as session:
- session.execute_write(self._delete_relationships, self,
- start_node_label, start_node_id_key, start_node_id_values,
- end_node_label, end_node_id_key, end_node_id_values, rel_type)
+ session.execute_write(
+ self._delete_relationships,
+ self,
+ start_node_label,
+ start_node_id_key,
+ start_node_id_values,
+ end_node_label,
+ end_node_id_key,
+ end_node_id_values,
+ rel_type,
+ )
@staticmethod
- def _delete_relationships(tx, self, start_node_label, start_node_id_key, start_node_id_values,
- end_node_label, end_node_id_key, end_node_id_values, rel_type):
+ def _delete_relationships(
+ tx,
+ self,
+ start_node_label,
+ start_node_id_key,
+ start_node_id_values,
+ end_node_label,
+ end_node_id_key,
+ end_node_id_values,
+ rel_type,
+ ):
query = (
"UNWIND $start_node_id_values AS start_node_id_value "
"UNWIND $end_node_id_values AS end_node_id_value "
f"MATCH (a:{self._escape_neo4j(start_node_label)} {{{self._escape_neo4j(start_node_id_key)}: start_node_id_value}})-[r:{self._escape_neo4j(rel_type)}]->"
f"(b:{self._escape_neo4j(end_node_label)} {{{self._escape_neo4j(end_node_id_key)}: end_node_id_value}}) DELETE r"
)
- tx.run(query, start_node_id_values=start_node_id_values, end_node_id_values=end_node_id_values)
+ tx.run(
+ query,
+ start_node_id_values=start_node_id_values,
+ end_node_id_values=end_node_id_values,
+ )
def _get_lucene_pattern(self):
string = re.escape(self._lucene_special_chars)
@@ -539,7 +710,7 @@ def _get_utf16_codepoints(self, string):
for ch in string:
data = ch.encode("utf-16-le")
for i in range(0, len(data), 2):
- value = int.from_bytes(data[i:i+2], "little")
+ value = int.from_bytes(data[i : i + 2], "little")
result.append(value)
return tuple(result)
@@ -562,6 +733,7 @@ def _escape_neo4j(self, name):
def _to_snake_case(self, name):
import re
+
words = re.findall("[A-Za-z][a-z0-9]*", name)
result = "_".join(words).lower()
return result
@@ -578,7 +750,9 @@ def _create_vector_field_name(self, property_key):
def create_index(self, label, property_key, index_name=None):
with self._driver.session(database=self._database) as session:
- session.execute_write(self._create_index, self, label, property_key, index_name)
+ session.execute_write(
+ self._create_index, self, label, property_key, index_name
+ )
@staticmethod
def _create_index(tx, self, label, property_key, index_name):
@@ -596,50 +770,87 @@ def create_text_index(self, labels, property_keys, index_name=None):
if index_name is None:
index_name = "_default_text_index"
label_spec = "|".join(self._escape_neo4j(label) for label in labels)
- property_spec = ", ".join(f"n.{self._escape_neo4j(key)}" for key in property_keys)
+ property_spec = ", ".join(
+ f"n.{self._escape_neo4j(key)}" for key in property_keys
+ )
query = (
f"CREATE FULLTEXT INDEX {self._escape_neo4j(index_name)} IF NOT EXISTS "
f"FOR (n:{label_spec}) ON EACH [{property_spec}]"
)
+
def do_create_text_index(tx):
tx.run(query)
+
with self._driver.session(database=self._database) as session:
session.execute_write(do_create_text_index)
return index_name
- def create_vector_index(self, label, property_key, index_name=None,
- vector_dimensions=768, metric_type="cosine",
- hnsw_m=None, hnsw_ef_construction=None):
+ def create_vector_index(
+ self,
+ label,
+ property_key,
+ index_name=None,
+ vector_dimensions=768,
+ metric_type="cosine",
+ hnsw_m=None,
+ hnsw_ef_construction=None,
+ ):
if index_name is None:
index_name = self._create_vector_index_name(label, property_key)
if not property_key.lower().endswith("vector"):
property_key = self._create_vector_field_name(property_key)
with self._driver.session(database=self._database) as session:
- session.execute_write(self._create_vector_index, self, label, property_key, index_name,
- vector_dimensions, metric_type, hnsw_m, hnsw_ef_construction)
+ session.execute_write(
+ self._create_vector_index,
+ self,
+ label,
+ property_key,
+ index_name,
+ vector_dimensions,
+ metric_type,
+ hnsw_m,
+ hnsw_ef_construction,
+ )
self.refresh_vector_index_meta(force=True)
return index_name
@staticmethod
- def _create_vector_index(tx, self, label, property_key, index_name, vector_dimensions, metric_type, hnsw_m, hnsw_ef_construction):
+ def _create_vector_index(
+ tx,
+ self,
+ label,
+ property_key,
+ index_name,
+ vector_dimensions,
+ metric_type,
+ hnsw_m,
+ hnsw_ef_construction,
+ ):
query = (
f"CREATE VECTOR INDEX {self._escape_neo4j(index_name)} IF NOT EXISTS FOR (n:{self._escape_neo4j(label)}) ON (n.{self._escape_neo4j(property_key)}) "
- "OPTIONS { indexConfig: {"
- " `vector.dimensions`: $vector_dimensions,"
- " `vector.similarity_function`: $metric_type"
+ "OPTIONS { indexConfig: {"
+ " `vector.dimensions`: $vector_dimensions,"
+ " `vector.similarity_function`: $metric_type"
)
if hnsw_m is not None:
query += ", `vector.hnsw.m`: $hnsw_m"
if hnsw_ef_construction is not None:
query += ", `vector.hnsw.ef_construction`: $hnsw_ef_construction"
query += "}}"
- tx.run(query, vector_dimensions=vector_dimensions, metric_type=metric_type,
- hnsw_m=hnsw_m, hnsw_ef_construction=hnsw_ef_construction)
+ tx.run(
+ query,
+ vector_dimensions=vector_dimensions,
+ metric_type=metric_type,
+ hnsw_m=hnsw_m,
+ hnsw_ef_construction=hnsw_ef_construction,
+ )
def refresh_vector_index_meta(self, force=False):
import time
+
if not force and time.time() - self._vec_meta_ts < self._vec_meta_timeout:
return
+
def do_refresh_vector_index_meta(tx):
query = "SHOW VECTOR INDEX"
res = tx.run(query)
@@ -647,14 +858,17 @@ def do_refresh_vector_index_meta(tx):
meta = dict()
for record in data:
if record["entityType"] == "NODE":
- label, = record["labelsOrTypes"]
- vector_field, = record["properties"]
- if vector_field.startswith("_") and vector_field.endswith("_vector"):
+ (label,) = record["labelsOrTypes"]
+ (vector_field,) = record["properties"]
+ if vector_field.startswith("_") and vector_field.endswith(
+ "_vector"
+ ):
if label not in meta:
meta[label] = []
meta[label].append(vector_field)
self._vec_meta = meta
self._vec_meta_ts = time.time()
+
with self._driver.session(database=self._database) as session:
session.execute_read(do_refresh_vector_index_meta)
@@ -678,7 +892,9 @@ def vectorizer(self):
def vectorizer(self, value):
self._vectorizer = value
- def text_search(self, query_string, label_constraints=None, topk=10, index_name=None):
+ def text_search(
+ self, query_string, label_constraints=None, topk=10, index_name=None
+ ):
if index_name is None:
index_name = "_default_text_index"
if label_constraints is None:
@@ -686,31 +902,48 @@ def text_search(self, query_string, label_constraints=None, topk=10, index_name=
elif isinstance(label_constraints, str):
label_constraints = self._escape_neo4j(label_constraints)
elif isinstance(label_constraints, (list, tuple)):
- label_constraints = "|".join(self._escape_neo4j(label_constraint) for label_constraint in label_constraints)
+ label_constraints = "|".join(
+ self._escape_neo4j(label_constraint)
+ for label_constraint in label_constraints
+ )
else:
message = f"invalid label_constraints: {label_constraints!r}"
raise RuntimeError(message)
if label_constraints is None:
- query = ("CALL db.index.fulltext.queryNodes($index_name, $query_string) "
- "YIELD node AS node, score "
- "RETURN node, score")
+ query = (
+ "CALL db.index.fulltext.queryNodes($index_name, $query_string) "
+ "YIELD node AS node, score "
+ "RETURN node, score"
+ )
else:
- query = ("CALL db.index.fulltext.queryNodes($index_name, $query_string) "
- "YIELD node AS node, score "
- f"WHERE (node:{label_constraints}) "
- "RETURN node, score")
+ query = (
+ "CALL db.index.fulltext.queryNodes($index_name, $query_string) "
+ "YIELD node AS node, score "
+ f"WHERE (node:{label_constraints}) "
+ "RETURN node, score"
+ )
query += " LIMIT $topk"
query_string = self._make_lucene_query(query_string)
def do_text_search(tx):
- res = tx.run(query, query_string=query_string, topk=topk, index_name=index_name)
+ res = tx.run(
+ query, query_string=query_string, topk=topk, index_name=index_name
+ )
data = res.data()
return data
with self._driver.session(database=self._database) as session:
return session.execute_read(do_text_search)
- def vector_search(self, label, property_key, query_text_or_vector, topk=10, index_name=None, ef_search=None):
+ def vector_search(
+ self,
+ label,
+ property_key,
+ query_text_or_vector,
+ topk=10,
+ index_name=None,
+ ef_search=None,
+ ):
if ef_search is not None:
if ef_search < topk:
message = f"ef_search must be greater than or equal to topk; {ef_search!r} is invalid"
@@ -719,13 +952,17 @@ def vector_search(self, label, property_key, query_text_or_vector, topk=10, inde
if index_name is None:
vec_meta = self._vec_meta
if label not in vec_meta:
- logger.warning(f"vector index not defined for label, return empty. label: {label}, "
- f"property_key: {property_key}, query_text_or_vector: {query_text_or_vector}.")
+ logger.warning(
+ f"vector index not defined for label, return empty. label: {label}, "
+ f"property_key: {property_key}, query_text_or_vector: {query_text_or_vector}."
+ )
return []
vector_field = self._create_vector_field_name(property_key)
if vector_field not in vec_meta[label]:
- logger.warning(f"vector index not defined for field, return empty. label: {label}, "
- f"property_key: {property_key}, query_text_or_vector: {query_text_or_vector}.")
+ logger.warning(
+ f"vector index not defined for field, return empty. label: {label}, "
+ f"property_key: {property_key}, query_text_or_vector: {query_text_or_vector}."
+ )
return []
if index_name is None:
index_name = self._create_vector_index_name(label, property_key)
@@ -736,16 +973,27 @@ def vector_search(self, label, property_key, query_text_or_vector, topk=10, inde
def do_vector_search(tx):
if ef_search is not None:
- query = ("CALL db.index.vector.queryNodes($index_name, $ef_search, $query_vector) "
- "YIELD node, score "
- "RETURN node, score, labels(node) as __labels__"
- f"LIMIT {topk}")
- res = tx.run(query, query_vector=query_vector, ef_search=ef_search, index_name=index_name)
+ query = (
+ "CALL db.index.vector.queryNodes($index_name, $ef_search, $query_vector) "
+ "YIELD node, score "
+                    "RETURN node, score, labels(node) as __labels__ "
+ f"LIMIT {topk}"
+ )
+ res = tx.run(
+ query,
+ query_vector=query_vector,
+ ef_search=ef_search,
+ index_name=index_name,
+ )
else:
- query = ("CALL db.index.vector.queryNodes($index_name, $topk, $query_vector) "
- "YIELD node, score "
- "RETURN node, score, labels(node) as __labels__")
- res = tx.run(query, query_vector=query_vector, topk=topk, index_name=index_name)
+ query = (
+ "CALL db.index.vector.queryNodes($index_name, $topk, $query_vector) "
+ "YIELD node, score "
+ "RETURN node, score, labels(node) as __labels__"
+ )
+ res = tx.run(
+ query, query_vector=query_vector, topk=topk, index_name=index_name
+ )
data = res.data()
for record in data:
record["node"]["__labels__"] = record["__labels__"]
@@ -757,41 +1005,59 @@ def do_vector_search(tx):
def _create_all_graph(self, graph_name):
with self._driver.session(database=self._database) as session:
- logger.debug(f"create pagerank graph graph_name:{graph_name} database:{self._database}")
- result = session.run(f"""
+ logger.debug(
+ f"create pagerank graph graph_name:{graph_name} database:{self._database}"
+ )
+ result = session.run(
+ f"""
CALL gds.graph.exists('{graph_name}') YIELD exists
WHERE exists
CALL gds.graph.drop('{graph_name}') YIELD graphName
RETURN graphName
- """)
+ """
+ )
summary = result.consume()
- logger.debug(f"create pagerank graph exists graph_name:{graph_name} database:{self._database} succeed "
- f"executed:{summary.result_available_after} consumed:{summary.result_consumed_after}")
+ logger.debug(
+ f"create pagerank graph exists graph_name:{graph_name} database:{self._database} succeed "
+ f"executed:{summary.result_available_after} consumed:{summary.result_consumed_after}"
+ )
- result = session.run(f"""
+ result = session.run(
+ f"""
CALL gds.graph.project('{graph_name}','*','*')
YIELD graphName, nodeCount AS nodes, relationshipCount AS rels
RETURN graphName, nodes, rels
- """)
+ """
+ )
summary = result.consume()
- logger.debug(f"create pagerank graph graph_name:{graph_name} database:{self._database} succeed "
- f"executed:{summary.result_available_after} consumed:{summary.result_consumed_after}")
+ logger.debug(
+ f"create pagerank graph graph_name:{graph_name} database:{self._database} succeed "
+ f"executed:{summary.result_available_after} consumed:{summary.result_consumed_after}"
+ )
def _drop_all_graph(self, graph_name):
with self._driver.session(database=self._database) as session:
- logger.debug(f"drop pagerank graph graph_name:{graph_name} database:{self._database}")
- result = session.run(f"""
+ logger.debug(
+ f"drop pagerank graph graph_name:{graph_name} database:{self._database}"
+ )
+ result = session.run(
+ f"""
CALL gds.graph.exists('{graph_name}') YIELD exists
WHERE exists
CALL gds.graph.drop('{graph_name}') YIELD graphName
RETURN graphName
- """)
+ """
+ )
result.consume()
- logger.debug(f"drop pagerank graph graph_name:{graph_name} database:{self._database} succeed")
+ logger.debug(
+ f"drop pagerank graph graph_name:{graph_name} database:{self._database} succeed"
+ )
def execute_pagerank(self, iterations=20, damping_factor=0.85):
with self._driver.session(database=self._database) as session:
- return session.execute_write(self._execute_pagerank, iterations, damping_factor)
+ return session.execute_write(
+ self._execute_pagerank, iterations, damping_factor
+ )
@staticmethod
def _execute_pagerank(tx, iterations, damping_factor):
@@ -809,7 +1075,9 @@ def get_pagerank_scores(self, start_nodes, target_type):
with self._driver.session(database=self._database) as session:
all_graph = self._allGraph
self._exists_all_graph(session, all_graph)
- data = session.execute_write(self._get_pagerank_scores, self, all_graph, start_nodes, target_type)
+ data = session.execute_write(
+ self._get_pagerank_scores, self, all_graph, start_nodes, target_type
+ )
return data
@staticmethod
@@ -817,13 +1085,15 @@ def _get_pagerank_scores(tx, self, graph_name, start_nodes, return_type):
match_clauses = []
match_identify = []
for index, node in enumerate(start_nodes):
- node_type, node_name = node['type'], node['name']
+ node_type, node_name = node["type"], node["name"]
node_identify = f"node_{index}"
- match_clauses.append(f"MATCH ({node_identify}:{self._escape_neo4j(node_type)} {{name: '{escape_single_quotes(node_name)}'}})")
+ match_clauses.append(
+ f"MATCH ({node_identify}:{self._escape_neo4j(node_type)} {{name: '{escape_single_quotes(node_name)}'}})"
+ )
match_identify.append(node_identify)
- match_query = ' '.join(match_clauses)
- match_identify_str = ', '.join(match_identify)
+ match_query = " ".join(match_clauses)
+ match_identify_str = ", ".join(match_identify)
pagerank_query = f"""
{match_query}
@@ -845,16 +1115,20 @@ def _get_pagerank_scores(tx, self, graph_name, start_nodes, return_type):
def _exists_all_graph(session, graph_name):
try:
logger.debug(f"exists pagerank graph graph_name:{graph_name}")
- result = session.run(f"""
+ result = session.run(
+ f"""
CALL gds.graph.exists('{graph_name}') YIELD exists
WHERE NOT exists
CALL gds.graph.project('{graph_name}','*','*')
YIELD graphName, nodeCount AS nodes, relationshipCount AS rels
RETURN graphName, nodes, rels
- """)
+ """
+ )
summary = result.consume()
- logger.debug(f"exists pagerank graph graph_name:{graph_name} succeed "
- f"executed:{summary.result_available_after} consumed:{summary.result_consumed_after}")
+ logger.debug(
+ f"exists pagerank graph graph_name:{graph_name} succeed "
+ f"executed:{summary.result_available_after} consumed:{summary.result_consumed_after}"
+ )
except Exception as e:
logger.debug(f"Error exists pagerank graph {graph_name}: {e}")
@@ -873,18 +1147,26 @@ def _count(tx, self, label):
def create_database(self, database):
with self._driver.session(database=self._database) as session:
database = database.lower()
- result = session.run(f"CREATE DATABASE {self._escape_neo4j(database)} IF NOT EXISTS")
+ result = session.run(
+ f"CREATE DATABASE {self._escape_neo4j(database)} IF NOT EXISTS"
+ )
summary = result.consume()
- logger.info(f"create_database {database} succeed "
- f"executed:{summary.result_available_after} consumed:{summary.result_consumed_after}")
+ logger.info(
+ f"create_database {database} succeed "
+ f"executed:{summary.result_available_after} consumed:{summary.result_consumed_after}"
+ )
def delete_all_data(self, database):
if self._database != database:
- raise ValueError(f"Error: Current database ({self._database}) is not the same as the target database ({database}).")
+ raise ValueError(
+ f"Error: Current database ({self._database}) is not the same as the target database ({database})."
+ )
with self._driver.session(database=database) as session:
while True:
- result = session.run("MATCH (n) WITH n LIMIT 100000 DETACH DELETE n RETURN count(*)")
+ result = session.run(
+ "MATCH (n) WITH n LIMIT 100000 DETACH DELETE n RETURN count(*)"
+ )
count = result.single()[0]
logger.info(f"Deleted {count} nodes in this batch.")
if count == 0:
@@ -893,7 +1175,9 @@ def delete_all_data(self, database):
def run_cypher_query(self, database, query, parameters=None):
if database and self._database != database:
- raise ValueError(f"Current database ({self._database}) is not the same as the target database ({database}).")
+ raise ValueError(
+ f"Current database ({self._database}) is not the same as the target database ({database})."
+ )
with self._driver.session(database=database) as session:
result = session.run(query, parameters)
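
The hunks above only reformat `Neo4jClient`; its public surface is unchanged. As a rough orientation for reviewers, the sketch below shows how the class is typically driven end to end. It is illustrative only: the import path, connection settings, and the assumption that the default full-text index already exists are guesses, not part of this change.

```python
# Illustrative sketch only (not part of this diff). The import path below is
# assumed; adjust it to wherever Neo4jClient is actually defined.
from kag.common.graphstore.neo4j_graph_store import Neo4jClient  # hypothetical path

client = Neo4jClient(
    uri="bolt://localhost:7687",
    user="neo4j",
    password="neo4j",
    database="neo4j",
)

# Upsert two nodes and connect them; extra_labels defaults to ("Entity",).
client.upsert_node("Person", {"id": "p1", "name": "Alice"})
client.upsert_node("Company", {"id": "c1", "name": "Acme"})
client.upsert_relationship(
    start_node_label="Person",
    start_node_id_value="p1",
    end_node_label="Company",
    end_node_id_value="c1",
    rel_type="worksFor",
    properties={"since": 2020},
)

# Full-text search assumes the "_default_text_index" has been created via
# create_text_index; vector_search additionally needs a vector index and a
# configured vectorizer.
print(client.text_search("Alice", label_constraints="Person", topk=5))

# PageRank scores require the Neo4j GDS plugin; start_nodes entries carry
# "type" and "name" keys, as read in _get_pagerank_scores.
print(client.get_pagerank_scores([{"type": "Person", "name": "Alice"}], "Company"))

client.close()
```
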
diff --git a/kag/common/graphstore/rest/__init__.py b/kag/common/graphstore/rest/__init__.py
index 923147a3..2cce4606 100644
--- a/kag/common/graphstore/rest/__init__.py
+++ b/kag/common/graphstore/rest/__init__.py
@@ -35,4 +35,6 @@
from kag.common.graphstore.rest.models.edge_record_instance import EdgeRecordInstance
from kag.common.graphstore.rest.models.upsert_edge_request import UpsertEdgeRequest
from kag.common.graphstore.rest.models.upsert_vertex_request import UpsertVertexRequest
-from kag.common.graphstore.rest.models.vertex_record_instance import VertexRecordInstance
+from kag.common.graphstore.rest.models.vertex_record_instance import (
+ VertexRecordInstance,
+)
diff --git a/kag/common/graphstore/rest/graph_api.py b/kag/common/graphstore/rest/graph_api.py
index e2875966..13dcd5ea 100644
--- a/kag/common/graphstore/rest/graph_api.py
+++ b/kag/common/graphstore/rest/graph_api.py
@@ -18,10 +18,7 @@
import six
from kag.common.rest.api_client import ApiClient
-from kag.common.rest.exceptions import ( # noqa: F401
- ApiTypeError,
- ApiValueError
-)
+from kag.common.rest.exceptions import ApiTypeError, ApiValueError # noqa: F401
class GraphApi(object):
@@ -57,7 +54,7 @@ def graph_delete_edge_post(self, **kwargs): # noqa: E501
If the method is called asynchronously,
returns the request thread.
"""
- kwargs['_return_http_data_only'] = True
+ kwargs["_return_http_data_only"] = True
return self.graph_delete_edge_post_with_http_info(**kwargs) # noqa: E501
def graph_delete_edge_post_with_http_info(self, **kwargs): # noqa: E501
@@ -86,26 +83,24 @@ def graph_delete_edge_post_with_http_info(self, **kwargs): # noqa: E501
local_var_params = locals()
- all_params = [
- 'delete_edge_request'
- ]
+ all_params = ["delete_edge_request"]
all_params.extend(
[
- 'async_req',
- '_return_http_data_only',
- '_preload_content',
- '_request_timeout'
+ "async_req",
+ "_return_http_data_only",
+ "_preload_content",
+ "_request_timeout",
]
)
- for key, val in six.iteritems(local_var_params['kwargs']):
+ for key, val in six.iteritems(local_var_params["kwargs"]):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_delete_edge_post" % key
)
local_var_params[key] = val
- del local_var_params['kwargs']
+ del local_var_params["kwargs"]
collection_formats = {}
@@ -119,34 +114,42 @@ def graph_delete_edge_post_with_http_info(self, **kwargs): # noqa: E501
local_var_files = {}
body_params = None
- if 'delete_edge_request' in local_var_params:
- body_params = local_var_params['delete_edge_request']
+ if "delete_edge_request" in local_var_params:
+ body_params = local_var_params["delete_edge_request"]
# HTTP header `Accept`
- header_params['Accept'] = self.api_client.select_header_accept(
- ['application/json']) # noqa: E501
+ header_params["Accept"] = self.api_client.select_header_accept(
+ ["application/json"]
+ ) # noqa: E501
# HTTP header `Content-Type`
- header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
- ['application/json']) # noqa: E501
+ header_params[
+ "Content-Type"
+ ] = self.api_client.select_header_content_type( # noqa: E501
+ ["application/json"]
+ ) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/graph/deleteEdge', 'POST',
+ "/graph/deleteEdge",
+ "POST",
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
- response_type='object', # noqa: E501
+ response_type="object", # noqa: E501
auth_settings=auth_settings,
- async_req=local_var_params.get('async_req'),
- _return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
- _preload_content=local_var_params.get('_preload_content', True),
- _request_timeout=local_var_params.get('_request_timeout'),
- collection_formats=collection_formats)
+ async_req=local_var_params.get("async_req"),
+ _return_http_data_only=local_var_params.get(
+ "_return_http_data_only"
+ ), # noqa: E501
+ _preload_content=local_var_params.get("_preload_content", True),
+ _request_timeout=local_var_params.get("_request_timeout"),
+ collection_formats=collection_formats,
+ )
def graph_delete_vertex_post(self, **kwargs): # noqa: E501
"""delete_vertex # noqa: E501
@@ -169,7 +172,7 @@ def graph_delete_vertex_post(self, **kwargs): # noqa: E501
If the method is called asynchronously,
returns the request thread.
"""
- kwargs['_return_http_data_only'] = True
+ kwargs["_return_http_data_only"] = True
return self.graph_delete_vertex_post_with_http_info(**kwargs) # noqa: E501
def graph_delete_vertex_post_with_http_info(self, **kwargs): # noqa: E501
@@ -198,26 +201,24 @@ def graph_delete_vertex_post_with_http_info(self, **kwargs): # noqa: E501
local_var_params = locals()
- all_params = [
- 'delete_vertex_request'
- ]
+ all_params = ["delete_vertex_request"]
all_params.extend(
[
- 'async_req',
- '_return_http_data_only',
- '_preload_content',
- '_request_timeout'
+ "async_req",
+ "_return_http_data_only",
+ "_preload_content",
+ "_request_timeout",
]
)
- for key, val in six.iteritems(local_var_params['kwargs']):
+ for key, val in six.iteritems(local_var_params["kwargs"]):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_delete_vertex_post" % key
)
local_var_params[key] = val
- del local_var_params['kwargs']
+ del local_var_params["kwargs"]
collection_formats = {}
@@ -231,34 +232,42 @@ def graph_delete_vertex_post_with_http_info(self, **kwargs): # noqa: E501
local_var_files = {}
body_params = None
- if 'delete_vertex_request' in local_var_params:
- body_params = local_var_params['delete_vertex_request']
+ if "delete_vertex_request" in local_var_params:
+ body_params = local_var_params["delete_vertex_request"]
# HTTP header `Accept`
- header_params['Accept'] = self.api_client.select_header_accept(
- ['application/json']) # noqa: E501
+ header_params["Accept"] = self.api_client.select_header_accept(
+ ["application/json"]
+ ) # noqa: E501
# HTTP header `Content-Type`
- header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
- ['application/json']) # noqa: E501
+ header_params[
+ "Content-Type"
+ ] = self.api_client.select_header_content_type( # noqa: E501
+ ["application/json"]
+ ) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/graph/deleteVertex', 'POST',
+ "/graph/deleteVertex",
+ "POST",
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
- response_type='object', # noqa: E501
+ response_type="object", # noqa: E501
auth_settings=auth_settings,
- async_req=local_var_params.get('async_req'),
- _return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
- _preload_content=local_var_params.get('_preload_content', True),
- _request_timeout=local_var_params.get('_request_timeout'),
- collection_formats=collection_formats)
+ async_req=local_var_params.get("async_req"),
+ _return_http_data_only=local_var_params.get(
+ "_return_http_data_only"
+ ), # noqa: E501
+ _preload_content=local_var_params.get("_preload_content", True),
+ _request_timeout=local_var_params.get("_request_timeout"),
+ collection_formats=collection_formats,
+ )
def graph_upsert_edge_post(self, **kwargs): # noqa: E501
"""upsert_edge # noqa: E501
@@ -281,7 +290,7 @@ def graph_upsert_edge_post(self, **kwargs): # noqa: E501
If the method is called asynchronously,
returns the request thread.
"""
- kwargs['_return_http_data_only'] = True
+ kwargs["_return_http_data_only"] = True
return self.graph_upsert_edge_post_with_http_info(**kwargs) # noqa: E501
def graph_upsert_edge_post_with_http_info(self, **kwargs): # noqa: E501
@@ -310,26 +319,24 @@ def graph_upsert_edge_post_with_http_info(self, **kwargs): # noqa: E501
local_var_params = locals()
- all_params = [
- 'upsert_edge_request'
- ]
+ all_params = ["upsert_edge_request"]
all_params.extend(
[
- 'async_req',
- '_return_http_data_only',
- '_preload_content',
- '_request_timeout'
+ "async_req",
+ "_return_http_data_only",
+ "_preload_content",
+ "_request_timeout",
]
)
- for key, val in six.iteritems(local_var_params['kwargs']):
+ for key, val in six.iteritems(local_var_params["kwargs"]):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_upsert_edge_post" % key
)
local_var_params[key] = val
- del local_var_params['kwargs']
+ del local_var_params["kwargs"]
collection_formats = {}
@@ -343,34 +350,42 @@ def graph_upsert_edge_post_with_http_info(self, **kwargs): # noqa: E501
local_var_files = {}
body_params = None
- if 'upsert_edge_request' in local_var_params:
- body_params = local_var_params['upsert_edge_request']
+ if "upsert_edge_request" in local_var_params:
+ body_params = local_var_params["upsert_edge_request"]
# HTTP header `Accept`
- header_params['Accept'] = self.api_client.select_header_accept(
- ['application/json']) # noqa: E501
+ header_params["Accept"] = self.api_client.select_header_accept(
+ ["application/json"]
+ ) # noqa: E501
# HTTP header `Content-Type`
- header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
- ['application/json']) # noqa: E501
+ header_params[
+ "Content-Type"
+ ] = self.api_client.select_header_content_type( # noqa: E501
+ ["application/json"]
+ ) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/graph/upsertEdge', 'POST',
+ "/graph/upsertEdge",
+ "POST",
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
- response_type='object', # noqa: E501
+ response_type="object", # noqa: E501
auth_settings=auth_settings,
- async_req=local_var_params.get('async_req'),
- _return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
- _preload_content=local_var_params.get('_preload_content', True),
- _request_timeout=local_var_params.get('_request_timeout'),
- collection_formats=collection_formats)
+ async_req=local_var_params.get("async_req"),
+ _return_http_data_only=local_var_params.get(
+ "_return_http_data_only"
+ ), # noqa: E501
+ _preload_content=local_var_params.get("_preload_content", True),
+ _request_timeout=local_var_params.get("_request_timeout"),
+ collection_formats=collection_formats,
+ )
def graph_upsert_vertex_post(self, **kwargs): # noqa: E501
"""upsert_vertex # noqa: E501
@@ -393,7 +408,7 @@ def graph_upsert_vertex_post(self, **kwargs): # noqa: E501
If the method is called asynchronously,
returns the request thread.
"""
- kwargs['_return_http_data_only'] = True
+ kwargs["_return_http_data_only"] = True
return self.graph_upsert_vertex_post_with_http_info(**kwargs) # noqa: E501
def graph_upsert_vertex_post_with_http_info(self, **kwargs): # noqa: E501
@@ -422,26 +437,24 @@ def graph_upsert_vertex_post_with_http_info(self, **kwargs): # noqa: E501
local_var_params = locals()
- all_params = [
- 'upsert_vertex_request'
- ]
+ all_params = ["upsert_vertex_request"]
all_params.extend(
[
- 'async_req',
- '_return_http_data_only',
- '_preload_content',
- '_request_timeout'
+ "async_req",
+ "_return_http_data_only",
+ "_preload_content",
+ "_request_timeout",
]
)
- for key, val in six.iteritems(local_var_params['kwargs']):
+ for key, val in six.iteritems(local_var_params["kwargs"]):
if key not in all_params:
raise ApiTypeError(
"Got an unexpected keyword argument '%s'"
" to method graph_upsert_vertex_post" % key
)
local_var_params[key] = val
- del local_var_params['kwargs']
+ del local_var_params["kwargs"]
collection_formats = {}
@@ -455,31 +468,39 @@ def graph_upsert_vertex_post_with_http_info(self, **kwargs): # noqa: E501
local_var_files = {}
body_params = None
- if 'upsert_vertex_request' in local_var_params:
- body_params = local_var_params['upsert_vertex_request']
+ if "upsert_vertex_request" in local_var_params:
+ body_params = local_var_params["upsert_vertex_request"]
# HTTP header `Accept`
- header_params['Accept'] = self.api_client.select_header_accept(
- ['application/json']) # noqa: E501
+ header_params["Accept"] = self.api_client.select_header_accept(
+ ["application/json"]
+ ) # noqa: E501
# HTTP header `Content-Type`
- header_params['Content-Type'] = self.api_client.select_header_content_type( # noqa: E501
- ['application/json']) # noqa: E501
+ header_params[
+ "Content-Type"
+ ] = self.api_client.select_header_content_type( # noqa: E501
+ ["application/json"]
+ ) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/graph/upsertVertex', 'POST',
+ "/graph/upsertVertex",
+ "POST",
path_params,
query_params,
header_params,
body=body_params,
post_params=form_params,
files=local_var_files,
- response_type='object', # noqa: E501
+ response_type="object", # noqa: E501
auth_settings=auth_settings,
- async_req=local_var_params.get('async_req'),
- _return_http_data_only=local_var_params.get('_return_http_data_only'), # noqa: E501
- _preload_content=local_var_params.get('_preload_content', True),
- _request_timeout=local_var_params.get('_request_timeout'),
- collection_formats=collection_formats)
+ async_req=local_var_params.get("async_req"),
+ _return_http_data_only=local_var_params.get(
+ "_return_http_data_only"
+ ), # noqa: E501
+ _preload_content=local_var_params.get("_preload_content", True),
+ _request_timeout=local_var_params.get("_request_timeout"),
+ collection_formats=collection_formats,
+ )
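
The changes to `GraphApi` above are likewise purely mechanical (quote style and argument wrapping); every `*_post` method still forwards its kwargs to the corresponding `*_post_with_http_info` and returns only the response body, because `_return_http_data_only` is forced to `True`. Below is a hedged sketch of how the regenerated client is invoked, assuming the usual OpenAPI-generator conventions for `ApiClient` and using the request models that appear later in this diff.

```python
# Illustrative only (not part of this diff). The ApiClient construction and the
# GraphApi constructor signature are assumed to follow standard
# OpenAPI-generator conventions.
from kag.common.rest.api_client import ApiClient
from kag.common.graphstore.rest.graph_api import GraphApi
from kag.common.graphstore.rest import EdgeRecordInstance, UpsertEdgeRequest

api = GraphApi(ApiClient())  # default server configuration assumed

edge = EdgeRecordInstance(
    src_type="Person",
    src_id="p1",
    dst_type="Company",
    dst_id="c1",
    label="worksFor",
    properties={"since": 2020},
)
request = UpsertEdgeRequest(
    project_id=1,
    upsert_adjacent_vertices=True,
    edges=[edge],
)

# Returns the parsed response body only, since _return_http_data_only is True.
result = api.graph_upsert_edge_post(upsert_edge_request=request)
print(result)
```
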
diff --git a/kag/common/graphstore/rest/models/__init__.py b/kag/common/graphstore/rest/models/__init__.py
index 9660757a..ef11492f 100644
--- a/kag/common/graphstore/rest/models/__init__.py
+++ b/kag/common/graphstore/rest/models/__init__.py
@@ -16,4 +16,6 @@
from kag.common.graphstore.rest.models.edge_record_instance import EdgeRecordInstance
from kag.common.graphstore.rest.models.upsert_edge_request import UpsertEdgeRequest
from kag.common.graphstore.rest.models.upsert_vertex_request import UpsertVertexRequest
-from kag.common.graphstore.rest.models.vertex_record_instance import VertexRecordInstance
+from kag.common.graphstore.rest.models.vertex_record_instance import (
+ VertexRecordInstance,
+)
diff --git a/kag/common/graphstore/rest/models/delete_edge_request.py b/kag/common/graphstore/rest/models/delete_edge_request.py
index 4dc2984f..6d0a03ed 100644
--- a/kag/common/graphstore/rest/models/delete_edge_request.py
+++ b/kag/common/graphstore/rest/models/delete_edge_request.py
@@ -32,17 +32,13 @@ class DeleteEdgeRequest(object):
attribute_map (dict): The key is attribute name
and the value is json key in definition.
"""
- openapi_types = {
- 'project_id': 'int',
- 'edges': 'list[EdgeRecordInstance]'
- }
+ openapi_types = {"project_id": "int", "edges": "list[EdgeRecordInstance]"}
- attribute_map = {
- 'project_id': 'projectId',
- 'edges': 'edges'
- }
+ attribute_map = {"project_id": "projectId", "edges": "edges"}
- def __init__(self, project_id=None, edges=None, local_vars_configuration=None): # noqa: E501
+ def __init__(
+ self, project_id=None, edges=None, local_vars_configuration=None
+ ): # noqa: E501
"""DeleteEdgeRequest - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
@@ -73,8 +69,12 @@ def project_id(self, project_id):
:param project_id: The project_id of this DeleteEdgeRequest. # noqa: E501
:type: int
"""
- if self.local_vars_configuration.client_side_validation and project_id is None: # noqa: E501
- raise ValueError("Invalid value for `project_id`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and project_id is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `project_id`, must not be `None`"
+ ) # noqa: E501
self._project_id = project_id
@@ -96,8 +96,12 @@ def edges(self, edges):
:param edges: The edges of this DeleteEdgeRequest. # noqa: E501
:type: list[EdgeRecordInstance]
"""
- if self.local_vars_configuration.client_side_validation and edges is None: # noqa: E501
- raise ValueError("Invalid value for `edges`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and edges is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `edges`, must not be `None`"
+ ) # noqa: E501
self._edges = edges
@@ -108,18 +112,20 @@ def to_dict(self):
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
- result[attr] = list(map(
- lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
- value
- ))
+ result[attr] = list(
+ map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)
+ )
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
- result[attr] = dict(map(
- lambda item: (item[0], item[1].to_dict())
- if hasattr(item[1], "to_dict") else item,
- value.items()
- ))
+ result[attr] = dict(
+ map(
+ lambda item: (item[0], item[1].to_dict())
+ if hasattr(item[1], "to_dict")
+ else item,
+ value.items(),
+ )
+ )
else:
result[attr] = value
diff --git a/kag/common/graphstore/rest/models/delete_vertex_request.py b/kag/common/graphstore/rest/models/delete_vertex_request.py
index 1e9b980a..f6384a20 100644
--- a/kag/common/graphstore/rest/models/delete_vertex_request.py
+++ b/kag/common/graphstore/rest/models/delete_vertex_request.py
@@ -32,17 +32,13 @@ class DeleteVertexRequest(object):
attribute_map (dict): The key is attribute name
and the value is json key in definition.
"""
- openapi_types = {
- 'project_id': 'int',
- 'vertices': 'list[VertexRecordInstance]'
- }
+ openapi_types = {"project_id": "int", "vertices": "list[VertexRecordInstance]"}
- attribute_map = {
- 'project_id': 'projectId',
- 'vertices': 'vertices'
- }
+ attribute_map = {"project_id": "projectId", "vertices": "vertices"}
- def __init__(self, project_id=None, vertices=None, local_vars_configuration=None): # noqa: E501
+ def __init__(
+ self, project_id=None, vertices=None, local_vars_configuration=None
+ ): # noqa: E501
"""DeleteVertexRequest - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
@@ -73,8 +69,12 @@ def project_id(self, project_id):
:param project_id: The project_id of this DeleteVertexRequest. # noqa: E501
:type: int
"""
- if self.local_vars_configuration.client_side_validation and project_id is None: # noqa: E501
- raise ValueError("Invalid value for `project_id`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and project_id is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `project_id`, must not be `None`"
+ ) # noqa: E501
self._project_id = project_id
@@ -96,8 +96,12 @@ def vertices(self, vertices):
:param vertices: The vertices of this DeleteVertexRequest. # noqa: E501
:type: list[VertexRecordInstance]
"""
- if self.local_vars_configuration.client_side_validation and vertices is None: # noqa: E501
- raise ValueError("Invalid value for `vertices`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and vertices is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `vertices`, must not be `None`"
+ ) # noqa: E501
self._vertices = vertices
@@ -108,18 +112,20 @@ def to_dict(self):
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
- result[attr] = list(map(
- lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
- value
- ))
+ result[attr] = list(
+ map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)
+ )
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
- result[attr] = dict(map(
- lambda item: (item[0], item[1].to_dict())
- if hasattr(item[1], "to_dict") else item,
- value.items()
- ))
+ result[attr] = dict(
+ map(
+ lambda item: (item[0], item[1].to_dict())
+ if hasattr(item[1], "to_dict")
+ else item,
+ value.items(),
+ )
+ )
else:
result[attr] = value
diff --git a/kag/common/graphstore/rest/models/edge_record_instance.py b/kag/common/graphstore/rest/models/edge_record_instance.py
index 77873ddd..e901fdde 100644
--- a/kag/common/graphstore/rest/models/edge_record_instance.py
+++ b/kag/common/graphstore/rest/models/edge_record_instance.py
@@ -33,24 +33,33 @@ class EdgeRecordInstance(object):
and the value is json key in definition.
"""
openapi_types = {
- 'src_type': 'str',
- 'src_id': 'str',
- 'dst_type': 'str',
- 'dst_id': 'str',
- 'label': 'str',
- 'properties': 'object'
+ "src_type": "str",
+ "src_id": "str",
+ "dst_type": "str",
+ "dst_id": "str",
+ "label": "str",
+ "properties": "object",
}
attribute_map = {
- 'src_type': 'srcType',
- 'src_id': 'srcId',
- 'dst_type': 'dstType',
- 'dst_id': 'dstId',
- 'label': 'label',
- 'properties': 'properties'
+ "src_type": "srcType",
+ "src_id": "srcId",
+ "dst_type": "dstType",
+ "dst_id": "dstId",
+ "label": "label",
+ "properties": "properties",
}
- def __init__(self, src_type=None, src_id=None, dst_type=None, dst_id=None, label=None, properties=None, local_vars_configuration=None): # noqa: E501
+ def __init__(
+ self,
+ src_type=None,
+ src_id=None,
+ dst_type=None,
+ dst_id=None,
+ label=None,
+ properties=None,
+ local_vars_configuration=None,
+ ): # noqa: E501
"""EdgeRecordInstance - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
@@ -89,8 +98,12 @@ def src_type(self, src_type):
:param src_type: The src_type of this EdgeRecordInstance. # noqa: E501
:type: str
"""
- if self.local_vars_configuration.client_side_validation and src_type is None: # noqa: E501
- raise ValueError("Invalid value for `src_type`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and src_type is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `src_type`, must not be `None`"
+ ) # noqa: E501
self._src_type = src_type
@@ -112,8 +125,12 @@ def src_id(self, src_id):
:param src_id: The src_id of this EdgeRecordInstance. # noqa: E501
:type: str
"""
- if self.local_vars_configuration.client_side_validation and src_id is None: # noqa: E501
- raise ValueError("Invalid value for `src_id`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and src_id is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `src_id`, must not be `None`"
+ ) # noqa: E501
self._src_id = src_id
@@ -135,8 +152,12 @@ def dst_type(self, dst_type):
:param dst_type: The dst_type of this EdgeRecordInstance. # noqa: E501
:type: str
"""
- if self.local_vars_configuration.client_side_validation and dst_type is None: # noqa: E501
- raise ValueError("Invalid value for `dst_type`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and dst_type is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `dst_type`, must not be `None`"
+ ) # noqa: E501
self._dst_type = dst_type
@@ -158,8 +179,12 @@ def dst_id(self, dst_id):
:param dst_id: The dst_id of this EdgeRecordInstance. # noqa: E501
:type: str
"""
- if self.local_vars_configuration.client_side_validation and dst_id is None: # noqa: E501
- raise ValueError("Invalid value for `dst_id`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and dst_id is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `dst_id`, must not be `None`"
+ ) # noqa: E501
self._dst_id = dst_id
@@ -181,8 +206,12 @@ def label(self, label):
:param label: The label of this EdgeRecordInstance. # noqa: E501
:type: str
"""
- if self.local_vars_configuration.client_side_validation and label is None: # noqa: E501
- raise ValueError("Invalid value for `label`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and label is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `label`, must not be `None`"
+ ) # noqa: E501
self._label = label
@@ -204,8 +233,12 @@ def properties(self, properties):
:param properties: The properties of this EdgeRecordInstance. # noqa: E501
:type: object
"""
- if self.local_vars_configuration.client_side_validation and properties is None: # noqa: E501
- raise ValueError("Invalid value for `properties`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and properties is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `properties`, must not be `None`"
+ ) # noqa: E501
self._properties = properties
@@ -216,18 +249,20 @@ def to_dict(self):
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
- result[attr] = list(map(
- lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
- value
- ))
+ result[attr] = list(
+ map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)
+ )
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
- result[attr] = dict(map(
- lambda item: (item[0], item[1].to_dict())
- if hasattr(item[1], "to_dict") else item,
- value.items()
- ))
+ result[attr] = dict(
+ map(
+ lambda item: (item[0], item[1].to_dict())
+ if hasattr(item[1], "to_dict")
+ else item,
+ value.items(),
+ )
+ )
else:
result[attr] = value
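
The model reformatting above keeps the generated validation and serialization behavior intact: with client-side validation enabled (the default `Configuration` typically turns it on), the property setters reject `None` for required fields, and `to_dict()` recursively serializes nested models under their snake_case attribute names. A small hedged sketch:

```python
# Illustrative only (not part of this diff); assumes the default Configuration
# enables client_side_validation.
from kag.common.graphstore.rest import EdgeRecordInstance

edge = EdgeRecordInstance(
    src_type="Person",
    src_id="p1",
    dst_type="Company",
    dst_id="c1",
    label="worksFor",
    properties={},
)
print(edge.to_dict())  # keys are the snake_case attribute names

try:
    edge.src_id = None  # required field: setter raises ValueError when validation is on
except ValueError as err:
    print(err)
```
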
diff --git a/kag/common/graphstore/rest/models/upsert_edge_request.py b/kag/common/graphstore/rest/models/upsert_edge_request.py
index 7dd1c89a..5cd69ed1 100644
--- a/kag/common/graphstore/rest/models/upsert_edge_request.py
+++ b/kag/common/graphstore/rest/models/upsert_edge_request.py
@@ -33,18 +33,24 @@ class UpsertEdgeRequest(object):
and the value is json key in definition.
"""
openapi_types = {
- 'project_id': 'int',
- 'upsert_adjacent_vertices': 'bool',
- 'edges': 'list[EdgeRecordInstance]'
+ "project_id": "int",
+ "upsert_adjacent_vertices": "bool",
+ "edges": "list[EdgeRecordInstance]",
}
attribute_map = {
- 'project_id': 'projectId',
- 'upsert_adjacent_vertices': 'upsertAdjacentVertices',
- 'edges': 'edges'
+ "project_id": "projectId",
+ "upsert_adjacent_vertices": "upsertAdjacentVertices",
+ "edges": "edges",
}
- def __init__(self, project_id=None, upsert_adjacent_vertices=None, edges=None, local_vars_configuration=None): # noqa: E501
+ def __init__(
+ self,
+ project_id=None,
+ upsert_adjacent_vertices=None,
+ edges=None,
+ local_vars_configuration=None,
+ ): # noqa: E501
"""UpsertEdgeRequest - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
@@ -77,8 +83,12 @@ def project_id(self, project_id):
:param project_id: The project_id of this UpsertEdgeRequest. # noqa: E501
:type: int
"""
- if self.local_vars_configuration.client_side_validation and project_id is None: # noqa: E501
- raise ValueError("Invalid value for `project_id`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and project_id is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `project_id`, must not be `None`"
+ ) # noqa: E501
self._project_id = project_id
@@ -100,8 +110,13 @@ def upsert_adjacent_vertices(self, upsert_adjacent_vertices):
:param upsert_adjacent_vertices: The upsert_adjacent_vertices of this UpsertEdgeRequest. # noqa: E501
:type: bool
"""
- if self.local_vars_configuration.client_side_validation and upsert_adjacent_vertices is None: # noqa: E501
- raise ValueError("Invalid value for `upsert_adjacent_vertices`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation
+ and upsert_adjacent_vertices is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `upsert_adjacent_vertices`, must not be `None`"
+ ) # noqa: E501
self._upsert_adjacent_vertices = upsert_adjacent_vertices
@@ -123,8 +138,12 @@ def edges(self, edges):
:param edges: The edges of this UpsertEdgeRequest. # noqa: E501
:type: list[EdgeRecordInstance]
"""
- if self.local_vars_configuration.client_side_validation and edges is None: # noqa: E501
- raise ValueError("Invalid value for `edges`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and edges is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `edges`, must not be `None`"
+ ) # noqa: E501
self._edges = edges
@@ -135,18 +154,20 @@ def to_dict(self):
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
- result[attr] = list(map(
- lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
- value
- ))
+ result[attr] = list(
+ map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)
+ )
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
- result[attr] = dict(map(
- lambda item: (item[0], item[1].to_dict())
- if hasattr(item[1], "to_dict") else item,
- value.items()
- ))
+ result[attr] = dict(
+ map(
+ lambda item: (item[0], item[1].to_dict())
+ if hasattr(item[1], "to_dict")
+ else item,
+ value.items(),
+ )
+ )
else:
result[attr] = value
diff --git a/kag/common/graphstore/rest/models/upsert_vertex_request.py b/kag/common/graphstore/rest/models/upsert_vertex_request.py
index 6ed6cec1..682968b8 100644
--- a/kag/common/graphstore/rest/models/upsert_vertex_request.py
+++ b/kag/common/graphstore/rest/models/upsert_vertex_request.py
@@ -32,17 +32,13 @@ class UpsertVertexRequest(object):
attribute_map (dict): The key is attribute name
and the value is json key in definition.
"""
- openapi_types = {
- 'project_id': 'int',
- 'vertices': 'list[VertexRecordInstance]'
- }
+ openapi_types = {"project_id": "int", "vertices": "list[VertexRecordInstance]"}
- attribute_map = {
- 'project_id': 'projectId',
- 'vertices': 'vertices'
- }
+ attribute_map = {"project_id": "projectId", "vertices": "vertices"}
- def __init__(self, project_id=None, vertices=None, local_vars_configuration=None): # noqa: E501
+ def __init__(
+ self, project_id=None, vertices=None, local_vars_configuration=None
+ ): # noqa: E501
"""UpsertVertexRequest - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
@@ -73,8 +69,12 @@ def project_id(self, project_id):
:param project_id: The project_id of this UpsertVertexRequest. # noqa: E501
:type: int
"""
- if self.local_vars_configuration.client_side_validation and project_id is None: # noqa: E501
- raise ValueError("Invalid value for `project_id`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and project_id is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `project_id`, must not be `None`"
+ ) # noqa: E501
self._project_id = project_id
@@ -96,8 +96,12 @@ def vertices(self, vertices):
:param vertices: The vertices of this UpsertVertexRequest. # noqa: E501
:type: list[VertexRecordInstance]
"""
- if self.local_vars_configuration.client_side_validation and vertices is None: # noqa: E501
- raise ValueError("Invalid value for `vertices`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and vertices is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `vertices`, must not be `None`"
+ ) # noqa: E501
self._vertices = vertices
@@ -108,18 +112,20 @@ def to_dict(self):
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
- result[attr] = list(map(
- lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
- value
- ))
+ result[attr] = list(
+ map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)
+ )
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
- result[attr] = dict(map(
- lambda item: (item[0], item[1].to_dict())
- if hasattr(item[1], "to_dict") else item,
- value.items()
- ))
+ result[attr] = dict(
+ map(
+ lambda item: (item[0], item[1].to_dict())
+ if hasattr(item[1], "to_dict")
+ else item,
+ value.items(),
+ )
+ )
else:
result[attr] = value
diff --git a/kag/common/graphstore/rest/models/vertex_record_instance.py b/kag/common/graphstore/rest/models/vertex_record_instance.py
index 8fe12ca2..710891c1 100644
--- a/kag/common/graphstore/rest/models/vertex_record_instance.py
+++ b/kag/common/graphstore/rest/models/vertex_record_instance.py
@@ -33,20 +33,27 @@ class VertexRecordInstance(object):
and the value is json key in definition.
"""
openapi_types = {
- 'type': 'str',
- 'id': 'str',
- 'properties': 'object',
- 'vectors': 'object'
+ "type": "str",
+ "id": "str",
+ "properties": "object",
+ "vectors": "object",
}
attribute_map = {
- 'type': 'type',
- 'id': 'id',
- 'properties': 'properties',
- 'vectors': 'vectors'
+ "type": "type",
+ "id": "id",
+ "properties": "properties",
+ "vectors": "vectors",
}
- def __init__(self, type=None, id=None, properties=None, vectors=None, local_vars_configuration=None): # noqa: E501
+ def __init__(
+ self,
+ type=None,
+ id=None,
+ properties=None,
+ vectors=None,
+ local_vars_configuration=None,
+ ): # noqa: E501
"""VertexRecordInstance - a model defined in OpenAPI""" # noqa: E501
if local_vars_configuration is None:
local_vars_configuration = Configuration()
@@ -81,8 +88,12 @@ def type(self, type):
:param type: The type of this VertexRecordInstance. # noqa: E501
:type: str
"""
- if self.local_vars_configuration.client_side_validation and type is None: # noqa: E501
- raise ValueError("Invalid value for `type`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and type is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `type`, must not be `None`"
+ ) # noqa: E501
self._type = type
@@ -104,7 +115,9 @@ def id(self, id):
:param id: The id of this VertexRecordInstance. # noqa: E501
:type: str
"""
- if self.local_vars_configuration.client_side_validation and id is None: # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and id is None
+ ): # noqa: E501
raise ValueError("Invalid value for `id`, must not be `None`") # noqa: E501
self._id = id
@@ -127,8 +140,12 @@ def properties(self, properties):
:param properties: The properties of this VertexRecordInstance. # noqa: E501
:type: object
"""
- if self.local_vars_configuration.client_side_validation and properties is None: # noqa: E501
- raise ValueError("Invalid value for `properties`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and properties is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `properties`, must not be `None`"
+ ) # noqa: E501
self._properties = properties
@@ -150,8 +167,12 @@ def vectors(self, vectors):
:param vectors: The vectors of this VertexRecordInstance. # noqa: E501
:type: object
"""
- if self.local_vars_configuration.client_side_validation and vectors is None: # noqa: E501
- raise ValueError("Invalid value for `vectors`, must not be `None`") # noqa: E501
+ if (
+ self.local_vars_configuration.client_side_validation and vectors is None
+ ): # noqa: E501
+ raise ValueError(
+ "Invalid value for `vectors`, must not be `None`"
+ ) # noqa: E501
self._vectors = vectors
@@ -162,18 +183,20 @@ def to_dict(self):
for attr, _ in six.iteritems(self.openapi_types):
value = getattr(self, attr)
if isinstance(value, list):
- result[attr] = list(map(
- lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
- value
- ))
+ result[attr] = list(
+ map(lambda x: x.to_dict() if hasattr(x, "to_dict") else x, value)
+ )
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
- result[attr] = dict(map(
- lambda item: (item[0], item[1].to_dict())
- if hasattr(item[1], "to_dict") else item,
- value.items()
- ))
+ result[attr] = dict(
+ map(
+ lambda item: (item[0], item[1].to_dict())
+ if hasattr(item[1], "to_dict")
+ else item,
+ value.items(),
+ )
+ )
else:
result[attr] = value
diff --git a/kag/common/llm/llm_client.py b/kag/common/llm/llm_client.py
index 8c509e50..88e4a05e 100644
--- a/kag/common/llm/llm_client.py
+++ b/kag/common/llm/llm_client.py
@@ -10,13 +10,10 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
-import os
import json
-from pathlib import Path
from typing import Union, Dict, List, Any
import logging
import traceback
-import yaml
from kag.common.base.prompt_op import PromptOp
from kag.common.registry import Registrable
diff --git a/kag/common/llm/ollama_client.py b/kag/common/llm/ollama_client.py
index 44613009..61b11df2 100644
--- a/kag/common/llm/ollama_client.py
+++ b/kag/common/llm/ollama_client.py
@@ -10,31 +10,11 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
-import os
-import ast
-import re
import json
-import time
-import uuid
-import html
-from binascii import b2a_hex
-from datetime import datetime
-from pathlib import Path
-from typing import Union, Dict, List, Any
-from urllib import request
-from collections import defaultdict
-from openai import OpenAI
import logging
from ollama import Client
-import requests
-import traceback
-from Crypto.Cipher import AES
-from requests import RequestException
-
-from kag.common import arks_pb2
-from kag.common.base.prompt_op import PromptOp
from kag.common.llm.llm_client import LLMClient
diff --git a/kag/common/llm/openai_client.py b/kag/common/llm/openai_client.py
index d6ad0c94..e4980c9c 100644
--- a/kag/common/llm/openai_client.py
+++ b/kag/common/llm/openai_client.py
@@ -12,7 +12,6 @@
import json
-from typing import Union
from openai import OpenAI
import logging
diff --git a/kag/common/llm/vllm_client.py b/kag/common/llm/vllm_client.py
index 11b12b3b..0700c8e2 100644
--- a/kag/common/llm/vllm_client.py
+++ b/kag/common/llm/vllm_client.py
@@ -10,31 +10,10 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
-import os
-import ast
-import re
-import json
-import time
-import uuid
-import html
-from binascii import b2a_hex
-from datetime import datetime
-from pathlib import Path
-from typing import Union, Dict, List, Any
-from urllib import request
-from collections import defaultdict
-from openai import OpenAI
+import json
import logging
-
import requests
-import traceback
-from Crypto.Cipher import AES
-from requests import RequestException
-
-from kag.common import arks_pb2
-from kag.common.base.prompt_op import PromptOp
-
from kag.common.llm.llm_client import LLMClient
diff --git a/kag/common/registry/__init__.py b/kag/common/registry/__init__.py
index 87617fc5..3ab66aed 100644
--- a/kag/common/registry/__init__.py
+++ b/kag/common/registry/__init__.py
@@ -1,4 +1,15 @@
# -*- coding: utf-8 -*-
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
+
from kag.common.registry.registrable import Registrable, ConfigurationError
from kag.common.registry.lazy import Lazy
from kag.common.registry.functor import Functor
diff --git a/kag/common/registry/functor.py b/kag/common/registry/functor.py
index b3d8259e..1051cf3e 100644
--- a/kag/common/registry/functor.py
+++ b/kag/common/registry/functor.py
@@ -1,4 +1,15 @@
# -*- coding: utf-8 -*-
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
+
import logging
import collections
from kag.common.registry.registrable import (
@@ -8,7 +19,8 @@
create_kwargs,
)
from types import FunctionType
-from typing import TypeVar, Type, Union, Callable, Dict, cast
+
+from typing import Type, Union, Callable, Dict, cast
from functools import partial
from pyhocon import ConfigTree, ConfigFactory
diff --git a/kag/common/registry/lazy.py b/kag/common/registry/lazy.py
index eb8fcc1b..1b3f281e 100644
--- a/kag/common/registry/lazy.py
+++ b/kag/common/registry/lazy.py
@@ -1,4 +1,14 @@
# -*- coding: utf-8 -*-
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
import inspect
from pyhocon import ConfigTree
diff --git a/kag/common/registry/registrable.py b/kag/common/registry/registrable.py
index 7596de9e..f42099b5 100644
--- a/kag/common/registry/registrable.py
+++ b/kag/common/registry/registrable.py
@@ -1,11 +1,24 @@
# -*- coding: utf-8 -*-
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
+
import inspect
import importlib
import logging
import functools
import collections
import traceback
-from pyhocon import ConfigList, ConfigTree, ConfigFactory
+
+from pathlib import Path
+from pyhocon import ConfigTree, ConfigFactory
from pyhocon.exceptions import ConfigMissingException
from copy import deepcopy
from collections import defaultdict
@@ -479,6 +492,10 @@ def constructor(**kwargs):
class Registrable:
+ """
+ This class is motivated by the original work:
+ https://github.com/allenai/allennlp/blob/main/allennlp/common/from_params.py
+ """
_registry: Dict[Type, Dict[str, Tuple[Type, Optional[str]]]] = defaultdict(dict)
default_implementation: Optional[str] = None
@@ -721,7 +738,7 @@ def from_config(
)
setattr(instant, "__original_parameters__", original_params)
        # if constructor takes kwargs, they can't be inferred from the constructor. Therefore we should record
- # which attrs are created by kwargs to correctly restore the configs by `to_params`.
+ # which attrs are created by kwargs to correctly restore the configs by `to_config`.
if accepts_kwargs:
remaining_kwargs = set(params)
params.clear()
@@ -733,15 +750,15 @@ def from_config(
return instant
- def _to_params(self, v):
+ def _to_config(self, v):
"""iteratively convert v to params"""
v_type = type(v)
- if hasattr(v, "to_params"):
- params = v.to_params()
+ if hasattr(v, "to_config"):
+ params = v.to_config()
elif v_type in {collections.abc.Mapping, Mapping, Dict, dict}:
params = {}
for subk, subv in v.items():
- params[subk] = self._to_params(subv)
+ params[subk] = self._to_config(subv)
elif v_type in {
collections.abc.Iterable,
Iterable,
@@ -752,12 +769,12 @@ def _to_params(self, v):
Set,
set,
}:
- params = [self._to_params(x) for x in v]
+ params = [self._to_config(x) for x in v]
else:
params = v
return params
- def to_params(self) -> ConfigTree:
+ def to_config(self) -> ConfigTree:
"""
convert object back to params.
Note: If the object is not instantiated by from_config, we can't transfer it back.
@@ -780,20 +797,20 @@ def to_params(self) -> ConfigTree:
# attrs of instance itself.
if hasattr(self, k):
v = getattr(self, k)
- if hasattr(v, "to_params"):
- conf = v.to_params()
+ if hasattr(v, "to_config"):
+ conf = v.to_config()
else:
- conf = self._to_params(v)
+ conf = self._to_config(v)
config[k] = conf
return ConfigFactory.from_dict(config)
- def to_params_with_constructor(self, constructor: str = None) -> ConfigTree:
+ def to_config_with_constructor(self, constructor: str = None) -> ConfigTree:
"""convert object back to params.
- Different from `to_params`, this function can convert objects that are not instantiated by `from_config`,
+ Different from `to_config`, this function can convert objects that are not instantiated by `from_config`,
        but sometimes it may not give a correct result.
        For example, suppose the class has more than one constructor and the object was instantiated via constructorA,
        but we convert it to params of constructorB. So use it with caution.
- One should always use `from_config` to instantiate the object and `to_params` to convert it back to params.
+ One should always use `from_config` to instantiate the object and `to_config` to convert it back to params.
"""
config = {}
@@ -816,10 +833,10 @@ def to_params_with_constructor(self, constructor: str = None) -> ConfigTree:
# get param instance from class attr
v_instance = getattr(self, v.name, None)
- if hasattr(v_instance, "to_params"):
- conf = v_instance.to_params()
+ if hasattr(v_instance, "to_config"):
+ conf = v_instance.to_config()
else:
- conf = self._to_params(v_instance)
+ conf = self._to_config(v_instance)
config[k] = conf
if accepts_kwargs:
for k in self.__from_config_kwargs__:
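The to_params -> to_config rename above is easiest to follow with a round-trip sketch. This is illustrative only: the Encoder/BagOfWordsEncoder names are made up, the register decorator is assumed to follow the AllenNLP style cited in the new class docstring, and from_config is assumed to accept a plain dict, as it does elsewhere in this patch (e.g. Vectorizer.from_config(vectorizer_config)).

from kag.common.registry import Registrable


class Encoder(Registrable):
    """Illustrative base class, not part of the repository."""


@Encoder.register("bag_of_words")  # assumed AllenNLP-style registration API
class BagOfWordsEncoder(Encoder):
    def __init__(self, dim: int = 128):
        self.dim = dim


# The `type` key selects the registered subclass; the remaining keys become kwargs.
encoder = Encoder.from_config({"type": "bag_of_words", "dim": 256})

# Objects built via from_config can be converted back to their config; this is
# the method renamed from to_params to to_config in this patch.
print(encoder.to_config())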
diff --git a/kag/common/registry/utils.py b/kag/common/registry/utils.py
index c2f26138..7255c27f 100644
--- a/kag/common/registry/utils.py
+++ b/kag/common/registry/utils.py
@@ -1,4 +1,15 @@
# -*- coding: utf-8 -*-
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
+
import os
import sys
import importlib
diff --git a/kag/common/reranker/__init__.py b/kag/common/reranker/__init__.py
index a945c8dd..4d9914d1 100644
--- a/kag/common/reranker/__init__.py
+++ b/kag/common/reranker/__init__.py
@@ -13,7 +13,4 @@
from kag.common.reranker.bge_reranker import BGEReranker
from kag.common.reranker.reranker import Reranker
-__all__ = [
- "BGEReranker",
- "Reranker"
-]
+__all__ = ["BGEReranker", "Reranker"]
diff --git a/kag/common/reranker/bge_reranker.py b/kag/common/reranker/bge_reranker.py
index 45a63615..e74cb022 100644
--- a/kag/common/reranker/bge_reranker.py
+++ b/kag/common/reranker/bge_reranker.py
@@ -20,60 +20,61 @@
def rrf_score(length, r: int = 1):
"""
    Calculates the RRF (Reciprocal Rank Fusion) scores.
-
+
This function generates a score sequence of the given length, where each score is calculated based on the index according to the formula 1/(r+i).
RRF is a method used in information retrieval and data analysis, and this function provides a way to generate weights based on document indices.
-
+
Parameters:
length: int, the length of the score sequence, i.e., the number of scores to generate.
r: int, optional, default is 1. Controls the starting index of the scores. Increasing the value of r shifts the emphasis towards later scores.
-
+
Returns:
numpy.ndarray, an array containing the scores calculated according to the given formula.
"""
return np.array([1 / (r + i) for i in range(length)])
-
class BGEReranker(Reranker):
"""
BGEReranker class is a subclass of Reranker that reranks given queries and passages.
-
+
This class uses the FlagReranker model from FlagEmbedding to score and reorder passages.
-
+
Args:
model_path (str): Path to the FlagReranker model.
use_fp16 (bool): Whether to use half-precision floating-point numbers for computation. Default is True.
"""
+
def __init__(self, model_path: str, use_fp16: bool = True):
from FlagEmbedding import FlagReranker
+
self.model_path = model_path
self.model = FlagReranker(self.model_path, use_fp16=use_fp16)
def rerank(self, queries: List[str], passages: List[str]):
"""
Reranks given queries and passages.
-
+
Args:
queries (List[str]): List of queries.
passages (List[str]): List of passages, where each passage is a string.
-
+
Returns:
new_passages (List[str]): List of passages after reranking.
"""
# Calculate initial ranking scores for passages
rank_scores = rrf_score(len(passages))
passage_scores = np.zeros(len(passages)) + rank_scores
-
+
# For each query, compute passage scores using the model and accumulate them
for query in queries:
scores = self.model.compute_score([[query, x] for x in passages])
sorted_idx = np.argsort(-np.array(scores))
for rank, passage_id in enumerate(sorted_idx):
passage_scores[passage_id] += rank_scores[rank]
-
+
# Perform final sorting of passages based on accumulated scores
merged_sorted_idx = np.argsort(-passage_scores)
-
+
new_passages = [passages[x] for x in merged_sorted_idx]
- return new_passages
\ No newline at end of file
+ return new_passages
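For the reranker above, the fusion step can be exercised without FlagEmbedding. The sketch below stubs the cross-encoder scores with made-up numbers and reproduces only the rank-fusion logic described in the rrf_score docstring: 1/(r+i) weights accumulated across queries.

import numpy as np


def rrf_score(length, r=1):
    # 1/(r+i): earlier ranks receive larger weights.
    return np.array([1 / (r + i) for i in range(length)])


passages = ["doc a", "doc b", "doc c"]
stub_scores_per_query = [
    [0.2, 0.9, 0.5],  # made-up model scores for query 1
    [0.7, 0.1, 0.6],  # made-up model scores for query 2
]

rank_scores = rrf_score(len(passages))
passage_scores = np.zeros(len(passages)) + rank_scores
for scores in stub_scores_per_query:
    sorted_idx = np.argsort(-np.array(scores))
    for rank, passage_id in enumerate(sorted_idx):
        passage_scores[passage_id] += rank_scores[rank]

reranked = [passages[i] for i in np.argsort(-passage_scores)]
print(reranked)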
diff --git a/kag/common/reranker/reranker.py b/kag/common/reranker/reranker.py
index 69b97a25..92e6d968 100644
--- a/kag/common/reranker/reranker.py
+++ b/kag/common/reranker/reranker.py
@@ -43,4 +43,4 @@ def rerank(self, queries: List[str], passages: List[str]):
The function is currently not implemented and raises an exception to indicate this.
"""
- raise NotImplementedError("rerank not implemented yet.")
\ No newline at end of file
+ raise NotImplementedError("rerank not implemented yet.")
diff --git a/kag/common/retriever/__init__.py b/kag/common/retriever/__init__.py
index 05156aa5..450fe9eb 100644
--- a/kag/common/retriever/__init__.py
+++ b/kag/common/retriever/__init__.py
@@ -12,7 +12,5 @@
from kag.common.retriever.kag_retriever import DefaultRetriever
from kag.common.retriever.retriever import Retriever
-__all__ = [
- "DefaultRetriever",
- "Retriever"
-]
+
+__all__ = ["DefaultRetriever", "Retriever"]
diff --git a/kag/common/retriever/kag_retriever.py b/kag/common/retriever/kag_retriever.py
index 9b40a3b8..30a1992c 100644
--- a/kag/common/retriever/kag_retriever.py
+++ b/kag/common/retriever/kag_retriever.py
@@ -50,7 +50,9 @@ def __init__(self, **kwargs):
self._init_search()
- self.ner_prompt = PromptOp.load(self.biz_scene, "question_ner")(language=self.language, project_id=self.project_id)
+ self.ner_prompt = PromptOp.load(self.biz_scene, "question_ner")(
+ language=self.language, project_id=self.project_id
+ )
self.std_prompt = PromptOp.load(self.biz_scene, "std")(language=self.language)
self.pagerank_threshold = 0.9
@@ -60,6 +62,7 @@ def __init__(self, **kwargs):
self.reranker_model_path = os.getenv("KAG_RETRIEVER_RERANKER_MODEL_PATH")
if self.reranker_model_path:
                from kag.common.reranker.bge_reranker import BGEReranker
+
self.reranker = BGEReranker(self.reranker_model_path, use_fp16=True)
else:
self.reranker = None
@@ -70,18 +73,15 @@ def _init_search(self):
self.sc: SearchClient = SearchClient(self.host_addr, self.project_id)
vectorizer_config = eval(os.getenv("KAG_VECTORIZER", "{}"))
if self.host_addr and self.project_id:
- config = ProjectClient(host_addr=self.host_addr, project_id=self.project_id).get_config(self.project_id)
+ config = ProjectClient(
+ host_addr=self.host_addr, project_id=self.project_id
+ ).get_config(self.project_id)
vectorizer_config.update(config.get("vectorizer", {}))
- self.vectorizer = Vectorizer.from_config(
- vectorizer_config
- )
+ self.vectorizer = Vectorizer.from_config(vectorizer_config)
self.reason: ReasonerClient = ReasonerClient(self.host_addr, self.project_id)
self.graph_algo = GraphAlgoClient(self.host_addr, self.project_id)
-
-
-
@retry(stop=stop_after_attempt(3))
def named_entity_recognition(self, query: str):
"""
@@ -119,7 +119,9 @@ def named_entity_standardization(self, query: str, entities: List[Dict]):
)
@staticmethod
- def append_official_name(source_entities: List[Dict], entities_with_official_name: List[Dict]):
+ def append_official_name(
+ source_entities: List[Dict], entities_with_official_name: List[Dict]
+ ):
"""
Appends official names to entities.
@@ -162,13 +164,11 @@ def calculate_sim_scores(self, query: str, doc_nums: int):
label=self.schema_util.get_label_within_prefix(CHUNK_TYPE),
property_key="content",
query_vector=query_vector,
- topk=doc_nums
+ topk=doc_nums,
)
scores = {item["node"]["id"]: item["score"] for item in top_k}
except Exception as e:
- logger.error(
- f"run calculate_sim_scores failed, info: {e}", exc_info=True
- )
+ logger.error(f"run calculate_sim_scores failed, info: {e}", exc_info=True)
return scores
def calculate_pagerank_scores(self, start_nodes: List[Dict]):
@@ -190,12 +190,12 @@ def calculate_pagerank_scores(self, start_nodes: List[Dict]):
if len(start_nodes) != 0:
try:
scores = self.graph_algo.calculate_pagerank_scores(
- self.schema_util.get_label_within_prefix(CHUNK_TYPE),
- start_nodes
+ self.schema_util.get_label_within_prefix(CHUNK_TYPE), start_nodes
)
except Exception as e:
logger.error(
- f"run calculate_pagerank_scores failed, info: {e}, start_nodes: {start_nodes}", exc_info=True
+ f"run calculate_pagerank_scores failed, info: {e}, start_nodes: {start_nodes}",
+ exc_info=True,
)
return scores
@@ -260,7 +260,9 @@ def match_entities(self, queries: Dict[str, str], top_k: int = 1):
logger.info(f"No entities matched for {queries}")
return matched_entities, matched_entities_scores
- def calculate_combined_scores(self, sim_scores: Dict[str, float], pagerank_scores: Dict[str, float]):
+ def calculate_combined_scores(
+ self, sim_scores: Dict[str, float], pagerank_scores: Dict[str, float]
+ ):
"""
Calculate and return the combined scores that integrate both similarity scores and PageRank scores.
@@ -271,6 +273,7 @@ def calculate_combined_scores(self, sim_scores: Dict[str, float], pagerank_score
Returns:
Dict[str, float]: A dictionary containing the combined scores, where keys are identifiers and values are the combined scores.
"""
+
def min_max_normalize(x):
if len(x) == 0:
return []
@@ -283,17 +286,24 @@ def min_max_normalize(x):
for key in all_keys:
sim_scores.setdefault(key, 0.0)
pagerank_scores.setdefault(key, 0.0)
- sim_scores = dict(zip(sim_scores.keys(), min_max_normalize(
- np.array(list(sim_scores.values()))
- )))
- pagerank_scores = dict(zip(pagerank_scores.keys(), min_max_normalize(
- np.array(list(pagerank_scores.values()))
- )))
+ sim_scores = dict(
+ zip(
+ sim_scores.keys(),
+ min_max_normalize(np.array(list(sim_scores.values()))),
+ )
+ )
+ pagerank_scores = dict(
+ zip(
+ pagerank_scores.keys(),
+ min_max_normalize(np.array(list(pagerank_scores.values()))),
+ )
+ )
combined_scores = dict()
for key in pagerank_scores.keys():
- combined_scores[key] = (sim_scores[key] * (1 - self.pagerank_weight) +
- pagerank_scores[key] * self.pagerank_weight
- )
+ combined_scores[key] = (
+ sim_scores[key] * (1 - self.pagerank_weight)
+ + pagerank_scores[key] * self.pagerank_weight
+ )
return combined_scores
def recall_docs(self, query: str, top_k: int = 5, **kwargs):
@@ -343,7 +353,9 @@ def recall_docs(self, query: str, top_k: int = 5, **kwargs):
elif matched_entities and np.min(matched_scores) > self.pagerank_threshold:
combined_scores = pagerank_scores
else:
- combined_scores = self.calculate_combined_scores(sim_scores, pagerank_scores)
+ combined_scores = self.calculate_combined_scores(
+ sim_scores, pagerank_scores
+ )
sorted_scores = sorted(
combined_scores.items(), key=lambda item: item[1], reverse=True
)
@@ -375,12 +387,19 @@ def get_all_docs_by_id(self, query: str, doc_ids: list, top_k: int):
else:
doc_score = doc_ids[doc_id]
counter += 1
- node = self.reason.query_node(label=self.schema_util.get_label_within_prefix(CHUNK_TYPE), id_value=doc_id)
+ node = self.reason.query_node(
+ label=self.schema_util.get_label_within_prefix(CHUNK_TYPE),
+ id_value=doc_id,
+ )
node_dict = dict(node.items())
- matched_docs.append(f"#{node_dict['name']}#{node_dict['content']}#{doc_score}")
- hits_docs.add(node_dict['name'])
+ matched_docs.append(
+ f"#{node_dict['name']}#{node_dict['content']}#{doc_score}"
+ )
+ hits_docs.add(node_dict["name"])
try:
- text_matched = self.sc.search_text(query, [self.schema_util.get_label_within_prefix(CHUNK_TYPE)], topk=1)
+ text_matched = self.sc.search_text(
+ query, [self.schema_util.get_label_within_prefix(CHUNK_TYPE)], topk=1
+ )
if text_matched:
for item in text_matched:
title = item["node"]["name"]
@@ -389,7 +408,9 @@ def get_all_docs_by_id(self, query: str, doc_ids: list, top_k: int):
matched_docs.pop()
else:
logger.warning(f"{query} matched docs is empty")
- matched_docs.append(f'#{item["node"]["name"]}#{item["node"]["content"]}#{item["score"]}')
+ matched_docs.append(
+ f'#{item["node"]["name"]}#{item["node"]["content"]}#{item["score"]}'
+ )
break
except Exception as e:
logger.warning(f"{query} query chunk failed: {e}", exc_info=True)
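The reformatted calculate_combined_scores above fuses dense similarity and PageRank scores. A standalone sketch of that fusion follows; the score values and the 0.5 weight are illustrative, and the constant-input branch of min_max_normalize is an assumption because it falls outside this hunk.

import numpy as np


def min_max_normalize(x):
    if len(x) == 0:
        return []
    if np.max(x) == np.min(x):
        return np.ones_like(x)  # assumed behaviour for constant input
    return (x - np.min(x)) / (np.max(x) - np.min(x))


sim_scores = {"chunk1": 0.82, "chunk2": 0.40, "chunk3": 0.10}
pagerank_scores = {"chunk1": 0.05, "chunk2": 0.30}
pagerank_weight = 0.5  # illustrative; the retriever reads this from its own config

# Make both maps cover the same keys before normalizing.
for key in set(sim_scores) | set(pagerank_scores):
    sim_scores.setdefault(key, 0.0)
    pagerank_scores.setdefault(key, 0.0)

sim_norm = dict(zip(sim_scores, min_max_normalize(np.array(list(sim_scores.values())))))
pr_norm = dict(zip(pagerank_scores, min_max_normalize(np.array(list(pagerank_scores.values())))))

combined = {
    key: sim_norm[key] * (1 - pagerank_weight) + pr_norm[key] * pagerank_weight
    for key in pr_norm
}
print(sorted(combined.items(), key=lambda kv: kv[1], reverse=True))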
diff --git a/kag/common/retriever/retriever.py b/kag/common/retriever/retriever.py
index e125248b..232c5fc7 100644
--- a/kag/common/retriever/retriever.py
+++ b/kag/common/retriever/retriever.py
@@ -117,7 +117,7 @@ def index(self, items: Union[Item, Iterable[Item]]) -> None:
@abstractmethod
def retrieve(
- self, queries: Union[str, Iterable[str]], top_k: int = 10
+ self, queries: Union[str, Iterable[str]], top_k: int = 10
) -> Union[RetrievalResult, Iterable[RetrievalResult]]:
"""
Retrieve items for the given query or queries.
@@ -130,5 +130,3 @@ def retrieve(
"""
message = "abstract method retrieve is not implemented"
raise NotImplementedError(message)
-
-
diff --git a/kag/common/utils.py b/kag/common/utils.py
index 2a6f5ac0..b6891952 100644
--- a/kag/common/utils.py
+++ b/kag/common/utils.py
@@ -12,7 +12,7 @@
import re
import sys
import json
-from typing import Type,Tuple
+from typing import Type, Tuple
import inspect
import os
from pathlib import Path
@@ -70,6 +70,7 @@ def append_python_path(path: str) -> bool:
return True
return False
+
def render_template(
root_dir: Union[str, os.PathLike], file: Union[str, os.PathLike], **kwargs: Any
) -> None:
@@ -113,7 +114,7 @@ def copyfile(src: Path, dst: Path, **kwargs):
_make_writable(dst)
if dst.suffix != ".tmpl":
return
- render_template('/', dst, **kwargs)
+ render_template("/", dst, **kwargs)
def remove_files_except(path, file, new_file):
@@ -194,8 +195,7 @@ def processing_phrases(phrase):
def to_camel_case(phrase):
s = processing_phrases(phrase).replace(" ", "_")
return "".join(
- word.capitalize() if i != 0 else word
- for i, word in enumerate(s.split("_"))
+ word.capitalize() if i != 0 else word for i, word in enumerate(s.split("_"))
)
diff --git a/kag/common/vectorizer/local_bge_m3_vectorizer.py b/kag/common/vectorizer/local_bge_m3_vectorizer.py
deleted file mode 100644
index b7d21621..00000000
--- a/kag/common/vectorizer/local_bge_m3_vectorizer.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright 2023 OpenSPG Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
-# in compliance with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software distributed under the License
-# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-# or implied.
-
-import io
-import os
-import threading
-import tarfile
-import requests
-from typing import Any, Union, Iterable, Dict
-from kag.common.vectorizer.vectorizer import Vectorizer
-
-
-EmbeddingVector = Iterable[float]
-
-LOCAL_MODEL_MAP = {}
diff --git a/kag/common/vectorizer/local_bge_vectorizer.py b/kag/common/vectorizer/local_bge_vectorizer.py
index b1366090..a3b0ceca 100644
--- a/kag/common/vectorizer/local_bge_vectorizer.py
+++ b/kag/common/vectorizer/local_bge_vectorizer.py
@@ -9,11 +9,9 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
-import io
import os
import logging
-import threading
-from typing import Any, Union, Iterable, Dict
+from typing import Union, Iterable
from kag.common.vectorizer.vectorizer import Vectorizer, EmbeddingVector
logger = logging.getLogger()
@@ -75,7 +73,7 @@ def __init__(
def _load_model(self, path):
        # We need to import sklearn first; otherwise it fails on macOS with Apple M-series chips.
- import sklearn
+ import sklearn # noqa
from FlagEmbedding import FlagModel
logger.info(
@@ -138,7 +136,8 @@ def __init__(
def _load_model(self, path):
        # We need to import sklearn first; otherwise it fails on macOS with Apple M-series chips.
- import sklearn
+
+ import sklearn # noqa
from FlagEmbedding import BGEM3FlagModel
logger.info(f"Loading BGEM3FlagModel from {path!r}")
diff --git a/kag/common/vectorizer/openai_vectorizer.py b/kag/common/vectorizer/openai_vectorizer.py
index f6d82ba6..009c6943 100644
--- a/kag/common/vectorizer/openai_vectorizer.py
+++ b/kag/common/vectorizer/openai_vectorizer.py
@@ -9,7 +9,7 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
-from typing import Any, Union, Iterable, Dict
+from typing import Union, Iterable
from openai import OpenAI
from kag.common.vectorizer.vectorizer import Vectorizer
diff --git a/kag/common/vectorizer/vectorizer.py b/kag/common/vectorizer/vectorizer.py
index cad34ac9..60644b5d 100644
--- a/kag/common/vectorizer/vectorizer.py
+++ b/kag/common/vectorizer/vectorizer.py
@@ -11,14 +11,13 @@
import io
import os
-import json
import tarfile
import requests
import logging
-from pathlib import Path
+
from kag.common.registry import Registrable
-from typing import Any, Union, Iterable, Optional, Dict
+from typing import Union, Iterable
EmbeddingVector = Iterable[float]
logger = logging.getLogger()
diff --git a/kag/examples/2wiki/builder/__init__.py b/kag/examples/2wiki/builder/__init__.py
index 94be39bc..7a018e7c 100644
--- a/kag/examples/2wiki/builder/__init__.py
+++ b/kag/examples/2wiki/builder/__init__.py
@@ -11,4 +11,4 @@
"""
Builder Dir.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/2wiki/builder/data/__init__.py b/kag/examples/2wiki/builder/data/__init__.py
index 6a8637b9..59bacd4d 100644
--- a/kag/examples/2wiki/builder/data/__init__.py
+++ b/kag/examples/2wiki/builder/data/__init__.py
@@ -11,4 +11,4 @@
"""
Place the files to be used for building the index in this directory.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/2wiki/builder/indexer.py b/kag/examples/2wiki/builder/indexer.py
index e85ed3dd..a302a651 100644
--- a/kag/examples/2wiki/builder/indexer.py
+++ b/kag/examples/2wiki/builder/indexer.py
@@ -47,8 +47,8 @@ def invoke(self, input: str, **kwargs) -> List[Output]:
for idx, item in enumerate(corpus):
chunk = Chunk(
id=str(idx),
- name=item['title'],
- content=item['text'],
+ name=item["title"],
+ content=item["text"],
)
chunks.append(chunk)
return chunks
@@ -71,10 +71,8 @@ def buildKB(corpusFilePath):
logger.info(f"\n\nbuildKB successfully for {corpusFilePath}\n\n")
-if __name__ == '__main__':
+if __name__ == "__main__":
filePath = "./data/2wiki_sub_corpus.json"
# filePath = "./data/2wiki_corpus.json"
- corpusFilePath = os.path.join(
- os.path.abspath(os.path.dirname(__file__)), filePath
- )
+ corpusFilePath = os.path.join(os.path.abspath(os.path.dirname(__file__)), filePath)
buildKB(corpusFilePath)
diff --git a/kag/examples/2wiki/builder/prompt/__init__.py b/kag/examples/2wiki/builder/prompt/__init__.py
index 247bb44c..ba7d5d56 100644
--- a/kag/examples/2wiki/builder/prompt/__init__.py
+++ b/kag/examples/2wiki/builder/prompt/__init__.py
@@ -11,4 +11,4 @@
"""
Place the prompts to be used for building the index in this directory.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/2wiki/builder/prompt/ner.py b/kag/examples/2wiki/builder/prompt/ner.py
index cf5aa897..79c022e5 100644
--- a/kag/examples/2wiki/builder/prompt/ner.py
+++ b/kag/examples/2wiki/builder/prompt/ner.py
@@ -85,9 +85,7 @@ class OpenIENERPrompt(PromptOp):
template_zh = template_en
- def __init__(
- self, language: Optional[str] = "en", **kwargs
- ):
+ def __init__(self, language: Optional[str] = "en", **kwargs):
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
diff --git a/kag/examples/2wiki/kag_config.cfg b/kag/examples/2wiki/kag_config.cfg
index 2c91dda0..e3a30c63 100644
--- a/kag/examples/2wiki/kag_config.cfg
+++ b/kag/examples/2wiki/kag_config.cfg
@@ -4,14 +4,14 @@ host_addr = http://127.0.0.1:8887
id = 11
[vectorizer]
-vectorizer = kag.common.vectorizer.OpenAIVectorizer
+type = openai
model = bge-m3
api_key = EMPTY
base_url = http://127.0.0.1:11434/v1
vector_dimensions = 1024
[llm]
-client_type = maas
+type = maas
base_url = https://api.deepseek.com/
api_key = put your deepseek api key here
model = deepseek-chat
diff --git a/kag/examples/2wiki/reasoner/__init__.py b/kag/examples/2wiki/reasoner/__init__.py
index a0c4032b..8b8a3c91 100644
--- a/kag/examples/2wiki/reasoner/__init__.py
+++ b/kag/examples/2wiki/reasoner/__init__.py
@@ -17,4 +17,4 @@
MATCH (s:DEFAULT.Company)
RETURN s.id, s.address
```
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/2wiki/schema/__init__.py b/kag/examples/2wiki/schema/__init__.py
index ef3dde6d..8ac86acc 100644
--- a/kag/examples/2wiki/schema/__init__.py
+++ b/kag/examples/2wiki/schema/__init__.py
@@ -15,4 +15,4 @@
You can execute `kag schema commit` to commit your schema to SPG server.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/2wiki/solver/evaFor2wiki.py b/kag/examples/2wiki/solver/evaFor2wiki.py
index d6a533b1..ec5983a4 100644
--- a/kag/examples/2wiki/solver/evaFor2wiki.py
+++ b/kag/examples/2wiki/solver/evaFor2wiki.py
@@ -18,6 +18,7 @@ class EvaFor2wiki:
"""
init for kag client
"""
+
def __init__(self, configFilePath):
self.configFilePath = configFilePath
init_kag_config(self.configFilePath)
@@ -25,6 +26,7 @@ def __init__(self, configFilePath):
"""
    qa over the knowledge base
"""
+
def qa(self, query):
# CA
resp = SolverPipeline()
@@ -37,6 +39,7 @@ def qa(self, query):
    parallel qa over the knowledge base
    and benchmark computation (em, f1, answer_similarity) via getBenchMark
"""
+
def parallelQaAndEvaluate(
self, qaFilePath, resFilePath, threadNum=1, upperLimit=10
):
@@ -115,9 +118,7 @@ def process_sample(data):
filePath = "./data/2wiki_qa_sub.json"
# filePath = "./data/2wiki_qa.json"
- qaFilePath = os.path.join(
- os.path.abspath(os.path.dirname(__file__)), filePath
- )
+ qaFilePath = os.path.join(os.path.abspath(os.path.dirname(__file__)), filePath)
start_time = time.time()
resFilePath = os.path.join(
@@ -126,7 +127,7 @@ def process_sample(data):
total_metrics = evalObj.parallelQaAndEvaluate(
qaFilePath, resFilePath, threadNum=20, upperLimit=1000
)
- total_metrics['cost'] = time.time() - start_time
+ total_metrics["cost"] = time.time() - start_time
with open(f"./2wiki_metrics_{start_time}.json", "w") as f:
json.dump(total_metrics, f)
diff --git a/kag/examples/2wiki/solver/prompt/__init__.py b/kag/examples/2wiki/solver/prompt/__init__.py
index dadd42a3..dfa931cd 100644
--- a/kag/examples/2wiki/solver/prompt/__init__.py
+++ b/kag/examples/2wiki/solver/prompt/__init__.py
@@ -11,4 +11,4 @@
"""
Place the prompts to be used for solving problems in this directory.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/2wiki/solver/prompt/resp_generator.py b/kag/examples/2wiki/solver/prompt/resp_generator.py
index 70e96cc9..fa70249c 100644
--- a/kag/examples/2wiki/solver/prompt/resp_generator.py
+++ b/kag/examples/2wiki/solver/prompt/resp_generator.py
@@ -9,12 +9,14 @@
class RespGenerator(PromptOp):
- template_zh = "基于给定的引用信息回答问题。" \
- "\n只输出答案,不需要输出额外的信息。" \
- "\n给定的引用信息:'$memory'\n问题:'$instruction'"
- template_en = "Answer the question based on the given reference." \
- "\nOnly give me the answer and do not output any other words." \
- "\nThe following are given reference:'$memory'\nQuestion: '$instruction'"
+ template_zh = (
+ "基于给定的引用信息回答问题。" "\n只输出答案,不需要输出额外的信息。" "\n给定的引用信息:'$memory'\n问题:'$instruction'"
+ )
+ template_en = (
+ "Answer the question based on the given reference."
+ "\nOnly give me the answer and do not output any other words."
+ "\nThe following are given reference:'$memory'\nQuestion: '$instruction'"
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -24,5 +26,5 @@ def template_variables(self) -> List[str]:
return ["memory", "instruction"]
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
return response
diff --git a/kag/examples/README.md b/kag/examples/README.md
index 6587ce7f..2ce64468 100644
--- a/kag/examples/README.md
+++ b/kag/examples/README.md
@@ -11,14 +11,14 @@ Create your new knext project from knext cli tool.
host_addr = http://localhost:8887
[vectorizer]
- vectorizer = kag.common.vectorizer.OpenAIVectorizer
+ type = openai
model = bge-m3
api_key = EMPTY
base_url = http://127.0.0.1:11434/v1
vector_dimensions = 1024
[llm]
- client_type = ollama
+ type = ollama
base_url = http://localhost:11434/api/generate
model = llama3.1
diff --git a/kag/examples/example.cfg b/kag/examples/example.cfg
index c8d2cdc4..371b1a03 100644
--- a/kag/examples/example.cfg
+++ b/kag/examples/example.cfg
@@ -4,15 +4,15 @@ host_addr = http://localhost:8887
 # vectorizer loaded by ollama
[vectorizer]
-vectorizer = kag.common.vectorizer.OpenAIVectorizer
+type = openai
model = bge-m3
api_key = EMPTY
base_url = http://127.0.0.1:11434/v1
vector_dimensions = 1024
[llm]
-client_type = maas
-base_url = https://api.deepseek.com
+type = maas
+base_url = https://api.deepseek.com/
api_key = put your deepseek api key here
model = deepseek-chat
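The key renames in these cfg files (vectorizer = kag.common.vectorizer.OpenAIVectorizer becoming type = openai, and client_type becoming type) match the registry-based loading seen earlier in this patch: type selects a registered implementation instead of naming a dotted class path. A minimal sketch, assuming Vectorizer.from_config accepts a plain dict as it does in DefaultRetriever._init_search, with the values copied from the [vectorizer] section above:

from kag.common.vectorizer.vectorizer import Vectorizer

# In the real flow these values arrive via init_kag_config / environment
# variables rather than being hard-coded.
vectorizer_config = {
    "type": "openai",
    "model": "bge-m3",
    "api_key": "EMPTY",
    "base_url": "http://127.0.0.1:11434/v1",
    "vector_dimensions": 1024,
}
vectorizer = Vectorizer.from_config(vectorizer_config)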
diff --git a/kag/examples/hotpotqa/builder/__init__.py b/kag/examples/hotpotqa/builder/__init__.py
index 94be39bc..7a018e7c 100644
--- a/kag/examples/hotpotqa/builder/__init__.py
+++ b/kag/examples/hotpotqa/builder/__init__.py
@@ -11,4 +11,4 @@
"""
Builder Dir.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/hotpotqa/builder/data/__init__.py b/kag/examples/hotpotqa/builder/data/__init__.py
index 6a8637b9..59bacd4d 100644
--- a/kag/examples/hotpotqa/builder/data/__init__.py
+++ b/kag/examples/hotpotqa/builder/data/__init__.py
@@ -11,4 +11,4 @@
"""
Place the files to be used for building the index in this directory.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/hotpotqa/builder/indexer.py b/kag/examples/hotpotqa/builder/indexer.py
index 68401c75..dcf08080 100644
--- a/kag/examples/hotpotqa/builder/indexer.py
+++ b/kag/examples/hotpotqa/builder/indexer.py
@@ -71,11 +71,9 @@ def buildKB(corpusFilePath):
logger.info(f"\n\nbuildKB successfully for {corpusFilePath}\n\n")
-if __name__ == '__main__':
+if __name__ == "__main__":
filePath = "./data/hotpotqa_sub_corpus.json"
# filePath = "./data/hotpotqa_train_corpus.json"
- corpusFilePath = os.path.join(
- os.path.abspath(os.path.dirname(__file__)), filePath
- )
+ corpusFilePath = os.path.join(os.path.abspath(os.path.dirname(__file__)), filePath)
buildKB(corpusFilePath)
diff --git a/kag/examples/hotpotqa/builder/prompt/__init__.py b/kag/examples/hotpotqa/builder/prompt/__init__.py
index 247bb44c..ba7d5d56 100644
--- a/kag/examples/hotpotqa/builder/prompt/__init__.py
+++ b/kag/examples/hotpotqa/builder/prompt/__init__.py
@@ -11,4 +11,4 @@
"""
Place the prompts to be used for building the index in this directory.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/hotpotqa/builder/prompt/ner.py b/kag/examples/hotpotqa/builder/prompt/ner.py
index cf5aa897..79c022e5 100644
--- a/kag/examples/hotpotqa/builder/prompt/ner.py
+++ b/kag/examples/hotpotqa/builder/prompt/ner.py
@@ -85,9 +85,7 @@ class OpenIENERPrompt(PromptOp):
template_zh = template_en
- def __init__(
- self, language: Optional[str] = "en", **kwargs
- ):
+ def __init__(self, language: Optional[str] = "en", **kwargs):
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
diff --git a/kag/examples/hotpotqa/kag_config.cfg b/kag/examples/hotpotqa/kag_config.cfg
index c46b019d..0b442631 100644
--- a/kag/examples/hotpotqa/kag_config.cfg
+++ b/kag/examples/hotpotqa/kag_config.cfg
@@ -4,14 +4,14 @@ host_addr = http://127.0.0.1:8887
id = 4
[vectorizer]
-vectorizer = kag.common.vectorizer.OpenAIVectorizer
+type = openai
model = bge-m3
api_key = EMPTY
base_url = http://127.0.0.1:11434/v1
vector_dimensions = 1024
[llm]
-client_type = maas
+type = maas
base_url = https://api.deepseek.com/
api_key = put your deepseek api key here
model = deepseek-chat
diff --git a/kag/examples/hotpotqa/reasoner/__init__.py b/kag/examples/hotpotqa/reasoner/__init__.py
index a0c4032b..8b8a3c91 100644
--- a/kag/examples/hotpotqa/reasoner/__init__.py
+++ b/kag/examples/hotpotqa/reasoner/__init__.py
@@ -17,4 +17,4 @@
MATCH (s:DEFAULT.Company)
RETURN s.id, s.address
```
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/hotpotqa/schema/__init__.py b/kag/examples/hotpotqa/schema/__init__.py
index ef3dde6d..8ac86acc 100644
--- a/kag/examples/hotpotqa/schema/__init__.py
+++ b/kag/examples/hotpotqa/schema/__init__.py
@@ -15,4 +15,4 @@
You can execute `kag schema commit` to commit your schema to SPG server.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/hotpotqa/solver/evaForHotpotqa.py b/kag/examples/hotpotqa/solver/evaForHotpotqa.py
index 041800d4..e7a31dc3 100644
--- a/kag/examples/hotpotqa/solver/evaForHotpotqa.py
+++ b/kag/examples/hotpotqa/solver/evaForHotpotqa.py
@@ -17,6 +17,7 @@ class EvaForHotpotqa:
"""
init for kag client
"""
+
def __init__(self):
pass
@@ -34,7 +35,7 @@ def qa(self, query):
"""
def parallelQaAndEvaluate(
- self, qaFilePath, resFilePath, threadNum=1, upperLimit=10, run_failed=False
+ self, qaFilePath, resFilePath, threadNum=1, upperLimit=10, run_failed=False
):
def process_sample(data):
try:
@@ -45,8 +46,8 @@ def process_sample(data):
if "prediction" not in sample.keys():
prediction, traceLog = self.qa(question)
else:
- prediction = sample['prediction']
- traceLog = sample['traceLog']
+ prediction = sample["prediction"]
+ traceLog = sample["traceLog"]
evaObj = Evaluate()
metrics = evaObj.getBenchMark([prediction], [gold])
@@ -72,9 +73,9 @@ def process_sample(data):
for sample_idx, sample in enumerate(qaList[:upperLimit])
]
for future in tqdm(
- as_completed(futures),
- total=len(futures),
- desc="parallelQaAndEvaluate completing: ",
+ as_completed(futures),
+ total=len(futures),
+ desc="parallelQaAndEvaluate completing: ",
):
result = future.result()
if result is not None:
@@ -115,9 +116,7 @@ def process_sample(data):
filePath = "./data/hotpotqa_qa_sub.json"
start_time = time.time()
- qaFilePath = os.path.join(
- os.path.abspath(os.path.dirname(__file__)), filePath
- )
+ qaFilePath = os.path.join(os.path.abspath(os.path.dirname(__file__)), filePath)
resFilePath = os.path.join(
os.path.abspath(os.path.dirname(__file__)), f"hotpotqa_res_{start_time}.json"
)
@@ -125,7 +124,7 @@ def process_sample(data):
qaFilePath, resFilePath, threadNum=20, upperLimit=100000, run_failed=True
)
- total_metrics['cost'] = time.time() - start_time
+ total_metrics["cost"] = time.time() - start_time
with open(f"./hotpotqa_metrics_{start_time}.json", "w") as f:
json.dump(total_metrics, f)
print(total_metrics)
diff --git a/kag/examples/hotpotqa/solver/prompt/__init__.py b/kag/examples/hotpotqa/solver/prompt/__init__.py
index dadd42a3..dfa931cd 100644
--- a/kag/examples/hotpotqa/solver/prompt/__init__.py
+++ b/kag/examples/hotpotqa/solver/prompt/__init__.py
@@ -11,4 +11,4 @@
"""
Place the prompts to be used for solving problems in this directory.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/hotpotqa/solver/prompt/resp_generator.py b/kag/examples/hotpotqa/solver/prompt/resp_generator.py
index 70e96cc9..fa70249c 100644
--- a/kag/examples/hotpotqa/solver/prompt/resp_generator.py
+++ b/kag/examples/hotpotqa/solver/prompt/resp_generator.py
@@ -9,12 +9,14 @@
class RespGenerator(PromptOp):
- template_zh = "基于给定的引用信息回答问题。" \
- "\n只输出答案,不需要输出额外的信息。" \
- "\n给定的引用信息:'$memory'\n问题:'$instruction'"
- template_en = "Answer the question based on the given reference." \
- "\nOnly give me the answer and do not output any other words." \
- "\nThe following are given reference:'$memory'\nQuestion: '$instruction'"
+ template_zh = (
+ "基于给定的引用信息回答问题。" "\n只输出答案,不需要输出额外的信息。" "\n给定的引用信息:'$memory'\n问题:'$instruction'"
+ )
+ template_en = (
+ "Answer the question based on the given reference."
+ "\nOnly give me the answer and do not output any other words."
+ "\nThe following are given reference:'$memory'\nQuestion: '$instruction'"
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -24,5 +26,5 @@ def template_variables(self) -> List[str]:
return ["memory", "instruction"]
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
return response
diff --git a/kag/examples/medicine/builder/indexer.py b/kag/examples/medicine/builder/indexer.py
index e5ec09de..985584be 100644
--- a/kag/examples/medicine/builder/indexer.py
+++ b/kag/examples/medicine/builder/indexer.py
@@ -6,7 +6,10 @@
from kag.builder.component.extractor import SPGExtractor, KAGExtractor
from kag.builder.component.mapping.spo_mapping import SPOMapping
from kag.builder.component.splitter import LengthSplitter
-from kag.builder.default_chain import DefaultStructuredBuilderChain, DefaultUnstructuredBuilderChain
+from kag.builder.default_chain import (
+ DefaultStructuredBuilderChain,
+ DefaultUnstructuredBuilderChain,
+)
from kag.common.env import init_kag_config
from knext.builder.builder_chain_abc import BuilderChainABC
@@ -30,7 +33,9 @@ def build(self, **kwargs):
class DiseaseBuilderChain(BuilderChainABC):
def build(self, **kwargs):
- source = CSVReader(output_type="Chunk", id_col="idx", name_col="title", content_col="text")
+ source = CSVReader(
+ output_type="Chunk", id_col="idx", name_col="title", content_col="text"
+ )
splitter = LengthSplitter(split_length=2000)
extractor = KAGExtractor()
vectorizer = BatchVectorizer()
@@ -39,15 +44,19 @@ def build(self, **kwargs):
return source >> splitter >> extractor >> vectorizer >> sink
-
def import_data():
file_path = os.path.dirname(__file__)
init_kag_config(os.path.join(file_path, "../kag_config.cfg"))
- DefaultStructuredBuilderChain("HumanBodyPart").invoke(file_path=os.path.join(file_path,"data/HumanBodyPart.csv"))
- DefaultStructuredBuilderChain("HospitalDepartment").invoke(file_path=os.path.join(file_path,"data/HospitalDepartment.csv"))
- DiseaseBuilderChain().invoke(file_path=os.path.join(file_path,"data/Disease.csv"))
+ DefaultStructuredBuilderChain("HumanBodyPart").invoke(
+ file_path=os.path.join(file_path, "data/HumanBodyPart.csv")
+ )
+ DefaultStructuredBuilderChain("HospitalDepartment").invoke(
+ file_path=os.path.join(file_path, "data/HospitalDepartment.csv")
+ )
+ DiseaseBuilderChain().invoke(file_path=os.path.join(file_path, "data/Disease.csv"))
+
+ SPOBuilderChain().invoke(file_path=os.path.join(file_path, "data/SPO.csv"))
- SPOBuilderChain().invoke(file_path=os.path.join(file_path,"data/SPO.csv"))
-if __name__ == '__main__':
+if __name__ == "__main__":
import_data()
diff --git a/kag/examples/medicine/builder/prompt/ner.py b/kag/examples/medicine/builder/prompt/ner.py
index 07c6298a..7a9a9333 100644
--- a/kag/examples/medicine/builder/prompt/ner.py
+++ b/kag/examples/medicine/builder/prompt/ner.py
@@ -45,9 +45,7 @@ class OpenIENERPrompt(PromptOp):
template_en = template_zh
- def __init__(
- self, language: Optional[str] = "en", **kwargs
- ):
+ def __init__(self, language: Optional[str] = "en", **kwargs):
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
diff --git a/kag/examples/medicine/kag_config.cfg b/kag/examples/medicine/kag_config.cfg
index 205804de..b4975051 100644
--- a/kag/examples/medicine/kag_config.cfg
+++ b/kag/examples/medicine/kag_config.cfg
@@ -4,14 +4,14 @@ host_addr = http://127.0.0.1:8887
id = 4
[vectorizer]
-vectorizer = kag.common.vectorizer.OpenAIVectorizer
+type = openai
model = bge-m3
api_key = EMPTY
base_url = http://127.0.0.1:11434/v1
vector_dimensions = 1024
[llm]
-client_type = maas
+type = maas
base_url = https://api.deepseek.com/
api_key = put your deepseek api key here
model = deepseek-chat
diff --git a/kag/examples/medicine/reasoner/client.py b/kag/examples/medicine/reasoner/client.py
index 32ab9c5b..9a54b989 100644
--- a/kag/examples/medicine/reasoner/client.py
+++ b/kag/examples/medicine/reasoner/client.py
@@ -3,22 +3,24 @@
from knext.reasoner.client import ReasonerClient
from kag.common.env import init_kag_config
+
def read_dsl_files(directory):
"""
Read all dsl files in the reasoner directory.
"""
-
- dsl_contents = []
+
+ dsl_contents = []
for filename in os.listdir(directory):
- if filename.endswith('.dsl'):
+ if filename.endswith(".dsl"):
file_path = os.path.join(directory, filename)
- with open(file_path, 'r', encoding='utf-8') as file:
- content = file.read()
- dsl_contents.append(content)
+ with open(file_path, "r", encoding="utf-8") as file:
+ content = file.read()
+ dsl_contents.append(content)
return dsl_contents
+
if __name__ == "__main__":
resonser_path = os.path.dirname(os.path.abspath(__file__))
project_path = os.path.dirname(resonser_path)
@@ -27,7 +29,9 @@ def read_dsl_files(directory):
host_addr = os.environ["KAG_PROJECT_HOST_ADDR"]
project_id = os.environ["KAG_PROJECT_ID"]
namespace = os.environ["KAG_PROJECT_NAMESPACE"]
- client = ReasonerClient(host_addr=host_addr, project_id=project_id,namespace=namespace)
+ client = ReasonerClient(
+ host_addr=host_addr, project_id=project_id, namespace=namespace
+ )
dsls = read_dsl_files(resonser_path)
for dsl in dsls:
client.execute(dsl)
diff --git a/kag/examples/medicine/solver/evaForMedicine.py b/kag/examples/medicine/solver/evaForMedicine.py
index d83a8f72..d49be93b 100644
--- a/kag/examples/medicine/solver/evaForMedicine.py
+++ b/kag/examples/medicine/solver/evaForMedicine.py
@@ -21,7 +21,7 @@ def qa(self, query):
resp = SolverPipeline()
answer, trace_log = resp.run(query)
- return answer,trace_log
+ return answer, trace_log
"""
parallel qa from knowledge base
@@ -32,7 +32,7 @@ def qa(self, query):
if __name__ == "__main__":
demo = MedicineDemo()
query = "甲状腺结节可以吃什么药?"
- answer,trace_log = demo.qa(query)
+ answer, trace_log = demo.qa(query)
print(f"Question: {query}")
print(f"Answer: {answer}")
print(f"TraceLog: {trace_log}")
diff --git a/kag/examples/medicine/solver/prompt/question_ner.py b/kag/examples/medicine/solver/prompt/question_ner.py
index 3eb8ea9d..9dd1c4e1 100644
--- a/kag/examples/medicine/solver/prompt/question_ner.py
+++ b/kag/examples/medicine/solver/prompt/question_ner.py
@@ -55,9 +55,7 @@ class QuestionNER(PromptOp):
template_en = template_zh
- def __init__(
- self, language: Optional[str] = "en", **kwargs
- ):
+ def __init__(self, language: Optional[str] = "en", **kwargs):
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
diff --git a/kag/examples/medicine/solver/prompt/resp_generator.py b/kag/examples/medicine/solver/prompt/resp_generator.py
index 91a910d5..40981c3b 100644
--- a/kag/examples/medicine/solver/prompt/resp_generator.py
+++ b/kag/examples/medicine/solver/prompt/resp_generator.py
@@ -9,10 +9,11 @@
class RespGenerator(PromptOp):
- template_zh = "基于给定的引用信息完整回答问题。" \
- "\n给定的引用信息:'$memory'\n问题:'$instruction'"
- template_en = "Answer the question completely based on the given reference." \
- "\nThe following are given reference:'$memory'\nQuestion: '$instruction'"
+ template_zh = "基于给定的引用信息完整回答问题。" "\n给定的引用信息:'$memory'\n问题:'$instruction'"
+ template_en = (
+ "Answer the question completely based on the given reference."
+ "\nThe following are given reference:'$memory'\nQuestion: '$instruction'"
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -22,5 +23,5 @@ def template_variables(self) -> List[str]:
return ["memory", "instruction"]
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
return response
diff --git a/kag/examples/musique/builder/__init__.py b/kag/examples/musique/builder/__init__.py
index 94be39bc..7a018e7c 100644
--- a/kag/examples/musique/builder/__init__.py
+++ b/kag/examples/musique/builder/__init__.py
@@ -11,4 +11,4 @@
"""
Builder Dir.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/musique/builder/data/__init__.py b/kag/examples/musique/builder/data/__init__.py
index 6a8637b9..59bacd4d 100644
--- a/kag/examples/musique/builder/data/__init__.py
+++ b/kag/examples/musique/builder/data/__init__.py
@@ -11,4 +11,4 @@
"""
Place the files to be used for building the index in this directory.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/musique/builder/indexer.py b/kag/examples/musique/builder/indexer.py
index a9934ddf..47a4a9eb 100644
--- a/kag/examples/musique/builder/indexer.py
+++ b/kag/examples/musique/builder/indexer.py
@@ -79,11 +79,9 @@ def buildKB(corpusFilePath):
logger.info(f"\n\nbuildKB successfully for {corpusFilePath}\n\n")
-if __name__ == '__main__':
+if __name__ == "__main__":
filePath = "./data/musique_sub_corpus.json"
# filePath = "./data/musique_train_corpus.json"
- corpusFilePath = os.path.join(
- os.path.abspath(os.path.dirname(__file__)),filePath
- )
+ corpusFilePath = os.path.join(os.path.abspath(os.path.dirname(__file__)), filePath)
buildKB(corpusFilePath)
diff --git a/kag/examples/musique/builder/prompt/__init__.py b/kag/examples/musique/builder/prompt/__init__.py
index 247bb44c..ba7d5d56 100644
--- a/kag/examples/musique/builder/prompt/__init__.py
+++ b/kag/examples/musique/builder/prompt/__init__.py
@@ -11,4 +11,4 @@
"""
Place the prompts to be used for building the index in this directory.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/musique/builder/prompt/ner.py b/kag/examples/musique/builder/prompt/ner.py
index cf5aa897..79c022e5 100644
--- a/kag/examples/musique/builder/prompt/ner.py
+++ b/kag/examples/musique/builder/prompt/ner.py
@@ -85,9 +85,7 @@ class OpenIENERPrompt(PromptOp):
template_zh = template_en
- def __init__(
- self, language: Optional[str] = "en", **kwargs
- ):
+ def __init__(self, language: Optional[str] = "en", **kwargs):
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
diff --git a/kag/examples/musique/kag_config.cfg b/kag/examples/musique/kag_config.cfg
index a392d61d..f8df808b 100644
--- a/kag/examples/musique/kag_config.cfg
+++ b/kag/examples/musique/kag_config.cfg
@@ -11,14 +11,14 @@ database = dev
[vectorizer]
-vectorizer = kag.common.vectorizer.OpenAIVectorizer
+type = openai
model = bge-m3
api_key = EMPTY
base_url = http://127.0.0.1:11434/v1
vector_dimensions = 1024
[llm]
-client_type = maas
+type = maas
base_url = https://api.deepseek.com/
api_key = put your deepseek api key here
model = deepseek-chat
diff --git a/kag/examples/musique/reasoner/__init__.py b/kag/examples/musique/reasoner/__init__.py
index a0c4032b..8b8a3c91 100644
--- a/kag/examples/musique/reasoner/__init__.py
+++ b/kag/examples/musique/reasoner/__init__.py
@@ -17,4 +17,4 @@
MATCH (s:DEFAULT.Company)
RETURN s.id, s.address
```
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/musique/schema/__init__.py b/kag/examples/musique/schema/__init__.py
index ef3dde6d..8ac86acc 100644
--- a/kag/examples/musique/schema/__init__.py
+++ b/kag/examples/musique/schema/__init__.py
@@ -15,4 +15,4 @@
You can execute `kag schema commit` to commit your schema to SPG server.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/musique/solver/evaForMusique.py b/kag/examples/musique/solver/evaForMusique.py
index 8770e5ba..2c5fd892 100644
--- a/kag/examples/musique/solver/evaForMusique.py
+++ b/kag/examples/musique/solver/evaForMusique.py
@@ -35,8 +35,9 @@ def qa(self, query):
def qaWithoutLogicForm(self, query):
# CA
- lf_solver = LFSolver(chunk_retriever=LFChunkRetriever(),
- kg_retriever=KGRetrieverByLlm())
+ lf_solver = LFSolver(
+ chunk_retriever=LFChunkRetriever(), kg_retriever=KGRetrieverByLlm()
+ )
reasoner = DefaultReasoner(lf_planner=LFPlannerABC(), lf_solver=lf_solver)
resp = SolverPipeline(reasoner=reasoner)
answer, trace_log = resp.run(query)
@@ -49,7 +50,7 @@ def qaWithoutLogicForm(self, query):
"""
def parallelQaAndEvaluate(
- self, qaFilePath, resFilePath, threadNum=1, upperLimit=10
+ self, qaFilePath, resFilePath, threadNum=1, upperLimit=10
):
def process_sample(data):
try:
@@ -83,9 +84,9 @@ def process_sample(data):
for sample_idx, sample in enumerate(qaList[:upperLimit])
]
for future in tqdm(
- as_completed(futures),
- total=len(futures),
- desc="parallelQaAndEvaluate completing: ",
+ as_completed(futures),
+ total=len(futures),
+ desc="parallelQaAndEvaluate completing: ",
):
result = future.result()
if result is not None:
@@ -124,11 +125,9 @@ def process_sample(data):
start_time = time.time()
filePath = "./data/musique_qa_sub.json"
- #filePath = "./data/musique_qa_train.json"
+ # filePath = "./data/musique_qa_train.json"
- qaFilePath = os.path.join(
- os.path.abspath(os.path.dirname(__file__)), filePath
- )
+ qaFilePath = os.path.join(os.path.abspath(os.path.dirname(__file__)), filePath)
resFilePath = os.path.join(
os.path.abspath(os.path.dirname(__file__)), f"musique_res_{start_time}.json"
)
@@ -136,7 +135,7 @@ def process_sample(data):
qaFilePath, resFilePath, threadNum=20, upperLimit=10000
)
- total_metrics['cost'] = time.time() - start_time
+ total_metrics["cost"] = time.time() - start_time
with open(f"./musique_metrics_{start_time}.json", "w") as f:
json.dump(total_metrics, f)
print(total_metrics)
diff --git a/kag/examples/musique/solver/prompt/__init__.py b/kag/examples/musique/solver/prompt/__init__.py
index dadd42a3..dfa931cd 100644
--- a/kag/examples/musique/solver/prompt/__init__.py
+++ b/kag/examples/musique/solver/prompt/__init__.py
@@ -11,4 +11,4 @@
"""
Place the prompts to be used for solving problems in this directory.
-"""
\ No newline at end of file
+"""
diff --git a/kag/examples/musique/solver/prompt/resp_generator.py b/kag/examples/musique/solver/prompt/resp_generator.py
index 70e96cc9..fa70249c 100644
--- a/kag/examples/musique/solver/prompt/resp_generator.py
+++ b/kag/examples/musique/solver/prompt/resp_generator.py
@@ -9,12 +9,14 @@
class RespGenerator(PromptOp):
- template_zh = "基于给定的引用信息回答问题。" \
- "\n只输出答案,不需要输出额外的信息。" \
- "\n给定的引用信息:'$memory'\n问题:'$instruction'"
- template_en = "Answer the question based on the given reference." \
- "\nOnly give me the answer and do not output any other words." \
- "\nThe following are given reference:'$memory'\nQuestion: '$instruction'"
+ template_zh = (
+ "基于给定的引用信息回答问题。" "\n只输出答案,不需要输出额外的信息。" "\n给定的引用信息:'$memory'\n问题:'$instruction'"
+ )
+ template_en = (
+ "Answer the question based on the given reference."
+ "\nOnly give me the answer and do not output any other words."
+ "\nThe following are given reference:'$memory'\nQuestion: '$instruction'"
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -24,5 +26,5 @@ def template_variables(self) -> List[str]:
return ["memory", "instruction"]
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
return response
diff --git a/kag/examples/riskmining/builder/indexer.py b/kag/examples/riskmining/builder/indexer.py
index f66bbb86..390ac629 100644
--- a/kag/examples/riskmining/builder/indexer.py
+++ b/kag/examples/riskmining/builder/indexer.py
@@ -33,6 +33,7 @@ def build(self, **kwargs):
chain = source >> mapping >> vectorizer >> sink
return chain
+
class RiskMiningRelationChain(BuilderChainABC):
def __init__(self, spg_type_name: str):
super().__init__()
@@ -71,28 +72,42 @@ def build(self, **kwargs):
def import_data():
file_path = os.path.dirname(__file__)
init_kag_config(os.path.join(file_path, "../kag_config.cfg"))
- RiskMiningEntityChain(spg_type_name="Cert").invoke(os.path.join(file_path, "data/Cert.csv"))
- RiskMiningEntityChain(spg_type_name="App").invoke(os.path.join(file_path, "data/App.csv"))
- RiskMiningEntityChain(spg_type_name="Company").invoke(os.path.join(file_path, "data/Company.csv"))
+ RiskMiningEntityChain(spg_type_name="Cert").invoke(
+ os.path.join(file_path, "data/Cert.csv")
+ )
+ RiskMiningEntityChain(spg_type_name="App").invoke(
+ os.path.join(file_path, "data/App.csv")
+ )
+ RiskMiningEntityChain(spg_type_name="Company").invoke(
+ os.path.join(file_path, "data/Company.csv")
+ )
RiskMiningRelationChain(spg_type_name="Company_hasCert_Cert").invoke(
os.path.join(file_path, "data/Company_hasCert_Cert.csv")
)
- RiskMiningEntityChain(spg_type_name="Device").invoke(os.path.join(file_path, "data/Device.csv"))
+ RiskMiningEntityChain(spg_type_name="Device").invoke(
+ os.path.join(file_path, "data/Device.csv")
+ )
RiskMiningPersonFundTransPersonChain(
spg_type_name="Person_fundTrans_Person"
).invoke(os.path.join(file_path, "data/Person_fundTrans_Person.csv"))
RiskMiningRelationChain(spg_type_name="Person_hasCert_Cert").invoke(
os.path.join(file_path, "data/Person_hasCert_Cert.csv")
)
- RiskMiningRelationChain(
- spg_type_name="Person_hasDevice_Device"
- ).invoke(os.path.join(file_path, "data/Person_hasDevice_Device.csv"))
- RiskMiningRelationChain(
- spg_type_name="Person_holdShare_Company"
- ).invoke(os.path.join(file_path, "data/Person_holdShare_Company.csv"))
- RiskMiningEntityChain(spg_type_name="Person").invoke(os.path.join(file_path, "data/Person.csv"))
- RiskMiningEntityChain(spg_type_name="TaxOfRiskApp").invoke(os.path.join(file_path, "data/TaxOfRiskApp.csv"))
- RiskMiningEntityChain(spg_type_name="TaxOfRiskUser").invoke(os.path.join(file_path, "data/TaxOfRiskUser.csv"))
+ RiskMiningRelationChain(spg_type_name="Person_hasDevice_Device").invoke(
+ os.path.join(file_path, "data/Person_hasDevice_Device.csv")
+ )
+ RiskMiningRelationChain(spg_type_name="Person_holdShare_Company").invoke(
+ os.path.join(file_path, "data/Person_holdShare_Company.csv")
+ )
+ RiskMiningEntityChain(spg_type_name="Person").invoke(
+ os.path.join(file_path, "data/Person.csv")
+ )
+ RiskMiningEntityChain(spg_type_name="TaxOfRiskApp").invoke(
+ os.path.join(file_path, "data/TaxOfRiskApp.csv")
+ )
+ RiskMiningEntityChain(spg_type_name="TaxOfRiskUser").invoke(
+ os.path.join(file_path, "data/TaxOfRiskUser.csv")
+ )
if __name__ == "__main__":
diff --git a/kag/examples/riskmining/kag_config.cfg b/kag/examples/riskmining/kag_config.cfg
index c351ecd4..9124ef88 100644
--- a/kag/examples/riskmining/kag_config.cfg
+++ b/kag/examples/riskmining/kag_config.cfg
@@ -4,14 +4,14 @@ host_addr = http://127.0.0.1:8887
id = 8
[vectorizer]
-vectorizer = kag.common.vectorizer.OpenAIVectorizer
+type = openai
model = bge-m3
api_key = EMPTY
base_url = http://127.0.0.1:11434/v1
vector_dimensions = 1024
[llm]
-client_type = maas
+type = maas
base_url = https://api.deepseek.com/
api_key = put your deepseek api key here
model = deepseek-chat
diff --git a/kag/examples/riskmining/reasoner/client.py b/kag/examples/riskmining/reasoner/client.py
index be11abb3..400802af 100644
--- a/kag/examples/riskmining/reasoner/client.py
+++ b/kag/examples/riskmining/reasoner/client.py
@@ -3,22 +3,24 @@
from knext.reasoner.client import ReasonerClient
from kag.common.env import init_kag_config
+
def read_dsl_files(directory):
"""
Read all dsl files in the reasoner directory.
"""
-
- dsl_contents = []
+
+ dsl_contents = []
for filename in os.listdir(directory):
- if filename.endswith('.dsl'):
+ if filename.endswith(".dsl"):
file_path = os.path.join(directory, filename)
- with open(file_path, 'r', encoding='utf-8') as file:
- content = file.read()
- dsl_contents.append(content)
+ with open(file_path, "r", encoding="utf-8") as file:
+ content = file.read()
+ dsl_contents.append(content)
return dsl_contents
+
if __name__ == "__main__":
reasoner_path = os.path.dirname(os.path.abspath(__file__))
project_path = os.path.dirname(reasoner_path)
@@ -27,7 +29,9 @@ def read_dsl_files(directory):
host_addr = os.environ["KAG_PROJECT_HOST_ADDR"]
project_id = os.environ["KAG_PROJECT_ID"]
namespace = os.environ["KAG_PROJECT_NAMESPACE"]
- client = ReasonerClient(host_addr=host_addr, project_id=project_id, namespace=namespace)
+ client = ReasonerClient(
+ host_addr=host_addr, project_id=project_id, namespace=namespace
+ )
dsls = read_dsl_files(reasoner_path)
for dsl in dsls:
client.execute(dsl)
diff --git a/kag/examples/riskmining/solver/prompt/logic_form_plan.py b/kag/examples/riskmining/solver/prompt/logic_form_plan.py
index ab87efb7..428db570 100644
--- a/kag/examples/riskmining/solver/prompt/logic_form_plan.py
+++ b/kag/examples/riskmining/solver/prompt/logic_form_plan.py
@@ -2,6 +2,7 @@
import re
from string import Template
from typing import List
+
logger = logging.getLogger(__name__)
from kag.common.base.prompt_op import PromptOp
@@ -103,7 +104,6 @@ def __init__(self, language: str):
def template_variables(self) -> List[str]:
return ["question"]
-
def parse_response(self, response: str, **kwargs):
try:
logger.debug(f"logic form:{response}")
@@ -111,17 +111,17 @@ def parse_response(self, response: str, **kwargs):
_output_string = response.strip()
sub_querys = []
logic_forms = []
- current_sub_query = ''
- for line in _output_string.split('\n'):
- if line.startswith('Step'):
- sub_querys_regex = re.search('Step\d+:(.*)', line)
+ current_sub_query = ""
+ for line in _output_string.split("\n"):
+ if line.startswith("Step"):
+ sub_querys_regex = re.search("Step\d+:(.*)", line)
if sub_querys_regex is not None:
sub_querys.append(sub_querys_regex.group(1))
current_sub_query = sub_querys_regex.group(1)
- elif line.startswith('Output'):
+ elif line.startswith("Output"):
sub_querys.append("output")
- elif line.startswith('Action'):
- logic_forms_regex = re.search('Action\d+:(.*)', line)
+ elif line.startswith("Action"):
+ logic_forms_regex = re.search("Action\d+:(.*)", line)
if logic_forms_regex:
logic_forms.append(logic_forms_regex.group(1))
if len(logic_forms) - len(sub_querys) == 1:
diff --git a/kag/examples/riskmining/solver/qa.py b/kag/examples/riskmining/solver/qa.py
index e81e68f2..701bcd29 100644
--- a/kag/examples/riskmining/solver/qa.py
+++ b/kag/examples/riskmining/solver/qa.py
@@ -37,8 +37,9 @@ def qa(self, query):
def qaWithoutLogicForm(self, query):
# CA
- lf_solver = LFSolver(chunk_retriever=LFChunkRetriever(),
- kg_retriever=KGRetrieverByLlm())
+ lf_solver = LFSolver(
+ chunk_retriever=LFChunkRetriever(), kg_retriever=KGRetrieverByLlm()
+ )
reasoner = DefaultReasoner(lf_planner=LFPlannerABC(), lf_solver=lf_solver)
resp = SolverPipeline(reasoner=reasoner)
answer, trace_log = resp.run(query)
@@ -51,7 +52,7 @@ def qaWithoutLogicForm(self, query):
"""
def parallelQaAndEvaluate(
- self, qaFilePath, resFilePath, threadNum=1, upperLimit=10
+ self, qaFilePath, resFilePath, threadNum=1, upperLimit=10
):
def process_sample(data):
try:
@@ -85,9 +86,9 @@ def process_sample(data):
for sample_idx, sample in enumerate(qaList[:upperLimit])
]
for future in tqdm(
- as_completed(futures),
- total=len(futures),
- desc="parallelQaAndEvaluate completing: ",
+ as_completed(futures),
+ total=len(futures),
+ desc="parallelQaAndEvaluate completing: ",
):
result = future.result()
if result is not None:
@@ -129,18 +130,17 @@ def process_sample(data):
project_id = os.getenv("KAG_PROJECT_ID")
host_addr = os.getenv("KAG_PROJECT_HOST_ADDR")
- sc = ReasonerClient(host_addr, project_id, )
- param = {
- "spg.reasoner.plan.pretty.print.logger.enable": "true"
-
- }
+ sc = ReasonerClient(
+ host_addr,
+ project_id,
+ )
+ param = {"spg.reasoner.plan.pretty.print.logger.enable": "true"}
-# ret = sc.syn_execute("""MATCH
-# (u:`RiskMining.TaxOfRiskUser`/`赌博App开发者`)
-# RETURN u.name
-# """, **param)
-# print(ret)
+ # ret = sc.syn_execute("""MATCH
+ # (u:`RiskMining.TaxOfRiskUser`/`赌博App开发者`)
+ # RETURN u.name
+ # """, **param)
+ # print(ret)
evaObj = EvaQA(configFilePath=configFilePath)
print(evaObj.qa("裘**是否有风险?"))
-
diff --git a/kag/examples/supplychain/builder/indexer.py b/kag/examples/supplychain/builder/indexer.py
index 08415712..a6fef708 100644
--- a/kag/examples/supplychain/builder/indexer.py
+++ b/kag/examples/supplychain/builder/indexer.py
@@ -17,18 +17,24 @@
from kag.builder.component import SPGTypeMapping, KGWriter, RelationMapping
from kag.builder.component.reader.csv_reader import CSVReader
from kag.examples.supplychain.builder.operator.event_kg_writer_op import EventKGWriter
-from kag.examples.supplychain.builder.operator.fund_date_process_op import FundDateProcessComponent
+from kag.examples.supplychain.builder.operator.fund_date_process_op import (
+ FundDateProcessComponent,
+)
from knext.search.client import SearchClient
from knext.builder.builder_chain_abc import BuilderChainABC
from knext.search.client import SearchClient
def company_link_func(prop_value, node):
- sc = SearchClient(os.getenv("KAG_PROJECT_HOST_ADDR"), int(os.getenv("KAG_PROJECT_ID")))
+ sc = SearchClient(
+ os.getenv("KAG_PROJECT_HOST_ADDR"), int(os.getenv("KAG_PROJECT_ID"))
+ )
company_id = []
- records = sc.search_text(prop_value, label_constraints=["SupplyChain.Company"], topk=1)
+ records = sc.search_text(
+ prop_value, label_constraints=["SupplyChain.Company"], topk=1
+ )
if records:
- company_id.append(records[0]["node"]['id'])
+ company_id.append(records[0]["node"]["id"])
return company_id
@@ -45,7 +51,9 @@ def build(self, **kwargs):
.add_property_mapping("id", "id")
.add_property_mapping("age", "age")
.add_property_mapping(
- "legalRepresentative", "legalRepresentative", link_func=company_link_func
+ "legalRepresentative",
+ "legalRepresentative",
+ link_func=company_link_func,
)
)
vectorizer = BatchVectorizer()
@@ -73,6 +81,7 @@ def build(self, **kwargs):
sink = KGWriter()
return source >> date_process_op >> mapping >> vectorizer >> sink
+
class SupplyChainDefaulStructuredBuilderChain(DefaultStructuredBuilderChain):
def __init__(self, spg_type_name: str, **kwargs):
super().__init__(spg_type_name, **kwargs)
@@ -116,35 +125,41 @@ def build(self, **kwargs):
chain = source >> mapping >> vectorizer >> sink
return chain
+
def import_data():
file_path = os.path.dirname(__file__)
init_kag_config(os.path.join(file_path, "../kag_config.cfg"))
-
SupplyChainDefaulStructuredBuilderChain(spg_type_name="TaxOfCompanyEvent").invoke(
- file_path=os.path.join(file_path,"data/TaxOfCompanyEvent.csv")
+ file_path=os.path.join(file_path, "data/TaxOfCompanyEvent.csv")
)
SupplyChainDefaulStructuredBuilderChain(spg_type_name="TaxOfProdEvent").invoke(
- file_path=os.path.join(file_path,"data/TaxOfProdEvent.csv")
+ file_path=os.path.join(file_path, "data/TaxOfProdEvent.csv")
+ )
+ SupplyChainDefaulStructuredBuilderChain(spg_type_name="Trend").invoke(
+ file_path=os.path.join(file_path, "data/Trend.csv")
)
- SupplyChainDefaulStructuredBuilderChain(spg_type_name="Trend").invoke(file_path=os.path.join(file_path,"data/Trend.csv"))
SupplyChainDefaulStructuredBuilderChain(spg_type_name="Industry").invoke(
- file_path=os.path.join(file_path,"data/Industry.csv")
+ file_path=os.path.join(file_path, "data/Industry.csv")
)
SupplyChainDefaulStructuredBuilderChain(spg_type_name="Product").invoke(
- file_path=os.path.join(file_path,"data/Product.csv")
+ file_path=os.path.join(file_path, "data/Product.csv")
)
SupplyChainDefaulStructuredBuilderChain(spg_type_name="Company").invoke(
- file_path=os.path.join(file_path,"data/Company.csv")
+ file_path=os.path.join(file_path, "data/Company.csv")
+ )
+ SupplyChainDefaulStructuredBuilderChain(spg_type_name="Index").invoke(
+ file_path=os.path.join(file_path, "data/Index.csv")
+ )
+ SupplyChainPersonChain(spg_type_name="Person").invoke(
+ file_path=os.path.join(file_path, "data/Person.csv")
)
- SupplyChainDefaulStructuredBuilderChain(spg_type_name="Index").invoke(file_path=os.path.join(file_path,"data/Index.csv"))
- SupplyChainPersonChain(spg_type_name="Person").invoke(file_path=os.path.join(file_path,"data/Person.csv"))
SupplyChainCompanyFundTransCompanyChain(
spg_type_name="Company_fundTrans_Company"
- ).invoke(file_path=os.path.join(file_path,"data/Company_fundTrans_Company.csv"))
+ ).invoke(file_path=os.path.join(file_path, "data/Company_fundTrans_Company.csv"))
SupplyChainEventBuilderChain(spg_type_name="ProductChainEvent").invoke(
- file_path=os.path.join(file_path,"data/ProductChainEvent.csv")
+ file_path=os.path.join(file_path, "data/ProductChainEvent.csv")
)
diff --git a/kag/examples/supplychain/builder/operator/company_link_op.py b/kag/examples/supplychain/builder/operator/company_link_op.py
index 226b5c38..ba3c1bd8 100644
--- a/kag/examples/supplychain/builder/operator/company_link_op.py
+++ b/kag/examples/supplychain/builder/operator/company_link_op.py
@@ -22,9 +22,11 @@ class CompanyLinkOp(LinkOpABC):
bind_to = "Company"
def invoke(self, source: Node, prop_value: str, target_type: str) -> List[str]:
- sc = SearchClient(os.getenv("KAG_PROJECT_HOST_ADDR"), int(os.getenv("KAG_PROJECT_ID")))
+ sc = SearchClient(
+ os.getenv("KAG_PROJECT_HOST_ADDR"), int(os.getenv("KAG_PROJECT_ID"))
+ )
company_id = []
records = sc.search_text(prop_value, label_constraints=[target_type], topk=1)
if records:
- company_id.append(records[0]["node"]['id'])
+ company_id.append(records[0]["node"]["id"])
return company_id
diff --git a/kag/examples/supplychain/builder/operator/event_kg_writer_op.py b/kag/examples/supplychain/builder/operator/event_kg_writer_op.py
index 9a83dc85..ecc13587 100644
--- a/kag/examples/supplychain/builder/operator/event_kg_writer_op.py
+++ b/kag/examples/supplychain/builder/operator/event_kg_writer_op.py
@@ -11,6 +11,9 @@ def __init__(self, project_id: str = None, **kwargs):
super().__init__(project_id, **kwargs)
def invoke(
- self, input: Input, alter_operation: str = AlterOperationEnum.Upsert, lead_to_builder: bool = True
+ self,
+ input: Input,
+ alter_operation: str = AlterOperationEnum.Upsert,
+ lead_to_builder: bool = True,
) -> List[Output]:
return super().invoke(input, alter_operation, lead_to_builder)
diff --git a/kag/examples/supplychain/kag_config.cfg b/kag/examples/supplychain/kag_config.cfg
index 590f42c0..020c29e9 100644
--- a/kag/examples/supplychain/kag_config.cfg
+++ b/kag/examples/supplychain/kag_config.cfg
@@ -4,14 +4,14 @@ host_addr = http://127.0.0.1:8887
id = 7
[vectorizer]
-vectorizer = kag.common.vectorizer.OpenAIVectorizer
+type = openai
model = bge-m3
api_key = EMPTY
base_url = http://127.0.0.1:11434/v1
vector_dimensions = 1024
[llm]
-client_type = maas
+type = maas
base_url = https://api.deepseek.com/
api_key = put your deepseek api key here
model = deepseek-chat
diff --git a/kag/examples/supplychain/reasoner/client.py b/kag/examples/supplychain/reasoner/client.py
index 32ab9c5b..9a54b989 100644
--- a/kag/examples/supplychain/reasoner/client.py
+++ b/kag/examples/supplychain/reasoner/client.py
@@ -3,22 +3,24 @@
from knext.reasoner.client import ReasonerClient
from kag.common.env import init_kag_config
+
def read_dsl_files(directory):
"""
Read all dsl files in the reasoner directory.
"""
-
- dsl_contents = []
+
+ dsl_contents = []
for filename in os.listdir(directory):
- if filename.endswith('.dsl'):
+ if filename.endswith(".dsl"):
file_path = os.path.join(directory, filename)
- with open(file_path, 'r', encoding='utf-8') as file:
- content = file.read()
- dsl_contents.append(content)
+ with open(file_path, "r", encoding="utf-8") as file:
+ content = file.read()
+ dsl_contents.append(content)
return dsl_contents
+
if __name__ == "__main__":
resonser_path = os.path.dirname(os.path.abspath(__file__))
project_path = os.path.dirname(resonser_path)
@@ -27,7 +29,9 @@ def read_dsl_files(directory):
host_addr = os.environ["KAG_PROJECT_HOST_ADDR"]
project_id = os.environ["KAG_PROJECT_ID"]
namespace = os.environ["KAG_PROJECT_NAMESPACE"]
- client = ReasonerClient(host_addr=host_addr, project_id=project_id,namespace=namespace)
+ client = ReasonerClient(
+ host_addr=host_addr, project_id=project_id, namespace=namespace
+ )
dsls = read_dsl_files(resonser_path)
for dsl in dsls:
client.execute(dsl)
diff --git a/kag/examples/supplychain/solver/prompt/logic_form_plan.py b/kag/examples/supplychain/solver/prompt/logic_form_plan.py
index 2448e4a1..e60082b3 100644
--- a/kag/examples/supplychain/solver/prompt/logic_form_plan.py
+++ b/kag/examples/supplychain/solver/prompt/logic_form_plan.py
@@ -1,6 +1,7 @@
import logging
import re
from typing import List
+
logger = logging.getLogger(__name__)
from kag.common.base.prompt_op import PromptOp
@@ -102,7 +103,6 @@ def __init__(self, language: str):
def template_variables(self) -> List[str]:
return ["question"]
-
def parse_response(self, response: str, **kwargs):
try:
logger.debug(f"logic form:{response}")
@@ -110,17 +110,17 @@ def parse_response(self, response: str, **kwargs):
_output_string = response.strip()
sub_querys = []
logic_forms = []
- current_sub_query = ''
- for line in _output_string.split('\n'):
- if line.startswith('Step'):
- sub_querys_regex = re.search('Step\d+:(.*)', line)
+ current_sub_query = ""
+ for line in _output_string.split("\n"):
+ if line.startswith("Step"):
+ sub_querys_regex = re.search("Step\d+:(.*)", line)
if sub_querys_regex is not None:
sub_querys.append(sub_querys_regex.group(1))
current_sub_query = sub_querys_regex.group(1)
- elif line.startswith('Output'):
+ elif line.startswith("Output"):
sub_querys.append("output")
- elif line.startswith('Action'):
- logic_forms_regex = re.search('Action\d+:(.*)', line)
+ elif line.startswith("Action"):
+ logic_forms_regex = re.search("Action\d+:(.*)", line)
if logic_forms_regex:
logic_forms.append(logic_forms_regex.group(1))
if len(logic_forms) - len(sub_querys) == 1:
diff --git a/kag/examples/supplychain/solver/qa.py b/kag/examples/supplychain/solver/qa.py
index 79334ca6..2bb4770d 100644
--- a/kag/examples/supplychain/solver/qa.py
+++ b/kag/examples/supplychain/solver/qa.py
@@ -38,8 +38,9 @@ def qa(self, query):
def qaWithoutLogicForm(self, query):
# CA
- lf_solver = LFSolver(chunk_retriever=LFChunkRetriever(),
- kg_retriever=KGRetrieverByLlm())
+ lf_solver = LFSolver(
+ chunk_retriever=LFChunkRetriever(), kg_retriever=KGRetrieverByLlm()
+ )
reasoner = DefaultReasoner(lf_planner=LFPlannerABC(), lf_solver=lf_solver)
resp = SolverPipeline(reasoner=reasoner)
answer, trace_log = resp.run(query)
@@ -52,7 +53,7 @@ def qaWithoutLogicForm(self, query):
"""
def parallelQaAndEvaluate(
- self, qaFilePath, resFilePath, threadNum=1, upperLimit=10
+ self, qaFilePath, resFilePath, threadNum=1, upperLimit=10
):
def process_sample(data):
try:
@@ -86,9 +87,9 @@ def process_sample(data):
for sample_idx, sample in enumerate(qaList[:upperLimit])
]
for future in tqdm(
- as_completed(futures),
- total=len(futures),
- desc="parallelQaAndEvaluate completing: ",
+ as_completed(futures),
+ total=len(futures),
+ desc="parallelQaAndEvaluate completing: ",
):
result = future.result()
if result is not None:
diff --git a/kag/examples/utils.py b/kag/examples/utils.py
index 6706523d..9023ffb5 100644
--- a/kag/examples/utils.py
+++ b/kag/examples/utils.py
@@ -24,16 +24,16 @@ def compute_sub_query(trace_log: dict):
round_max_sub_query = 0
kg_direct_num = 0
sub_query_num = 0
- round_max_sub_query += 0 if len(trace_log) == 0 else len(trace_log[0]['history'])
+ round_max_sub_query += 0 if len(trace_log) == 0 else len(trace_log[0]["history"])
for log in trace_log:
- if 'history' not in log:
+ if "history" not in log:
continue
- history = log['history']
+ history = log["history"]
for h in history:
sub_query_num += 1
- source_type = h.get('answer_source', 'chunk')
- if source_type == 'spo':
+ source_type = h.get("answer_source", "chunk")
+ if source_type == "spo":
kg_direct_num += 1
return kg_direct_num, sub_query_num, round_max_sub_query
@@ -51,7 +51,7 @@ def run_rerank_by_score(recall_docs: list):
tmp_dict = {}
for doc in iter_recall_docs:
score = doc.split("#")[-1]
- header = doc.replace(f"#{score}", '')
+ header = doc.replace(f"#{score}", "")
tmp_dict[header] = score
normalized_iter_doc_scores = min_max_normalize(
np.array(list(tmp_dict.values())).astype(float)
@@ -76,7 +76,9 @@ def compute_recall_metrics(recall_docs: list, supporting_facts: list, extract_co
if header is None:
raise Exception(f"doc header extra failed {doc}")
recall_docs_header.append(header)
- return compute_hit_in_recalls(recall_docs_header, supporting_facts, 2), compute_hit_in_recalls(recall_docs_header,
- supporting_facts,
- 5), compute_hit_in_recalls(
- recall_docs_header, supporting_facts, 10), compute_hit_in_recalls(recall_docs_header, supporting_facts, 10000)
+ return (
+ compute_hit_in_recalls(recall_docs_header, supporting_facts, 2),
+ compute_hit_in_recalls(recall_docs_header, supporting_facts, 5),
+ compute_hit_in_recalls(recall_docs_header, supporting_facts, 10),
+ compute_hit_in_recalls(recall_docs_header, supporting_facts, 10000),
+ )
diff --git a/kag/interface/builder/__init__.py b/kag/interface/builder/__init__.py
index 8f7be0a4..97ffe341 100644
--- a/kag/interface/builder/__init__.py
+++ b/kag/interface/builder/__init__.py
@@ -25,5 +25,5 @@
"MappingABC",
"AlignerABC",
"SinkWriterABC",
- "BuilderChainABC"
+ "BuilderChainABC",
]
diff --git a/kag/interface/retriever/chunk_retriever_abc.py b/kag/interface/retriever/chunk_retriever_abc.py
index 33d8fb3c..41602d22 100644
--- a/kag/interface/retriever/chunk_retriever_abc.py
+++ b/kag/interface/retriever/chunk_retriever_abc.py
@@ -7,6 +7,7 @@
class ChunkRetrieverABC(KagBaseModule, ABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)
+
"""
An abstract base class for chunk retrieval strategies.
diff --git a/kag/interface/retriever/kg_retriever_abc.py b/kag/interface/retriever/kg_retriever_abc.py
index c6593221..8d06686d 100644
--- a/kag/interface/retriever/kg_retriever_abc.py
+++ b/kag/interface/retriever/kg_retriever_abc.py
@@ -3,7 +3,11 @@
from kag.solver.common.base import KagBaseModule
from kag.solver.logic.core_modules.common.base_model import SPOEntity
-from kag.solver.logic.core_modules.common.one_hop_graph import OneHopGraphData, KgGraph, EntityData
+from kag.solver.logic.core_modules.common.one_hop_graph import (
+ OneHopGraphData,
+ KgGraph,
+ EntityData,
+)
from kag.solver.logic.core_modules.parser.logic_node_parser import GetSPONode
@@ -23,9 +27,12 @@ def __init__(self, **kwargs):
retrieval_entity(entity_mention, topk=1, params={}):
Retrieves related entities based on the given entity mention.
"""
+
@abstractmethod
- def retrieval_relation(self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], **kwargs) -> KgGraph:
- '''
+ def retrieval_relation(
+ self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], **kwargs
+ ) -> KgGraph:
+ """
Input:
n: GetSPONode, the relation to be standardized
one_hop_graph_list: List[OneHopGraphData], list of candidate sets
@@ -33,10 +40,12 @@ def retrieval_relation(self, n: GetSPONode, one_hop_graph_list: List[OneHopGraph
Output:
Returns KgGraph
- '''
+ """
@abstractmethod
- def retrieval_entity(self, mention_entity: SPOEntity, topk=1, **kwargs) -> List[EntityData]:
+ def retrieval_entity(
+ self, mention_entity: SPOEntity, topk=1, **kwargs
+ ) -> List[EntityData]:
"""
Retrieve related entities based on the given entity mention.
diff --git a/kag/interface/solver/kag_generator_abc.py b/kag/interface/solver/kag_generator_abc.py
index 9f3d4aef..eda8154f 100644
--- a/kag/interface/solver/kag_generator_abc.py
+++ b/kag/interface/solver/kag_generator_abc.py
@@ -6,9 +6,10 @@
class KAGGeneratorABC(KagBaseModule, ABC):
"""
- The Generator class is an abstract base class for generating responses using a language model module.
- It initializes prompts for judging and generating responses based on the business scene and language settings.
- """
+ The Generator class is an abstract base class for generating responses using a language model module.
+ It initializes prompts for judging and generating responses based on the business scene and language settings.
+ """
+
def __init__(self, **kwargs):
super().__init__(**kwargs)
diff --git a/kag/interface/solver/kag_memory_abc.py b/kag/interface/solver/kag_memory_abc.py
index c5716745..142d57f0 100644
--- a/kag/interface/solver/kag_memory_abc.py
+++ b/kag/interface/solver/kag_memory_abc.py
@@ -45,4 +45,4 @@ def refresh(self):
Refreshes the memory.
This method is used to reset the memory state.
- """
\ No newline at end of file
+ """
diff --git a/kag/interface/solver/kag_reasoner_abc.py b/kag/interface/solver/kag_reasoner_abc.py
index c2eb3adb..83c5d70a 100644
--- a/kag/interface/solver/kag_reasoner_abc.py
+++ b/kag/interface/solver/kag_reasoner_abc.py
@@ -24,7 +24,10 @@ class KagReasonerABC(KagBaseModule):
- kg_direct: Number of direct knowledge graph queries.
- trace_log: List to log trace information.
"""
- def __init__(self, lf_planner: LFPlannerABC = None, lf_solver: LFSolver = None, **kwargs):
+
+ def __init__(
+ self, lf_planner: LFPlannerABC = None, lf_solver: LFSolver = None, **kwargs
+ ):
super().__init__(**kwargs)
@abstractmethod
@@ -40,4 +43,4 @@ def reason(self, question: str) -> Tuple[str, str, dict]:
- solved_answer: The final answer derived from solving the logical forms.
- supporting_fact: Supporting facts gathered during the reasoning process.
- history_log: A dictionary containing the history of QA pairs and re-ranked documents.
- """
\ No newline at end of file
+ """
diff --git a/kag/interface/solver/kag_reflector_abc.py b/kag/interface/solver/kag_reflector_abc.py
index ef75afd9..4ee59e24 100644
--- a/kag/interface/solver/kag_reflector_abc.py
+++ b/kag/interface/solver/kag_reflector_abc.py
@@ -22,7 +22,9 @@ def reflect_query(self, memory: KagMemoryABC, instruction: str) -> (bool, str):
- refined_query: The refined query (string)
"""
can_answer = self._can_answer(memory, instruction)
- refined_query = self._refine_query(memory, instruction) if not can_answer else instruction
+ refined_query = (
+ self._refine_query(memory, instruction) if not can_answer else instruction
+ )
return can_answer, refined_query
diff --git a/kag/interface/solver/lf_planner_abc.py b/kag/interface/solver/lf_planner_abc.py
index fc7719c4..1570b747 100644
--- a/kag/interface/solver/lf_planner_abc.py
+++ b/kag/interface/solver/lf_planner_abc.py
@@ -1,4 +1,3 @@
-import os
from abc import ABC, abstractmethod
from typing import List
@@ -27,4 +26,4 @@ def lf_planing(self, question, llm_output=None) -> List[LFPlanResult]:
Returns:
list of LFPlanResult
"""
- pass
\ No newline at end of file
+ pass
diff --git a/kag/solver/implementation/default_generator.py b/kag/solver/implementation/default_generator.py
index 05e67ea6..0944cc57 100644
--- a/kag/solver/implementation/default_generator.py
+++ b/kag/solver/implementation/default_generator.py
@@ -7,9 +7,10 @@
class DefaultGenerator(KAGGeneratorABC):
"""
- The Generator class is an abstract base class for generating responses using a language model module.
- It initializes prompts for judging and generating responses based on the business scene and language settings.
- """
+ The Generator class is an abstract base class for generating responses using a language model module.
+ It initializes prompts for judging and generating responses based on the business scene and language settings.
+ """
+
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.generate_prompt = PromptOp.load(self.biz_scene, "resp_generator")(
@@ -22,4 +23,9 @@ def generate(self, instruction, memory: DefaultMemory):
if solved_answer is not None:
return solved_answer
present_memory = memory.serialize_memory()
- return self.llm_module.invoke({'memory': present_memory, 'instruction': instruction}, self.generate_prompt, with_json_parse=False, with_except=True)
+ return self.llm_module.invoke(
+ {"memory": present_memory, "instruction": instruction},
+ self.generate_prompt,
+ with_json_parse=False,
+ with_except=True,
+ )
diff --git a/kag/solver/implementation/default_kg_retrieval.py b/kag/solver/implementation/default_kg_retrieval.py
index fe2a8f8f..df1a1573 100644
--- a/kag/solver/implementation/default_kg_retrieval.py
+++ b/kag/solver/implementation/default_kg_retrieval.py
@@ -9,14 +9,30 @@
from kag.interface.retriever.kg_retriever_abc import KGRetrieverABC
from knext.search.client import SearchClient
from kag.solver.logic.core_modules.common.base_model import SPOEntity
-from kag.solver.logic.core_modules.common.one_hop_graph import KgGraph, OneHopGraphData, EntityData
+from kag.solver.logic.core_modules.common.one_hop_graph import (
+ KgGraph,
+ OneHopGraphData,
+ EntityData,
+)
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.common.text_sim_by_vector import TextSimilarity
-from kag.solver.logic.core_modules.common.utils import get_recall_node_label, generate_biz_id_with_type
+from kag.solver.logic.core_modules.common.utils import (
+ get_recall_node_label,
+ generate_biz_id_with_type,
+)
from kag.solver.logic.core_modules.config import LogicFormConfiguration
-from kag.solver.logic.core_modules.parser.logic_node_parser import GetSPONode, ParseLogicForm
-from kag.solver.logic.core_modules.retriver.graph_retriver.dsl_executor import DslRunner, DslRunnerOnGraphStore
-from kag.solver.logic.core_modules.retriver.retrieval_spo import FuzzyMatchRetrievalSpo, ExactMatchRetrievalSpo
+from kag.solver.logic.core_modules.parser.logic_node_parser import (
+ GetSPONode,
+ ParseLogicForm,
+)
+from kag.solver.logic.core_modules.retriver.graph_retriver.dsl_executor import (
+ DslRunner,
+ DslRunnerOnGraphStore,
+)
+from kag.solver.logic.core_modules.retriver.retrieval_spo import (
+ FuzzyMatchRetrievalSpo,
+ ExactMatchRetrievalSpo,
+)
current_dir = os.path.dirname(os.path.abspath(__file__))
import logging
@@ -37,7 +53,9 @@ def __init__(self, disable_exact_match=False, **kwargs):
vectorizer_config = eval(os.getenv("KAG_VECTORIZER", "{}"))
if self.host_addr and self.project_id:
- config = ProjectClient(host_addr=self.host_addr, project_id=self.project_id).get_config(self.project_id)
+ config = ProjectClient(
+ host_addr=self.host_addr, project_id=self.project_id
+ ).get_config(self.project_id)
vectorizer_config.update(config.get("vectorizer", {}))
self.vectorizer: Vectorizer = Vectorizer.from_config(vectorizer_config)
self.text_similarity = TextSimilarity(vec_config=vectorizer_config)
@@ -47,38 +65,62 @@ def __init__(self, disable_exact_match=False, **kwargs):
self.disable_exact_match = disable_exact_match
self.sc: SearchClient = SearchClient(self.host_addr, self.project_id)
- self.dsl_runner: DslRunner = DslRunnerOnGraphStore(self.project_id, self.schema, LogicFormConfiguration(kwargs))
+ self.dsl_runner: DslRunner = DslRunnerOnGraphStore(
+ self.project_id, self.schema, LogicFormConfiguration(kwargs)
+ )
- self.fuzzy_match = FuzzyMatchRetrievalSpo(text_similarity=self.text_similarity, llm=self.llm_module)
+ self.fuzzy_match = FuzzyMatchRetrievalSpo(
+ text_similarity=self.text_similarity, llm=self.llm_module
+ )
self.exact_match = ExactMatchRetrievalSpo(self.schema)
self.parser = ParseLogicForm(self.schema, None)
- def retrieval_relation(self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], **kwargs) -> KgGraph:
- req_id = kwargs.get('req_id', '')
- debug_info = kwargs.get('debug_info', {})
+ def retrieval_relation(
+ self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], **kwargs
+ ) -> KgGraph:
+ req_id = kwargs.get("req_id", "")
+ debug_info = kwargs.get("debug_info", {})
if not self.disable_exact_match:
- process_kg, is_matched = self._exact_match_spo(n, one_hop_graph_list, req_id)
+ process_kg, is_matched = self._exact_match_spo(
+ n, one_hop_graph_list, req_id
+ )
if is_matched:
- debug_info["exact_match_spo"] = True and debug_info.get('exact_match_spo', True)
+ debug_info["exact_match_spo"] = True and debug_info.get(
+ "exact_match_spo", True
+ )
return process_kg
else:
debug_info["exact_match_spo"] = False
return self._fuzzy_match_spo(n, one_hop_graph_list, req_id)
- def retrieval_entity(self, mention_entity: SPOEntity, topk=5, **kwargs) -> List[EntityData]:
- recalled_el_set = self._search_retrieval_entity(mention_entity, topk=topk, kwargs=kwargs)
+ def retrieval_entity(
+ self, mention_entity: SPOEntity, topk=5, **kwargs
+ ) -> List[EntityData]:
+ recalled_el_set = self._search_retrieval_entity(
+ mention_entity, topk=topk, kwargs=kwargs
+ )
if len(mention_entity.value_list) == 0:
return recalled_el_set
# 存在参数进行过滤,先过去一跳子图
- one_hop_graph_map = self.dsl_runner.query_vertex_one_graph_by_s_o_ids(recalled_el_set, [], {})
+ one_hop_graph_map = self.dsl_runner.query_vertex_one_graph_by_s_o_ids(
+ recalled_el_set, [], {}
+ )
matched_entity_list = recalled_el_set
# 将待匹配的作为spg进行匹配
for k, v in mention_entity.value_list:
- choosed_one_hop_graph_list = self._get_matched_one_hop(one_hop_graph_map, matched_entity_list)
+ choosed_one_hop_graph_list = self._get_matched_one_hop(
+ one_hop_graph_map, matched_entity_list
+ )
param_spo = f"get_spo(s=s1:{mention_entity.get_entity_first_type_or_zh()}[{mention_entity.entity_name}],p=p1:{k},o=o1:Entity[{v}])"
- tmp_spo = self.parser.parse_logic_form(param_spo, parsed_entity_set={}, sub_query=f"{mention_entity.entity_name} {k} {v}")
+ tmp_spo = self.parser.parse_logic_form(
+ param_spo,
+ parsed_entity_set={},
+ sub_query=f"{mention_entity.entity_name} {k} {v}",
+ )
debug_info = {}
- kg_graph = self.retrieval_relation(tmp_spo, choosed_one_hop_graph_list, debug_info=debug_info)
+ kg_graph = self.retrieval_relation(
+ tmp_spo, choosed_one_hop_graph_list, debug_info=debug_info
+ )
kg_graph.nodes_alias.append(tmp_spo.s.alias_name)
kg_graph.nodes_alias.append(tmp_spo.o.alias_name)
kg_graph.edge_alias.append(tmp_spo.p.alias_name)
@@ -91,31 +133,43 @@ def retrieval_entity(self, mention_entity: SPOEntity, topk=5, **kwargs) -> List[
def _get_matched_one_hop(self, one_hop_graph_map: dict, matched_entity_list: list):
ret_one_hop_list = []
for matched_entity in matched_entity_list:
- cached_id = generate_biz_id_with_type(matched_entity.biz_id,
- matched_entity.type_zh if matched_entity.type_zh else matched_entity.type)
+ cached_id = generate_biz_id_with_type(
+ matched_entity.biz_id,
+ matched_entity.type_zh
+ if matched_entity.type_zh
+ else matched_entity.type,
+ )
if cached_id in one_hop_graph_map:
ret_one_hop_list.append(one_hop_graph_map[cached_id])
return ret_one_hop_list
- def _search_retrieval_entity(self, mention_entity: SPOEntity, topk=5, **kwargs) -> List[EntityData]:
+ def _search_retrieval_entity(
+ self, mention_entity: SPOEntity, topk=5, **kwargs
+ ) -> List[EntityData]:
retdata = []
if mention_entity is None:
return retdata
- content = kwargs.get('content', mention_entity.entity_name)
+ content = kwargs.get("content", mention_entity.entity_name)
query_type = mention_entity.get_entity_first_type_or_zh()
- recognition_threshold = kwargs.get('recognition_threshold', 0.8)
+ recognition_threshold = kwargs.get("recognition_threshold", 0.8)
recall_topk = topk
if "entity" not in query_type.lower():
recall_topk = 10
query_vector = self.vectorizer.vectorize(mention_entity.entity_name)
typed_nodes = self.sc.search_vector(
- label="Entity", property_key="name", query_vector=query_vector, topk=recall_topk
+ label="Entity",
+ property_key="name",
+ query_vector=query_vector,
+ topk=recall_topk,
)
# 根据query召回
if query_type not in ["Others", "Entity"]:
content_vector = self.vectorizer.vectorize(content)
content_recall_nodes = self.sc.search_vector(
- label="Entity", property_key="desc", query_vector=content_vector, topk=recall_topk
+ label="Entity",
+ property_key="desc",
+ query_vector=content_vector,
+ topk=recall_topk,
)
else:
content_recall_nodes = []
@@ -127,48 +181,69 @@ def _search_retrieval_entity(self, mention_entity: SPOEntity, topk=5, **kwargs)
def rerank_sematic_type(cands_nodes: list, sematic_type: str):
sematic_type_list = []
for cands in cands_nodes:
- node = cands['node']
- if "semanticType" not in node.keys() or node['semanticType'] == '':
+ node = cands["node"]
+ if "semanticType" not in node.keys() or node["semanticType"] == "":
continue
- sematic_type_list.append(node['semanticType'])
+ sematic_type_list.append(node["semanticType"])
sematic_type_list = list(set(sematic_type_list))
- sematic_match_score_list = self.text_similarity.text_sim_result(sematic_type, sematic_type_list,
- len(sematic_type_list), low_score=-1)
+ sematic_match_score_list = self.text_similarity.text_sim_result(
+ sematic_type, sematic_type_list, len(sematic_type_list), low_score=-1
+ )
sematic_match_score_map = {}
for i in sematic_match_score_list:
sematic_match_score_map[i[0]] = i[1]
for node in cands_nodes:
- recall_node_label = get_recall_node_label(node['node']['__labels__'])
+ recall_node_label = get_recall_node_label(node["node"]["__labels__"])
if recall_node_label == sematic_type:
- node['type_match_score'] = node['score']
- elif "semanticType" not in node['node'].keys() or node['node']['semanticType'] == '':
- node['type_match_score'] = 0.3
+ node["type_match_score"] = node["score"]
+ elif (
+ "semanticType" not in node["node"].keys()
+ or node["node"]["semanticType"] == ""
+ ):
+ node["type_match_score"] = 0.3
else:
- node['type_match_score'] = node['score'] * sematic_match_score_map[node['node']['semanticType']]
- sorted_people_dicts = sorted(cands_nodes, key=lambda node: node['type_match_score'], reverse=True)
+ node["type_match_score"] = (
+ node["score"]
+ * sematic_match_score_map[node["node"]["semanticType"]]
+ )
+ sorted_people_dicts = sorted(
+ cands_nodes, key=lambda node: node["type_match_score"], reverse=True
+ )
# 取top5
return sorted_people_dicts[:topk]
if "entity" not in query_type.lower():
sorted_nodes = rerank_sematic_type(sorted_nodes, query_type)
- sorted_people_dicts = sorted(sorted_nodes, key=lambda node: node['score'], reverse=True)
+ sorted_people_dicts = sorted(
+ sorted_nodes, key=lambda node: node["score"], reverse=True
+ )
for recall in sorted_people_dicts:
- if len(sorted_people_dicts) != 0 and recall["score"] >= recognition_threshold:
+ if (
+ len(sorted_people_dicts) != 0
+ and recall["score"] >= recognition_threshold
+ ):
recalled_entity = EntityData()
recalled_entity.score = recall["score"]
recalled_entity.biz_id = recall["node"]["id"]
recalled_entity.name = recall["node"]["name"]
- recalled_entity.type = get_recall_node_label(recall["node"]["__labels__"])
+ recalled_entity.type = get_recall_node_label(
+ recall["node"]["__labels__"]
+ )
retdata.append(recalled_entity)
else:
break
return retdata[:topk]
- def _exact_match_spo(self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], req_id: str):
+ def _exact_match_spo(
+ self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], req_id: str
+ ):
start_time = time.time()
- total_one_kg_graph, matched_flag = self.exact_match.match_spo(n, one_hop_graph_list)
+ total_one_kg_graph, matched_flag = self.exact_match.match_spo(
+ n, one_hop_graph_list
+ )
logger.debug(
- f"{req_id} _exact_match_spo cost={time.time() - start_time} matched_flag={matched_flag}")
+ f"{req_id} _exact_match_spo cost={time.time() - start_time} matched_flag={matched_flag}"
+ )
if not matched_flag:
return total_one_kg_graph, matched_flag
for alias_name in total_one_kg_graph.entity_map.keys():
@@ -179,10 +254,15 @@ def _exact_match_spo(self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphDa
return total_one_kg_graph, False
return total_one_kg_graph, matched_flag
- def _fuzzy_match_spo(self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], req_id: str):
+ def _fuzzy_match_spo(
+ self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], req_id: str
+ ):
start_time = time.time()
- total_one_kg_graph, matched_flag = self.fuzzy_match.match_spo(n, one_hop_graph_list)
+ total_one_kg_graph, matched_flag = self.fuzzy_match.match_spo(
+ n, one_hop_graph_list
+ )
logger.debug(
- f"{req_id} _fuzzy_match_spo cost={time.time() - start_time} matched_flag={matched_flag}")
+ f"{req_id} _fuzzy_match_spo cost={time.time() - start_time} matched_flag={matched_flag}"
+ )
return total_one_kg_graph
diff --git a/kag/solver/implementation/default_lf_planner.py b/kag/solver/implementation/default_lf_planner.py
index 75cc9314..e94d8520 100644
--- a/kag/solver/implementation/default_lf_planner.py
+++ b/kag/solver/implementation/default_lf_planner.py
@@ -61,29 +61,36 @@ def _split_sub_query(self, logic_nodes: List[LogicNode]) -> List[LFPlanResult]:
def _parse_lf(self, question, sub_querys, logic_forms) -> List[LFPlanResult]:
if sub_querys is None:
sub_querys = []
- parsed_logic_nodes = self.parser.parse_logic_form_set(logic_forms, sub_querys, question)
+ parsed_logic_nodes = self.parser.parse_logic_form_set(
+ logic_forms, sub_querys, question
+ )
return self._split_sub_query(parsed_logic_nodes)
def generate_logic_form(self, question: str):
- return self.llm_module.invoke({'question': question}, self.logic_form_plan_prompt, with_json_parse=False, with_except=True)
+ return self.llm_module.invoke(
+ {"question": question},
+ self.logic_form_plan_prompt,
+ with_json_parse=False,
+ with_except=True,
+ )
def parse_logic_form_llm_output(self, llm_output):
_output_string = llm_output.replace(":", ":")
_output_string = llm_output.strip()
sub_querys = []
logic_forms = []
- current_sub_query = ''
- for line in _output_string.split('\n'):
+ current_sub_query = ""
+ for line in _output_string.split("\n"):
line = line.strip()
- if line.startswith('Step'):
- sub_querys_regex = re.search('Step\d+:(.*)', line)
+ if line.startswith("Step"):
+ sub_querys_regex = re.search("Step\d+:(.*)", line)
if sub_querys_regex is not None:
sub_querys.append(sub_querys_regex.group(1))
current_sub_query = sub_querys_regex.group(1)
- elif line.startswith('Output'):
+ elif line.startswith("Output"):
sub_querys.append("output")
- elif line.startswith('Action'):
- logic_forms_regex = re.search('Action\d+:(.*)', line)
+ elif line.startswith("Action"):
+ logic_forms_regex = re.search("Action\d+:(.*)", line)
if logic_forms_regex:
logic_forms.append(logic_forms_regex.group(1))
if len(logic_forms) - len(sub_querys) == 1:
diff --git a/kag/solver/implementation/default_memory.py b/kag/solver/implementation/default_memory.py
index ae55c96b..faf5305b 100644
--- a/kag/solver/implementation/default_memory.py
+++ b/kag/solver/implementation/default_memory.py
@@ -20,9 +20,12 @@ def __init__(self, **kwargs):
@retry(stop=stop_after_attempt(3))
def _verifier(self, supporting_fact, sub_instruction):
- res = self.llm_module.invoke({'sub_instruction': sub_instruction,
- 'supporting_fact': supporting_fact}, self.verify_prompt,
- with_json_parse=False, with_except=True)
+ res = self.llm_module.invoke(
+ {"sub_instruction": sub_instruction, "supporting_fact": supporting_fact},
+ self.verify_prompt,
+ with_json_parse=False,
+ with_except=True,
+ )
if res is None:
return
if res not in self.state_memory:
@@ -30,10 +33,14 @@ def _verifier(self, supporting_fact, sub_instruction):
@retry(stop=stop_after_attempt(3))
def _extractor(self, supporting_fact, instruction):
- if supporting_fact is None or supporting_fact == '':
+ if supporting_fact is None or supporting_fact == "":
return
- evidence = self.llm_module.invoke({'supporting_fact': supporting_fact, 'instruction': instruction},
- self.extractor_prompt, with_json_parse=False, with_except=True)
+ evidence = self.llm_module.invoke(
+ {"supporting_fact": supporting_fact, "instruction": instruction},
+ self.extractor_prompt,
+ with_json_parse=False,
+ with_except=True,
+ )
if evidence not in self.evidence_memory:
self.evidence_memory.append(evidence)
@@ -62,4 +69,4 @@ def serialize_memory(self):
def refresh(self):
self.state_memory = []
self.evidence_memory = []
- self.exact_answer = []
\ No newline at end of file
+ self.exact_answer = []
diff --git a/kag/solver/implementation/default_reasoner.py b/kag/solver/implementation/default_reasoner.py
index 54a4316e..1b3f29ba 100644
--- a/kag/solver/implementation/default_reasoner.py
+++ b/kag/solver/implementation/default_reasoner.py
@@ -11,6 +11,7 @@
logger = logging.getLogger()
+
class DefaultReasoner(KagReasonerABC):
"""
A processor class for handling logical form tasks in language processing.
@@ -29,18 +30,16 @@ class DefaultReasoner(KagReasonerABC):
- trace_log: List to log trace information.
"""
- def __init__(self, lf_planner: LFPlannerABC = None, lf_solver: LFSolver = None, **kwargs):
- super().__init__(
- lf_planner=lf_planner,
- lf_solver=lf_solver,
- **kwargs
- )
+ def __init__(
+ self, lf_planner: LFPlannerABC = None, lf_solver: LFSolver = None, **kwargs
+ ):
+ super().__init__(lf_planner=lf_planner, lf_solver=lf_solver, **kwargs)
self.lf_planner = lf_planner or DefaultLFPlanner(**kwargs)
self.lf_solver = lf_solver or LFSolver(
kg_retriever=KGRetrieverByLlm(**kwargs),
chunk_retriever=LFChunkRetriever(**kwargs),
- **kwargs
+ **kwargs,
)
self.sub_query_total = 0
@@ -63,21 +62,22 @@ def reason(self, question: str):
lf_nodes: List[LFPlanResult] = self.lf_planner.lf_planing(question)
# logic form execution
- solved_answer, sub_qa_pair, recall_docs, history_qa_log = self.lf_solver.solve(question, lf_nodes)
+ solved_answer, sub_qa_pair, recall_docs, history_qa_log = self.lf_solver.solve(
+ question, lf_nodes
+ )
# Generate supporting facts for sub question-answer pair
- supporting_fact = '\n'.join(sub_qa_pair)
+ supporting_fact = "\n".join(sub_qa_pair)
# Retrieve and rank documents
sub_querys = [lf.query for lf in lf_nodes]
if self.lf_solver.chunk_retriever:
- docs = self.lf_solver.chunk_retriever.rerank_docs([question] + sub_querys, recall_docs)
+ docs = self.lf_solver.chunk_retriever.rerank_docs(
+ [question] + sub_querys, recall_docs
+ )
else:
logger.info("DefaultReasoner not enable chunk retriever")
docs = []
- history_log = {
- 'history': history_qa_log,
- 'rerank_docs': docs
- }
+ history_log = {"history": history_qa_log, "rerank_docs": docs}
if len(docs) > 0:
# Append supporting facts for retrieved chunks
supporting_fact += f"\nPassages:{str(docs)}"
diff --git a/kag/solver/implementation/default_reflector.py b/kag/solver/implementation/default_reflector.py
index b972ad5b..a5bce7b7 100644
--- a/kag/solver/implementation/default_reflector.py
+++ b/kag/solver/implementation/default_reflector.py
@@ -44,8 +44,12 @@ def _can_answer(self, memory: KagMemoryABC, instruction: str):
if memory.get_solved_answer() != "":
return True
- return self.llm_module.invoke({'memory': serialize_memory, 'instruction': instruction}, self.judge_prompt,
- with_json_parse=False, with_except=True)
+ return self.llm_module.invoke(
+ {"memory": serialize_memory, "instruction": instruction},
+ self.judge_prompt,
+ with_json_parse=False,
+ with_except=True,
+ )
@retry(stop=stop_after_attempt(3))
def _refine_query(self, memory: KagMemoryABC, instruction: str):
@@ -60,9 +64,12 @@ def _refine_query(self, memory: KagMemoryABC, instruction: str):
if serialize_memory == "":
return instruction
- update_reason_path = self.llm_module.invoke({"memory": serialize_memory, "instruction": instruction},
- self.refine_prompt,
- with_json_parse=False, with_except=True)
+ update_reason_path = self.llm_module.invoke(
+ {"memory": serialize_memory, "instruction": instruction},
+ self.refine_prompt,
+ with_json_parse=False,
+ with_except=True,
+ )
if len(update_reason_path) == 0:
return None
- return update_reason_path[0]
\ No newline at end of file
+ return update_reason_path[0]
diff --git a/kag/solver/implementation/lf_chunk_retriever.py b/kag/solver/implementation/lf_chunk_retriever.py
index 576b7239..70a2cfdc 100644
--- a/kag/solver/implementation/lf_chunk_retriever.py
+++ b/kag/solver/implementation/lf_chunk_retriever.py
@@ -6,7 +6,10 @@
from kag.common.retriever import DefaultRetriever
from kag.solver.logic.core_modules.common.one_hop_graph import EntityData
-from kag.solver.logic.core_modules.common.text_sim_by_vector import TextSimilarity, cosine_similarity
+from kag.solver.logic.core_modules.common.text_sim_by_vector import (
+ TextSimilarity,
+ cosine_similarity,
+)
from kag.solver.logic.core_modules.retriver.retrieval_spo import logger
from knext.project.client import ProjectClient
@@ -16,7 +19,9 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
vectorizer_config = eval(os.getenv("KAG_VECTORIZER", "{}"))
if self.host_addr and self.project_id:
- config = ProjectClient(host_addr=self.host_addr, project_id=self.project_id).get_config(self.project_id)
+ config = ProjectClient(
+ host_addr=self.host_addr, project_id=self.project_id
+ ).get_config(self.project_id)
vectorizer_config.update(config.get("vectorizer", {}))
self.text_sim = TextSimilarity(vec_config=vectorizer_config)
@@ -32,7 +37,10 @@ def rerank(self, queries: List[str], passages: List[str]):
for query in queries:
query_emb = self.text_sim.sentence_encode(query)
- scores = [cosine_similarity(np.array(query_emb), np.array(passage_emb)) for passage_emb in passages_embs]
+ scores = [
+ cosine_similarity(np.array(query_emb), np.array(passage_emb))
+ for passage_emb in passages_embs
+ ]
sorted_idx = np.argsort(-np.array(scores))
for rank, passage_id in enumerate(sorted_idx):
passage_scores[passage_id] += rank_scores[rank]
@@ -43,12 +51,14 @@ def rerank(self, queries: List[str], passages: List[str]):
return new_passages[:10]
def recall_docs(self, query: str, top_k=5, **kwargs):
- all_related_entities = kwargs.get('related_entities', None)
- query_ner_dict = kwargs.get('query_ner_dict', None)
- req_id = kwargs.get('req_id', '')
+ all_related_entities = kwargs.get("related_entities", None)
+ query_ner_dict = kwargs.get("query_ner_dict", None)
+ req_id = kwargs.get("req_id", "")
if all_related_entities is None:
return super().recall_docs(query, top_k, **kwargs)
- return self.recall_docs_by_entities(query, all_related_entities, top_k, req_id, query_ner_dict)
+ return self.recall_docs_by_entities(
+ query, all_related_entities, top_k, req_id, query_ner_dict
+ )
def get_std_ner_by_query(self, query: str):
"""
@@ -61,9 +71,7 @@ def get_std_ner_by_query(self, query: str):
dict: A dictionary containing standardized entity names and types along with their scores.
"""
entities = self.named_entity_recognition(query)
- entities_with_official_name = self.named_entity_standardization(
- query, entities
- )
+ entities_with_official_name = self.named_entity_standardization(query, entities)
query_ner_list = entities_with_official_name
try:
@@ -95,17 +103,19 @@ def get_std_ner_by_query(self, query: str):
return_data[f"{phrases['name']}_{phrases['type']}"] = phrases
return return_data
- def recall_docs_by_entities(self, query: str, all_related_entities: List[EntityData], top_k=10,
- req_id='', query_ner_dict: dict = None):
+ def recall_docs_by_entities(
+ self,
+ query: str,
+ all_related_entities: List[EntityData],
+ top_k=10,
+ req_id="",
+ query_ner_dict: dict = None,
+ ):
def convert_entity_data_to_ppr_cand(related_entities: List[EntityData]):
ret_ppr_candis = {}
for e in related_entities:
k = f"{e.name}_{e.type}"
- ret_ppr_candis[k] = {
- 'name': e.name,
- 'type': e.type,
- 'score': e.score
- }
+ ret_ppr_candis[k] = {"name": e.name, "type": e.type, "score": e.score}
return ret_ppr_candis
start_time = time.time()
@@ -114,18 +124,20 @@ def convert_entity_data_to_ppr_cand(related_entities: List[EntityData]):
kg_cands = convert_entity_data_to_ppr_cand(all_related_entities)
except Exception as e:
kg_cands = {}
- logger.warning(f"{req_id} {query} generate logic form failed {str(e)}", exc_info=True)
+ logger.warning(
+ f"{req_id} {query} generate logic form failed {str(e)}", exc_info=True
+ )
for k, v in ner_cands.items():
if k in kg_cands.keys():
- if v['score'] > kg_cands[k]['score']:
- kg_cands[k]['score'] = v['score']
+ if v["score"] > kg_cands[k]["score"]:
+ kg_cands[k]["score"] = v["score"]
else:
kg_cands[k] = v
if query_ner_dict is not None:
for k, v in query_ner_dict.items():
if k in kg_cands.keys():
- if v['score'] > kg_cands[k]['score']:
- kg_cands[k]['score'] = v['score']
+ if v["score"] > kg_cands[k]["score"]:
+ kg_cands[k]["score"] = v["score"]
else:
kg_cands[k] = v
@@ -133,7 +145,7 @@ def convert_entity_data_to_ppr_cand(related_entities: List[EntityData]):
matched_entities_scores = []
for _, v in kg_cands.items():
matched_entities.append(v)
- matched_entities_scores.append(v['score'])
+ matched_entities_scores.append(v["score"])
logger.info(f"{req_id} kgpath ner cost={time.time() - start_time}")
start_time = time.time()
@@ -141,7 +153,8 @@ def convert_entity_data_to_ppr_cand(related_entities: List[EntityData]):
combined_scores = self.calculate_sim_scores(query, top_k * 20)
logger.info(f"{req_id} only get_dpr_scores cost={time.time() - start_time}")
elif (
- matched_entities_scores and np.min(matched_entities_scores) > self.pagerank_threshold
+ matched_entities_scores
+ and np.min(matched_entities_scores) > self.pagerank_threshold
): # high confidence in named entities
combined_scores = self.calculate_pagerank_scores(matched_entities)
else:
@@ -152,11 +165,13 @@ def convert_entity_data_to_ppr_cand(related_entities: List[EntityData]):
sim_doc_scores = self.calculate_sim_scores(query, top_k * 20)
logger.info(f"{req_id} only get_dpr_scores cost={time.time() - start_time}")
- combined_scores = self.calculate_combined_scores(sim_doc_scores, pagerank_scores)
+ combined_scores = self.calculate_combined_scores(
+ sim_doc_scores, pagerank_scores
+ )
# Return ranked docs and ranked scores
sorted_doc_ids = sorted(
- combined_scores.items(), key=lambda item: item[1], reverse=True
- )
+ combined_scores.items(), key=lambda item: item[1], reverse=True
+ )
logger.debug(f"kgpath chunk recall cost={time.time() - start_time}")
- return self.get_all_docs_by_id(query, sorted_doc_ids, top_k)
\ No newline at end of file
+ return self.get_all_docs_by_id(query, sorted_doc_ids, top_k)
diff --git a/kag/solver/logic/core_modules/common/base_model.py b/kag/solver/logic/core_modules/common/base_model.py
index 048c2ee5..fa4d8105 100644
--- a/kag/solver/logic/core_modules/common/base_model.py
+++ b/kag/solver/logic/core_modules/common/base_model.py
@@ -37,8 +37,11 @@ def __repr__(self):
def parse_entity(raw_entity):
if raw_entity is None:
return []
- entity_parts = re.findall(r'(?:`(.+?)`|([^|]+))', raw_entity)
- return [part.replace('``', '|') if part else escaping_part for escaping_part, part in entity_parts]
+ entity_parts = re.findall(r"(?:`(.+?)`|([^|]+))", raw_entity)
+ return [
+ part.replace("``", "|") if part else escaping_part
+ for escaping_part, part in entity_parts
+ ]
class SPOBase:
@@ -52,7 +55,7 @@ def __repr__(self):
return f"{self.alias_name}:{self.get_entity_first_type_or_en()}"
def get_value_list_str(self):
- return [f"{self.alias_name}.{k}={v}" for k,v in self.value_list]
+ return [f"{self.alias_name}.{k}={v}" for k, v in self.value_list]
def get_mention_name(self):
return ""
@@ -63,7 +66,9 @@ def get_type_with_gql_format(self):
if len(entity_types) == 0 and len(entity_zh_types) == 0:
return None
if None in entity_types and None in entity_zh_types:
- raise RuntimeError(f"None type in entity type en {entity_types} zh {entity_zh_types}")
+ raise RuntimeError(
+ f"None type in entity type en {entity_types} zh {entity_zh_types}"
+ )
if len(entity_types) > 0:
return "|".join(entity_types)
if len(entity_zh_types) > 0:
@@ -154,7 +159,7 @@ def parse_logic_form(input_str):
rel_type_set = []
# Split the input string into alias and entity_type_set parts
- split_input = input_str.split(':', 1)
+ split_input = input_str.split(":", 1)
alias = split_input[0]
# If entity_type_set exists, process it further
if len(split_input) > 1:
@@ -173,8 +178,15 @@ def parse_logic_form(input_str):
class SPOEntity(SPOBase):
- def __init__(self, entity_id=None, entity_type=None, entity_type_zh=None, entity_name=None, alias_name=None,
- is_attribute=False):
+ def __init__(
+ self,
+ entity_id=None,
+ entity_type=None,
+ entity_type_zh=None,
+ entity_name=None,
+ alias_name=None,
+ is_attribute=False,
+ ):
super().__init__()
self.is_attribute = is_attribute
self.id_set = []
@@ -191,12 +203,15 @@ def __init__(self, entity_id=None, entity_type=None, entity_type_zh=None, entity
self.type_set.append(type_info)
def __str__(self):
- show = [f"{self.alias_name}:{self.get_entity_first_type_or_en()}{'' if self.entity_name is None else '[' + self.entity_name + ']'} "]
+ show = [
+ f"{self.alias_name}:{self.get_entity_first_type_or_en()}{'' if self.entity_name is None else '[' + self.entity_name + ']'} "
+ ]
show = show + self.get_value_list_str()
return ",".join(show)
def get_mention_name(self):
return self.entity_name
+
def generate_id_key(self):
if len(self.id_set) == 0:
return None
@@ -210,13 +225,16 @@ def generate_start_infos(self, prefix=None):
return []
id_type_info = list(itertools.product(self.id_set, self.type_set))
- return [{
- "alias": self.alias_name.alias_name,
- "id": info[0],
- "type": info[1].entity_type if '.' in info[1].entity_type else (
- prefix + '.' if prefix is not None else '') +
- info[1].entity_type
- } for info in id_type_info]
+ return [
+ {
+ "alias": self.alias_name.alias_name,
+ "id": info[0],
+ "type": info[1].entity_type
+ if "." in info[1].entity_type
+ else (prefix + "." if prefix is not None else "") + info[1].entity_type,
+ }
+ for info in id_type_info
+ ]
@staticmethod
def parse_logic_form(input_str):
@@ -235,9 +253,9 @@ def parse_logic_form(input_str):
entity_type_set = parse_entity(entity_type_raw)
# 解析entity_name和entity_id_set
- entity_name = entity_name_raw.strip('][') if entity_name_raw else None
- entity_name = entity_name.strip('`') if entity_name else None
- entity_id_set = parse_entity(entity_id_raw.strip('][')) if entity_id_raw else []
+ entity_name = entity_name_raw.strip("][") if entity_name_raw else None
+ entity_name = entity_name.strip("`") if entity_name else None
+ entity_id_set = parse_entity(entity_id_raw.strip("][")) if entity_id_raw else []
spo_entity = SPOEntity()
spo_entity.id_set = entity_id_set
@@ -252,7 +270,14 @@ def parse_logic_form(input_str):
class Entity:
- def __init__(self, entity_id=None, entity_type=None, entity_type_zh=None, entity_name=None, alias_name=None):
+ def __init__(
+ self,
+ entity_id=None,
+ entity_type=None,
+ entity_type_zh=None,
+ entity_name=None,
+ alias_name=None,
+ ):
self.id = entity_id
self.type = entity_type
self.entity_type_zh = entity_type_zh
@@ -262,7 +287,9 @@ def __init__(self, entity_id=None, entity_type=None, entity_type_zh=None, entity
def __repr__(self):
return f"{[self.entity_name, self.alias_name]}:{self.id}({self.type, self.entity_type_zh})"
- def save_args(self, id=None, type=None, entity_type_zh=None, entity_name=None, alias_name=None):
+ def save_args(
+ self, id=None, type=None, entity_type_zh=None, entity_name=None, alias_name=None
+ ):
self.id = id if id else self.id
self.type = type if type else self.type
self.entity_type_zh = entity_type_zh if entity_type_zh else self.entity_type_zh
@@ -271,37 +298,42 @@ def save_args(self, id=None, type=None, entity_type_zh=None, entity_name=None, a
@staticmethod
def parse_zh(entity_str):
- alias, type_zh, name = '', '', ''
- entity_str = entity_str.replace(':', ':')
- match_alias_type_entity = re.match(r'(.*):(.*)\[(.*)\]', entity_str)
+ alias, type_zh, name = "", "", ""
+ entity_str = entity_str.replace(":", ":")
+ match_alias_type_entity = re.match(r"(.*):(.*)\[(.*)\]", entity_str)
if match_alias_type_entity:
alias, type_zh, name = match_alias_type_entity.groups()
else:
- match_alias_type = re.match(r'(.*):(.*)', entity_str)
+ match_alias_type = re.match(r"(.*):(.*)", entity_str)
if match_alias_type:
alias, type_zh = match_alias_type.groups()
else:
alias = entity_str
- return Entity(entity_type_zh=type_zh.strip(), entity_name=name.strip(), alias_name=alias.strip())
+ return Entity(
+ entity_type_zh=type_zh.strip(),
+ entity_name=name.strip(),
+ alias_name=alias.strip(),
+ )
class LogicNode:
def __init__(self, operator, args):
self.operator = operator
self.args = args
- self.sub_query = args.get('sub_query', '')
+ self.sub_query = args.get("sub_query", "")
def __repr__(self):
params = [f"{k}={v}" for k, v in self.args.items()]
- params_str = ','.join(params)
+ params_str = ",".join(params)
return f"{self.operator}({params_str})"
def to_dict(self):
return json.loads(self.to_json())
def to_json(self):
- return json.dumps(obj=self,
- default=lambda x: x.__dict__, sort_keys=False, indent=2)
+ return json.dumps(
+ obj=self, default=lambda x: x.__dict__, sort_keys=False, indent=2
+ )
def to_dsl(self):
raise NotImplementedError("Subclasses should implement this method.")
@@ -309,10 +341,10 @@ def to_dsl(self):
def to_std(self, args):
for key, value in args.items():
self.args[key] = value
- self.sub_query = args.get('sub_query', '')
+ self.sub_query = args.get("sub_query", "")
class LFPlanResult:
def __init__(self, query: str, lf_nodes: List[LogicNode]):
self.query: str = query
- self.lf_nodes: List[LogicNode] = lf_nodes
\ No newline at end of file
+ self.lf_nodes: List[LogicNode] = lf_nodes
diff --git a/kag/solver/logic/core_modules/common/one_hop_graph.py b/kag/solver/logic/core_modules/common/one_hop_graph.py
index 0cb1848c..4edaf809 100644
--- a/kag/solver/logic/core_modules/common/one_hop_graph.py
+++ b/kag/solver/logic/core_modules/common/one_hop_graph.py
@@ -22,23 +22,21 @@ def find_and_extra_prop_objects(text):
list: A list of dictionaries representing the parsed objects.
"""
- pattern = re.compile(r'\001(.*?)\003')
+ pattern = re.compile(r"\001(.*?)\003")
matches = pattern.findall(text)
objects = []
for match in matches:
- attributes = match.split('\002')
+ attributes = match.split("\002")
if len(attributes) != 3:
logger.info(f"find_and_extra_prop_objects attribute not match {match}")
continue
- objects.append({
- "id": attributes[1],
- "name": attributes[0],
- "type": attributes[2]
- })
+ objects.append(
+ {"id": attributes[1], "name": attributes[0], "type": attributes[2]}
+ )
return objects
@@ -76,7 +74,7 @@ def from_dict(json_dict: dict, label_name: str, schema: SchemaUtils):
attr_en_zh = Prop._get_attr_en_zh_by_label(label_name, schema)
black_attr = ["biz_node_id", "gdb_timestamp"]
for k in json_dict.keys():
- if json_dict[k] == '' or k in black_attr:
+ if json_dict[k] == "" or k in black_attr:
continue
if k.startswith("_") or k in ext_attrs:
continue
@@ -97,13 +95,16 @@ def from_dict(json_dict: dict, label_name: str, schema: SchemaUtils):
prop.linked_prop_map[k] = link_res
prop.extend_prop_map = basic_info
except Exception as e:
- logger.warning(f"parse basic info failed reasone: {json_dict[ext_attr]}", exc_info=True)
+ logger.warning(
+ f"parse basic info failed reasone: {json_dict[ext_attr]}",
+ exc_info=True,
+ )
return prop
def to_json(self):
return {
"origin_prop_map": self.origin_prop_map,
- "extend_prop_map": self.extend_prop_map
+ "extend_prop_map": self.extend_prop_map,
}
def get_prop_value(self, p):
@@ -130,7 +131,11 @@ def get_properties_map_list_value(self):
return self.prop.get_properties_map_list_value()
def to_show_id(self):
- if self.type in ['verify_op_result'] and self.description is not None and self.description != '':
+ if (
+ self.type in ["verify_op_result"]
+ and self.description is not None
+ and self.description != ""
+ ):
return f"{self.type_zh}[{self.name}]({self.description})"
if self.name == self.biz_id:
return f"{self.type_zh}[{self.name}]"
@@ -143,7 +148,7 @@ def to_json(self):
"name": self.name,
"description": self.description,
"type": self.type,
- "type_zh": self.type_zh
+ "type_zh": self.type_zh,
}
def get_attribute_value(self, p):
@@ -154,42 +159,52 @@ def get_attribute_value(self, p):
def merge_entity_data(self, other):
if other.prop is not None:
self.prop = other.prop
- if other.name is not None and other.name != '':
+ if other.name is not None and other.name != "":
self.name = other.name
- if other.description is not None and other.description != '':
+ if other.description is not None and other.description != "":
self.description = other.description
- if other.type is not None and other.type != '':
+ if other.type is not None and other.type != "":
self.type = other.type
- if other.type_zh is not None and other.type_zh != '':
+ if other.type_zh is not None and other.type_zh != "":
self.type_zh = other.type_zh
def to_spo_list(self):
spo_list = []
- spo_list.append(json.dumps({
- "s": self.name,
- "p": "归属类型",
- "o": self.type
- }, ensure_ascii=False))
+ spo_list.append(
+ json.dumps(
+ {"s": self.name, "p": "归属类型", "o": self.type}, ensure_ascii=False
+ )
+ )
if self.prop is not None:
for prop_key in self.prop.origin_prop_map.keys():
if prop_key.startswith("_"):
continue
- if prop_key in ['id', 'name']:
+ if prop_key in ["id", "name"]:
continue
- spo_list.append(json.dumps({
- "s": self.name,
- "p": prop_key,
- "o": self.prop.origin_prop_map[prop_key]
- }, ensure_ascii=False))
+ spo_list.append(
+ json.dumps(
+ {
+ "s": self.name,
+ "p": prop_key,
+ "o": self.prop.origin_prop_map[prop_key],
+ },
+ ensure_ascii=False,
+ )
+ )
for prop_key in self.prop.extend_prop_map.keys():
- spo_list.append(json.dumps({
- "s": self.name,
- "p": prop_key,
- "o": self.prop.extend_prop_map[prop_key]
- }, ensure_ascii=False))
+ spo_list.append(
+ json.dumps(
+ {
+ "s": self.name,
+ "p": prop_key,
+ "o": self.prop.extend_prop_map[prop_key],
+ },
+ ensure_ascii=False,
+ )
+ )
return spo_list
# def __repr__(self):
@@ -240,13 +255,13 @@ def to_json(self):
"from_type": self.from_type,
"end_entity_name": self.end_entity.name,
"end_type": self.end_type,
- "type": self.type
+ "type": self.type,
}
def _get_entity_description(self, entity: EntityData):
if entity is None:
return None
- if entity.description is None or entity.description == '':
+ if entity.description is None or entity.description == "":
return None
if entity.type == "attribute":
return None
@@ -260,31 +275,39 @@ def _get_entity_id(self, name: str, id: str):
def to_spo_list(self):
spo_list = []
- rel = {
- "s": self.from_entity.name,
- "p": self.type,
- "o": self.end_entity.name
- }
+ rel = {"s": self.from_entity.name, "p": self.type, "o": self.end_entity.name}
spo_list.append(json.dumps(rel, ensure_ascii=False))
# prop
if self.prop is not None:
for prop_key in self.prop.origin_prop_map.keys():
- spo_list.append(json.dumps({
- "s": rel,
- "p": prop_key,
- "o": self.prop.origin_prop_map[prop_key]
- }, ensure_ascii=False))
+ spo_list.append(
+ json.dumps(
+ {
+ "s": rel,
+ "p": prop_key,
+ "o": self.prop.origin_prop_map[prop_key],
+ },
+ ensure_ascii=False,
+ )
+ )
for prop_key in self.prop.extend_prop_map.keys():
- spo_list.append(json.dumps({
- "s": rel,
- "p": prop_key,
- "o": self.prop.extend_prop_map[prop_key]
- }, ensure_ascii=False))
+ spo_list.append(
+ json.dumps(
+ {
+ "s": rel,
+ "p": prop_key,
+ "o": self.prop.extend_prop_map[prop_key],
+ },
+ ensure_ascii=False,
+ )
+ )
return spo_list
def __repr__(self):
from_entity_desc = self._get_entity_description(self.from_entity)
- from_entity_desc_str = "" if from_entity_desc is None else f"({from_entity_desc})"
+ from_entity_desc_str = (
+ "" if from_entity_desc is None else f"({from_entity_desc})"
+ )
to_entity_desc = self._get_entity_description(self.end_entity)
to_entity_desc_str = "" if to_entity_desc is None else f"({to_entity_desc})"
return f"({self.from_entity.name}{from_entity_desc_str} {self.type} {self.end_entity.name}{to_entity_desc_str})"
@@ -370,13 +393,9 @@ def to_graph_detail(self):
s_po_map[f"{s} {p} {o}"] = rel_prop_map
start_prop_map = rel.from_entity.get_properties_map_list_value()
if s not in s_po_map.keys():
- s_po_map[s] = {
- p: [o]
- }
+ s_po_map[s] = {p: [o]}
else:
- s_po_map[s].update({
- p: [o]
- })
+ s_po_map[s].update({p: [o]})
s_po_map[s].update(start_prop_map)
return s_po_map
@@ -392,7 +411,7 @@ def get_s_all_attribute_spo(self):
return attr_name_set_map
if len(self.s.prop.origin_prop_map) > 0:
for k in self.s.prop.origin_prop_map.keys():
- if k in ['id', 'name']:
+ if k in ["id", "name"]:
continue
if k.startswith("_"):
continue
@@ -405,7 +424,7 @@ def get_s_all_attribute_spo(self):
attr_name_set_map[k] = spo_list
if len(self.s.prop.extend_prop_map) > 0:
for k in self.s.prop.extend_prop_map.keys():
- if k in ['id', 'name']:
+ if k in ["id", "name"]:
continue
if k.startswith("_"):
continue
@@ -481,11 +500,10 @@ def _prase_entity_relation(self, std_p: str, o_value: EntityData):
if self.s_alias_name == "o":
o_entity = self.s
s_entity = o_value
- if o_value.description is None or o_value.description == '':
+ if o_value.description is None or o_value.description == "":
o_value.description = f"{s_entity.name} {std_p} {o_entity.name}"
return RelationData.from_prop_value(s_entity, std_p, o_entity)
-
def get_std_attr_value_by_spo_text(self, p, spo_text):
spo_list = []
@@ -493,7 +511,7 @@ def get_std_attr_value_by_spo_text(self, p, spo_text):
return spo_list
if len(self.s.prop.origin_prop_map) > 0:
for k in self.s.prop.origin_prop_map.keys():
- if k in ['id', 'name']:
+ if k in ["id", "name"]:
continue
if k.startswith("_"):
continue
@@ -510,14 +528,14 @@ def get_std_p_value_by_spo_text(self, p, spo_text):
relation_value_set = []
if p in self.in_relations.keys():
for rel in self.in_relations[p]:
- if spo_text == str(rel).strip('(').strip(')'):
+ if spo_text == str(rel).strip("(").strip(")"):
if "s" == self.s_alias_name:
relation_value_set.append(rel.revert_spo())
else:
relation_value_set.append(rel)
if p in self.out_relations.keys():
for rel in self.out_relations[p]:
- if spo_text == str(rel).strip('(').strip(')'):
+ if spo_text == str(rel).strip("(").strip(")"):
if "o" == self.s_alias_name:
relation_value_set.append(rel.revert_spo())
else:
@@ -529,7 +547,6 @@ def get_std_p_value_by_spo_text(self, p, spo_text):
relation_value_set.append(self._prase_attribute_relation(p, str(rel)))
return relation_value_set
-
def get_edge_en_to_zh(self, k):
if self.schema is None:
return k
@@ -540,19 +557,19 @@ def get_s_all_relation_spo(self):
relation_name_set_map = {}
if len(self.in_relations) > 0:
for k in self.in_relations.keys():
- if k in ['similarity']:
+ if k in ["similarity"]:
continue
spo_list = []
for v in self.in_relations[k]:
- spo_list.append(str(v).strip('(').strip(')'))
+ spo_list.append(str(v).strip("(").strip(")"))
relation_name_set_map[k] = spo_list
if len(self.out_relations) > 0:
for k in self.out_relations.keys():
- if k in ['similarity']:
+ if k in ["similarity"]:
continue
spo_list = []
for v in self.out_relations[k]:
- spo_list.append(str(v).strip('(').strip(')'))
+ spo_list.append(str(v).strip("(").strip(")"))
relation_name_set_map[k] = spo_list
return relation_name_set_map
@@ -607,7 +624,9 @@ def merge_kg_graph(self, other, wo_intersect=True):
for e_alias in other.edge_map.keys():
if e_alias in self.edge_map.keys():
- self.edge_map[e_alias] = self.edge_map[e_alias] + other.edge_map[e_alias]
+ self.edge_map[e_alias] = (
+ self.edge_map[e_alias] + other.edge_map[e_alias]
+ )
else:
self.edge_map[e_alias] = other.edge_map[e_alias]
for p in other.query_graph.keys():
@@ -675,11 +694,7 @@ def to_answer_path(self):
sp_o_map[(s, p)] = [o]
used_entities = []
for k in sp_o_map.keys():
- answer_path.append({
- "s": k[0],
- "p": k[1],
- "o": sp_o_map[k]
- })
+ answer_path.append({"s": k[0], "p": k[1], "o": sp_o_map[k]})
used_entities.append(k[0])
used_entities = used_entities + sp_o_map[k]
used_entities = list(set(used_entities))
@@ -720,9 +735,15 @@ def _graph_to_json(self):
for d in self.edge_map[k]:
has_entity = True
rels.append(d.to_json())
- if d.from_alias == "s" and d.from_entity not in total_entity_map[s_alias]:
+ if (
+ d.from_alias == "s"
+ and d.from_entity not in total_entity_map[s_alias]
+ ):
total_entity_map[s_alias].append(d.from_entity)
- if d.from_alias == "o" and d.from_entity not in total_entity_map[o_alias]:
+ if (
+ d.from_alias == "o"
+ and d.from_entity not in total_entity_map[o_alias]
+ ):
total_entity_map[o_alias].append(d.from_entity)
if d.end_alias == "s" and d.end_entity not in total_entity_map[s_alias]:
@@ -758,7 +779,7 @@ def to_json(self):
"start_node_alias_name": list(set(self.start_node_alias_name)),
"start_node_name": list(set(self.start_node_name)),
"entity_map": node_dict,
- "edge_map": edge_dict
+ "edge_map": edge_dict,
}
def to_edge_str(self):
@@ -840,8 +861,9 @@ def rmv_node_ins(self, alias_name, alias_ins_set):
allowed_entity_dict[s.alias_name] = []
allowed_entity_dict[o.alias_name] = []
for rel in self.edge_map[p]:
- if (s.alias_name == alias_name and rel.from_id not in alias_ins_set) \
- or (o.alias_name == alias_name and rel.end_id not in alias_ins_set):
+ if (
+ s.alias_name == alias_name and rel.from_id not in alias_ins_set
+ ) or (o.alias_name == alias_name and rel.end_id not in alias_ins_set):
rel_list.append(rel)
self.append_into_map(allowed_entity_dict, s.alias_name, rel.from_id)
self.append_into_map(allowed_entity_dict, o.alias_name, rel.end_id)
diff --git a/kag/solver/logic/core_modules/common/schema_utils.py b/kag/solver/logic/core_modules/common/schema_utils.py
index 964341c8..74679376 100644
--- a/kag/solver/logic/core_modules/common/schema_utils.py
+++ b/kag/solver/logic/core_modules/common/schema_utils.py
@@ -50,7 +50,7 @@ def __init__(self, config: LogicFormConfiguration):
self.get_schema()
def get_spo_with_p(self, spo):
- _, p, _ = spo.split('_')
+ _, p, _ = spo.split("_")
return p
def get_label_within_prefix(self, label_name_without_prefix):
@@ -75,18 +75,14 @@ def _add_attr_with_label(self, label_name, nameZh, name):
attr_en_zh_tmp = self.attr_en_zh_by_label[label_name]
attr_en_zh_tmp[name] = nameZh
else:
- attr_en_zh_tmp = {
- name: nameZh
- }
+ attr_en_zh_tmp = {name: nameZh}
self.attr_en_zh_by_label[label_name] = attr_en_zh_tmp
if label_name in self.attr_zh_en_by_label.keys():
attr_zh_en_tmp = self.attr_zh_en_by_label[label_name]
attr_zh_en_tmp[nameZh] = name
else:
- attr_zh_en_tmp = {
- nameZh: name
- }
+ attr_zh_en_tmp = {nameZh: name}
self.attr_zh_en_by_label[label_name] = attr_zh_en_tmp
def get_attr_en_zh_by_label(self, label_name):
@@ -108,18 +104,23 @@ def get_attr(self, label_name, attributes):
continue
# print('attribute:', attribute)
attribute = json.loads(attribute)
- if 'constraints' in attribute and 'name' in attribute['constraints'] and attribute['constraints'][
- 'name'] == "Enum":
- enums = list(attribute['constraints']['value'].keys())
+ if (
+ "constraints" in attribute
+ and "name" in attribute["constraints"]
+ and attribute["constraints"]["name"] == "Enum"
+ ):
+ enums = list(attribute["constraints"]["value"].keys())
else:
enums = None
- if attribute['name'].startswith('kg') and attribute['name'].endswith('Raw'):
+ if attribute["name"].startswith("kg") and attribute["name"].endswith("Raw"):
continue
- self.attr_zh_en[attribute['nameZh']] = attribute['name']
- self.attr_en_zh[attribute['name']] = attribute['nameZh']
- self.attr_enums[attribute['nameZh']] = enums
- self._add_attr_with_label(label_name, attribute['nameZh'], attribute['name'])
- attributes_namezh.append(attribute['nameZh'])
+ self.attr_zh_en[attribute["nameZh"]] = attribute["name"]
+ self.attr_en_zh[attribute["name"]] = attribute["nameZh"]
+ self.attr_enums[attribute["nameZh"]] = enums
+ self._add_attr_with_label(
+ label_name, attribute["nameZh"], attribute["name"]
+ )
+ attributes_namezh.append(attribute["nameZh"])
return attributes_namezh
def get_ext_json_prop(self):
@@ -152,7 +153,7 @@ def get_schema_from_spg(self):
entity_default_attributes = [
'{"name": "name", "nameZh": "名称"}',
'{"name": "id", "nameZh": "实体主键"}',
- '{"name": "description", "nameZh": "描述"}'
+ '{"name": "description", "nameZh": "描述"}',
]
attributes += entity_default_attributes
attributes_namezh = self.get_attr(name_en, attributes)
@@ -194,7 +195,9 @@ def get_schema_from_spg(self):
if o_name_zh not in self.node_edge:
self.node_edge[o_name_zh] = set()
self.node_edge[o_name_zh].add(p_name_zh)
- r_attributes = self._convert_spg_attr_set(list(relation.sub_properties.values()))
+ r_attributes = self._convert_spg_attr_set(
+ list(relation.sub_properties.values())
+ )
r_attributes_namezh = self.get_attr(name_en, r_attributes)
self.edge_attr[p_name_zh] = r_attributes_namezh
@@ -206,18 +209,24 @@ def get_schema_from_csv(self):
# next(reader)
node_attributes = {}
for row in reader:
- obj, name_zh, name_en, father_en, edge_direction, attributes = row[0], row[1], row[2], row[3], row[4], row[
- 6:]
+ obj, name_zh, name_en, father_en, edge_direction, attributes = (
+ row[0],
+ row[1],
+ row[2],
+ row[3],
+ row[4],
+ row[6:],
+ )
if "nodeType/edgeType" in obj:
continue
- name_en = name_en.replace(self.prefix, '')
- # if name_en in ['Event', 'ProductTaxon']:
+ name_en = name_en.replace(self.prefix, "")
+ # if name_en in ['Event', 'ProductTaxon']:
if father_en and father_en in node_attributes:
attributes += node_attributes[father_en]
node_attributes[name_en] = attributes
- if obj not in ['edge', 'inputEdge']:
+ if obj not in ["edge", "inputEdge"]:
# if name_zh in ['百科实体', '热点事件', '事件']:
- if name_zh in ['百科实体']:
+ if name_zh in ["百科实体"]:
continue
self.nodes.add(name_zh)
self.node_zh_en[name_zh] = name_en
@@ -225,17 +234,16 @@ def get_schema_from_csv(self):
entity_default_attributes = [
'{"name": "name", "nameZh": "名称"}',
'{"name": "id", "nameZh": "实体主键"}',
- '{"name": "description", "nameZh": "描述"}'
+ '{"name": "description", "nameZh": "描述"}',
]
attributes += entity_default_attributes
attributes_namezh = self.get_attr(name_en, attributes)
self.node_attr[name_zh] = attributes_namezh
-
- elif obj == 'edge':
- s, p, o = name_zh.split('_')
+ elif obj == "edge":
+ s, p, o = name_zh.split("_")
# if s in ['百科实体', '热点事件', '事件'] or o in ['百科实体', '热点事件', '事件']:
- if s in ['百科实体'] or o in ['百科实体']:
+ if s in ["百科实体"] or o in ["百科实体"]:
continue
if name_zh not in self.spo:
self.spo.add(name_zh)
@@ -246,7 +254,7 @@ def get_schema_from_csv(self):
self.so_p[(s, o)].add(p)
self.sp_o[(s, p)].add(o)
self.sp_o[(o, p)].add(s)
- s_en, p_en, o_en = name_en.split('_')
+ s_en, p_en, o_en = name_en.split("_")
self.so_p_en[(s_en, o_en)].add(p_en)
self.sp_o_en[(s_en, p_en)].add(o_en)
self.op_s_en[(o_en, p_en)].add(s_en)
@@ -268,22 +276,27 @@ def get_schema_rdf(self, path_node, path_edge):
f_node = open(path_node)
f_edge = open(path_edge)
for row in csv.DictReader(f_node):
- name, id = row['name'], row['alias']
+ name, id = row["name"], row["alias"]
self.nodes.add(name)
for row in csv.DictReader(f_edge):
- name = row['name']
+ name = row["name"]
self.edges.add(name)
def _convert_spg_attr_set(self, attr_set: List[Property]):
- return [json.dumps({
- 'constraints': attr.to_dict().get('constraint', {}),
- 'name': attr.to_dict().get('name'),
- 'nameZh': attr.to_dict().get('name_zh')
- }) for attr in attr_set]
+ return [
+ json.dumps(
+ {
+ "constraints": attr.to_dict().get("constraint", {}),
+ "name": attr.to_dict().get("name"),
+ "nameZh": attr.to_dict().get("name_zh"),
+ }
+ )
+ for attr in attr_set
+ ]
def generate_nodes_edges_hetero(schema):
- '''
+ """
nodes {
hetero {
"CommonSenseKG.Person" {
@@ -308,21 +321,27 @@ def generate_nodes_edges_hetero(schema):
}
}
}
- '''
+ """
nodes_hetero, edges_hetero = defaultdict(dict), defaultdict(dict)
for node in schema.nodes:
node = schema.node_zh_en[node]
features = []
for attr in schema.node_attr[schema.node_en_zh[node]]:
attr = schema.attr_zh_en[attr]
- features.append(attr + ';Raw|use_fe=False;Direct;str')
- node = schema.prefix + '.' + node
- nodes_hetero[node] = {'fe': features}
+ features.append(attr + ";Raw|use_fe=False;Direct;str")
+ node = schema.prefix + "." + node
+ nodes_hetero[node] = {"fe": features}
for spo in schema.spo_en:
- s, p, o = spo.split('_')
- edge = '_'.join([schema.prefix + '.' + s, p, schema.prefix + '.' + o])
- edges_hetero[edge] = {'fe': []}
+ s, p, o = spo.split("_")
+ edge = "_".join([schema.prefix + "." + s, p, schema.prefix + "." + o])
+ edges_hetero[edge] = {"fe": []}
- print('nodes_hetero:', json.dumps(nodes_hetero, indent=2).replace('"fe"', 'fe').replace('},', '}'))
- print('edges_hetero:', json.dumps(edges_hetero, indent=2).replace('"fe"', 'fe').replace('},', '}'))
+ print(
+ "nodes_hetero:",
+ json.dumps(nodes_hetero, indent=2).replace('"fe"', "fe").replace("},", "}"),
+ )
+ print(
+ "edges_hetero:",
+ json.dumps(edges_hetero, indent=2).replace('"fe"', "fe").replace("},", "}"),
+ )
diff --git a/kag/solver/logic/core_modules/common/text_sim_by_vector.py b/kag/solver/logic/core_modules/common/text_sim_by_vector.py
index 44b20bc0..9b4c886a 100644
--- a/kag/solver/logic/core_modules/common/text_sim_by_vector.py
+++ b/kag/solver/logic/core_modules/common/text_sim_by_vector.py
@@ -7,9 +7,12 @@
def cosine_similarity(vector1, vector2):
- cosine = np.dot(vector1, vector2) / (np.linalg.norm(vector1) * np.linalg.norm(vector2))
+ cosine = np.dot(vector1, vector2) / (
+ np.linalg.norm(vector1) * np.linalg.norm(vector2)
+ )
return cosine
+
def split_list(input_list, max_length=30):
"""
Splits a list into multiple sublists where each sublist has a maximum length of max_length.
@@ -18,7 +21,9 @@ def split_list(input_list, max_length=30):
:param max_length: The maximum length of each sublist
:return: A list containing multiple sublists
"""
- return [input_list[i:i + max_length] for i in range(0, len(input_list), max_length)]
+ return [
+ input_list[i : i + max_length] for i in range(0, len(input_list), max_length)
+ ]
class TextSimilarity:
@@ -60,13 +65,15 @@ def sentence_encode(self, sentences, is_cached=False):
return ret
def text_sim_result(self, mention, candidates: List[str], topk=1, low_score=0.63):
- '''
+ """
output: [(candi_name, candi_score),...]
- '''
+ """
if mention is None:
return []
mention_emb = self.sentence_encode(mention)
- candidates = [cand for cand in candidates if cand is not None and cand.strip() != '']
+ candidates = [
+ cand for cand in candidates if cand is not None and cand.strip() != ""
+ ]
if len(candidates) == 0:
return []
candidates_emb = self.sentence_encode(candidates)
@@ -76,15 +83,17 @@ def text_sim_result(self, mention, candidates: List[str], topk=1, low_score=0.63
if cosine < low_score:
continue
candidates_dis[candidate] = cosine
- candidates_dis = sorted(candidates_dis.items(), key=lambda x:x[-1], reverse=True)
+ candidates_dis = sorted(
+ candidates_dis.items(), key=lambda x: x[-1], reverse=True
+ )
candis = candidates_dis[:topk]
return candis
def text_type_sim(self, mention, candidates, topk=1):
- '''
+ """
output: [(candi_name, candi_score),...]
- '''
+ """
res = self.text_sim_result(mention, candidates, topk)
if len(res) == 0:
- return [('Entity', 1.)]
+ return [("Entity", 1.0)]
return res
diff --git a/kag/solver/logic/core_modules/common/utils.py b/kag/solver/logic/core_modules/common/utils.py
index 4c384e82..e4897d9f 100644
--- a/kag/solver/logic/core_modules/common/utils.py
+++ b/kag/solver/logic/core_modules/common/utils.py
@@ -1,4 +1,4 @@
-#coding=utf8
+# coding=utf8
import random
import re
import string
@@ -6,23 +6,28 @@
def generate_random_string(bit=8):
possible_characters = string.ascii_letters + string.digits
- return ''.join(random.choice(possible_characters) for _ in range(bit))
+ return "".join(random.choice(possible_characters) for _ in range(bit))
+
def generate_biz_id_with_type(biz_id, type_name):
return f"{biz_id}_{type_name}"
+
def get_p_clean(p):
if re.search(".*[\\u4e00-\\u9fa5]+.*", p):
- p = re.sub('[ \t::()“”‘’\'"\[\]\(\)]+?', '', p)
+ p = re.sub("[ \t::()“”‘’'\"\[\]\(\)]+?", "", p)
else:
p = None
return p
+
def get_recall_node_label(label_set):
for l in label_set:
if l != "Entity":
return l
-def node_2_doc(node:dict):
+
+
+def node_2_doc(node: dict):
prop_set = []
for key in node.keys():
if key in ["id"]:
@@ -39,4 +44,4 @@ def node_2_doc(node:dict):
else:
prop = f"{key}:{value}"
prop_set.append(prop)
- return "\n".join(prop_set)
\ No newline at end of file
+ return "\n".join(prop_set)
diff --git a/kag/solver/logic/core_modules/config.py b/kag/solver/logic/core_modules/config.py
index 5a5b02b8..0a77e01b 100644
--- a/kag/solver/logic/core_modules/config.py
+++ b/kag/solver/logic/core_modules/config.py
@@ -2,22 +2,28 @@
class LogicFormConfiguration:
-
def __init__(self, args={}):
self.resource_path = args.get("resource_path", "./")
self.prefix = args.get("prefix", "")
# kg graph project ID.
- self.project_id = args.get("KAG_PROJECT_ID", None) or os.getenv("KAG_PROJECT_ID")
+ self.project_id = args.get("KAG_PROJECT_ID", None) or os.getenv(
+ "KAG_PROJECT_ID"
+ )
if not self.project_id:
- raise RuntimeError("init LogicFormConfiguration failed, not found params KAG_PROJECT_ID")
+ raise RuntimeError(
+ "init LogicFormConfiguration failed, not found params KAG_PROJECT_ID"
+ )
# kg graph schema file path.
self.schema_file_name = args.get("schema_file_name", "")
- self.host_addr = args.get("KAG_PROJECT_HOST_ADDR", None) or os.getenv("KAG_PROJECT_HOST_ADDR")
+ self.host_addr = args.get("KAG_PROJECT_HOST_ADDR", None) or os.getenv(
+ "KAG_PROJECT_HOST_ADDR"
+ )
if not self.host_addr:
- raise RuntimeError("init LogicFormConfiguration failed, not found params KAG_PROJECT_HOST_ADDR")
-
+ raise RuntimeError(
+ "init LogicFormConfiguration failed, not found params KAG_PROJECT_HOST_ADDR"
+ )
diff --git a/kag/solver/logic/core_modules/lf_executor.py b/kag/solver/logic/core_modules/lf_executor.py
index 9ec12103..b2cd1166 100644
--- a/kag/solver/logic/core_modules/lf_executor.py
+++ b/kag/solver/logic/core_modules/lf_executor.py
@@ -12,14 +12,23 @@
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.common.text_sim_by_vector import TextSimilarity
from kag.solver.logic.core_modules.config import LogicFormConfiguration
-from kag.solver.logic.core_modules.op_executor.op_deduce.deduce_executor import DeduceExecutor
+from kag.solver.logic.core_modules.op_executor.op_deduce.deduce_executor import (
+ DeduceExecutor,
+)
from kag.solver.logic.core_modules.op_executor.op_math.math_executor import MathExecutor
-from kag.solver.logic.core_modules.op_executor.op_output.output_executor import OutputExecutor
-from kag.solver.logic.core_modules.op_executor.op_retrieval.retrieval_executor import RetrievalExecutor
+from kag.solver.logic.core_modules.op_executor.op_output.output_executor import (
+ OutputExecutor,
+)
+from kag.solver.logic.core_modules.op_executor.op_retrieval.retrieval_executor import (
+ RetrievalExecutor,
+)
from kag.solver.logic.core_modules.op_executor.op_sort.sort_executor import SortExecutor
from kag.solver.logic.core_modules.parser.logic_node_parser import ParseLogicForm
from kag.solver.logic.core_modules.retriver.entity_linker import EntityLinkerBase
-from kag.solver.logic.core_modules.retriver.graph_retriver.dsl_executor import DslRunner, DslRunnerOnGraphStore
+from kag.solver.logic.core_modules.retriver.graph_retriver.dsl_executor import (
+ DslRunner,
+ DslRunnerOnGraphStore,
+)
from kag.solver.logic.core_modules.retriver.schema_std import SchemaRetrieval
from kag.solver.logic.core_modules.rule_runner.rule_runner import OpRunner
from kag.solver.tools.info_processor import ReporterIntermediateProcessTool
@@ -28,13 +37,24 @@
class LogicExecutor:
- def __init__(self, query: str, project_id: str,
- schema: SchemaUtils, kg_retriever: KGRetrieverABC,
- chunk_retriever: ChunkRetrieverABC, std_schema: SchemaRetrieval, el: EntityLinkerBase, generator,
- dsl_runner: DslRunner,
- text_similarity: TextSimilarity=None,
- req_id='',
- need_detail=False, llm=None, report_tool=None, params=None):
+ def __init__(
+ self,
+ query: str,
+ project_id: str,
+ schema: SchemaUtils,
+ kg_retriever: KGRetrieverABC,
+ chunk_retriever: ChunkRetrieverABC,
+ std_schema: SchemaRetrieval,
+ el: EntityLinkerBase,
+ generator,
+ dsl_runner: DslRunner,
+ text_similarity: TextSimilarity = None,
+ req_id="",
+ need_detail=False,
+ llm=None,
+ report_tool=None,
+ params=None,
+ ):
"""
Initializes the LogicEngine with necessary parameters and configurations.
@@ -78,7 +98,7 @@ def __init__(self, query: str, project_id: str,
"el_detail": [],
"std_out": [],
"get_empty": [],
- "sub_qa_pair": []
+ "sub_qa_pair": [],
}
self.op_runner = OpRunner(self.kg_graph, llm, query, self.req_id)
self.parser = ParseLogicForm(self.schema, std_schema)
@@ -90,19 +110,55 @@ def __init__(self, query: str, project_id: str,
self.force_chunk_retriever = os.getenv("KAG_QA_FORCE_CHUNK_RETRIEVER", False)
# Initialize executors for different operations.
- self.retrieval_executor = RetrievalExecutor(query, self.kg_graph, self.schema, self.kg_retriever,
- self.el,
- self.dsl_runner, self.debug_info, text_similarity,KAG_PROJECT_ID = self.project_id)
- self.deduce_executor = DeduceExecutor(query, self.kg_graph, self.schema, self.op_runner, self.debug_info, KAG_PROJECT_ID = self.project_id)
- self.sort_executor = SortExecutor(query, self.kg_graph, self.schema, self.debug_info, KAG_PROJECT_ID = self.project_id)
- self.math_executor = MathExecutor(query, self.kg_graph, self.schema, self.debug_info, KAG_PROJECT_ID = self.project_id)
- self.output_executor = OutputExecutor(query, self.kg_graph, self.schema, self.el,
- self.dsl_runner,
- self.retrieval_executor.query_one_graph_cache, self.debug_info, KAG_PROJECT_ID = self.project_id)
+ self.retrieval_executor = RetrievalExecutor(
+ query,
+ self.kg_graph,
+ self.schema,
+ self.kg_retriever,
+ self.el,
+ self.dsl_runner,
+ self.debug_info,
+ text_similarity,
+ KAG_PROJECT_ID=self.project_id,
+ )
+ self.deduce_executor = DeduceExecutor(
+ query,
+ self.kg_graph,
+ self.schema,
+ self.op_runner,
+ self.debug_info,
+ KAG_PROJECT_ID=self.project_id,
+ )
+ self.sort_executor = SortExecutor(
+ query,
+ self.kg_graph,
+ self.schema,
+ self.debug_info,
+ KAG_PROJECT_ID=self.project_id,
+ )
+ self.math_executor = MathExecutor(
+ query,
+ self.kg_graph,
+ self.schema,
+ self.debug_info,
+ KAG_PROJECT_ID=self.project_id,
+ )
+ self.output_executor = OutputExecutor(
+ query,
+ self.kg_graph,
+ self.schema,
+ self.el,
+ self.dsl_runner,
+ self.retrieval_executor.query_one_graph_cache,
+ self.debug_info,
+ KAG_PROJECT_ID=self.project_id,
+ )
self.with_sub_answer = os.getenv("KAG_QA_WITH_SUB_ANSWER", True)
- def _convert_logic_nodes_2_question(self, logic_nodes: List[LFPlanResult]) -> List[Question]:
+ def _convert_logic_nodes_2_question(
+ self, logic_nodes: List[LFPlanResult]
+ ) -> List[Question]:
ret_question = []
for i in range(0, len(logic_nodes)):
if i == 0:
@@ -111,17 +167,20 @@ def _convert_logic_nodes_2_question(self, logic_nodes: List[LFPlanResult]) -> Li
)
else:
question = Question(
- question=logic_nodes[i].query,
- dependencies=[ret_question[i - 1]]
+ question=logic_nodes[i].query, dependencies=[ret_question[i - 1]]
)
question.id = i
ret_question.append(question)
return ret_question
- def _generate_sub_answer(self, history: list, spo_retrieved: list, docs: list, sub_query: str):
+ def _generate_sub_answer(
+ self, history: list, spo_retrieved: list, docs: list, sub_query: str
+ ):
if not self.with_sub_answer:
return "I don't know"
- return self.generator.generate_sub_answer(sub_query, spo_retrieved, docs, history)
+ return self.generator.generate_sub_answer(
+ sub_query, spo_retrieved, docs, history
+ )
def execute(self, lf_nodes: List[LFPlanResult], init_query):
"""
@@ -137,7 +196,9 @@ def execute(self, lf_nodes: List[LFPlanResult], init_query):
history = []
query_ner_list = {}
# get NER results for the initial query, for chunk retrieve
- if self.chunk_retriever and hasattr(self.chunk_retriever, 'get_std_ner_by_query'):
+ if self.chunk_retriever and hasattr(
+ self.chunk_retriever, "get_std_ner_by_query"
+ ):
query_ner_list = self.chunk_retriever.get_std_ner_by_query(init_query)
query_num = 0
@@ -149,7 +210,9 @@ def execute(self, lf_nodes: List[LFPlanResult], init_query):
node_begin_time = time.time()
sub_logic_nodes_str = "\n".join([str(ln) for ln in sub_logic_nodes])
- question = self._create_sub_question_report_node(query_num, sub_logic_nodes_str, sub_query)
+ question = self._create_sub_question_report_node(
+ query_num, sub_logic_nodes_str, sub_query
+ )
if self.kg_retriever:
kg_qa_result, spo_retrieved = self._execute_lf(sub_logic_nodes)
else:
@@ -157,59 +220,86 @@ def execute(self, lf_nodes: List[LFPlanResult], init_query):
kg_qa_result, spo_retrieved = [], []
question.context.append(f"#### spo retrieved:")
- question.context.append(f"{spo_retrieved if len(spo_retrieved) > 0 else 'no spo tuple retrieved'}.")
- self._update_sub_question_status(question, None, ReporterIntermediateProcessTool.STATE.RUNNING)
-
+ question.context.append(
+ f"{spo_retrieved if len(spo_retrieved) > 0 else 'no spo tuple retrieved'}."
+ )
+ self._update_sub_question_status(
+ question, None, ReporterIntermediateProcessTool.STATE.RUNNING
+ )
answer_source = "spo"
docs_with_score = []
all_related_entities, sub_answer = self._generate_sub_answer_by_graph(
- history, kg_qa_result, spo_retrieved, sub_query)
+ history, kg_qa_result, spo_retrieved, sub_query
+ )
# if sub answer is `I don't know`, we use chunk retriever
if "i don't know" in sub_answer.lower() and self.chunk_retriever:
answer_source = "chunk"
question.context.append(f"## Chunk Retriever")
- self._update_sub_question_status(question, None, ReporterIntermediateProcessTool.STATE.RUNNING)
+ self._update_sub_question_status(
+ question, None, ReporterIntermediateProcessTool.STATE.RUNNING
+ )
start_time = time.time()
# Update parameters to include retrieved SPO entities as starting points for chunk retrieval.
params = {
- 'related_entities': all_related_entities,
- 'query_ner_dict': query_ner_list,
- 'req_id': self.req_id
+ "related_entities": all_related_entities,
+ "query_ner_dict": query_ner_list,
+ "req_id": self.req_id,
}
# Retrieve chunks using the updated parameters.
- sub_query_with_history_qa = self._generate_sub_query_with_history_qa(history, sub_query)
- docs_with_score = self.chunk_retriever.recall_docs(sub_query_with_history_qa, top_k=10, **params)
+ sub_query_with_history_qa = self._generate_sub_query_with_history_qa(
+ history, sub_query
+ )
+ docs_with_score = self.chunk_retriever.recall_docs(
+ sub_query_with_history_qa, top_k=10, **params
+ )
docs = ["#".join(item.split("#")[:-1]) for item in docs_with_score]
self._update_sub_question_recall_docs(docs, question)
- self._update_sub_question_status(question, None, ReporterIntermediateProcessTool.STATE.RUNNING)
+ self._update_sub_question_status(
+ question, None, ReporterIntermediateProcessTool.STATE.RUNNING
+ )
retrival_time = time.time() - start_time
- sub_answer = self._generate_sub_answer(history, spo_retrieved, docs, sub_query)
+ sub_answer = self._generate_sub_answer(
+ history, spo_retrieved, docs, sub_query
+ )
question.context.append("#### answer based by fuzzy retrieved:")
question.context.append(f"{sub_answer}")
- logger.info(f"{self.req_id} call by docs cost: {retrival_time} docs num={len(docs)}")
+ logger.info(
+ f"{self.req_id} call by docs cost: {retrival_time} docs num={len(docs)}"
+ )
history.append(
- {"sub_query": sub_query, "sub_answer": sub_answer, 'docs': docs_with_score,
- 'spo_retrieved': spo_retrieved,
- 'exactly_match': self.debug_info.get('exact_match_spo', False),
- 'logic_expr': sub_logic_nodes_str, 'answer_source': answer_source,
- 'cost': time.time() - node_begin_time})
- self.debug_info['sub_qa_pair'].append([sub_query, sub_answer])
- self._update_sub_question_status(question, sub_answer, ReporterIntermediateProcessTool.STATE.FINISH)
+ {
+ "sub_query": sub_query,
+ "sub_answer": sub_answer,
+ "docs": docs_with_score,
+ "spo_retrieved": spo_retrieved,
+ "exactly_match": self.debug_info.get("exact_match_spo", False),
+ "logic_expr": sub_logic_nodes_str,
+ "answer_source": answer_source,
+ "cost": time.time() - node_begin_time,
+ }
+ )
+ self.debug_info["sub_qa_pair"].append([sub_query, sub_answer])
+ self._update_sub_question_status(
+ question, sub_answer, ReporterIntermediateProcessTool.STATE.FINISH
+ )
return kg_qa_result, self.kg_graph, history
def _generate_sub_query_with_history_qa(self, history, sub_query):
# Generate a sub-query with history qa pair
if history:
- history_sub_answer = [h['sub_answer'] for h in history[:3] if
- "i don't know" not in h['sub_answer'].lower()]
- sub_query_with_history_qa = '\n'.join(history_sub_answer) + '\n' + sub_query
+ history_sub_answer = [
+ h["sub_answer"]
+ for h in history[:3]
+ if "i don't know" not in h["sub_answer"].lower()
+ ]
+ sub_query_with_history_qa = "\n".join(history_sub_answer) + "\n" + sub_query
else:
sub_query_with_history_qa = sub_query
return sub_query_with_history_qa
@@ -217,10 +307,12 @@ def _generate_sub_query_with_history_qa(self, history, sub_query):
def _update_sub_question_recall_docs(self, docs, question):
question.context.extend(["|id|content|", "|-|-|"])
for i, d in enumerate(docs, start=1):
- _d = d.replace('\n', '<br>')
+ _d = d.replace("\n", "<br>")
question.context.append(f"|{i}|{_d}|")
- def _generate_sub_answer_by_graph(self, history, kg_qa_result, spo_retrieved, sub_query):
+ def _generate_sub_answer_by_graph(
+ self, history, kg_qa_result, spo_retrieved, sub_query
+ ):
sub_answer = "I don't know"
all_related_entities = self.kg_graph.get_all_entity()
all_related_entities = list(set(all_related_entities))
@@ -232,7 +324,7 @@ def _generate_sub_answer_by_graph(self, history, kg_qa_result, spo_retrieved, su
if len(spo_retrieved) == 0 and len(kg_qa_result) == 0:
all_related_entities = []
# if there is answer in kg_qa_result, and the answer is exact match with spo, we generate answer with kg result
- elif self.debug_info.get('exact_match_spo', False):
+ elif self.debug_info.get("exact_match_spo", False):
if len(kg_qa_result) > 0:
sub_answer = str(kg_qa_result)
else:
@@ -241,27 +333,39 @@ def _generate_sub_answer_by_graph(self, history, kg_qa_result, spo_retrieved, su
else:
if len(spo_retrieved) == 0:
spo_retrieved = kg_qa_result
- sub_answer = self._generate_sub_answer(history, spo_retrieved, [], sub_query)
+ sub_answer = self._generate_sub_answer(
+ history, spo_retrieved, [], sub_query
+ )
return all_related_entities, sub_answer
def _update_sub_question_status(self, question, answer, status):
if self.report_tool:
self.report_tool.report_node(question, answer, status)
- def _create_sub_question_report_node(self, query_num, sub_logic_nodes_str, sub_query):
+ def _create_sub_question_report_node(
+ self, query_num, sub_logic_nodes_str, sub_query
+ ):
question = Question(
question=sub_query,
)
question.id = query_num
- question.context = ["## SPO Retriever", "#### logic_form expression: ",
- f'```java\n{sub_logic_nodes_str}\n```']
+ question.context = [
+ "## SPO Retriever",
+ "#### logic_form expression: ",
+ f"```java\n{sub_logic_nodes_str}\n```",
+ ]
if self.report_tool:
- self.report_tool.report_node(question, None, ReporterIntermediateProcessTool.STATE.RUNNING)
+ self.report_tool.report_node(
+ question, None, ReporterIntermediateProcessTool.STATE.RUNNING
+ )
return question
def _create_report_pipeline(self, init_query, lf_nodes):
if self.report_tool:
- self.report_tool.report_pipeline(Question(question=init_query), self._convert_logic_nodes_2_question(lf_nodes))
+ self.report_tool.report_pipeline(
+ Question(question=init_query),
+ self._convert_logic_nodes_2_question(lf_nodes),
+ )
def _execute_lf(self, sub_logic_nodes):
kg_qa_result = []
@@ -283,7 +387,9 @@ def _execute_lf(self, sub_logic_nodes):
elif self.sort_executor.is_this_op(n):
self.sort_executor.executor(n, self.req_id, self.params)
elif self.output_executor.is_this_op(n):
- kg_qa_result += self.output_executor.executor(n, self.req_id, self.params)
+ kg_qa_result += self.output_executor.executor(
+ n, self.req_id, self.params
+ )
else:
logger.warning(f"unknown operator: {n.operator}")
return kg_qa_result, spo_set
diff --git a/kag/solver/logic/core_modules/lf_generator.py b/kag/solver/logic/core_modules/lf_generator.py
index 89a303ca..6204bf4e 100644
--- a/kag/solver/logic/core_modules/lf_generator.py
+++ b/kag/solver/logic/core_modules/lf_generator.py
@@ -12,21 +12,23 @@ class LFGenerator(KagBaseModule):
This class can be extended to implement custom generation strategies.
"""
- def __init__(self,**kwargs):
+ def __init__(self, **kwargs):
super().__init__(**kwargs)
self.solve_question_prompt = PromptOp.load(self.biz_scene, "solve_question")(
language=self.language
)
- self.solve_question_without_docs_prompt = PromptOp.load(self.biz_scene, "solve_question_without_docs")(
- language=self.language
- )
+ self.solve_question_without_docs_prompt = PromptOp.load(
+ self.biz_scene, "solve_question_without_docs"
+ )(language=self.language)
- self.solve_question_without_spo_prompt = PromptOp.load(self.biz_scene, "solve_question_without_spo")(
- language=self.language
- )
+ self.solve_question_without_spo_prompt = PromptOp.load(
+ self.biz_scene, "solve_question_without_spo"
+ )(language=self.language)
- def generate_sub_answer(self, question: str, knowledge_graph: [], docs: [], history=[]):
+ def generate_sub_answer(
+ self, question: str, knowledge_graph: [], docs: [], history=[]
+ ):
"""
Generates a sub-answer based on the given question, knowledge graph, documents, and history.
@@ -39,33 +41,39 @@ def generate_sub_answer(self, question: str, knowledge_graph: [], docs: [], hist
Returns:
str: The generated sub-answer.
"""
- history_qa = [f"query{i}: {item['sub_query']}\nanswer{i}: {item['sub_answer']}" for i, item in
- enumerate(history)]
+ history_qa = [
+ f"query{i}: {item['sub_query']}\nanswer{i}: {item['sub_answer']}"
+ for i, item in enumerate(history)
+ ]
if knowledge_graph:
if len(docs) > 0:
prompt = self.solve_question_prompt
params = {
- 'question': question,
- 'knowledge_graph': str(knowledge_graph),
- 'docs': str(docs),
- 'history': '\n'.join(history_qa)
+ "question": question,
+ "knowledge_graph": str(knowledge_graph),
+ "docs": str(docs),
+ "history": "\n".join(history_qa),
}
else:
prompt = self.solve_question_without_docs_prompt
params = {
- 'question': question,
- 'knowledge_graph': str(knowledge_graph),
- 'history': '\n'.join(history_qa)
+ "question": question,
+ "knowledge_graph": str(knowledge_graph),
+ "history": "\n".join(history_qa),
}
else:
prompt = self.solve_question_without_spo_prompt
params = {
- 'question': question,
- 'docs': str(docs),
- 'history': '\n'.join(history_qa)
+ "question": question,
+ "docs": str(docs),
+ "history": "\n".join(history_qa),
}
- llm_output = self.llm_module.invoke(params, prompt, with_json_parse=False, with_except=True)
- logger.debug(f"sub_question:{question}\n sub_answer:{llm_output} prompt:\n{prompt}")
+ llm_output = self.llm_module.invoke(
+ params, prompt, with_json_parse=False, with_except=True
+ )
+ logger.debug(
+ f"sub_question:{question}\n sub_answer:{llm_output} prompt:\n{prompt}"
+ )
if llm_output:
return llm_output
return "I don't know"
diff --git a/kag/solver/logic/core_modules/lf_solver.py b/kag/solver/logic/core_modules/lf_solver.py
index a24d98b5..3aa6471e 100644
--- a/kag/solver/logic/core_modules/lf_solver.py
+++ b/kag/solver/logic/core_modules/lf_solver.py
@@ -15,7 +15,9 @@
from kag.solver.logic.core_modules.lf_executor import LogicExecutor
from kag.solver.logic.core_modules.lf_generator import LFGenerator
from kag.solver.logic.core_modules.retriver.entity_linker import DefaultEntityLinker
-from kag.solver.logic.core_modules.retriver.graph_retriver.dsl_executor import DslRunnerOnGraphStore
+from kag.solver.logic.core_modules.retriver.graph_retriver.dsl_executor import (
+ DslRunnerOnGraphStore,
+)
from kag.solver.logic.core_modules.retriver.schema_std import SchemaRetrieval
from knext.project.client import ProjectClient
@@ -28,9 +30,13 @@ class LFSolver:
    This class can be extended to implement custom solver strategies.
"""
- def __init__(self, kg_retriever: KGRetrieverABC = None,
- chunk_retriever: ChunkRetrieverABC = None,
- report_tool=None, **kwargs):
+ def __init__(
+ self,
+ kg_retriever: KGRetrieverABC = None,
+ chunk_retriever: ChunkRetrieverABC = None,
+ report_tool=None,
+ **kwargs,
+ ):
"""
Initializes the solver with necessary modules and configurations.
@@ -46,12 +52,16 @@ def __init__(self, kg_retriever: KGRetrieverABC = None,
ValueError: If both `kg_retriever` and `chunk_retriever` are None.
"""
if kg_retriever is None and chunk_retriever is None:
- raise ValueError("At least one of `kg_retriever` or `chunk_retriever` must be provided.")
+ raise ValueError(
+ "At least one of `kg_retriever` or `chunk_retriever` must be provided."
+ )
self.kg_retriever = kg_retriever
self.chunk_retriever = chunk_retriever
self.project_id = kwargs.get("KAG_PROJECT_ID") or os.getenv("KAG_PROJECT_ID")
- self.host_addr = kwargs.get("KAG_PROJECT_HOST_ADDR") or os.getenv("KAG_PROJECT_HOST_ADDR")
+ self.host_addr = kwargs.get("KAG_PROJECT_HOST_ADDR") or os.getenv(
+ "KAG_PROJECT_HOST_ADDR"
+ )
if report_tool and report_tool.project_id:
self.project_id = report_tool.project_id
@@ -65,7 +75,9 @@ def __init__(self, kg_retriever: KGRetrieverABC = None,
vectorizer_config = eval(os.getenv("KAG_VECTORIZER", "{}"))
if self.host_addr and self.project_id:
- config = ProjectClient(host_addr=self.host_addr, project_id=self.project_id).get_config(self.project_id)
+ config = ProjectClient(
+ host_addr=self.host_addr, project_id=self.project_id
+ ).get_config(self.project_id)
vectorizer_config.update(config.get("vectorizer", {}))
self.vectorizer: Vectorizer = Vectorizer.from_config(vectorizer_config)
self.text_similarity = TextSimilarity(vec_config=vectorizer_config)
@@ -85,10 +97,12 @@ def _process_history(self, history):
for i, h in enumerate(history):
if "sub_query" not in h:
continue
- if 'sub_answer' in h and h['sub_answer'].lower() != "i don't know":
- sub_qa_pair.append(f"query{i + 1}: {h['sub_query']}\nanswer{i + 1}: {h['sub_answer']}")
- if "docs" in h and len(h['docs']) > 0:
- docs_set.append(h['docs'])
+ if "sub_answer" in h and h["sub_answer"].lower() != "i don't know":
+ sub_qa_pair.append(
+ f"query{i + 1}: {h['sub_query']}\nanswer{i + 1}: {h['sub_answer']}"
+ )
+ if "docs" in h and len(h["docs"]) > 0:
+ docs_set.append(h["docs"])
return sub_qa_pair, docs_set
def _flat_passages_set(self, passages_set: list):
@@ -113,7 +127,12 @@ def _flat_passages_set(self, passages_set: list):
else:
score_map[passage] = score
- return [k for k, v in sorted(score_map.items(), key=lambda item: item[1], reverse=True)]
+ return [
+ k
+ for k, v in sorted(
+ score_map.items(), key=lambda item: item[1], reverse=True
+ )
+ ]
def solve(self, query, lf_nodes: List[LFPlanResult]):
"""
@@ -129,19 +148,27 @@ def solve(self, query, lf_nodes: List[LFPlanResult]):
try:
start_time = time.time()
executor = LogicExecutor(
- query, self.project_id, self.schema,
+ query,
+ self.project_id,
+ self.schema,
kg_retriever=self.kg_retriever,
chunk_retriever=self.chunk_retriever,
std_schema=self.std_schema,
el=self.el,
text_similarity=self.text_similarity,
- dsl_runner=DslRunnerOnGraphStore(self.project_id, self.schema, LogicFormConfiguration({
- "KAG_PROJECT_ID": self.project_id,
- "KAG_PROJECT_HOST_ADDR": self.host_addr
- })),
+ dsl_runner=DslRunnerOnGraphStore(
+ self.project_id,
+ self.schema,
+ LogicFormConfiguration(
+ {
+ "KAG_PROJECT_ID": self.project_id,
+ "KAG_PROJECT_HOST_ADDR": self.host_addr,
+ }
+ ),
+ ),
generator=self.generator,
report_tool=self.report_tool,
- req_id=generate_random_string(10)
+ req_id=generate_random_string(10),
)
kg_qa_result, kg_graph, history = executor.execute(lf_nodes, query)
logger.info(
@@ -159,7 +186,7 @@ def solve(self, query, lf_nodes: List[LFPlanResult]):
docs = self._flat_passages_set(docs_set)
if len(docs) == 0 and len(sub_qa_pair) == 0 and self.chunk_retriever:
cur_step_recall_docs = self.chunk_retriever.recall_docs(query)
- history.append({'docs': cur_step_recall_docs})
+ history.append({"docs": cur_step_recall_docs})
docs = self._flat_passages_set([cur_step_recall_docs])
if len(docs) != 0:
self.last_iter_docs = docs
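Reading aid for the lf_solver.py hunks above: _flat_passages_set merges several recall lists into one ranked list by accumulating a per-passage score and sorting in descending order. The hunk only shows the merge and the sort, so the reciprocal-rank score used below is an assumption for illustration only; the function name, inputs, and printed result are invented.

# Minimal standalone sketch of the merge-and-rank idea, not the project's code.
# NOTE: the 1 / (rank + 1) score is an assumption; the real score is computed
# earlier in the file and is not visible in this hunk.
def flat_passages_set(passages_set):
    score_map = {}
    for passages in passages_set:
        for rank, passage in enumerate(passages):
            score = 1.0 / (rank + 1)
            if passage in score_map:
                score_map[passage] += score  # duplicates accumulate score
            else:
                score_map[passage] = score
    # highest combined score first, as in the sorted(...) call in the diff above
    return [k for k, v in sorted(score_map.items(), key=lambda item: item[1], reverse=True)]

if __name__ == "__main__":
    print(flat_passages_set([["doc_a", "doc_b"], ["doc_b", "doc_c"]]))
    # -> ['doc_b', 'doc_a', 'doc_c']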
diff --git a/kag/solver/logic/core_modules/op_executor/op_deduce/deduce_executor.py b/kag/solver/logic/core_modules/op_executor/op_deduce/deduce_executor.py
index f8d67c82..165bb241 100644
--- a/kag/solver/logic/core_modules/op_executor/op_deduce/deduce_executor.py
+++ b/kag/solver/logic/core_modules/op_executor/op_deduce/deduce_executor.py
@@ -4,33 +4,74 @@
from kag.solver.logic.core_modules.common.one_hop_graph import KgGraph
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.op_executor.op_deduce.module.choice import ChoiceOp
-from kag.solver.logic.core_modules.op_executor.op_deduce.module.entailment import EntailmentOp
-from kag.solver.logic.core_modules.op_executor.op_deduce.module.judgement import JudgementOp
-from kag.solver.logic.core_modules.op_executor.op_deduce.module.multi_choice import MultiChoiceOp
+from kag.solver.logic.core_modules.op_executor.op_deduce.module.entailment import (
+ EntailmentOp,
+)
+from kag.solver.logic.core_modules.op_executor.op_deduce.module.judgement import (
+ JudgementOp,
+)
+from kag.solver.logic.core_modules.op_executor.op_deduce.module.multi_choice import (
+ MultiChoiceOp,
+)
from kag.solver.logic.core_modules.op_executor.op_executor import OpExecutor
-from kag.solver.logic.core_modules.parser.logic_node_parser import FilterNode, VerifyNode, \
- ExtractorNode, DeduceNode
+from kag.solver.logic.core_modules.parser.logic_node_parser import (
+ FilterNode,
+ VerifyNode,
+ ExtractorNode,
+ DeduceNode,
+)
from kag.solver.logic.core_modules.rule_runner.rule_runner import OpRunner
class DeduceExecutor(OpExecutor):
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, rule_runner: OpRunner, debug_info: dict,
- **kwargs):
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ rule_runner: OpRunner,
+ debug_info: dict,
+ **kwargs
+ ):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
- self.KAG_PROJECT_ID = kwargs.get('KAG_PROJECT_ID')
+ self.KAG_PROJECT_ID = kwargs.get("KAG_PROJECT_ID")
self.rule_runner = rule_runner
self.op_register_map = {
- 'verify': self.rule_runner.run_verify_op,
- 'filter': self.rule_runner.run_filter_op,
- 'extractor': self.rule_runner.run_extractor_op
+ "verify": self.rule_runner.run_verify_op,
+ "filter": self.rule_runner.run_filter_op,
+ "extractor": self.rule_runner.run_extractor_op,
}
def _deduce_call(self, node: DeduceNode, req_id: str, param: dict) -> list:
op_mapping = {
- 'choice': ChoiceOp(self.nl_query, self.kg_graph, self.schema, self.debug_info,KAG_PROJECT_ID = self.KAG_PROJECT_ID),
- 'multiChoice': MultiChoiceOp(self.nl_query, self.kg_graph, self.schema, self.debug_info,KAG_PROJECT_ID = self.KAG_PROJECT_ID),
- 'entailment': EntailmentOp(self.nl_query, self.kg_graph, self.schema, self.debug_info,KAG_PROJECT_ID = self.KAG_PROJECT_ID),
- 'judgement': JudgementOp(self.nl_query, self.kg_graph, self.schema, self.debug_info,KAG_PROJECT_ID = self.KAG_PROJECT_ID)
+ "choice": ChoiceOp(
+ self.nl_query,
+ self.kg_graph,
+ self.schema,
+ self.debug_info,
+ KAG_PROJECT_ID=self.KAG_PROJECT_ID,
+ ),
+ "multiChoice": MultiChoiceOp(
+ self.nl_query,
+ self.kg_graph,
+ self.schema,
+ self.debug_info,
+ KAG_PROJECT_ID=self.KAG_PROJECT_ID,
+ ),
+ "entailment": EntailmentOp(
+ self.nl_query,
+ self.kg_graph,
+ self.schema,
+ self.debug_info,
+ KAG_PROJECT_ID=self.KAG_PROJECT_ID,
+ ),
+ "judgement": JudgementOp(
+ self.nl_query,
+ self.kg_graph,
+ self.schema,
+ self.debug_info,
+ KAG_PROJECT_ID=self.KAG_PROJECT_ID,
+ ),
}
result = []
for op in node.deduce_ops:
@@ -40,9 +81,13 @@ def _deduce_call(self, node: DeduceNode, req_id: str, param: dict) -> list:
return result
def is_this_op(self, logic_node: LogicNode) -> bool:
- return isinstance(logic_node, (DeduceNode, FilterNode, VerifyNode, ExtractorNode))
+ return isinstance(
+ logic_node, (DeduceNode, FilterNode, VerifyNode, ExtractorNode)
+ )
- def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> Union[KgGraph, list]:
+ def executor(
+ self, logic_node: LogicNode, req_id: str, param: dict
+ ) -> Union[KgGraph, list]:
if isinstance(logic_node, DeduceNode):
return self._deduce_call(logic_node, req_id, param)
op_func = self.op_register_map.get(logic_node.operator, None)
diff --git a/kag/solver/logic/core_modules/op_executor/op_deduce/module/choice.py b/kag/solver/logic/core_modules/op_executor/op_deduce/module/choice.py
index ef51a843..e34af87e 100644
--- a/kag/solver/logic/core_modules/op_executor/op_deduce/module/choice.py
+++ b/kag/solver/logic/core_modules/op_executor/op_deduce/module/choice.py
@@ -6,7 +6,14 @@
class ChoiceOp(OpExecutor):
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ debug_info: dict,
+ **kwargs,
+ ):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.prompt = PromptOp.load(self.biz_scene, "deduce_choice")(
language=self.language
@@ -16,6 +23,10 @@ def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> list:
# get history qa pair from debug_info
history_qa_pair = self.debug_info.get("sub_qa_pair", [])
qa_pair = "\n".join([f"Q: {q}\nA: {a}" for q, a in history_qa_pair])
- if_answered, answer = self.llm_module.invoke({'instruction': self.nl_query, 'memory': qa_pair},
- self.prompt, with_json_parse=False, with_except=True)
+ if_answered, answer = self.llm_module.invoke(
+ {"instruction": self.nl_query, "memory": qa_pair},
+ self.prompt,
+ with_json_parse=False,
+ with_except=True,
+ )
return [if_answered, answer]
diff --git a/kag/solver/logic/core_modules/op_executor/op_deduce/module/entailment.py b/kag/solver/logic/core_modules/op_executor/op_deduce/module/entailment.py
index ed737946..5bce8fd8 100644
--- a/kag/solver/logic/core_modules/op_executor/op_deduce/module/entailment.py
+++ b/kag/solver/logic/core_modules/op_executor/op_deduce/module/entailment.py
@@ -6,7 +6,14 @@
class EntailmentOp(OpExecutor):
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ debug_info: dict,
+ **kwargs,
+ ):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.prompt = PromptOp.load(self.biz_scene, "deduce_entail")(
language=self.language
@@ -17,6 +24,10 @@ def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> list:
qa_pair = "\n".join([f"Q: {q}\nA: {a}" for q, a in history_qa_pair])
spo_info = self.kg_graph.to_evidence()
information = str(spo_info) + "\n" + qa_pair
- if_answered, answer = self.llm_module.invoke({'instruction': self.nl_query, 'memory': information},
- self.prompt, with_json_parse=False, with_except=True)
- return [if_answered, answer]
\ No newline at end of file
+ if_answered, answer = self.llm_module.invoke(
+ {"instruction": self.nl_query, "memory": information},
+ self.prompt,
+ with_json_parse=False,
+ with_except=True,
+ )
+ return [if_answered, answer]
diff --git a/kag/solver/logic/core_modules/op_executor/op_deduce/module/judgement.py b/kag/solver/logic/core_modules/op_executor/op_deduce/module/judgement.py
index 965f1bf4..105d78cf 100644
--- a/kag/solver/logic/core_modules/op_executor/op_deduce/module/judgement.py
+++ b/kag/solver/logic/core_modules/op_executor/op_deduce/module/judgement.py
@@ -6,7 +6,14 @@
class JudgementOp(OpExecutor):
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ debug_info: dict,
+ **kwargs,
+ ):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.prompt = PromptOp.load(self.biz_scene, "deduce_judge")(
language=self.language
@@ -17,6 +24,10 @@ def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> list:
qa_pair = "\n".join([f"Q: {q}\nA: {a}" for q, a in history_qa_pair])
spo_info = self.kg_graph.to_evidence()
information = str(spo_info) + "\n" + qa_pair
- if_answered, answer = self.llm_module.invoke({'instruction': self.nl_query, 'memory': information},
- self.prompt, with_json_parse=False, with_except=True)
- return [if_answered, answer]
\ No newline at end of file
+ if_answered, answer = self.llm_module.invoke(
+ {"instruction": self.nl_query, "memory": information},
+ self.prompt,
+ with_json_parse=False,
+ with_except=True,
+ )
+ return [if_answered, answer]
diff --git a/kag/solver/logic/core_modules/op_executor/op_deduce/module/multi_choice.py b/kag/solver/logic/core_modules/op_executor/op_deduce/module/multi_choice.py
index ead1646f..1706c4f5 100644
--- a/kag/solver/logic/core_modules/op_executor/op_deduce/module/multi_choice.py
+++ b/kag/solver/logic/core_modules/op_executor/op_deduce/module/multi_choice.py
@@ -6,7 +6,14 @@
class MultiChoiceOp(OpExecutor):
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ debug_info: dict,
+ **kwargs,
+ ):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.prompt = PromptOp.load(self.biz_scene, "deduce_multi_choice")(
language=self.language
@@ -16,6 +23,10 @@ def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> list:
# get history qa pair from debug_info
history_qa_pair = self.debug_info.get("sub_qa_pair", [])
qa_pair = "\n".join([f"Q: {q}\nA: {a}" for q, a in history_qa_pair])
- if_answered, answer = self.llm_module.invoke({'instruction': self.nl_query, 'memory': qa_pair},
- self.prompt, with_json_parse=False, with_except=True)
+ if_answered, answer = self.llm_module.invoke(
+ {"instruction": self.nl_query, "memory": qa_pair},
+ self.prompt,
+ with_json_parse=False,
+ with_except=True,
+ )
return [if_answered, answer]
diff --git a/kag/solver/logic/core_modules/op_executor/op_executor.py b/kag/solver/logic/core_modules/op_executor/op_executor.py
index 0bc47931..8ed1b0fa 100644
--- a/kag/solver/logic/core_modules/op_executor/op_executor.py
+++ b/kag/solver/logic/core_modules/op_executor/op_executor.py
@@ -13,7 +13,15 @@ class OpExecutor(KagBaseModule, ABC):
Each subclass must implement the execution and judgment functions.
"""
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
+
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ debug_info: dict,
+ **kwargs
+ ):
"""
Initializes the operator executor with necessary components.
@@ -29,20 +37,22 @@ def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_
self.nl_query = nl_query
self.debug_info = debug_info
- def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> Union[KgGraph, list]:
+ def executor(
+ self, logic_node: LogicNode, req_id: str, param: dict
+ ) -> Union[KgGraph, list]:
"""
- Executes the operation based on the given logic node.
+ Executes the operation based on the given logic node.
- This method should be implemented by subclasses to define how the operation is executed.
+ This method should be implemented by subclasses to define how the operation is executed.
- Parameters:
- logic_node (LogicNode): The logic node that defines the operation to execute.
- req_id (str): Request identifier.
- param (dict): Parameters needed for the execution.
+ Parameters:
+ logic_node (LogicNode): The logic node that defines the operation to execute.
+ req_id (str): Request identifier.
+ param (dict): Parameters needed for the execution.
- Returns:
- Union[KgGraph, list]: The result of the operation, which could be a knowledge graph or a list.
- """
+ Returns:
+ Union[KgGraph, list]: The result of the operation, which could be a knowledge graph or a list.
+ """
pass
def is_this_op(self, logic_node: LogicNode) -> bool:
@@ -58,4 +68,4 @@ def is_this_op(self, logic_node: LogicNode) -> bool:
Returns:
bool: True if this executor can handle the logic node, False otherwise.
"""
- pass
\ No newline at end of file
+ pass
diff --git a/kag/solver/logic/core_modules/op_executor/op_math/math_executor.py b/kag/solver/logic/core_modules/op_executor/op_math/math_executor.py
index 511b0288..81ac5043 100644
--- a/kag/solver/logic/core_modules/op_executor/op_math/math_executor.py
+++ b/kag/solver/logic/core_modules/op_executor/op_math/math_executor.py
@@ -8,11 +8,20 @@
class MathExecutor(OpExecutor):
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ debug_info: dict,
+ **kwargs
+ ):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
def is_this_op(self, logic_node: LogicNode) -> bool:
return isinstance(logic_node, (CountNode, SumNode))
- def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> Union[KgGraph, list]:
+ def executor(
+ self, logic_node: LogicNode, req_id: str, param: dict
+ ) -> Union[KgGraph, list]:
pass
diff --git a/kag/solver/logic/core_modules/op_executor/op_output/module/get_executor.py b/kag/solver/logic/core_modules/op_executor/op_output/module/get_executor.py
index 61612951..1c1450f8 100644
--- a/kag/solver/logic/core_modules/op_executor/op_output/module/get_executor.py
+++ b/kag/solver/logic/core_modules/op_executor/op_output/module/get_executor.py
@@ -1,15 +1,31 @@
from kag.solver.logic.core_modules.common.base_model import SPOEntity, LogicNode
-from kag.solver.logic.core_modules.common.one_hop_graph import KgGraph, EntityData, RelationData
+from kag.solver.logic.core_modules.common.one_hop_graph import (
+ KgGraph,
+ EntityData,
+ RelationData,
+)
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.op_executor.op_executor import OpExecutor
from kag.solver.logic.core_modules.parser.logic_node_parser import GetNode
-from kag.solver.logic.core_modules.retriver.entity_linker import spo_entity_linker, EntityLinkerBase
+from kag.solver.logic.core_modules.retriver.entity_linker import (
+ spo_entity_linker,
+ EntityLinkerBase,
+)
from kag.solver.logic.core_modules.retriver.graph_retriver.dsl_executor import DslRunner
class GetExecutor(OpExecutor):
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, el: EntityLinkerBase,
- dsl_runner: DslRunner, cached_map: dict, debug_info: dict, **kwargs):
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ el: EntityLinkerBase,
+ dsl_runner: DslRunner,
+ cached_map: dict,
+ debug_info: dict,
+ **kwargs
+ ):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.el = el
@@ -18,7 +34,9 @@ def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, el: En
def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> list:
kg_qa_result = []
- if not isinstance(logic_node, GetNode) or not self.debug_info.get('exact_match_spo', False):
+ if not isinstance(logic_node, GetNode) or not self.debug_info.get(
+ "exact_match_spo", False
+ ):
return kg_qa_result
n = logic_node
@@ -28,43 +46,41 @@ def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> list:
start_info_set = n.s.generate_start_infos()
for start_info in start_info_set:
id_info = EntityData()
- id_info.type = start_info['type']
- id_info.biz_id = start_info['id']
+ id_info.type = start_info["type"]
+ id_info.biz_id = start_info["id"]
s_data_set = [id_info]
elif n.s.entity_name:
- el_results, el_request, err_msg, call_result_data = spo_entity_linker(self.kg_graph,
- n,
- self.nl_query,
- self.el,
- self.schema,
- req_id,
- param)
- self.debug_info['el'] = self.debug_info['el'] + el_results
- self.debug_info['el_detail'] = self.debug_info['el_detail'] + [{
- "el_request": el_request,
- 'el_results': el_results,
- 'el_debug_result': call_result_data,
- 'err_msg': err_msg
- }]
+ el_results, el_request, err_msg, call_result_data = spo_entity_linker(
+ self.kg_graph, n, self.nl_query, self.el, self.schema, req_id, param
+ )
+ self.debug_info["el"] = self.debug_info["el"] + el_results
+ self.debug_info["el_detail"] = self.debug_info["el_detail"] + [
+ {
+ "el_request": el_request,
+ "el_results": el_results,
+ "el_debug_result": call_result_data,
+ "err_msg": err_msg,
+ }
+ ]
n.to_std(n.args)
s_data_set = self.kg_graph.get_entity_by_alias(n.alias_name)
if s_data_set is None:
- self.debug_info['get_empty'].append(n.to_dict())
+ self.debug_info["get_empty"].append(n.to_dict())
return kg_qa_result
s_biz_id_set = []
for s_data in s_data_set:
if isinstance(s_data, EntityData):
- if s_data.name == '':
+ if s_data.name == "":
s_biz_id_set.append(s_data.biz_id)
else:
kg_qa_result.append(s_data.name)
if isinstance(s_data, RelationData):
kg_qa_result.append(str(s_data))
if len(s_biz_id_set) > 0:
- one_hop_cached_map = self.dsl_runner.query_vertex_property_by_s_ids(s_biz_id_set,
- n.s.get_entity_first_type(),
- self.cached_map)
+ one_hop_cached_map = self.dsl_runner.query_vertex_property_by_s_ids(
+ s_biz_id_set, n.s.get_entity_first_type(), self.cached_map
+ )
self.kg_graph.nodes_alias.append(n.alias_name)
entities = []
@@ -75,5 +91,7 @@ def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> list:
if n.alias_name not in self.kg_graph.entity_map.keys():
self.kg_graph.entity_map[n.alias_name] = entities
else:
- self.kg_graph.entity_map[n.alias_name] = self.kg_graph.entity_map[n.alias_name] + entities
+ self.kg_graph.entity_map[n.alias_name] = (
+ self.kg_graph.entity_map[n.alias_name] + entities
+ )
return kg_qa_result
diff --git a/kag/solver/logic/core_modules/op_executor/op_output/output_executor.py b/kag/solver/logic/core_modules/op_executor/op_output/output_executor.py
index 57ac70fa..22187367 100644
--- a/kag/solver/logic/core_modules/op_executor/op_output/output_executor.py
+++ b/kag/solver/logic/core_modules/op_executor/op_output/output_executor.py
@@ -4,24 +4,47 @@
from kag.solver.logic.core_modules.common.one_hop_graph import KgGraph
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.op_executor.op_executor import OpExecutor
-from kag.solver.logic.core_modules.op_executor.op_output.module.get_executor import GetExecutor
+from kag.solver.logic.core_modules.op_executor.op_output.module.get_executor import (
+ GetExecutor,
+)
from kag.solver.logic.core_modules.parser.logic_node_parser import GetNode
from kag.solver.logic.core_modules.retriver.entity_linker import EntityLinkerBase
from kag.solver.logic.core_modules.retriver.graph_retriver.dsl_executor import DslRunner
class OutputExecutor(OpExecutor):
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, el: EntityLinkerBase, dsl_runner: DslRunner, cached_map: dict, debug_info: dict, **kwargs):
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ el: EntityLinkerBase,
+ dsl_runner: DslRunner,
+ cached_map: dict,
+ debug_info: dict,
+ **kwargs
+ ):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
- self.KAG_PROJECT_ID = kwargs.get('KAG_PROJECT_ID')
+ self.KAG_PROJECT_ID = kwargs.get("KAG_PROJECT_ID")
self.op_register_map = {
- 'get': GetExecutor(nl_query, kg_graph, schema, el, dsl_runner, cached_map, self.debug_info,KAG_PROJECT_ID = kwargs.get('KAG_PROJECT_ID'))
+ "get": GetExecutor(
+ nl_query,
+ kg_graph,
+ schema,
+ el,
+ dsl_runner,
+ cached_map,
+ self.debug_info,
+ KAG_PROJECT_ID=kwargs.get("KAG_PROJECT_ID"),
+ )
}
def is_this_op(self, logic_node: LogicNode) -> bool:
return isinstance(logic_node, GetNode)
- def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> Union[KgGraph, list]:
+ def executor(
+ self, logic_node: LogicNode, req_id: str, param: dict
+ ) -> Union[KgGraph, list]:
op = self.op_register_map.get(logic_node.operator, None)
if op is None:
return []
diff --git a/kag/solver/logic/core_modules/op_executor/op_retrieval/module/get_spo_executor.py b/kag/solver/logic/core_modules/op_executor/op_retrieval/module/get_spo_executor.py
index d8f5d657..5bbda570 100644
--- a/kag/solver/logic/core_modules/op_executor/op_retrieval/module/get_spo_executor.py
+++ b/kag/solver/logic/core_modules/op_executor/op_retrieval/module/get_spo_executor.py
@@ -4,13 +4,20 @@
from kag.interface.retriever.kg_retriever_abc import KGRetrieverABC
from kag.solver.logic.core_modules.common.base_model import SPOEntity, LogicNode
-from kag.solver.logic.core_modules.common.one_hop_graph import KgGraph, EntityData, RelationData, \
- OneHopGraphData
+from kag.solver.logic.core_modules.common.one_hop_graph import (
+ KgGraph,
+ EntityData,
+ RelationData,
+ OneHopGraphData,
+)
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.common.text_sim_by_vector import TextSimilarity
from kag.solver.logic.core_modules.op_executor.op_executor import OpExecutor
from kag.solver.logic.core_modules.parser.logic_node_parser import GetSPONode
-from kag.solver.logic.core_modules.retriver.entity_linker import EntityLinkerBase, spo_entity_linker
+from kag.solver.logic.core_modules.retriver.entity_linker import (
+ EntityLinkerBase,
+ spo_entity_linker,
+)
from kag.solver.logic.core_modules.retriver.graph_retriver.dsl_executor import DslRunner
logger = logging.getLogger()
@@ -23,9 +30,20 @@ class GetSPOExecutor(OpExecutor):
This class is used to retrieve one-hop graphs based on the given parameters.
It extends the base `OpExecutor` class and initializes additional components specific to retrieving SPO triples.
"""
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, retrieval_spo: KGRetrieverABC,
- el: EntityLinkerBase,
- dsl_runner: DslRunner, query_one_graph_cache: dict, debug_info: dict, text_similarity: TextSimilarity=None,**kwargs):
+
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ retrieval_spo: KGRetrieverABC,
+ el: EntityLinkerBase,
+ dsl_runner: DslRunner,
+ query_one_graph_cache: dict,
+ debug_info: dict,
+ text_similarity: TextSimilarity = None,
+ **kwargs,
+ ):
"""
Initializes the GetSPOExecutor with necessary components.
@@ -47,46 +65,74 @@ def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, retrie
self.text_similarity = text_similarity or TextSimilarity()
- def _find_relation_result(self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], req_id: str):
+ def _find_relation_result(
+ self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], req_id: str
+ ):
one_kg_graph = KgGraph()
is_find_relation = False
- if n.p.get_entity_first_type_or_zh() is None and n.o.get_entity_first_type_or_zh() is None:
+ if (
+ n.p.get_entity_first_type_or_zh() is None
+ and n.o.get_entity_first_type_or_zh() is None
+ ):
is_find_relation = True
for one_hop_graph in one_hop_graph_list:
rel_set = one_hop_graph.get_all_relation_value()
one_kg_graph_ = KgGraph()
- recall_alias_name = n.s.alias_name if one_hop_graph.s_alias_name == "s" else n.o.alias_name
+ recall_alias_name = (
+ n.s.alias_name
+ if one_hop_graph.s_alias_name == "s"
+ else n.o.alias_name
+ )
one_kg_graph_.entity_map[recall_alias_name] = [one_hop_graph.s]
one_kg_graph_.edge_map[n.p.alias_name] = rel_set
one_kg_graph.merge_kg_graph(one_kg_graph_)
return one_kg_graph, is_find_relation
- def _get_spo_value_in_one_hop_graph_set(self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], req_id: str):
+ def _get_spo_value_in_one_hop_graph_set(
+ self, n: GetSPONode, one_hop_graph_list: List[OneHopGraphData], req_id: str
+ ):
process_kg, is_rel = self._find_relation_result(n, one_hop_graph_list, req_id)
if is_rel:
return process_kg
- return self.retrieval_spo.retrieval_relation(n, one_hop_graph_list, req_id=req_id, debug_info = self.debug_info)
-
- def _run_query_vertex_one_graph(self, s_node_set: List[EntityData], o_node_set: List[EntityData], n: GetSPONode=None):
- return self.dsl_runner.query_vertex_one_graph_by_s_o_ids(s_node_set,
- o_node_set,
- self.query_one_graph_cache, n)
-
- def _execute_get_spo_by_set(self, n: GetSPONode, s_node_set: List[EntityData], o_node_set: List[EntityData], req_id):
+ return self.retrieval_spo.retrieval_relation(
+ n, one_hop_graph_list, req_id=req_id, debug_info=self.debug_info
+ )
+
+ def _run_query_vertex_one_graph(
+ self,
+ s_node_set: List[EntityData],
+ o_node_set: List[EntityData],
+ n: GetSPONode = None,
+ ):
+ return self.dsl_runner.query_vertex_one_graph_by_s_o_ids(
+ s_node_set, o_node_set, self.query_one_graph_cache, n
+ )
+
+ def _execute_get_spo_by_set(
+ self,
+ n: GetSPONode,
+ s_node_set: List[EntityData],
+ o_node_set: List[EntityData],
+ req_id,
+ ):
start_time = time.time()
kg_graph = KgGraph()
kg_graph.query_graph[n.p.alias_name] = {
"s": n.s.alias_name,
"p": n.p.alias_name,
- "o": n.o.alias_name
+ "o": n.o.alias_name,
}
- if (s_node_set is None or len(s_node_set) == 0) and (o_node_set is None or len(o_node_set) == 0):
+ if (s_node_set is None or len(s_node_set) == 0) and (
+ o_node_set is None or len(o_node_set) == 0
+ ):
logger.info(f"{req_id} not found id is spo " + str(n))
return kg_graph
one_hop_graph_map = self._run_query_vertex_one_graph(s_node_set, o_node_set, n)
end_time = time.time()
- logger.debug(f"{req_id} execute_get_spo_by_set {n} recall subgraph cost {end_time - start_time}")
+ logger.debug(
+ f"{req_id} execute_get_spo_by_set {n} recall subgraph cost {end_time - start_time}"
+ )
if len(one_hop_graph_map) == 0:
logger.debug(f"{req_id} execute_get_spo_by_set one_hop_graph_map is empty")
return kg_graph
@@ -104,7 +150,8 @@ def _execute_get_spo_by_set(self, n: GetSPONode, s_node_set: List[EntityData], o
kg_graph.merge_kg_graph(res)
logger.debug(
- f"{req_id} execute_get_spo_by_set merged kg graph ={kg_graph.to_edge_str()} cost = {time.time() - start_time}")
+ f"{req_id} execute_get_spo_by_set merged kg graph ={kg_graph.to_edge_str()} cost = {time.time() - start_time}"
+ )
return kg_graph
def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> KgGraph:
@@ -118,21 +165,19 @@ def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> KgGraph:
self.kg_graph.logic_form_base[n.o.alias_name] = n.o
# 实体标准化, 针对有实体名称的节点,需要做链指
- el_results, el_request, err_msg, call_result_data = spo_entity_linker(self.kg_graph,
- n,
- self.nl_query,
- self.el,
- self.schema,
- req_id,
- param)
- if el_request and el_request['entity_mentions']:
- self.debug_info['el'] = self.debug_info['el'] + el_results
- self.debug_info['el_detail'] = self.debug_info['el_detail'] + [{
- "el_request": el_request,
- 'el_results': el_results,
- 'el_debug_result': call_result_data,
- 'err_msg': err_msg
- }]
+ el_results, el_request, err_msg, call_result_data = spo_entity_linker(
+ self.kg_graph, n, self.nl_query, self.el, self.schema, req_id, param
+ )
+ if el_request and el_request["entity_mentions"]:
+ self.debug_info["el"] = self.debug_info["el"] + el_results
+ self.debug_info["el_detail"] = self.debug_info["el_detail"] + [
+ {
+ "el_request": el_request,
+ "el_results": el_results,
+ "el_debug_result": call_result_data,
+ "err_msg": err_msg,
+ }
+ ]
n.to_std(n.args)
s_data_set = []
@@ -150,7 +195,9 @@ def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> KgGraph:
relation_data_set.append(s_data)
if len(relation_data_set) > 0:
- logger.info(f"{req_id} get_spo relation_data_set is not empty {str(relation_data_set)}, need get prop")
+ logger.info(
+ f"{req_id} get_spo relation_data_set is not empty {str(relation_data_set)}, need get prop"
+ )
return kg_graph
o_data_set = []
@@ -184,4 +231,4 @@ def _get_entity_node_from_lf(self, e: SPOEntity):
d.type = e.get_entity_first_type()
d.type_zh = e.get_entity_first_type_or_en()
ret.append(d)
- return ret
\ No newline at end of file
+ return ret
diff --git a/kag/solver/logic/core_modules/op_executor/op_retrieval/module/search_s.py b/kag/solver/logic/core_modules/op_executor/op_retrieval/module/search_s.py
index fa1ae134..b96d5c08 100644
--- a/kag/solver/logic/core_modules/op_executor/op_retrieval/module/search_s.py
+++ b/kag/solver/logic/core_modules/op_executor/op_retrieval/module/search_s.py
@@ -5,8 +5,15 @@
class SearchS(OpExecutor):
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ debug_info: dict,
+ **kwargs
+ ):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> KgGraph:
- raise NotImplementedError("search s not impl")
\ No newline at end of file
+ raise NotImplementedError("search s not impl")
diff --git a/kag/solver/logic/core_modules/op_executor/op_retrieval/retrieval_executor.py b/kag/solver/logic/core_modules/op_executor/op_retrieval/retrieval_executor.py
index aa80e989..e9e8729e 100644
--- a/kag/solver/logic/core_modules/op_executor/op_retrieval/retrieval_executor.py
+++ b/kag/solver/logic/core_modules/op_executor/op_retrieval/retrieval_executor.py
@@ -6,8 +6,12 @@
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.common.text_sim_by_vector import TextSimilarity
from kag.solver.logic.core_modules.op_executor.op_executor import OpExecutor
-from kag.solver.logic.core_modules.op_executor.op_retrieval.module.get_spo_executor import GetSPOExecutor
-from kag.solver.logic.core_modules.op_executor.op_retrieval.module.search_s import SearchS
+from kag.solver.logic.core_modules.op_executor.op_retrieval.module.get_spo_executor import (
+ GetSPOExecutor,
+)
+from kag.solver.logic.core_modules.op_executor.op_retrieval.module.search_s import (
+ SearchS,
+)
from kag.solver.logic.core_modules.parser.logic_node_parser import GetSPONode
from kag.solver.logic.core_modules.retriver.entity_linker import EntityLinkerBase
from kag.solver.logic.core_modules.retriver.graph_retriver.dsl_executor import DslRunner
@@ -16,13 +20,40 @@
class RetrievalExecutor(OpExecutor):
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, retrieval_spo: KGRetrieverABC, el: EntityLinkerBase,
- dsl_runner: DslRunner, debug_info: dict, text_similarity: TextSimilarity=None,**kwargs):
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ retrieval_spo: KGRetrieverABC,
+ el: EntityLinkerBase,
+ dsl_runner: DslRunner,
+ debug_info: dict,
+ text_similarity: TextSimilarity = None,
+ **kwargs,
+ ):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.query_one_graph_cache = {}
self.op_register_map = {
- 'get_spo': GetSPOExecutor(nl_query, kg_graph, schema, retrieval_spo, el, dsl_runner, self.query_one_graph_cache, self.debug_info, text_similarity,KAG_PROJECT_ID = kwargs.get('KAG_PROJECT_ID')),
- 'search_s': SearchS(nl_query, kg_graph, schema, self.debug_info,KAG_PROJECT_ID = kwargs.get('KAG_PROJECT_ID'))
+ "get_spo": GetSPOExecutor(
+ nl_query,
+ kg_graph,
+ schema,
+ retrieval_spo,
+ el,
+ dsl_runner,
+ self.query_one_graph_cache,
+ self.debug_info,
+ text_similarity,
+ KAG_PROJECT_ID=kwargs.get("KAG_PROJECT_ID"),
+ ),
+ "search_s": SearchS(
+ nl_query,
+ kg_graph,
+ schema,
+ self.debug_info,
+ KAG_PROJECT_ID=kwargs.get("KAG_PROJECT_ID"),
+ ),
}
def is_this_op(self, logic_node: LogicNode) -> bool:
diff --git a/kag/solver/logic/core_modules/op_executor/op_sort/sort_executor.py b/kag/solver/logic/core_modules/op_executor/op_sort/sort_executor.py
index a3e47e49..bee87b10 100644
--- a/kag/solver/logic/core_modules/op_executor/op_sort/sort_executor.py
+++ b/kag/solver/logic/core_modules/op_executor/op_sort/sort_executor.py
@@ -8,10 +8,20 @@
class SortExecutor(OpExecutor):
- def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
+ def __init__(
+ self,
+ nl_query: str,
+ kg_graph: KgGraph,
+ schema: SchemaUtils,
+ debug_info: dict,
+ **kwargs
+ ):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
def is_this_op(self, logic_node: LogicNode) -> bool:
return isinstance(logic_node, SortNode)
- def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> Union[KgGraph, list]:
+
+ def executor(
+ self, logic_node: LogicNode, req_id: str, param: dict
+ ) -> Union[KgGraph, list]:
pass
diff --git a/kag/solver/logic/core_modules/parser/logic_node_parser.py b/kag/solver/logic/core_modules/parser/logic_node_parser.py
index e0e95ee8..21d1a0b3 100644
--- a/kag/solver/logic/core_modules/parser/logic_node_parser.py
+++ b/kag/solver/logic/core_modules/parser/logic_node_parser.py
@@ -1,20 +1,25 @@
import logging
import re
-from kag.solver.logic.core_modules.common.base_model import SPOBase, SPOEntity, SPORelation, Identifer, \
- TypeInfo, LogicNode
+from kag.solver.logic.core_modules.common.base_model import (
+ SPOBase,
+ SPOEntity,
+ SPORelation,
+ Identifer,
+ TypeInfo,
+ LogicNode,
+)
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
-logger = logging.getLogger(__name__)\
-
+logger = logging.getLogger(__name__)
# get_spg(s, p, o)
class GetSPONode(LogicNode):
def __init__(self, operator, args):
super().__init__(operator, args)
- self.s: SPOBase = args.get('s', None)
- self.p: SPOBase = args.get('p', None)
- self.o: SPOBase = args.get('o', None)
+ self.s: SPOBase = args.get("s", None)
+ self.p: SPOBase = args.get("p", None)
+ self.o: SPOBase = args.get("o", None)
self.sub_query = args.get("sub_query", None)
self.query = args.get("query", None)
@@ -24,14 +29,14 @@ def to_dsl(self):
def to_std(self, args):
for key, value in args.items():
self.args[key] = value
- self.s = args.get('s', self.s)
- self.p = args.get('p', self.p)
- self.o = args.get('o', self.o)
- self.sub_query = args.get('sub_query', self.sub_query)
+ self.s = args.get("s", self.s)
+ self.p = args.get("p", self.p)
+ self.o = args.get("o", self.o)
+ self.sub_query = args.get("sub_query", self.sub_query)
@staticmethod
def parse_node(input_str):
- equality_list = re.findall(r'([\w.]+=[^=]+)(,|,|$)', input_str)
+ equality_list = re.findall(r"([\w.]+=[^=]+)(,|,|$)", input_str)
if len(equality_list) < 3:
raise RuntimeError(f"parse {input_str} error not found s,p,o")
spo_params = [e[0] for e in equality_list[:3]]
@@ -47,7 +52,7 @@ def parse_node_spo(spo_params):
p = None
o = None
for spo_param in spo_params:
- key, param = spo_param.split('=')
+ key, param = spo_param.split("=")
if key == "s":
s = SPOEntity.parse_logic_form(param)
elif key == "o":
@@ -60,17 +65,13 @@ def parse_node_spo(spo_params):
raise RuntimeError(f"parse {str(spo_params)} error not found p")
if o is None:
raise RuntimeError(f"parse {str(spo_params)} error not found o")
- return GetSPONode("get_spo", {
- "s": s,
- "p": p,
- "o": o
- })
+ return GetSPONode("get_spo", {"s": s, "p": p, "o": o})
@staticmethod
def parse_node_value(get_spo_node_op, value_params):
for value_param in value_params:
# a.value=123,b.brand=345
- value_pair = re.findall(r'(?:[,\s]*(\w+)\.(\w+)=([^,,]+))', value_param)
+ value_pair = re.findall(r"(?:[,\s]*(\w+)\.(\w+)=([^,,]+))", value_param)
for key, property, value in value_pair:
node = None
if key == "s":
@@ -83,15 +84,15 @@ def parse_node_value(get_spo_node_op, value_params):
def binary_expr_parse(input_str):
- pattern = re.compile(r'(\w+)=((?:(?!\w+=).)*)')
+ pattern = re.compile(r"(\w+)=((?:(?!\w+=).)*)")
matches = pattern.finditer(input_str)
left_expr = None
right_expr = None
op = None
for match in matches:
key = match.group(1).strip()
- value = match.group(2).strip().rstrip(',')
- value = value.rstrip(',')
+ value = match.group(2).strip().rstrip(",")
+ value = value.rstrip(",")
if key == "left_expr":
if "," in value:
left_expr_list = list(set([Identifer(v) for v in value.split(",")]))
@@ -104,7 +105,7 @@ def binary_expr_parse(input_str):
else:
left_expr = left_expr_list
elif key == "right_expr":
- if value != '':
+ if value != "":
right_expr = value
elif key == "op":
op = value
@@ -113,21 +114,17 @@ def binary_expr_parse(input_str):
if op is None:
raise RuntimeError(f"parse {input_str} error not found op")
- return {
- "left_expr": left_expr,
- "right_expr": right_expr,
- "op": op
- }
+ return {"left_expr": left_expr, "right_expr": right_expr, "op": op}
# filter(left_expr=alias, right_expr=other_alias or const_data, op=equal|lt|gt|le|ge|in|contains|and|or|not)
class FilterNode(LogicNode):
def __init__(self, operator, args):
super().__init__(operator, args)
- self.left_expr = args.get('left_expr', None)
- self.right_expr = args.get('right_expr', None)
- self.op = args.get('op', None)
- self.OP = 'equal|lt|gt|le|ge|in|contains|and|or|not'.split('|')
+ self.left_expr = args.get("left_expr", None)
+ self.right_expr = args.get("right_expr", None)
+ self.op = args.get("op", None)
+ self.OP = "equal|lt|gt|le|ge|in|contains|and|or|not".split("|")
def to_dsl(self):
raise NotImplementedError("Subclasses should implement this method.")
@@ -135,9 +132,9 @@ def to_dsl(self):
def to_std(self, args):
for key, value in args.items():
self.args[key] = value
- self.left_expr = args.get('left_expr', self.left_expr)
- self.right_expr = args.get('right_expr', self.right_expr)
- self.op = args.get('op', self.op)
+ self.left_expr = args.get("left_expr", self.left_expr)
+ self.right_expr = args.get("right_expr", self.right_expr)
+ self.op = args.get("op", self.op)
@staticmethod
def parse_node(input_str):
@@ -157,7 +154,7 @@ def to_dsl(self):
@staticmethod
def parse_node(input_str, output_name):
- args = {'alias_name': output_name, 'set': input_str}
+ args = {"alias_name": output_name, "set": input_str}
return CountNode("count", args)
@@ -174,7 +171,7 @@ def to_dsl(self):
@staticmethod
def parse_node(input_str):
# count_alias=count(alias)
- match = re.match(r'(\w+)[\(\(](.*)[\)\)](->)?(.*)?', input_str)
+ match = re.match(r"(\w+)[\(\(](.*)[\)\)](->)?(.*)?", input_str)
if not match:
raise RuntimeError(f"parse logic form error {input_str}")
# print('match:',match.groups())
@@ -182,9 +179,9 @@ def parse_node(input_str):
operator, params, _, alias_name = match.groups()
else:
operator, params = match.groups()
- alias_name = 'sum1'
- params = params.replace(',', ',').split(',')
- args = {'alias_name': alias_name, 'set': params}
+ alias_name = "sum1"
+ params = params.replace(",", ",").split(",")
+ args = {"alias_name": alias_name, "set": params}
return SumNode("sum", args)
@@ -210,13 +207,15 @@ def get_set(self):
@staticmethod
def parse_node(input_str):
- equality_list = re.findall(r'([\w.]+=[^=]+)(,|,|$)', input_str)
+ equality_list = re.findall(r"([\w.]+=[^=]+)(,|,|$)", input_str)
if len(equality_list) < 4:
- raise RuntimeError(f"parse {input_str} error not found set,orderby,direction,limit")
+ raise RuntimeError(
+ f"parse {input_str} error not found set,orderby,direction,limit"
+ )
params = [e[0] for e in equality_list[:4]]
params_dict = {}
for param in params:
- key, value = param.split('=')
+ key, value = param.split("=")
params_dict[key] = value
return SortNode("sort", params_dict)
@@ -241,15 +240,24 @@ def get_set(self):
@staticmethod
def parse_node(input_str):
- equality_list = re.findall(r'([\w.]+=[^=]+)(,|,|$)', input_str)
+ equality_list = re.findall(r"([\w.]+=[^=]+)(,|,|$)", input_str)
if len(equality_list) < 2:
- raise RuntimeError(f"parse {input_str} error not found set,orderby,direction,limit")
+ raise RuntimeError(
+ f"parse {input_str} error not found set,orderby,direction,limit"
+ )
params = [e[0] for e in equality_list[:2]]
params_dict = {}
for param in params:
- key, value = param.split('=')
- if key == 'set':
- value = value.strip().replace(',', ',').replace(' ', '').strip('[').strip(']').split(',')
+ key, value = param.split("=")
+ if key == "set":
+ value = (
+ value.strip()
+ .replace(",", ",")
+ .replace(" ", "")
+ .strip("[")
+ .strip("]")
+ .split(",")
+ )
params_dict[key] = value
return CompareNode("compare", params_dict)
@@ -266,19 +274,25 @@ def __str__(self):
def parse_node(input_str):
ops = input_str.replace("op=", "")
input_ops = ops.split(",")
- return DeduceNode("deduce", {
- "deduce_ops": input_ops
- })
+ return DeduceNode("deduce", {"deduce_ops": input_ops})
# verity(left_expr=alias, right_expr=other_alias or const_data, op=equal|gt|lt|ge|le|in|contains)
class VerifyNode(LogicNode):
def __init__(self, operator, args):
super().__init__(operator, args)
- self.left_expr = args.get('left_expr', None)
- self.right_expr = args.get('right_expr', None)
- self.op = args.get('op', None)
- self.OP = {'等于': 'equal', '大于': 'gt', '小于': 'lt', '大于等于': 'ge', '小于等于': 'le', '属于': 'in', '包含': 'contains'}
+ self.left_expr = args.get("left_expr", None)
+ self.right_expr = args.get("right_expr", None)
+ self.op = args.get("op", None)
+ self.OP = {
+ "等于": "equal",
+ "大于": "gt",
+ "小于": "lt",
+ "大于等于": "ge",
+ "小于等于": "le",
+ "属于": "in",
+ "包含": "contains",
+ }
def to_dsl(self):
raise NotImplementedError("Subclasses should implement this method.")
@@ -299,16 +313,16 @@ def get_left_expr_name(self):
def to_std(self, args):
for key, value in args.items():
self.args[key] = value
- self.left_expr = args.get('left_expr', self.left_expr)
- self.right_expr = args.get('right_expr', self.right_expr)
- self.op = args.get('op', self.op)
+ self.left_expr = args.get("left_expr", self.left_expr)
+ self.right_expr = args.get("right_expr", self.right_expr)
+ self.op = args.get("op", self.op)
if self.op in self.OP.values():
self.op = self.OP[self.op]
@staticmethod
def parse_node(input_str):
if "verify" in input_str:
- match = re.match(r'(\w+)[\(\(](.*)[\)\)](->)?(.*)?', input_str)
+ match = re.match(r"(\w+)[\(\(](.*)[\)\)](->)?(.*)?", input_str)
if not match:
raise RuntimeError(f"parse logic form error {input_str}")
# print('match:',match.groups())
@@ -332,9 +346,7 @@ def to_dsl(self):
def parse_node(input_str):
params = set(input_str.split(","))
alias_set = [Identifer(p) for p in params]
- ex_node = ExtractorNode("extractor", {
- "alias_set": alias_set
- })
+ ex_node = ExtractorNode("extractor", {"alias_set": alias_set})
ex_node.alias_set = alias_set
return ex_node
@@ -354,22 +366,25 @@ def to_dsl(self):
@staticmethod
def parse_node(input_str):
input_args = input_str.split(",")
- return GetNode("get", {
- "alias_name": Identifer(input_args[0]),
- "alias_name_set": [Identifer(e) for e in input_args]
- })
+ return GetNode(
+ "get",
+ {
+ "alias_name": Identifer(input_args[0]),
+ "alias_name_set": [Identifer(e) for e in input_args],
+ },
+ )
# search_s()
class SearchNode(LogicNode):
def __init__(self, operator, args):
super().__init__(operator, args)
- self.s = SPOEntity(None, None, args['type'], None, args['alias'], False)
- self.s.value_list = args['conditions']
+ self.s = SPOEntity(None, None, args["type"], None, args["alias"], False)
+ self.s.value_list = args["conditions"]
@staticmethod
def parse_node(input_str):
- pattern = re.compile(r'[,\s]*s=(\w+):([^,\s]+),(.*)')
+ pattern = re.compile(r"[,\s]*s=(\w+):([^,\s]+),(.*)")
matches = pattern.match(input_str)
args = dict()
args["alias"] = matches.group(1)
@@ -378,21 +393,21 @@ def parse_node(input_str):
search_condition = dict()
s_condition = matches.group(3)
- condition_pattern = re.compile(r'(?:[,\s]*(\w+)\.(\w+)=([^,,]+))')
+ condition_pattern = re.compile(r"(?:[,\s]*(\w+)\.(\w+)=([^,,]+))")
condition_list = condition_pattern.findall(s_condition)
for condition in condition_list:
s_property = condition[1]
s_value = condition[2]
s_value = SearchNode.check_value_is_reference(s_value)
search_condition[s_property] = s_value
- args['conditions'] = search_condition
+ args["conditions"] = search_condition
- return SearchNode('search_s', args)
+ return SearchNode("search_s", args)
@staticmethod
def check_value_is_reference(value_str):
- if '.' in value_str:
- return value_str.split('.')
+ if "." in value_str:
+ return value_str.split(".")
return value_str
@@ -431,7 +446,9 @@ def std_parse_kg_node(self, entity: SPOBase, parsed_entity_set):
o_candis_set = self.schema.sp_o[sp_index]
for candis in o_candis_set:
spo_zh = f"{s_type_zh}_{entity_type}_{candis}"
- type_info.entity_type = self.schema.get_spo_with_p(self.schema.spo_zh_en[spo_zh])
+ type_info.entity_type = self.schema.get_spo_with_p(
+ self.schema.spo_zh_en[spo_zh]
+ )
break
if not type_info.entity_type and s_type_zh == "Entity":
@@ -440,10 +457,16 @@ def std_parse_kg_node(self, entity: SPOBase, parsed_entity_set):
s_candis_set = self.schema.op_s[op_index]
for candis in s_candis_set:
spo_zh = f"{candis}_{entity_type}_{o_type_zh}"
- type_info.entity_type = self.schema.get_spo_with_p(self.schema.spo_zh_en[spo_zh])
+ type_info.entity_type = self.schema.get_spo_with_p(
+ self.schema.spo_zh_en[spo_zh]
+ )
break
- if not type_info.entity_type and o_type_zh != "Entity" and s_type_zh != "Entity":
+ if (
+ not type_info.entity_type
+ and o_type_zh != "Entity"
+ and s_type_zh != "Entity"
+ ):
so_index = (s_type_zh, o_type_zh)
if so_index not in self.schema.so_p:
so_index = (o_type_zh, s_type_zh)
@@ -451,15 +474,21 @@ def std_parse_kg_node(self, entity: SPOBase, parsed_entity_set):
for p_candis in candis_set:
if p_candis == entity_type:
spo_zh = f"{s_type_zh}_{p_candis}_{o_type_zh}"
- type_info.entity_type = self.schema.get_spo_with_p(self.schema.spo_zh_en[spo_zh])
+ type_info.entity_type = self.schema.get_spo_with_p(
+ self.schema.spo_zh_en[spo_zh]
+ )
if not type_info.entity_type:
# maybe a property
- s_attr_zh_en = self.schema.attr_zh_en_by_label.get(s_type_en, [])
+ s_attr_zh_en = self.schema.attr_zh_en_by_label.get(
+ s_type_en, []
+ )
if s_attr_zh_en and entity_type in s_attr_zh_en:
type_info.entity_type = s_attr_zh_en[entity_type]
if not type_info.entity_type:
- o_attr_zh_en = self.schema.attr_zh_en_by_label.get(o_type_en, [])
+ o_attr_zh_en = self.schema.attr_zh_en_by_label.get(
+ o_type_en, []
+ )
if o_attr_zh_en and entity_type in o_attr_zh_en:
type_info.entity_type = o_attr_zh_en[entity_type]
std_entity_type_set.append(type_info)
@@ -503,8 +532,10 @@ def std_parse_edge(self, edge: SPORelation, parsed_entity_set):
parsed_entity_set[alias_name] = edge
return edge
- def parse_logic_form(self, input_str: str, parsed_entity_set={}, sub_query=None, query=None):
- match = re.match(r'(\w+)[\(\(](.*)[\)\)](->)?(.*)?', input_str.strip())
+ def parse_logic_form(
+ self, input_str: str, parsed_entity_set={}, sub_query=None, query=None
+ ):
+ match = re.match(r"(\w+)[\(\(](.*)[\)\)](->)?(.*)?", input_str.strip())
if not match:
raise RuntimeError(f"parse logic form error {input_str}")
if len(match.groups()) == 4:
@@ -525,12 +556,14 @@ def parse_logic_form(self, input_str: str, parsed_entity_set={}, sub_query=None,
node.p.s = s_node
node.p.o = o_node
p_node = self.std_parse_kg_node(node.p, parsed_entity_set)
- node.to_std({
- "s": s_node,
- "p": p_node,
- "o": o_node,
- "sub_query": sub_query,
- })
+ node.to_std(
+ {
+ "s": s_node,
+ "p": p_node,
+ "o": o_node,
+ "sub_query": sub_query,
+ }
+ )
elif low_operator in ["filter"]:
node: FilterNode = FilterNode.parse_node(args_str)
elif low_operator in ["deduce"]:
@@ -547,19 +580,19 @@ def parse_logic_form(self, input_str: str, parsed_entity_set={}, sub_query=None,
node: SortNode = CompareNode.parse_node(args_str)
elif low_operator in ["extractor"]:
node: ExtractorNode = ExtractorNode.parse_node(args_str)
- elif low_operator in ['search_s']:
+ elif low_operator in ["search_s"]:
node: SearchNode = SearchNode.parse_node(args_str)
self.std_parse_node(node.s, parsed_entity_set)
else:
raise NotImplementedError(f"not impl {input_str}")
- node.to_std({
- "sub_query": sub_query
- })
+ node.to_std({"sub_query": sub_query})
return node
- def parse_logic_form_set(self, input_str_set: list, sub_querys: list, question: str):
+ def parse_logic_form_set(
+ self, input_str_set: list, sub_querys: list, question: str
+ ):
parsed_cached_map = {}
parsed_node = []
for i, input_str in enumerate(input_str_set):
@@ -568,7 +601,9 @@ def parse_logic_form_set(self, input_str_set: list, sub_querys: list, question:
else:
sub_query = None
try:
- logic_node = self.parse_logic_form(input_str, parsed_cached_map, sub_query=sub_query, query=question)
+ logic_node = self.parse_logic_form(
+ input_str, parsed_cached_map, sub_query=sub_query, query=question
+ )
parsed_node.append(logic_node)
except Exception as e:
logger.warning(f"parse node {input_str} error", exc_info=True)
@@ -577,7 +612,9 @@ def parse_logic_form_set(self, input_str_set: list, sub_querys: list, question:
def std_node_type_name(self, type_name):
if self.schema_retrieval is None:
return type_name
- search_entity_labels = self.schema_retrieval.retrieval_entity(SPOEntity(entity_name=type_name))
+ search_entity_labels = self.schema_retrieval.retrieval_entity(
+ SPOEntity(entity_name=type_name)
+ )
if len(search_entity_labels) > 0:
return search_entity_labels[0].name
return type_name
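Reading aid for the logic_node_parser.py hunks above: the (\w+)=((?:(?!\w+=).)*) pattern used by binary_expr_parse splits an argument string into key/value pairs, where the lookahead lets a value run until the next "key=" token. A minimal sketch follows; the input string is an invented example and the full-width comma handling of the real code is omitted.

import re

# Sketch only: same regex as binary_expr_parse above, simplified post-processing.
pattern = re.compile(r"(\w+)=((?:(?!\w+=).)*)")

def parse_pairs(input_str):
    pairs = {}
    for match in pattern.finditer(input_str):
        key = match.group(1).strip()
        value = match.group(2).strip().rstrip(",")
        pairs[key] = value
    return pairs

print(parse_pairs("left_expr=a, right_expr=2005, op=equal"))
# {'left_expr': 'a', 'right_expr': '2005', 'op': 'equal'}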
diff --git a/kag/solver/logic/core_modules/retriver/entity_linker.py b/kag/solver/logic/core_modules/retriver/entity_linker.py
index e6ce92de..e5b4cfc8 100644
--- a/kag/solver/logic/core_modules/retriver/entity_linker.py
+++ b/kag/solver/logic/core_modules/retriver/entity_linker.py
@@ -18,15 +18,12 @@ class EntityLinkerBase:
def __init__(self, config):
self.config = config
- def entity_linking(self, content, entities: List[SPOEntity], req_id='', **kwargs):
+ def entity_linking(self, content, entities: List[SPOEntity], req_id="", **kwargs):
logger.info(f"EntityLinkerBase {req_id} return empty linker")
- return [
- ], []
+ return [], []
def get_service_name(self):
- return {
- 'scene_name': '空链指调用'
- }
+ return {"scene_name": "空链指调用"}
class DefaultEntityLinker(EntityLinkerBase):
@@ -36,109 +33,147 @@ def __init__(self, config, kg_retriever: KGRetrieverABC):
self.kg_retriever = kg_retriever
def get_service_name(self):
- return {
- 'scene_name': 'neo4j'
- }
+ return {"scene_name": "neo4j"}
def _call_feature(self, feature):
- mention_entity = feature.get('mention_entity', None)
+ mention_entity = feature.get("mention_entity", None)
return self.kg_retriever.retrieval_entity(mention_entity, params=feature)
- def compose_features(self, content, entities: List[SPOEntity], req_id='', params={}):
+ def compose_features(
+ self, content, entities: List[SPOEntity], req_id="", params={}
+ ):
features = []
for i, entity in enumerate(entities):
content = f"{content}[Entity]{entity.entity_name}"
feature = {
"mention_entity": entity,
"property_key": "name",
- 'content': content,
+ "content": content,
"query_text": entity.entity_name,
- 'recognition_threshold': self.recognition_threshold
+ "recognition_threshold": self.recognition_threshold,
}
feature.update(params)
features.append(feature)
return features
## ha3召回+精排链指
- def entity_linking(self, content, entities: List[SPOEntity], req_id='', **kwargs):
- '''
+ def entity_linking(self, content, entities: List[SPOEntity], req_id="", **kwargs):
+ """
input:
content: str, context
entities: [], entity spans to be linked
types: [], entity types to be linked
output:
[{'content': '吉林省抚松县被人们称为是哪种药材之乡?', 'entities': [{'word': '吉林省抚松县', 'start_idx': 0, 'recall': []}]}
- '''
+ """
features = self.compose_features(content, entities, req_id, kwargs)
entity_recalls = {}
logger.debug(f"{req_id} entity_linking {features}")
call_datas = []
if len(features) == 1:
res = self._call_feature(features[0])
- call_datas.append({'res': res, 'recalls': entity_recalls, 'content': content})
+ call_datas.append(
+ {"res": res, "recalls": entity_recalls, "content": content}
+ )
else:
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
- call_datas = [{'res': d, 'recalls': entity_recalls, 'content': content} for d in
- list(executor.map(self._call_feature, features))]
- logger.debug(f'{req_id} entity_linking result: {call_datas}')
+ call_datas = [
+ {"res": d, "recalls": entity_recalls, "content": content}
+ for d in list(executor.map(self._call_feature, features))
+ ]
+ logger.debug(f"{req_id} entity_linking result: {call_datas}")
results = []
for data in call_datas:
- recalled_entities = data['res']
+ recalled_entities = data["res"]
results.append(recalled_entities)
return results, call_datas
-def spo_entity_linker(kg_graph: KgGraph, n: Union[GetSPONode, GetNode], nl_query, el: EntityLinkerBase, schema: SchemaUtils, req_id='',
- params={}):
+def spo_entity_linker(
+ kg_graph: KgGraph,
+ n: Union[GetSPONode, GetNode],
+ nl_query,
+ el: EntityLinkerBase,
+ schema: SchemaUtils,
+ req_id="",
+ params={},
+):
el_results = []
call_result_data = []
entities_candis = []
args_entity_mentions = [[], [], []] # [keys, entities_name, entities_type]
s_data = kg_graph.get_entity_by_alias(n.s.alias_name)
- if s_data is None and isinstance(n.s, SPOEntity) and n.s.entity_name and len(n.s.id_set) == 0:
+ if (
+ s_data is None
+ and isinstance(n.s, SPOEntity)
+ and n.s.entity_name
+ and len(n.s.id_set) == 0
+ ):
entities_candis.append(n.s)
el_kg_graph = KgGraph()
if isinstance(n, GetSPONode):
o_data = kg_graph.get_entity_by_alias(n.o.alias_name)
- if o_data is None and isinstance(n.o, SPOEntity) and n.o.entity_name and len(n.o.id_set) == 0:
+ if (
+ o_data is None
+ and isinstance(n.o, SPOEntity)
+ and n.o.entity_name
+ and len(n.o.id_set) == 0
+ ):
entities_candis.append(n.o)
el_kg_graph.query_graph[n.p.alias_name] = {
"s": n.s.alias_name,
"p": n.p.alias_name,
- "o": n.o.alias_name
+ "o": n.o.alias_name,
}
- el_request = {
- "nl_query": nl_query,
- "entity_mentions": entities_candis
- }
+ el_request = {"nl_query": nl_query, "entity_mentions": entities_candis}
err_msg = ""
if entities_candis and el is not None:
try:
- el_results, call_result_data = el.entity_linking(nl_query, entities_candis, req_id, kwargs=params)
+ el_results, call_result_data = el.entity_linking(
+ nl_query, entities_candis, req_id, kwargs=params
+ )
except Exception as e:
- logger.error(f"{req_id} spo_entity_linker error, we need use name to id {str(e)}", exc_info=True)
+ logger.error(
+ f"{req_id} spo_entity_linker error, we need use name to id {str(e)}",
+ exc_info=True,
+ )
el_results = []
call_result_data = []
err_msg = str(e)
for i in range(len(entities_candis)):
candis_entitiy = entities_candis[i]
entity_data_set = []
- if el_results and i < len(el_results) and el_results[i] is not None and len(el_results[i]) > 0:
+ if (
+ el_results
+ and i < len(el_results)
+ and el_results[i] is not None
+ and len(el_results[i]) > 0
+ ):
el_recalls = el_results[i]
for entity_id_info in el_recalls:
- entity_type_zh = schema.node_en_zh[
- entity_id_info.type] if schema is not None and entity_id_info.type in schema.node_en_zh.keys() else None
+ entity_type_zh = (
+ schema.node_en_zh[entity_id_info.type]
+ if schema is not None
+ and entity_id_info.type in schema.node_en_zh.keys()
+ else None
+ )
entity_id_info.type_zh = entity_type_zh
entity_data_set.append(entity_id_info)
else:
entity_id_info = EntityData()
entity_id_info.name = candis_entitiy.entity_name
entity_id_info.biz_id = candis_entitiy.entity_name
- entity_id_info.type = schema.get_label_within_prefix(candis_entitiy.get_entity_first_type())
- entity_type_zh = schema.node_en_zh[
- entity_id_info.type] if schema is not None and entity_id_info.type in schema.node_en_zh.keys() else None
+ entity_id_info.type = schema.get_label_within_prefix(
+ candis_entitiy.get_entity_first_type()
+ )
+ entity_type_zh = (
+ schema.node_en_zh[entity_id_info.type]
+ if schema is not None
+ and entity_id_info.type in schema.node_en_zh.keys()
+ else None
+ )
entity_id_info.type_zh = entity_type_zh
entity_data_set.append(entity_id_info)
el_kg_graph.nodes_alias.append(candis_entitiy.alias_name)
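Reading aid for the entity_linker.py hunks above: when the linker returns no candidates for a mention, spo_entity_linker falls back to an EntityData built from the mention itself, with the surface name doubling as the business id. A minimal standalone sketch of that fallback; EntityData here is a stand-in dataclass (the real one lives in kag.solver.logic.core_modules.common.one_hop_graph), and the arguments in the example call are invented.

from dataclasses import dataclass

@dataclass
class EntityData:
    name: str = ""
    biz_id: str = ""
    type: str = ""

def link_or_fallback(mention_name, mention_type, recalled):
    if recalled:                    # linker found candidates: use them as-is
        return list(recalled)
    fallback = EntityData()         # otherwise keep the surface mention
    fallback.name = mention_name
    fallback.biz_id = mention_name  # name doubles as the business id
    fallback.type = mention_type
    return [fallback]

print(link_or_fallback("抚松县", "AdministrativeArea", []))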
diff --git a/kag/solver/logic/core_modules/retriver/graph_retriver/dsl_executor.py b/kag/solver/logic/core_modules/retriver/graph_retriver/dsl_executor.py
index e10e6a0e..84b867af 100644
--- a/kag/solver/logic/core_modules/retriver/graph_retriver/dsl_executor.py
+++ b/kag/solver/logic/core_modules/retriver/graph_retriver/dsl_executor.py
@@ -6,8 +6,13 @@
from knext.reasoner import TableResult, ReasonTask
from knext.reasoner.client import ReasonerClient
-from kag.solver.logic.core_modules.common.one_hop_graph import copy_one_hop_graph_data, EntityData, Prop, \
- OneHopGraphData, RelationData
+from kag.solver.logic.core_modules.common.one_hop_graph import (
+ copy_one_hop_graph_data,
+ EntityData,
+ Prop,
+ OneHopGraphData,
+ RelationData,
+)
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils
from kag.solver.logic.core_modules.common.utils import generate_biz_id_with_type
from kag.solver.logic.core_modules.config import LogicFormConfiguration
@@ -17,7 +22,9 @@
class DslRunner:
- def __init__(self, project_id: str, schema: SchemaUtils, config: LogicFormConfiguration):
+ def __init__(
+ self, project_id: str, schema: SchemaUtils, config: LogicFormConfiguration
+ ):
# Initialize the DslRunner with project ID, schema, and configuration.
"""
Initialize the DslRunner for graph database access using Cypher or other languages to retrieve results as OneHopGraph.
@@ -36,36 +43,53 @@ def get_cached_one_hop_data(self, query_one_graph_cache: dict, biz_id, spo_name)
return copy_one_hop_graph_data(query_one_graph_cache[biz_id], spo_name)
return None
- def call_sub_event(self, s_biz_id_set: list, s_node_type: str, o_node_type: str, p_name: str, out_direct: bool,
- filter_map: dict = None):
+ def call_sub_event(
+ self,
+ s_biz_id_set: list,
+ s_node_type: str,
+ o_node_type: str,
+ p_name: str,
+ out_direct: bool,
+ filter_map: dict = None,
+ ):
pass
- def run_dsl(self, query, dsl, start_id, params, schema: SchemaUtils, graph_output=False):
+ def run_dsl(
+ self, query, dsl, start_id, params, schema: SchemaUtils, graph_output=False
+ ):
pass
"""
batch query with s and o
"""
- def query_vertex_one_graph_by_s_o_ids(self, s_node_set: List[EntityData], o_node_set: List[EntityData],
- cached_map: dict,n: GetSPONode=None):
+ def query_vertex_one_graph_by_s_o_ids(
+ self,
+ s_node_set: List[EntityData],
+ o_node_set: List[EntityData],
+ cached_map: dict,
+ n: GetSPONode = None,
+ ):
pass
"""
batch query with s, only get property
"""
- def query_vertex_property_by_s_ids(self, s_biz_id_set: list, s_node_type: str, cached_map: dict):
+ def query_vertex_property_by_s_ids(
+ self, s_biz_id_set: list, s_node_type: str, cached_map: dict
+ ):
pass
class DslRunnerOnGraphStore(DslRunner):
-
def __init__(self, project_id, schema: SchemaUtils, config: LogicFormConfiguration):
super().__init__(project_id, schema, config)
self.schema = schema
- def run_dsl(self, query, dsl, start_id, params, schema: SchemaUtils, graph_output=False):
+ def run_dsl(
+ self, query, dsl, start_id, params, schema: SchemaUtils, graph_output=False
+ ):
pass
def _get_filter_gql(self, filter_map: dict, alias: str):
@@ -76,28 +100,27 @@ def _get_filter_gql(self, filter_map: dict, alias: str):
def _convert_node_to_json(self, node_str):
try:
import json
+
node = json.loads(node_str)
except:
return {}
return {
- 'id': node['id'],
- 'type': node['__label__'],
- 'propertyValues': dict(node)
+ "id": node["id"],
+ "type": node["__label__"],
+ "propertyValues": dict(node),
}
def _convert_edge_to_json(self, p_str):
try:
import json
+
p = json.loads(p_str)
except:
return {}
prop = dict(p)
- prop['original_src_id1__'] = p['__from_id__']
- prop['original_dst_id2__'] = p['__to_id__']
- return {
- 'type': p['__label__'],
- 'propertyValues': prop
- }
+ prop["original_src_id1__"] = p["__from_id__"]
+ prop["original_dst_id2__"] = p["__to_id__"]
+ return {"type": p["__label__"], "propertyValues": prop}
def replace_qota(self, s: str):
return s.replace("'", "\\'")
@@ -106,12 +129,20 @@ def _generate_gql_type(self, biz_set: list, node_type: str):
if biz_set is None or len(biz_set) == 0 or node_type is None:
return ":Entity"
return f":{node_type}"
+
"""
batch query with s and o
"""
- def _do_query_vertex_one_graph_by_s_o_ids(self, s_biz_id: list, s_node_type: str, o_biz_id: list, o_node_type: str,
- p_name: str = None, filter_map: dict = None):
+ def _do_query_vertex_one_graph_by_s_o_ids(
+ self,
+ s_biz_id: list,
+ s_node_type: str,
+ o_biz_id: list,
+ o_node_type: str,
+ p_name: str = None,
+ filter_map: dict = None,
+ ):
s_biz_id_set = [f'"{self.replace_qota(str(s_id))}"' for s_id in s_biz_id]
@@ -153,7 +184,7 @@ def _do_query_vertex_one_graph_by_s_o_ids(self, s_biz_id: list, s_node_type: str
gql_param = {
"start_alias": "s" if len(s_biz_id_set) > 0 else "o",
"s_type": s_node_type,
- "o_type": o_node_type
+ "o_type": o_node_type,
}
if len(s_biz_id_set) > 0:
where_cluase.append(f"s.id in $sid")
@@ -164,7 +195,15 @@ def _do_query_vertex_one_graph_by_s_o_ids(self, s_biz_id: list, s_node_type: str
gql_param["oid"] = f'[{",".join(o_biz_id_set)}]'
if p_name is None:
p_name = "rdf_expand()"
- gql_set = self._generate_gql_prio_set(s_node_type, s_biz_id_set, o_node_type, o_biz_id_set, p_name, where_cluase, return_cluase)
+ gql_set = self._generate_gql_prio_set(
+ s_node_type,
+ s_biz_id_set,
+ o_node_type,
+ o_biz_id_set,
+ p_name,
+ where_cluase,
+ return_cluase,
+ )
logger.debug("query_vertex_one_graph_by_s_o_ids query " + str(gql_set))
start_time = time.time()
@@ -176,12 +215,23 @@ def _do_query_vertex_one_graph_by_s_o_ids(self, s_biz_id: list, s_node_type: str
if len(o_biz_id_set) > 0:
add_alias.append("o")
out = self.parse_one_hot_graph_graph_detail_with_id_map(res.task, add_alias)
- logger.debug(f"query_vertex_one_graph_by_s_o_ids {s_biz_id_set} cost end time {time.time() - start_time}")
+ logger.debug(
+ f"query_vertex_one_graph_by_s_o_ids {s_biz_id_set} cost end time {time.time() - start_time}"
+ )
if out is not None and len(out) > 0:
return out
return {}
- def _generate_gql_prio_set(self, s_type, s_biz_id_set, o_type, o_biz_id_set, p_type, where_cluase, return_cluase):
+ def _generate_gql_prio_set(
+ self,
+ s_type,
+ s_biz_id_set,
+ o_type,
+ o_biz_id_set,
+ p_type,
+ where_cluase,
+ return_cluase,
+ ):
s_gql = f"(s{self._generate_gql_type(s_biz_id_set, s_type)})"
o_gql = f"(o{self._generate_gql_type(o_biz_id_set, o_type)})"
rdf_expand_gql = f"""match {s_gql}-[p:rdf_expand()]-{o_gql}
@@ -192,14 +242,28 @@ def _generate_gql_prio_set(self, s_type, s_biz_id_set, o_type, o_biz_id_set, p_t
s_without_prefix_type = self.schema.get_label_without_prefix(s_type)
o_without_prefix_type = self.schema.get_label_without_prefix(o_type)
ret_gql = []
- if (s_without_prefix_type, o_without_prefix_type) in self.schema.so_p_en and p_type in self.schema.so_p_en[(s_without_prefix_type, o_without_prefix_type)]:
- ret_gql.append(f"""match (s:{s_type})-[p:{p_type}]->(o:{o_type})
+ if (
+ s_without_prefix_type,
+ o_without_prefix_type,
+ ) in self.schema.so_p_en and p_type in self.schema.so_p_en[
+ (s_without_prefix_type, o_without_prefix_type)
+ ]:
+ ret_gql.append(
+ f"""match (s:{s_type})-[p:{p_type}]->(o:{o_type})
{'where ' + "and".join(where_cluase) if len(where_cluase) > 0 else ''}
- return {','.join(return_cluase)}""")
- if (o_without_prefix_type, s_without_prefix_type) in self.schema.op_s_en and p_type in self.schema.op_s_en[(o_without_prefix_type, s_without_prefix_type)]:
- ret_gql.append(f"""match (s:{s_type})<-[p:{p_type}]-(o:{o_type})
+ return {','.join(return_cluase)}"""
+ )
+ if (
+ o_without_prefix_type,
+ s_without_prefix_type,
+ ) in self.schema.op_s_en and p_type in self.schema.op_s_en[
+ (o_without_prefix_type, s_without_prefix_type)
+ ]:
+ ret_gql.append(
+ f"""match (s:{s_type})<-[p:{p_type}]-(o:{o_type})
{'where ' + "and".join(where_cluase) if len(where_cluase) > 0 else ''}
- return {','.join(return_cluase)}""")
+ return {','.join(return_cluase)}"""
+ )
ret_gql.append(rdf_expand_gql)
return ret_gql
@@ -214,17 +278,32 @@ def _get_node_type_zh(self, node_type):
if node_type == "attribute":
return "文本"
return node_type
+
def _get_p_type_name(self, n: GetSPONode):
if n is None:
return None
return n.p.get_entity_first_type()
- def _get_entity_type_name(self, d: EntityData, n: GetSPONode=None, alias=None):
+
+ def _get_entity_type_name(self, d: EntityData, n: GetSPONode = None, alias=None):
if d is None and n is None:
return None
- return self.schema.get_label_within_prefix(n.s.get_entity_first_type() if alias=="s" else n.o.get_entity_first_type()) if d is None else d.type
-
- def query_vertex_one_graph_by_s_o_ids(self, s_node_set: List[EntityData], o_node_set: List[EntityData],
- cached_map: dict, n: GetSPONode=None):
+ return (
+ self.schema.get_label_within_prefix(
+ n.s.get_entity_first_type()
+ if alias == "s"
+ else n.o.get_entity_first_type()
+ )
+ if d is None
+ else d.type
+ )
+
+ def query_vertex_one_graph_by_s_o_ids(
+ self,
+ s_node_set: List[EntityData],
+ o_node_set: List[EntityData],
+ cached_map: dict,
+ n: GetSPONode = None,
+ ):
one_hop_graph_map = {}
is_enable_cache = True
if (len(s_node_set) != 0 and len(o_node_set) != 0) or self._get_p_type_name(n):
@@ -234,7 +313,12 @@ def query_vertex_one_graph_by_s_o_ids(self, s_node_set: List[EntityData], o_node
o_uncached_biz_id = []
if is_enable_cache:
for s_node in s_node_set:
- cached_id = generate_biz_id_with_type(s_node.biz_id, self._get_node_type_zh(s_node.type_zh if s_node.type_zh else s_node.type))
+ cached_id = generate_biz_id_with_type(
+ s_node.biz_id,
+ self._get_node_type_zh(
+ s_node.type_zh if s_node.type_zh else s_node.type
+ ),
+ )
cached_graph = self.get_cached_one_hop_data(cached_map, cached_id, "s")
if cached_graph:
one_hop_graph_map[cached_id] = cached_graph
@@ -242,7 +326,12 @@ def query_vertex_one_graph_by_s_o_ids(self, s_node_set: List[EntityData], o_node
s_uncached_biz_id.append(s_node)
for o_node in o_node_set:
- cached_id = generate_biz_id_with_type(o_node.biz_id, self._get_node_type_zh(o_node.type_zh if o_node.type_zh else o_node.type))
+ cached_id = generate_biz_id_with_type(
+ o_node.biz_id,
+ self._get_node_type_zh(
+ o_node.type_zh if o_node.type_zh else o_node.type
+ ),
+ )
cached_graph = self.get_cached_one_hop_data(cached_map, cached_id, "o")
if cached_graph:
one_hop_graph_map[cached_id] = cached_graph
@@ -257,16 +346,26 @@ def query_vertex_one_graph_by_s_o_ids(self, s_node_set: List[EntityData], o_node
combined_list = self._shuffle_query_node(s_uncached_biz_id, o_uncached_biz_id)
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
- executor.submit(self._do_query_vertex_one_graph_by_s_o_ids, [] if s_node is None else [s_node.biz_id],
- self._get_entity_type_name(s_node, n, "s"), [] if o_node is None else [o_node.biz_id],
- self._get_entity_type_name(o_node, n, "o"), self._get_p_type_name(n)) for
- s_node, o_node in
- combined_list]
- results = [future.result() for future in concurrent.futures.as_completed(futures)]
+ executor.submit(
+ self._do_query_vertex_one_graph_by_s_o_ids,
+ [] if s_node is None else [s_node.biz_id],
+ self._get_entity_type_name(s_node, n, "s"),
+ [] if o_node is None else [o_node.biz_id],
+ self._get_entity_type_name(o_node, n, "o"),
+ self._get_p_type_name(n),
+ )
+ for s_node, o_node in combined_list
+ ]
+ results = [
+ future.result() for future in concurrent.futures.as_completed(futures)
+ ]
for r in results:
one_hop_graph_map.update(r)
for node in s_uncached_biz_id + o_uncached_biz_id:
- cached_id = generate_biz_id_with_type(node.biz_id, self._get_node_type_zh(node.type_zh if node.type_zh else node.type))
+ cached_id = generate_biz_id_with_type(
+ node.biz_id,
+ self._get_node_type_zh(node.type_zh if node.type_zh else node.type),
+ )
if cached_id not in one_hop_graph_map:
continue
one_hop_graph = one_hop_graph_map[cached_id]
@@ -277,7 +376,9 @@ def query_vertex_one_graph_by_s_o_ids(self, s_node_set: List[EntityData], o_node
def _shuffle_query_node(self, s_nodes: List[EntityData], o_nodes: List[EntityData]):
s_group_types = self._extra_node_id_group_by_type(s_nodes)
o_group_types = self._extra_node_id_group_by_type(o_nodes)
- combined_list = self._cartesian_product_with_default(s_group_types, o_group_types)
+ combined_list = self._cartesian_product_with_default(
+ s_group_types, o_group_types
+ )
node_ids = []
for s_group, o_group in combined_list:
node_ids = node_ids + self._cartesian_product_with_default(s_group, o_group)
@@ -291,13 +392,17 @@ def _extra_node_id_group_by_type(self, nodes: List[EntityData]):
type_map[node.type].append(node)
return list(type_map.values())
- def query_vertex_property_by_s_ids(self, s_biz_id_set: list, s_node_type: str, cached_map: dict):
+ def query_vertex_property_by_s_ids(
+ self, s_biz_id_set: list, s_node_type: str, cached_map: dict
+ ):
one_hop_graph_dict = {}
s_uncached_biz_id_set = []
s_node_type_zh = self._get_node_type_zh(s_node_type)
for s_id in s_biz_id_set:
cache_id_with_type = generate_biz_id_with_type(s_id, s_node_type_zh)
- cached_graph = self.get_cached_one_hop_data(cached_map, cache_id_with_type, "s")
+ cached_graph = self.get_cached_one_hop_data(
+ cached_map, cache_id_with_type, "s"
+ )
if cached_graph:
one_hop_graph_dict[cache_id_with_type] = cached_graph
else:
@@ -311,9 +416,7 @@ def query_vertex_property_by_s_ids(self, s_biz_id_set: list, s_node_type: str, c
return_cluase = []
id_rep = "id"
- gql_param = {
- "start_alias": "s"
- }
+ gql_param = {"start_alias": "s"}
if len(s_biz_id_set) > 0:
s_where_cluase = "s.id in $sid"
gql_param["sid"] = f"[{','.join(s_biz_id_set)}]"
@@ -336,29 +439,35 @@ def query_vertex_property_by_s_ids(self, s_biz_id_set: list, s_node_type: str, c
add_alias.append("s")
out = self.parse_one_hot_graph_graph_detail_with_id_map(res.task, add_alias)
one_hop_graph_dict.update(out)
- logger.debug(f"query_vertex_one_graph_by_s_o_ids {s_biz_id_set} cost end time {time.time() - start_time}")
+ logger.debug(
+ f"query_vertex_one_graph_by_s_o_ids {s_biz_id_set} cost end time {time.time() - start_time}"
+ )
return one_hop_graph_dict
def _trans_normal_p_json(self, p_json, s_json, o_json):
- s_type = s_json['type']
+ s_type = s_json["type"]
s_biz_id = s_json["propertyValues"]["id"]
- o_type = o_json['type']
+ o_type = o_json["type"]
o_biz_id = o_json["propertyValues"]["id"]
- p_total_type_name = p_json['type']
+ p_total_type_name = p_json["type"]
if len(s_type) > len(o_type):
- p_type = p_json['type'].replace(s_type, "").replace(o_type, "").replace("_", "")
+ p_type = (
+ p_json["type"].replace(s_type, "").replace(o_type, "").replace("_", "")
+ )
else:
- p_type = p_json['type'].replace(o_type, "").replace(s_type, "").replace("_", "")
+ p_type = (
+ p_json["type"].replace(o_type, "").replace(s_type, "").replace("_", "")
+ )
p_info = {}
from_id = None
to_id = None
- for property_key in p_json['propertyValues'].keys():
+ for property_key in p_json["propertyValues"].keys():
if property_key == "original_src_id1__":
- from_id = p_json['propertyValues'][property_key]
+ from_id = p_json["propertyValues"][property_key]
elif property_key == "original_dst_id2__":
- to_id = p_json['propertyValues'][property_key]
+ to_id = p_json["propertyValues"][property_key]
else:
- p_info[property_key] = p_json['propertyValues'][property_key]
+ p_info[property_key] = p_json["propertyValues"][property_key]
if from_id is None or to_id is None:
return None
"""
@@ -374,21 +483,25 @@ def _trans_normal_p_json(self, p_json, s_json, o_json):
is_out_edge = from_id == s_biz_id
if is_out_edge:
- p_info.update({
- "__label__": p_type,
- "__from_id__": s_biz_id,
- "__from_id_type__": s_type,
- "__to_id__": o_biz_id,
- "__to_id_type__": o_type
- })
+ p_info.update(
+ {
+ "__label__": p_type,
+ "__from_id__": s_biz_id,
+ "__from_id_type__": s_type,
+ "__to_id__": o_biz_id,
+ "__to_id_type__": o_type,
+ }
+ )
else:
- p_info.update({
- "__label__": p_type,
- "__from_id__": o_biz_id,
- "__from_id_type__": o_type,
- "__to_id__": s_biz_id,
- "__to_id_type__": s_type
- })
+ p_info.update(
+ {
+ "__label__": p_type,
+ "__from_id__": o_biz_id,
+ "__from_id_type__": o_type,
+ "__to_id__": s_biz_id,
+ "__to_id_type__": s_type,
+ }
+ )
return p_info
@@ -403,7 +516,9 @@ def _check_need_property(self, json_data):
return False
return True
- def parse_one_hot_graph_graph_detail_with_id_map(self, task_resp: ReasonTask, add_alias: list):
+ def parse_one_hot_graph_graph_detail_with_id_map(
+ self, task_resp: ReasonTask, add_alias: list
+ ):
one_hop_graph_map = {}
if task_resp is None or task_resp.status != "FINISH":
return one_hop_graph_map
@@ -434,10 +549,12 @@ def parse_one_hot_graph_graph_detail_with_id_map(self, task_resp: ReasonTask, ad
s_json = self._convert_node_to_json(data[s_index])
if self._check_need_property(s_json) is False:
continue
- prop_values = s_json['propertyValues']
+ prop_values = s_json["propertyValues"]
s_biz_id = prop_values["id"]
s_type_name = s_json["type"]
- s_biz_id_with_type_name = generate_biz_id_with_type(s_biz_id, self._get_node_type_zh(s_type_name))
+ s_biz_id_with_type_name = generate_biz_id_with_type(
+ s_biz_id, self._get_node_type_zh(s_type_name)
+ )
if s_biz_id_with_type_name not in tmp_graph_parse_result_map.keys():
s_entity = EntityData()
s_entity.type = s_type_name
@@ -453,8 +570,13 @@ def parse_one_hot_graph_graph_detail_with_id_map(self, task_resp: ReasonTask, ad
else:
s_entity = tmp_graph_parse_result_map[s_biz_id_with_type_name].s
- if "s" in add_alias and s_biz_id_with_type_name not in one_hop_graph_map.keys():
- one_hop_graph_map[s_biz_id_with_type_name] = tmp_graph_parse_result_map[s_biz_id_with_type_name]
+ if (
+ "s" in add_alias
+ and s_biz_id_with_type_name not in one_hop_graph_map.keys()
+ ):
+ one_hop_graph_map[
+ s_biz_id_with_type_name
+ ] = tmp_graph_parse_result_map[s_biz_id_with_type_name]
else:
s_biz_id_with_type_name = None
@@ -466,10 +588,12 @@ def parse_one_hot_graph_graph_detail_with_id_map(self, task_resp: ReasonTask, ad
o_json = self._convert_node_to_json(data[o_index])
if self._check_need_property(o_json) is False:
continue
- prop_values = o_json['propertyValues']
+ prop_values = o_json["propertyValues"]
o_biz_id = prop_values["id"]
o_type_name = o_json["type"]
- o_biz_id_with_type_name = generate_biz_id_with_type(o_biz_id, self._get_node_type_zh(o_type_name))
+ o_biz_id_with_type_name = generate_biz_id_with_type(
+ o_biz_id, self._get_node_type_zh(o_type_name)
+ )
if o_biz_id_with_type_name not in tmp_graph_parse_result_map.keys():
o_entity = EntityData()
o_entity.type = o_type_name
@@ -486,13 +610,20 @@ def parse_one_hot_graph_graph_detail_with_id_map(self, task_resp: ReasonTask, ad
else:
o_entity = tmp_graph_parse_result_map[o_biz_id_with_type_name].s
- if "o" in add_alias and o_biz_id_with_type_name not in one_hop_graph_map.keys():
- one_hop_graph_map[o_biz_id_with_type_name] = tmp_graph_parse_result_map[o_biz_id_with_type_name]
+ if (
+ "o" in add_alias
+ and o_biz_id_with_type_name not in one_hop_graph_map.keys()
+ ):
+ one_hop_graph_map[
+ o_biz_id_with_type_name
+ ] = tmp_graph_parse_result_map[o_biz_id_with_type_name]
else:
o_biz_id_with_type_name = None
if s_entity is None and o_entity is None:
- logger.info("parse_one_hot_graph_graph_detail_with_id_map entity is None")
+ logger.info(
+ "parse_one_hot_graph_graph_detail_with_id_map entity is None"
+ )
continue
if p_index == -1:
@@ -524,35 +655,84 @@ def parse_one_hot_graph_graph_detail_with_id_map(self, task_resp: ReasonTask, ad
o_entity.type = rel.from_type
o_entity.type_zh = self._get_node_type_zh(rel.from_type)
- if generate_biz_id_with_type(rel.from_id, rel.from_type) == generate_biz_id_with_type(s_entity.biz_id,
- s_type_name):
+ if generate_biz_id_with_type(
+ rel.from_id, rel.from_type
+ ) == generate_biz_id_with_type(s_entity.biz_id, s_type_name):
rel.from_entity = s_entity
rel.end_entity = o_entity
- if s_biz_id_with_type_name is not None and s_biz_id_with_type_name in tmp_graph_parse_result_map.keys():
- if rel.type in tmp_graph_parse_result_map[s_biz_id_with_type_name].out_relations.keys():
- tmp_graph_parse_result_map[s_biz_id_with_type_name].out_relations[rel.type].append(rel)
+ if (
+ s_biz_id_with_type_name is not None
+ and s_biz_id_with_type_name in tmp_graph_parse_result_map.keys()
+ ):
+ if (
+ rel.type
+ in tmp_graph_parse_result_map[
+ s_biz_id_with_type_name
+ ].out_relations.keys()
+ ):
+ tmp_graph_parse_result_map[
+ s_biz_id_with_type_name
+ ].out_relations[rel.type].append(rel)
else:
- tmp_graph_parse_result_map[s_biz_id_with_type_name].out_relations[rel.type] = [rel]
- if o_biz_id_with_type_name is not None and o_biz_id_with_type_name in tmp_graph_parse_result_map.keys():
- if rel.type in tmp_graph_parse_result_map[o_biz_id_with_type_name].in_relations.keys():
- tmp_graph_parse_result_map[o_biz_id_with_type_name].in_relations[rel.type].append(rel)
+ tmp_graph_parse_result_map[
+ s_biz_id_with_type_name
+ ].out_relations[rel.type] = [rel]
+ if (
+ o_biz_id_with_type_name is not None
+ and o_biz_id_with_type_name in tmp_graph_parse_result_map.keys()
+ ):
+ if (
+ rel.type
+ in tmp_graph_parse_result_map[
+ o_biz_id_with_type_name
+ ].in_relations.keys()
+ ):
+ tmp_graph_parse_result_map[
+ o_biz_id_with_type_name
+ ].in_relations[rel.type].append(rel)
else:
- tmp_graph_parse_result_map[o_biz_id_with_type_name].in_relations[rel.type] = [rel]
+ tmp_graph_parse_result_map[
+ o_biz_id_with_type_name
+ ].in_relations[rel.type] = [rel]
else:
rel.from_entity = o_entity
rel.from_alias = "o"
rel.end_entity = s_entity
rel.end_alias = "s"
- if s_biz_id_with_type_name is not None and s_biz_id_with_type_name in tmp_graph_parse_result_map.keys():
- if rel.type in tmp_graph_parse_result_map[s_biz_id_with_type_name].in_relations.keys():
- tmp_graph_parse_result_map[s_biz_id_with_type_name].in_relations[rel.type].append(rel)
+ if (
+ s_biz_id_with_type_name is not None
+ and s_biz_id_with_type_name in tmp_graph_parse_result_map.keys()
+ ):
+ if (
+ rel.type
+ in tmp_graph_parse_result_map[
+ s_biz_id_with_type_name
+ ].in_relations.keys()
+ ):
+ tmp_graph_parse_result_map[
+ s_biz_id_with_type_name
+ ].in_relations[rel.type].append(rel)
else:
- tmp_graph_parse_result_map[s_biz_id_with_type_name].in_relations[rel.type] = [rel]
-
- if o_biz_id_with_type_name is not None and o_biz_id_with_type_name in tmp_graph_parse_result_map.keys():
- if rel.type in tmp_graph_parse_result_map[o_biz_id_with_type_name].out_relations.keys():
- tmp_graph_parse_result_map[o_biz_id_with_type_name].out_relations[rel.type].append(rel)
+ tmp_graph_parse_result_map[
+ s_biz_id_with_type_name
+ ].in_relations[rel.type] = [rel]
+
+ if (
+ o_biz_id_with_type_name is not None
+ and o_biz_id_with_type_name in tmp_graph_parse_result_map.keys()
+ ):
+ if (
+ rel.type
+ in tmp_graph_parse_result_map[
+ o_biz_id_with_type_name
+ ].out_relations.keys()
+ ):
+ tmp_graph_parse_result_map[
+ o_biz_id_with_type_name
+ ].out_relations[rel.type].append(rel)
else:
- tmp_graph_parse_result_map[o_biz_id_with_type_name].out_relations[rel.type] = [rel]
+ tmp_graph_parse_result_map[
+ o_biz_id_with_type_name
+ ].out_relations[rel.type] = [rel]
return one_hop_graph_map
diff --git a/kag/solver/logic/core_modules/retriver/graph_retriver/dsl_model.py b/kag/solver/logic/core_modules/retriver/graph_retriver/dsl_model.py
index bd1250f0..37cd4980 100644
--- a/kag/solver/logic/core_modules/retriver/graph_retriver/dsl_model.py
+++ b/kag/solver/logic/core_modules/retriver/graph_retriver/dsl_model.py
@@ -45,6 +45,7 @@ def from_dict(json_dict):
entity.data = json_dict["data"]
return entity
+
class RelationDetail:
def __init__(self):
self.start_entity_type_name = None
@@ -90,18 +91,18 @@ def from_json(json_str):
@staticmethod
def from_dict(json_dict):
graph_detail = GraphDetail()
- nodes = json_dict['nodes']
+ nodes = json_dict["nodes"]
if len(nodes) != 0:
for node in nodes:
graph_detail.nodes.append(EntityDetail.from_dict(node))
- edges = json_dict['edges']
+ edges = json_dict["edges"]
if len(edges) != 0:
for edge in edges:
graph_detail.edges.append(RelationDetail.from_dict(edge))
- graph_detail.other = json_dict['other']
- graph_detail.next_query_id = json_dict['nextQueryId']
- graph_detail.view_level = ViewLevel[json_dict['viewLevel'].upper()]
+ graph_detail.other = json_dict["other"]
+ graph_detail.next_query_id = json_dict["nextQueryId"]
+ graph_detail.view_level = ViewLevel[json_dict["viewLevel"].upper()]
if "tableDetail" in json_dict.keys() and json_dict["tableDetail"] is not None:
graph_detail.tableData = TableData.from_dict(json_dict["tableDetail"])
diff --git a/kag/solver/logic/core_modules/retriver/schema_std.py b/kag/solver/logic/core_modules/retriver/schema_std.py
index 27448d42..1ac73d67 100644
--- a/kag/solver/logic/core_modules/retriver/schema_std.py
+++ b/kag/solver/logic/core_modules/retriver/schema_std.py
@@ -7,7 +7,7 @@
from kag.solver.logic.core_modules.common.base_model import SPOEntity
from kag.solver.logic.core_modules.common.one_hop_graph import EntityData
-sys.path.append('../logic_form_executor/')
+sys.path.append("../logic_form_executor/")
current_dir = os.path.dirname(os.path.abspath(__file__))
import logging
@@ -18,11 +18,16 @@ class SchemaRetrieval(KGRetrieverByLlm):
def __init__(self, **kwargs):
super().__init__(**kwargs)
- def retrieval_entity(self, mention_entity: SPOEntity, topk=1, **kwargs) -> List[EntityData]:
+ def retrieval_entity(
+ self, mention_entity: SPOEntity, topk=1, **kwargs
+ ) -> List[EntityData]:
# Recall candidate entities based on the mention
- label = self.schema.get_label_within_prefix('SemanticConcept')
+ label = self.schema.get_label_within_prefix("SemanticConcept")
typed_nodes = self.sc.search_vector(
- label=label, property_key="name", query_vector=self.vectorizer.vectorize(mention_entity.entity_name), topk=1
+ label=label,
+ property_key="name",
+ query_vector=self.vectorizer.vectorize(mention_entity.entity_name),
+ topk=1,
)
recalled_entity = EntityData()
recalled_entity.type = "SemanticConcept"
diff --git a/kag/solver/logic/core_modules/rule_runner/rule_runner.py b/kag/solver/logic/core_modules/rule_runner/rule_runner.py
index ff199f1a..f666f5cf 100644
--- a/kag/solver/logic/core_modules/rule_runner/rule_runner.py
+++ b/kag/solver/logic/core_modules/rule_runner/rule_runner.py
@@ -3,9 +3,16 @@
from enum import Enum
from kag.solver.logic.core_modules.common.base_model import Identifer
-from kag.solver.logic.core_modules.common.one_hop_graph import KgGraph, EntityData, RelationData
-from kag.solver.logic.core_modules.parser.logic_node_parser import FilterNode, ExtractorNode, \
- VerifyNode
+from kag.solver.logic.core_modules.common.one_hop_graph import (
+ KgGraph,
+ EntityData,
+ RelationData,
+)
+from kag.solver.logic.core_modules.parser.logic_node_parser import (
+ FilterNode,
+ ExtractorNode,
+ VerifyNode,
+)
class MatchRes(Enum):
@@ -16,7 +23,7 @@ class MatchRes(Enum):
class MatchInfo:
- def __init__(self, res: MatchRes, desc: str = ''):
+ def __init__(self, res: MatchRes, desc: str = ""):
self.res = res
self.desc = desc
@@ -30,9 +37,9 @@ def trans_match_res_to_str(self):
else:
return "不相关"
+
def trans_str_res_to_match(res: str):
- if res is None or res == '' or "无相关信息" in res \
- or "不相关" in res:
+ if res is None or res == "" or "无相关信息" in res or "不相关" in res:
return MatchRes.UN_RELATED
return MatchRes.RELATED
@@ -54,7 +61,7 @@ def __init__(self):
"exist": self.run_exists,
"necessary": self.run_necessary,
"collect_in": self.run_collect_in,
- "collect_contains": self.run_collect_contains
+ "collect_contains": self.run_collect_contains,
}
def run_rule(self, op_name: str, left_value, right_value):
@@ -179,12 +186,13 @@ def run_collect_contains(self, left_value, right_value):
class ModelRunner(StrRuleRunner):
- def __init__(self, llm, kg_graph: KgGraph, query:str, req_id: str):
+ def __init__(self, llm, kg_graph: KgGraph, query: str, req_id: str):
super().__init__()
self.llm = llm
self.kg_graph = kg_graph
self.query = query
self.req_id = req_id
+
def _get_kg_graph_data(self):
return self.kg_graph.to_spo()
@@ -198,7 +206,9 @@ def run_match(self, left_value, right_value):
right_value, right_value, str(self._get_kg_graph_data())
)
res = self.llm.generate(prompt, max_output_len=100)
- logging.info(f"ModelRunner {self.req_id} cost={time.time() - start_time} prompt={prompt} res={res}")
+ logging.info(
+ f"ModelRunner {self.req_id} cost={time.time() - start_time} prompt={prompt} res={res}"
+ )
return MatchInfo(trans_str_res_to_match(res), res)
def run_collect_in(self, left_value, right_value):
@@ -211,7 +221,9 @@ def run_collect_in(self, left_value, right_value):
right_value, right_value, str(self._get_kg_graph_data())
)
res = self.llm.generate(prompt, max_output_len=100)
- logging.info(f"ModelRunner {self.req_id} cost={time.time() - start_time} prompt={prompt} res={res}")
+ logging.info(
+ f"ModelRunner {self.req_id} cost={time.time() - start_time} prompt={prompt} res={res}"
+ )
return MatchInfo(trans_str_res_to_match(res), res)
def run_collect_contains(self, left_value, right_value):
@@ -224,7 +236,9 @@ def run_collect_contains(self, left_value, right_value):
right_value, right_value, str(self._get_kg_graph_data())
)
res = self.llm.generate(prompt, max_output_len=100)
- logging.info(f"ModelRunner {self.req_id} cost={time.time() - start_time} prompt={prompt} res={res}")
+ logging.info(
+ f"ModelRunner {self.req_id} cost={time.time() - start_time} prompt={prompt} res={res}"
+ )
return MatchInfo(trans_str_res_to_match(res), res)
@@ -238,7 +252,7 @@ def __init__(self, kg_graph: KgGraph, llm, query: str, req_id: str):
self.runner: ModelRunner = ModelRunner(llm, kg_graph, query, req_id)
self.llm = llm
- def _get_identifer_to_doc(self, alias:Identifer):
+ def _get_identifer_to_doc(self, alias: Identifer):
data = self.kg_graph.get_entity_by_alias(alias)
if data is None:
return []
@@ -321,8 +335,8 @@ def run_single_unary_exec_rule(self, op_name: str, left_value):
def single_rule_dispatch(self, op_name: str, left_value, right_value):
op_name = self._get_op_zh_2_en(op_name)
- binary_op = ['equal', 'lt', 'gt', 'le', 'ge', 'in', 'contains', 'and', 'or']
- unary_op = ['not']
+ binary_op = ["equal", "lt", "gt", "le", "ge", "in", "contains", "and", "or"]
+ unary_op = ["not"]
if op_name in binary_op:
return self.run_single_binary_exec_rule(op_name, left_value, right_value)
@@ -333,8 +347,8 @@ def single_rule_dispatch(self, op_name: str, left_value, right_value):
def collect_rule_dispatch(self, op_name: str, left_value, right_value):
op_name = self._get_op_zh_2_en(op_name)
- collect_binary_op = ['match', 'contains', 'in']
- collect_unary_op = ['exist', 'necessary']
+ collect_binary_op = ["match", "contains", "in"]
+ collect_unary_op = ["exist", "necessary"]
if op_name in collect_unary_op:
return self.run_collect_unary_exec_rule(op_name, left_value)
elif op_name in collect_binary_op:
@@ -343,22 +357,24 @@ def collect_rule_dispatch(self, op_name: str, left_value, right_value):
# agg by self
res = self.single_rule_dispatch(op_name, left_value, right_value)
if res is not None and True in res.values():
- return MatchInfo(MatchRes.MATCH, '')
- return MatchInfo(MatchRes.UN_MATCH, '')
+ return MatchInfo(MatchRes.MATCH, "")
+ return MatchInfo(MatchRes.UN_MATCH, "")
def run_collect_binary_exec_rule(self, op_name: str, left_value, right_value):
collect_op_name_map = {
"in": "collect_in",
"contains": "collect_contains",
"necessary": "necessary",
- "match": "match"
+ "match": "match",
}
left_value = self._get_value_ins(left_value)
right_value = self._get_value_ins(right_value)
"""
res = MatchRes
"""
- res: MatchRes = self.runner.op_map[collect_op_name_map[op_name]](left_value, right_value)
+ res: MatchRes = self.runner.op_map[collect_op_name_map[op_name]](
+ left_value, right_value
+ )
return res
def run_collect_unary_exec_rule(self, op_name: str, left_value):
@@ -377,7 +393,7 @@ def _get_op_zh_2_en(self, op_name):
"必要": "necessary",
"等于": "equal",
"大于": "gt",
- "小于": "lt"
+ "小于": "lt",
}
if op_name not in name_map.keys():
return op_name
@@ -385,7 +401,10 @@ def _get_op_zh_2_en(self, op_name):
def run_filter_op(self, f: FilterNode):
# Filtering is not applied to edges
- if isinstance(f.left_expr, Identifer) and f.left_expr in self.kg_graph.edge_alias:
+ if (
+ isinstance(f.left_expr, Identifer)
+ and f.left_expr in self.kg_graph.edge_alias
+ ):
return
res = self.single_rule_dispatch(f.op, f.left_expr, f.right_expr)
failed_list = []
@@ -395,11 +414,9 @@ def run_filter_op(self, f: FilterNode):
self.kg_graph.rmv_ins(f.left_expr, failed_list)
def run_extractor_op(self, f: ExtractorNode):
- update_verify = VerifyNode("verify", {
- "left_expr": f.alias_set,
- "right_expr": self.query,
- "op": "匹配"
- })
+ update_verify = VerifyNode(
+ "verify", {"left_expr": f.alias_set, "right_expr": self.query, "op": "匹配"}
+ )
return self.run_verify_op(update_verify)
def run_verify_op(self, f: VerifyNode):
@@ -411,7 +428,7 @@ def run_verify_op(self, f: VerifyNode):
verify_kg_graph.query_graph[p_alias_name] = {
"s": s_alias_name,
"p": p_alias_name,
- "o": o_alias_name
+ "o": o_alias_name,
}
left_value = self._get_alias_to_doc(f.left_expr)
if len(left_value) == 0:
@@ -434,7 +451,7 @@ def run_verify_op(self, f: VerifyNode):
if len(description) > 0:
s_entity_data.description = "\n\n".join(description)
right_value = f.right_expr
- if right_value is None or right_value == '':
+ if right_value is None or right_value == "":
right_value = self.query
right_value = self._get_alias_to_doc(right_value)
match_info = self.collect_rule_dispatch(f.op, f.left_expr, f.right_expr)
diff --git a/kag/solver/logic/solver_pipeline.py b/kag/solver/logic/solver_pipeline.py
index 907268f0..c34eb00d 100644
--- a/kag/solver/logic/solver_pipeline.py
+++ b/kag/solver/logic/solver_pipeline.py
@@ -12,8 +12,14 @@
class SolverPipeline:
- def __init__(self, max_run=3, reflector: KagReflectorABC = None, reasoner: KagReasonerABC = None,
- generator: KAGGeneratorABC = None, **kwargs):
+ def __init__(
+ self,
+ max_run=3,
+ reflector: KagReflectorABC = None,
+ reasoner: KagReasonerABC = None,
+ generator: KAGGeneratorABC = None,
+ **kwargs
+ ):
"""
Initializes the think-and-act loop class.
@@ -33,35 +39,39 @@ def __init__(self, max_run=3, reflector: KagReflectorABC = None, reasoner: KagRe
def run(self, question):
"""
- Executes the core logic of the problem-solving system.
+ Executes the core logic of the problem-solving system.
- Parameters:
- - question (str): The question to be answered.
+ Parameters:
+ - question (str): The question to be answered.
- Returns:
- - tuple: answer, trace log
- """
+ Returns:
+ - tuple: answer, trace log
+ """
instruction = question
if_finished = False
- logger.debug('input instruction:{}'.format(instruction))
+ logger.debug("input instruction:{}".format(instruction))
present_instruction = instruction
run_cnt = 0
while not if_finished and run_cnt < self.max_run:
run_cnt += 1
- logger.debug('present_instruction is:{}'.format(present_instruction))
+ logger.debug("present_instruction is:{}".format(present_instruction))
# Attempt to solve the current instruction and get the answer, supporting facts, and history log
- solved_answer, supporting_fact, history_log = self.reasoner.reason(present_instruction)
+ solved_answer, supporting_fact, history_log = self.reasoner.reason(
+ present_instruction
+ )
# Extract evidence from supporting facts
self.memory.save_memory(solved_answer, supporting_fact, instruction)
- history_log['present_instruction'] = present_instruction
- history_log['present_memory'] = self.memory.serialize_memory()
+ history_log["present_instruction"] = present_instruction
+ history_log["present_memory"] = self.memory.serialize_memory()
self.trace_log.append(history_log)
# Reflect the current instruction based on the current memory and instruction
- if_finished, present_instruction = self.reflector.reflect_query(self.memory, present_instruction)
+ if_finished, present_instruction = self.reflector.reflect_query(
+ self.memory, present_instruction
+ )
response = self.generator.generate(instruction, self.memory)
return response, self.trace_log
diff --git a/kag/solver/prompt/default/deduce_choice.py b/kag/solver/prompt/default/deduce_choice.py
index 6e488097..5f925f23 100644
--- a/kag/solver/prompt/default/deduce_choice.py
+++ b/kag/solver/prompt/default/deduce_choice.py
@@ -7,17 +7,20 @@
class DeduceEntail(PromptOp):
- template_zh = "根据提供的选项及相关答案,请选择其中一个选项回答问题“$instruction”。" \
- "无需解释;" \
- "如果没有可选择的选项,直接回复“无相关信息”无需解释" \
- "\n【信息】:“$memory”\n请确保所提供的信息直接准确地来自检索文档,不允许任何自身推测。"
- template_en = "Based on the provided options and related answers, choose one option to respond to the question '$instruction'." \
- "No explanation is needed;" \
- "If there are no available options, simply reply 'No relevant information' without explanation." \
- "\n[Information]: '$memory'" \
- "\nEnsure that the information provided comes directly and accurately from the retrieved document, " \
- "without any speculation."
-
+ template_zh = (
+ "根据提供的选项及相关答案,请选择其中一个选项回答问题“$instruction”。"
+ "无需解释;"
+ "如果没有可选择的选项,直接回复“无相关信息”无需解释"
+ "\n【信息】:“$memory”\n请确保所提供的信息直接准确地来自检索文档,不允许任何自身推测。"
+ )
+ template_en = (
+ "Based on the provided options and related answers, choose one option to respond to the question '$instruction'."
+ "No explanation is needed;"
+ "If there are no available options, simply reply 'No relevant information' without explanation."
+ "\n[Information]: '$memory'"
+ "\nEnsure that the information provided comes directly and accurately from the retrieved document, "
+ "without any speculation."
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -27,7 +30,7 @@ def template_variables(self) -> List[str]:
return ["memory", "instruction"]
def parse_response_en(self, satisfied_info: str):
- if satisfied_info.startswith('No relevant information'):
+ if satisfied_info.startswith("No relevant information"):
if_answered = False
else:
if_answered = True
@@ -41,7 +44,7 @@ def parse_response_zh(self, satisfied_info: str):
return if_answered, satisfied_info
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
if self.language == "en":
return self.parse_response_en(response)
return self.parse_response_zh(response)
diff --git a/kag/solver/prompt/default/deduce_entail.py b/kag/solver/prompt/default/deduce_entail.py
index dff35752..c3019683 100644
--- a/kag/solver/prompt/default/deduce_entail.py
+++ b/kag/solver/prompt/default/deduce_entail.py
@@ -7,21 +7,24 @@
class DeduceEntail(PromptOp):
- template_zh = "根据提供的信息,请首先判断是否能够直接回答指令“$instruction”。如果可以直接回答,请直接回复答案," \
- "无需解释;如果不能直接回答但存在关联信息,请总结其中与指令“$instruction”相关的关键信息,并明确解释为何与指令相关;" \
- "如果没有任何相关信息,直接回复“无相关信息”无需解释。" \
- "\n【信息】:“$memory”\n请确保所提供的信息直接准确地来自检索文档,不允许任何自身推测。"
- template_en = "Based on the provided information, first determine whether you can directly respond to the " \
- "instruction '$instruction'. If you can directly answer, " \
- "reply with the answer without any explanation;" \
- " if you cannot answer directly but there is related information, " \
- "summarize the key information related to the instruction '$instruction' " \
- "and clearly explain why it is related; " \
- "if there is no relevant information, simply reply 'No relevant information' without explanation." \
- "\n[Information]: '$memory'" \
- "\nEnsure that the information provided comes directly and accurately from the retrieved document, " \
- "without any speculation."
-
+ template_zh = (
+ "根据提供的信息,请首先判断是否能够直接回答指令“$instruction”。如果可以直接回答,请直接回复答案,"
+ "无需解释;如果不能直接回答但存在关联信息,请总结其中与指令“$instruction”相关的关键信息,并明确解释为何与指令相关;"
+ "如果没有任何相关信息,直接回复“无相关信息”无需解释。"
+ "\n【信息】:“$memory”\n请确保所提供的信息直接准确地来自检索文档,不允许任何自身推测。"
+ )
+ template_en = (
+ "Based on the provided information, first determine whether you can directly respond to the "
+ "instruction '$instruction'. If you can directly answer, "
+ "reply with the answer without any explanation;"
+ " if you cannot answer directly but there is related information, "
+ "summarize the key information related to the instruction '$instruction' "
+ "and clearly explain why it is related; "
+ "if there is no relevant information, simply reply 'No relevant information' without explanation."
+ "\n[Information]: '$memory'"
+ "\nEnsure that the information provided comes directly and accurately from the retrieved document, "
+ "without any speculation."
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -31,7 +34,7 @@ def template_variables(self) -> List[str]:
return ["memory", "instruction"]
def parse_response_en(self, satisfied_info: str):
- if satisfied_info.startswith('No relevant information'):
+ if satisfied_info.startswith("No relevant information"):
if_answered = False
else:
if_answered = True
@@ -45,7 +48,7 @@ def parse_response_zh(self, satisfied_info: str):
return if_answered, satisfied_info
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
if self.language == "en":
return self.parse_response_en(response)
return self.parse_response_zh(response)
diff --git a/kag/solver/prompt/default/deduce_judge.py b/kag/solver/prompt/default/deduce_judge.py
index 8bb8e3a6..6cb8a9e8 100644
--- a/kag/solver/prompt/default/deduce_judge.py
+++ b/kag/solver/prompt/default/deduce_judge.py
@@ -7,19 +7,22 @@
class DeduceJudge(PromptOp):
- template_zh = "根据提供的信息,请首先判断是否能够直接判断问题“$instruction”。如果可以直接回答,请直接根据提供信息对问题给出判断是或者否," \
- "无需解释;" \
- "如果没有任何相关信息,直接回复“无相关信息”无需解释。" \
- "\n【信息】:“$memory”\n请确保所提供的信息直接准确地来自检索文档,不允许任何自身推测。" \
- "\n【问题】:“$instruction”"
- template_en = "Based on the provided information, first determine if the question '$instruction' can be directly assessed. " \
- "If it can be directly answered, simply respond with Yes or No based on the provided information, no explanation needed;" \
- "If there is no relevant information, simply reply 'No relevant information' without explanation." \
- "\n[Information]: '$memory'" \
- "\nEnsure that the information provided comes directly and accurately from the retrieved document, " \
- "without any speculation."\
- "\n[Question]: '$instruction'"
-
+ template_zh = (
+ "根据提供的信息,请首先判断是否能够直接判断问题“$instruction”。如果可以直接回答,请直接根据提供信息对问题给出判断是或者否,"
+ "无需解释;"
+ "如果没有任何相关信息,直接回复“无相关信息”无需解释。"
+ "\n【信息】:“$memory”\n请确保所提供的信息直接准确地来自检索文档,不允许任何自身推测。"
+ "\n【问题】:“$instruction”"
+ )
+ template_en = (
+ "Based on the provided information, first determine if the question '$instruction' can be directly assessed. "
+ "If it can be directly answered, simply respond with Yes or No based on the provided information, no explanation needed;"
+ "If there is no relevant information, simply reply 'No relevant information' without explanation."
+ "\n[Information]: '$memory'"
+ "\nEnsure that the information provided comes directly and accurately from the retrieved document, "
+ "without any speculation."
+ "\n[Question]: '$instruction'"
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -29,7 +32,7 @@ def template_variables(self) -> List[str]:
return ["memory", "instruction"]
def parse_response_en(self, satisfied_info: str):
- if satisfied_info.startswith('No relevant information'):
+ if satisfied_info.startswith("No relevant information"):
if_answered = False
else:
if_answered = True
@@ -43,7 +46,7 @@ def parse_response_zh(self, satisfied_info: str):
return if_answered, satisfied_info
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
if self.language == "en":
return self.parse_response_en(response)
return self.parse_response_zh(response)
diff --git a/kag/solver/prompt/default/deduce_multi_choice.py b/kag/solver/prompt/default/deduce_multi_choice.py
index 158c21bc..f839a3c8 100644
--- a/kag/solver/prompt/default/deduce_multi_choice.py
+++ b/kag/solver/prompt/default/deduce_multi_choice.py
@@ -7,17 +7,20 @@
class DeduceEntail(PromptOp):
- template_zh = "根据提供的选项及相关答案,请选择其中至少一个选项回答问题“$instruction”。" \
- "无需解释;" \
- "如果没有可选择的选项,直接回复“无相关信息”无需解释" \
- "\n【信息】:“$memory”\n请确保所提供的信息直接准确地来自检索文档,不允许任何自身推测。"
- template_en = "Based on the provided options and related answers, choose at least one option to respond to the question '$instruction'." \
- "No explanation is needed;" \
- "If there are no available options, simply reply 'No relevant information' without explanation." \
- "\n[Information]: '$memory'" \
- "\nEnsure that the information provided comes directly and accurately from the retrieved document, " \
- "without any speculation."
-
+ template_zh = (
+ "根据提供的选项及相关答案,请选择其中至少一个选项回答问题“$instruction”。"
+ "无需解释;"
+ "如果没有可选择的选项,直接回复“无相关信息”无需解释"
+ "\n【信息】:“$memory”\n请确保所提供的信息直接准确地来自检索文档,不允许任何自身推测。"
+ )
+ template_en = (
+ "Based on the provided options and related answers, choose at least one option to respond to the question '$instruction'."
+ "No explanation is needed;"
+ "If there are no available options, simply reply 'No relevant information' without explanation."
+ "\n[Information]: '$memory'"
+ "\nEnsure that the information provided comes directly and accurately from the retrieved document, "
+ "without any speculation."
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -27,7 +30,7 @@ def template_variables(self) -> List[str]:
return ["memory", "instruction"]
def parse_response_en(self, satisfied_info: str):
- if satisfied_info.startswith('No relevant information'):
+ if satisfied_info.startswith("No relevant information"):
if_answered = False
else:
if_answered = True
@@ -41,7 +44,7 @@ def parse_response_zh(self, satisfied_info: str):
return if_answered, satisfied_info
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
if self.language == "en":
return self.parse_response_en(response)
return self.parse_response_zh(response)
diff --git a/kag/solver/prompt/default/logic_form_plan.py b/kag/solver/prompt/default/logic_form_plan.py
index 11f26605..990ed86f 100644
--- a/kag/solver/prompt/default/logic_form_plan.py
+++ b/kag/solver/prompt/default/logic_form_plan.py
@@ -127,7 +127,6 @@ def __init__(self, language: str):
def template_variables(self) -> List[str]:
return ["question"]
-
def parse_response(self, response: str, **kwargs):
try:
logger.debug(f"logic form:{response}")
@@ -135,17 +134,17 @@ def parse_response(self, response: str, **kwargs):
_output_string = response.strip()
sub_querys = []
logic_forms = []
- current_sub_query = ''
- for line in _output_string.split('\n'):
- if line.startswith('Step'):
- sub_querys_regex = re.search('Step\d+:(.*)', line)
+ current_sub_query = ""
+ for line in _output_string.split("\n"):
+ if line.startswith("Step"):
+ sub_querys_regex = re.search("Step\d+:(.*)", line)
if sub_querys_regex is not None:
sub_querys.append(sub_querys_regex.group(1))
current_sub_query = sub_querys_regex.group(1)
- elif line.startswith('Output'):
+ elif line.startswith("Output"):
sub_querys.append("output")
- elif line.startswith('Action'):
- logic_forms_regex = re.search('Action\d+:(.*)', line)
+ elif line.startswith("Action"):
+ logic_forms_regex = re.search("Action\d+:(.*)", line)
if logic_forms_regex:
logic_forms.append(logic_forms_regex.group(1))
if len(logic_forms) - len(sub_querys) == 1:
diff --git a/kag/solver/prompt/default/question_ner.py b/kag/solver/prompt/default/question_ner.py
index 8f6175e9..1004bc47 100644
--- a/kag/solver/prompt/default/question_ner.py
+++ b/kag/solver/prompt/default/question_ner.py
@@ -45,11 +45,11 @@ class QuestionNER(PromptOp):
template_zh = template_en
- def __init__(
- self, language: Optional[str] = "en", **kwargs
- ):
+ def __init__(self, language: Optional[str] = "en", **kwargs):
super().__init__(language, **kwargs)
- self.schema = ReasonerClient(project_id=self.project_id).get_reason_schema().keys()
+ self.schema = (
+ ReasonerClient(project_id=self.project_id).get_reason_schema().keys()
+ )
self.template = Template(self.template).safe_substitute(schema=self.schema)
@property
diff --git a/kag/solver/prompt/default/resp_extractor.py b/kag/solver/prompt/default/resp_extractor.py
index 724bd731..e05b9190 100644
--- a/kag/solver/prompt/default/resp_extractor.py
+++ b/kag/solver/prompt/default/resp_extractor.py
@@ -9,13 +9,17 @@
class RespExtractor(PromptOp):
- template_zh = "已知信息:\n$supporting_fact\n" \
- "你的任务是作为一名专业作家。你将仅根据提供的支持段落中的信息,撰写一段高质量的文章,以支持关于问题的给定预测。" \
- "现在,开始生成。在写完后,请输出[DONE]来表示已经完成任务。在生成段落时不要写前缀(例如:'Response:')。"\
- "\n问题:$instruction\n段落:"
- template_en = "Known information:\n $supporting_fact\nYour job is to act as a professional writer. " \
- "You will write a good-quality passage that can support the given prediction about the question only based on the information in the provided supporting passages. " \
- "Now, let's start. After you write, please write [DONE] to indicate you are done. Do not write a prefix (e.g., 'Response:'') while writing a passage.\nQuestion:$instruction\nPassage:"
+ template_zh = (
+ "已知信息:\n$supporting_fact\n"
+ "你的任务是作为一名专业作家。你将仅根据提供的支持段落中的信息,撰写一段高质量的文章,以支持关于问题的给定预测。"
+ "现在,开始生成。在写完后,请输出[DONE]来表示已经完成任务。在生成段落时不要写前缀(例如:'Response:')。"
+ "\n问题:$instruction\n段落:"
+ )
+ template_en = (
+ "Known information:\n $supporting_fact\nYour job is to act as a professional writer. "
+ "You will write a good-quality passage that can support the given prediction about the question only based on the information in the provided supporting passages. "
+ "Now, let's start. After you write, please write [DONE] to indicate you are done. Do not write a prefix (e.g., 'Response:'') while writing a passage.\nQuestion:$instruction\nPassage:"
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -24,7 +28,6 @@ def __init__(self, language: str):
def template_variables(self) -> List[str]:
return ["supporting_fact", "instruction"]
-
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
return response
diff --git a/kag/solver/prompt/default/resp_generator.py b/kag/solver/prompt/default/resp_generator.py
index 693e21d7..1dc08297 100644
--- a/kag/solver/prompt/default/resp_generator.py
+++ b/kag/solver/prompt/default/resp_generator.py
@@ -9,12 +9,14 @@
class RespGenerator(PromptOp):
- template_zh = "基于给定的引用信息回答问题。" \
- "\n输出答案,并且给出理由。" \
- "\n给定的引用信息:'$memory'\n问题:'$instruction'"
- template_en = "Answer the question based on the given reference." \
- "\nGive me the answer and why." \
- "\nThe following are given reference:'$memory'\nQuestion: '$instruction'"
+ template_zh = (
+ "基于给定的引用信息回答问题。" "\n输出答案,并且给出理由。" "\n给定的引用信息:'$memory'\n问题:'$instruction'"
+ )
+ template_en = (
+ "Answer the question based on the given reference."
+ "\nGive me the answer and why."
+ "\nThe following are given reference:'$memory'\nQuestion: '$instruction'"
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -24,5 +26,5 @@ def template_variables(self) -> List[str]:
return ["memory", "instruction"]
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
return response
diff --git a/kag/solver/prompt/default/resp_judge.py b/kag/solver/prompt/default/resp_judge.py
index e27c72f3..4b1a3905 100644
--- a/kag/solver/prompt/default/resp_judge.py
+++ b/kag/solver/prompt/default/resp_judge.py
@@ -9,14 +9,18 @@
class RespJudge(PromptOp):
- template_zh = "根据当前已知信息进行判断,不允许进行推理," \
- "你能否完全并准确地回答这个问题'$instruction'?\n已知信息:'$memory'。" \
- "\n如果你能,请直接回复‘是’\n如果不能且需要更多信息,请直接回复‘否’。"
- template_en = "Judging based solely on the current known information and without allowing for inference, " \
- "are you able to completely and accurately respond to the question '$instruction'? " \
- "\nKnown information: '$memory'. " \
- "\nIf you can, please reply with 'Yes' directly; " \
- "if you cannot and need more information, please reply with 'No' directly."
+ template_zh = (
+ "根据当前已知信息进行判断,不允许进行推理,"
+ "你能否完全并准确地回答这个问题'$instruction'?\n已知信息:'$memory'。"
+ "\n如果你能,请直接回复‘是’\n如果不能且需要更多信息,请直接回复‘否’。"
+ )
+ template_en = (
+ "Judging based solely on the current known information and without allowing for inference, "
+ "are you able to completely and accurately respond to the question '$instruction'? "
+ "\nKnown information: '$memory'. "
+ "\nIf you can, please reply with 'Yes' directly; "
+ "if you cannot and need more information, please reply with 'No' directly."
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -26,7 +30,7 @@ def template_variables(self) -> List[str]:
return ["memory", "instruction"]
def parse_response_en(self, satisfied_info: str):
- if satisfied_info[:3] == 'Yes':
+ if satisfied_info[:3] == "Yes":
if_finished = True
else:
if_finished = False
@@ -40,7 +44,7 @@ def parse_response_zh(self, satisfied_info: str):
return if_finished
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
if self.language == "en":
return self.parse_response_en(response)
return self.parse_response_zh(response)
diff --git a/kag/solver/prompt/default/resp_reflector.py b/kag/solver/prompt/default/resp_reflector.py
index 9c186122..2bf44c0a 100644
--- a/kag/solver/prompt/default/resp_reflector.py
+++ b/kag/solver/prompt/default/resp_reflector.py
@@ -7,13 +7,17 @@
class RespRewriter(PromptOp):
- template_zh = "你是一个智能助手,擅长通过复杂的、多跳的推理帮助用户在多文档中获取信息。请理解当前已知信息与目标问题之间的信息差。" \
- "你的任务是直接生成一个用于下一步检索的思考问题。" \
- "不要一次性生成所有思考过程!\n[已知信息]: $memory\n[目标问题]:$instruction\n[你的思考]:"
- template_en = "You serve as an intelligent assistant, adept at facilitating users through complex, " \
- "multi-hop reasoning across multiple documents. Please understand the information gap between the currently known information and the target problem." \
- "Your task is to generate one thought in the form of question for next retrieval step directly. " \
- "DON\'T generate the whole thoughts at once!\n[Known information]: $memory\n[Target question]: $instruction\n[You Thought]:"
+ template_zh = (
+ "你是一个智能助手,擅长通过复杂的、多跳的推理帮助用户在多文档中获取信息。请理解当前已知信息与目标问题之间的信息差。"
+ "你的任务是直接生成一个用于下一步检索的思考问题。"
+ "不要一次性生成所有思考过程!\n[已知信息]: $memory\n[目标问题]:$instruction\n[你的思考]:"
+ )
+ template_en = (
+ "You serve as an intelligent assistant, adept at facilitating users through complex, "
+ "multi-hop reasoning across multiple documents. Please understand the information gap between the currently known information and the target problem."
+ "Your task is to generate one thought in the form of question for next retrieval step directly. "
+ "DON'T generate the whole thoughts at once!\n[Known information]: $memory\n[Target question]: $instruction\n[You Thought]:"
+ )
def __init__(self, language: str):
super().__init__(language)
@@ -26,26 +30,26 @@ def parse_response_en(self, response: str):
update_reason_path = []
split_path = response.split("\n")
for p in split_path:
- if 'Here are the steps' in p or p == '\n' or p == '':
+ if "Here are the steps" in p or p == "\n" or p == "":
continue
else:
update_reason_path.append(p)
- logger.debug('cur path:{}'.format(str(update_reason_path)))
+ logger.debug("cur path:{}".format(str(update_reason_path)))
return update_reason_path
def parse_response_zh(self, response: str):
update_reason_path = []
split_path = response.split("\n")
for p in split_path:
- if '步骤为' in p or p == '\n' or p == '':
+ if "步骤为" in p or p == "\n" or p == "":
continue
else:
update_reason_path.append(p)
- logger.debug('cur path:{}'.format(str(update_reason_path)))
+ logger.debug("cur path:{}".format(str(update_reason_path)))
return update_reason_path
def parse_response(self, response: str, **kwargs):
- logger.debug('infer result:{}'.format(response))
+ logger.debug("infer result:{}".format(response))
if self.language == "en":
return self.parse_response_en(response)
return self.parse_response_zh(response)
diff --git a/kag/solver/prompt/default/resp_verifier.py b/kag/solver/prompt/default/resp_verifier.py
index 600dd111..dbd4cfc1 100644
--- a/kag/solver/prompt/default/resp_verifier.py
+++ b/kag/solver/prompt/default/resp_verifier.py
@@ -9,30 +9,35 @@
class RespVerifier(PromptOp):
- template_zh = "仅根据当前已知的信息,并且不允许进行推理," \
- "你能否完全并准确地回答这个问题'$sub_instruction'?\n已知信息:'$supporting_fact'。" \
- "\n如果你能,请直接回复‘是’,并给出问题'$sub_instruction'的答案,无需重复问题;如果不可以,请直接回答'否'。"
- template_en = "Judging based solely on the current known information and without allowing for inference, " \
- "are you able to respond completely and accurately to the question '$sub_instruction'? \n" \
- "Known information: '$supporting_fact'. If yes, please reply with 'Yes', followed by an accurate response to the question '$sub_instruction', " \
- "without restating the question; if no, please reply with 'No' directly."
+ template_zh = (
+ "仅根据当前已知的信息,并且不允许进行推理,"
+ "你能否完全并准确地回答这个问题'$sub_instruction'?\n已知信息:'$supporting_fact'。"
+ "\n如果你能,请直接回复‘是’,并给出问题'$sub_instruction'的答案,无需重复问题;如果不可以,请直接回答'否'。"
+ )
+ template_en = (
+ "Judging based solely on the current known information and without allowing for inference, "
+ "are you able to respond completely and accurately to the question '$sub_instruction'? \n"
+ "Known information: '$supporting_fact'. If yes, please reply with 'Yes', followed by an accurate response to the question '$sub_instruction', "
+ "without restating the question; if no, please reply with 'No' directly."
+ )
def __init__(self, language: str):
super().__init__(language)
-
@property
def template_variables(self) -> List[str]:
return ["sub_instruction", "supporting_fact"]
def parse_response_en(self, satisfied_info: str):
- if satisfied_info[:3] == 'Yes':
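+ # The verifier template asks the model to reply starting with "Yes" or "No"; a leading "Yes" means the sub-question is answerable.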
+ if satisfied_info[:3] == "Yes":
satisfied = True
else:
satisfied = False
if satisfied:
- satisfied_info = satisfied_info.replace('Yes', '').strip()
- res = "The answer to the Question'{}' is '{}'".format(self.template_variables_value["sub_instruction"], satisfied_info)
+ satisfied_info = satisfied_info.replace("Yes", "").strip()
+ res = "The answer to the Question'{}' is '{}'".format(
+ self.template_variables_value["sub_instruction"], satisfied_info
+ )
return res
return None
@@ -42,13 +47,15 @@ def parse_response_zh(self, satisfied_info: str):
else:
satisfied = False
if satisfied:
- satisfied_info = satisfied_info.replace('是', '').strip()
- res = "问题'{}' 的答案是 '{}'".format(self.template_variables_value["sub_instruction"], satisfied_info)
+ satisfied_info = satisfied_info.replace("是", "").strip()
+ res = "问题'{}' 的答案是 '{}'".format(
+ self.template_variables_value["sub_instruction"], satisfied_info
+ )
return res
return None
def parse_response(self, response: str, **kwargs):
- logger.debug('推理器判别:{}'.format(response))
+ logger.debug("推理器判别:{}".format(response))
if self.language == "en":
return self.parse_response_en(response)
return self.parse_response_zh(response)
diff --git a/kag/solver/prompt/default/solve_question_without_spo.py b/kag/solver/prompt/default/solve_question_without_spo.py
index 82ef9409..79974abb 100644
--- a/kag/solver/prompt/default/solve_question_without_spo.py
+++ b/kag/solver/prompt/default/solve_question_without_spo.py
@@ -41,7 +41,7 @@ def __init__(self, language: str):
@property
def template_variables(self) -> List[str]:
- return ["history", "question", "docs"]
+ return ["history", "question", "docs"]
def parse_response(self, response: str, **kwargs):
return response
diff --git a/kag/solver/prompt/default/spo_retrieval.py b/kag/solver/prompt/default/spo_retrieval.py
index 3f3a7384..fc3cd0ba 100644
--- a/kag/solver/prompt/default/spo_retrieval.py
+++ b/kag/solver/prompt/default/spo_retrieval.py
@@ -66,7 +66,7 @@ def template_variables(self) -> List[str]:
return ["question", "mention", "candis"]
def parse_response_en(self, satisfied_info: str):
- if satisfied_info[:3] == 'Yes':
+ if satisfied_info[:3] == "Yes":
if_finished = True
else:
if_finished = False
@@ -82,7 +82,8 @@ def parse_response_zh(self, satisfied_info: str):
def parse_response(self, response: str, **kwargs):
logger.debug(
f"SpoRetrieval {response} mention:{self.template_variables_value.get('mention', '')} "
- f"candis:{self.template_variables_value.get('candis', '')}")
- llm_output = response.replace('Expected Output:', '')
- llm_output = llm_output.replace('"', '')
+ f"candis:{self.template_variables_value.get('candis', '')}"
+ )
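+ # Normalize the raw model output: drop an echoed "Expected Output:" prefix and any double quotes before returning.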
+ llm_output = response.replace("Expected Output:", "")
+ llm_output = llm_output.replace('"', "")
return llm_output.strip()
diff --git a/kag/solver/prompt/lawbench/logic_form_plan.py b/kag/solver/prompt/lawbench/logic_form_plan.py
index f067ea19..7fbde78c 100644
--- a/kag/solver/prompt/lawbench/logic_form_plan.py
+++ b/kag/solver/prompt/lawbench/logic_form_plan.py
@@ -5,7 +5,6 @@
logger = logging.getLogger(__name__)
-
class LawLogicFormPlanPrompt(LogicFormPlanPrompt):
default_case_zh = """"cases": [
{
diff --git a/kag/solver/prompt/medical/question_ner.py b/kag/solver/prompt/medical/question_ner.py
index 3eb8ea9d..9dd1c4e1 100644
--- a/kag/solver/prompt/medical/question_ner.py
+++ b/kag/solver/prompt/medical/question_ner.py
@@ -55,9 +55,7 @@ class QuestionNER(PromptOp):
template_en = template_zh
- def __init__(
- self, language: Optional[str] = "en", **kwargs
- ):
+ def __init__(self, language: Optional[str] = "en", **kwargs):
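+ # Resolve the project's schema types and substitute them into the $schema slot of the prompt template.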
super().__init__(language, **kwargs)
self.schema = SchemaClient(project_id=self.project_id).extract_types()
self.template = Template(self.template).safe_substitute(schema=self.schema)
diff --git a/kag/solver/tools/info_processor.py b/kag/solver/tools/info_processor.py
index f9b1cc9b..7d7d2c6c 100644
--- a/kag/solver/tools/info_processor.py
+++ b/kag/solver/tools/info_processor.py
@@ -10,6 +10,8 @@
from knext.reasoner.rest.reasoner_api import ReasonerApi
logger = logging.getLogger(__name__)
+
+
class ReporterIntermediateProcessTool:
class STATE(str, Enum):
WAITING = "WAITING"
@@ -24,7 +26,8 @@ def __init__(self, report_log=False, task_id=None, project_id=None, host_addr=No
self.task_id = task_id
self.project_id = project_id
self.client: ReasonerApi = ReasonerApi(
- api_client=ApiClient(configuration=Configuration(host=host_addr)))
+ api_client=ApiClient(configuration=Configuration(host=host_addr))
+ )
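+ # report_log controls whether updates are posted to the reasoner service or only written to the local log.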
def report_pipeline(self, question, rewrite_question_list=[]):
# print(question)
@@ -35,10 +38,26 @@ def report_pipeline(self, question, rewrite_question_list=[]):
pipeline = CaPipeline()
pipeline.nodes = []
pipeline.edges = []
- pipeline.nodes.append(Node(id=self.ROOT_ID, state=self.STATE.WAITING, question=question.question, answer=None, logs=None))
+ pipeline.nodes.append(
+ Node(
+ id=self.ROOT_ID,
+ state=self.STATE.WAITING,
+ question=question.question,
+ answer=None,
+ logs=None,
+ )
+ )
dep_question_list = []
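+ # Each rewritten sub-question becomes its own WAITING node; its dependencies are added as edges below.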
for item in rewrite_question_list:
- pipeline.nodes.append(Node(id=item.id, state=self.STATE.WAITING, question=item.question, answer=None, logs=None))
+ pipeline.nodes.append(
+ Node(
+ id=item.id,
+ state=self.STATE.WAITING,
+ question=item.question,
+ answer=None,
+ logs=None,
+ )
+ )
if item.dependencies:
for dep_item in item.dependencies:
pipeline.edges.append(Edge(_from=dep_item.id, to=item.id))
@@ -54,12 +73,25 @@ def report_pipeline(self, question, rewrite_question_list=[]):
if node.id not in to_list:
first_nodes.append(node.id)
# str([n.question for n in pipeline.nodes if n.id != self.ROOT_ID])
- pipeline.nodes.insert(0, Node(id=1, state=self.STATE.FINISH, question=question.question, answer=str([n.question for n in pipeline.nodes if n.id != self.ROOT_ID]), logs=None))
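+ # Prepend a synthetic FINISH node (id 1) whose answer lists the decomposed sub-questions; edges from it to the entry nodes follow.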
+ pipeline.nodes.insert(
+ 0,
+ Node(
+ id=1,
+ state=self.STATE.FINISH,
+ question=question.question,
+ answer=str(
+ [n.question for n in pipeline.nodes if n.id != self.ROOT_ID]
+ ),
+ logs=None,
+ ),
+ )
for n in first_nodes:
pipeline.edges.insert(0, Edge(_from=1, to=n))
request = ReportPipelineRequest(task_id=self.task_id, pipeline=pipeline)
if self.report_log:
- self.client.reasoner_dialog_report_pipeline_post(report_pipeline_request=request)
+ self.client.reasoner_dialog_report_pipeline_post(
+ report_pipeline_request=request
+ )
else:
logger.info(request)
@@ -67,11 +99,18 @@ def report_node(self, question, answer, state):
logs = self.format_logs(question.context)
if not question.id:
question.id = self.ROOT_ID
- node = Node(id=(question.id+1 if question.id != 0 else 0), state=state, question=question.question, answer=answer,
- logs=logs)
+ node = Node(
+ id=(question.id + 1 if question.id != 0 else 0),
+ state=state,
+ question=question.question,
+ answer=answer,
+ logs=logs,
+ )
request = ReportPipelineRequest(task_id=self.task_id, node=node)
if self.report_log:
- self.client.reasoner_dialog_report_node_post(report_pipeline_request=request)
+ self.client.reasoner_dialog_report_node_post(
+ report_pipeline_request=request
+ )
else:
logger.info(request)
diff --git a/kag/templates/project/builder/__init__.py b/kag/templates/project/builder/__init__.py
index 94be39bc..7a018e7c 100644
--- a/kag/templates/project/builder/__init__.py
+++ b/kag/templates/project/builder/__init__.py
@@ -11,4 +11,4 @@
"""
Builder Dir.
-"""
\ No newline at end of file
+"""
diff --git a/kag/templates/project/builder/indexer.py b/kag/templates/project/builder/indexer.py
index f9e16285..6f6914a4 100644
--- a/kag/templates/project/builder/indexer.py
+++ b/kag/templates/project/builder/indexer.py
@@ -8,4 +8,3 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
-
diff --git a/kag/templates/project/builder/prompt/__init__.py b/kag/templates/project/builder/prompt/__init__.py
index 247bb44c..ba7d5d56 100644
--- a/kag/templates/project/builder/prompt/__init__.py
+++ b/kag/templates/project/builder/prompt/__init__.py
@@ -11,4 +11,4 @@
"""
Place the prompts to be used for building the index in this directory.
-"""
\ No newline at end of file
+"""
diff --git a/kag/templates/project/reasoner/__init__.py b/kag/templates/project/reasoner/__init__.py
index a0c4032b..8b8a3c91 100644
--- a/kag/templates/project/reasoner/__init__.py
+++ b/kag/templates/project/reasoner/__init__.py
@@ -17,4 +17,4 @@
MATCH (s:DEFAULT.Company)
RETURN s.id, s.address
```
-"""
\ No newline at end of file
+"""
diff --git a/kag/templates/project/schema/__init__.py b/kag/templates/project/schema/__init__.py
index ef3dde6d..8ac86acc 100644
--- a/kag/templates/project/schema/__init__.py
+++ b/kag/templates/project/schema/__init__.py
@@ -15,4 +15,4 @@
You can execute `kag schema commit` to commit your schema to SPG server.
-"""
\ No newline at end of file
+"""
diff --git a/kag/templates/project/solver/prompt/__init__.py b/kag/templates/project/solver/prompt/__init__.py
index dadd42a3..dfa931cd 100644
--- a/kag/templates/project/solver/prompt/__init__.py
+++ b/kag/templates/project/solver/prompt/__init__.py
@@ -11,4 +11,4 @@
"""
Place the prompts to be used for solving problems in this directory.
-"""
\ No newline at end of file
+"""
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 00000000..47773b7a
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,17 @@
+[flake8]
+max-line-length = 120
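+# E203 and W503 are listed in the ignore set below because they conflict with Black-style formatting.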
+ignore = E203, E266, E501, W503, W291, C901, E722
+select = C,E,F,W,B,B950
+exclude =
+ .git,
+ __pycache__,
+ setup.py,
+ build,
+ dist,
+ kag/examples/*.py,
+ kag/common/arks_pb2.py,
+ kag/solver/*.py,
+per-file-ignores =
+ __init__.py: F401
+ tests/*: F811
+max-complexity = 10
\ No newline at end of file
diff --git a/tests/common/registry/test_registry.py b/tests/common/registry/test_registry.py
index 19a0555b..0ead581a 100644
--- a/tests/common/registry/test_registry.py
+++ b/tests/common/registry/test_registry.py
@@ -2,7 +2,7 @@
from typing import List, Dict, Union
from pyhocon import ConfigTree, ConfigFactory
-from kag_ant.common.registry import Registrable, Lazy, Functor
+from kag.common.registry import Registrable, Lazy, Functor
import numpy as np
@@ -13,7 +13,7 @@ def __init__(self, name: str = "mock_model"):
@MockModel.register("Simple")
class Simple(MockModel):
- def __init__(self, name, age=999):
+ def __init__(self, name, age=None):
pass