From 6a09a94717a13286228287b5d552d10656ce00ef Mon Sep 17 00:00:00 2001 From: zhuzhongshu123 Date: Wed, 13 Nov 2024 17:32:17 +0800 Subject: [PATCH] add component of load external graph and postprocess subgraphs --- kag/builder/component/__init__.py | 15 +- .../{kag_post_processor.py => kag_aligner.py} | 2 +- .../{spg_post_processor.py => spg_aligner.py} | 2 +- .../component/external_graph/__init__.py | 0 .../external_graph/external_graph.py | 185 +++++++++++++++++ .../component/extractor/kag_extractor.py | 85 ++++---- .../component/postprocessor/__init__.py | 0 .../postprocessor/kag_postprocessor.py | 106 ++++++++++ .../component/vectorizer/batch_vectorizer.py | 46 ++--- kag/builder/component/writer/kg_writer.py | 1 + kag/builder/model/sub_graph.py | 4 +- kag/common/conf.py | 57 ++++-- kag/common/env.py | 143 +++++++++++++ kag/common/llm/mock_llm.py | 69 +++++++ kag/common/utils.py | 6 + kag/examples/musique/builder/indexer.py | 4 +- kag/interface/__init__.py | 9 +- kag/interface/builder/external_graph_abc.py | 58 ++++++ kag/interface/builder/postprocessor_abc.py | 35 ++++ .../builder/component/test_external_graph.py | 82 ++++++++ .../builder/component/test_post_processor.py | 88 ++++++++ tests/builder/data/edges.json | 154 ++++++++++++++ tests/builder/data/nodes.json | 192 ++++++++++++++++++ tests/builder/kag_config.yaml | 23 +++ 24 files changed, 1258 insertions(+), 108 deletions(-) rename kag/builder/component/aligner/{kag_post_processor.py => kag_aligner.py} (97%) rename kag/builder/component/aligner/{spg_post_processor.py => spg_aligner.py} (99%) create mode 100644 kag/builder/component/external_graph/__init__.py create mode 100644 kag/builder/component/external_graph/external_graph.py create mode 100644 kag/builder/component/postprocessor/__init__.py create mode 100644 kag/builder/component/postprocessor/kag_postprocessor.py create mode 100644 kag/common/env.py create mode 100644 kag/common/llm/mock_llm.py create mode 100644 kag/interface/builder/external_graph_abc.py create mode 100644 kag/interface/builder/postprocessor_abc.py create mode 100644 tests/builder/component/test_external_graph.py create mode 100644 tests/builder/data/edges.json create mode 100644 tests/builder/data/nodes.json create mode 100644 tests/builder/kag_config.yaml diff --git a/kag/builder/component/__init__.py b/kag/builder/component/__init__.py index 18df97af..80d211ce 100644 --- a/kag/builder/component/__init__.py +++ b/kag/builder/component/__init__.py @@ -10,10 +10,15 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
+from kag.builder.component.external_graph.external_graph import ( + DefaultExternalGraphLoader, +) from kag.builder.component.extractor.kag_extractor import KAGExtractor from kag.builder.component.extractor.spg_extractor import SPGExtractor -from kag.builder.component.aligner.kag_post_processor import KAGPostProcessorAligner -from kag.builder.component.aligner.spg_post_processor import SPGPostProcessorAligner +from kag.builder.component.aligner.kag_aligner import KAGAligner +from kag.builder.component.aligner.spg_aligner import SPGAligner +from kag.builder.component.postprocessor.kag_postprocessor import KAGPostProcessor + from kag.builder.component.mapping.spg_type_mapping import SPGTypeMapping from kag.builder.component.mapping.relation_mapping import RelationMapping from kag.builder.component.mapping.spo_mapping import SPOMapping @@ -38,10 +43,12 @@ __all__ = [ + "DefaultExternalGraphLoader", "KAGExtractor", "SPGExtractor", - "KAGPostProcessorAligner", - "SPGPostProcessorAligner", + "KAGAligner", + "SPGAligner", + "KAGPostProcessor", "KGWriter", "SPGTypeMapping", "RelationMapping", diff --git a/kag/builder/component/aligner/kag_post_processor.py b/kag/builder/component/aligner/kag_aligner.py similarity index 97% rename from kag/builder/component/aligner/kag_post_processor.py rename to kag/builder/component/aligner/kag_aligner.py index eb43269b..328d1aa7 100644 --- a/kag/builder/component/aligner/kag_post_processor.py +++ b/kag/builder/component/aligner/kag_aligner.py @@ -18,7 +18,7 @@ @AlignerABC.register("kag") -class KAGPostProcessorAligner(AlignerABC): +class KAGAligner(AlignerABC): def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/kag/builder/component/aligner/spg_post_processor.py b/kag/builder/component/aligner/spg_aligner.py similarity index 99% rename from kag/builder/component/aligner/spg_post_processor.py rename to kag/builder/component/aligner/spg_aligner.py index 659759b3..7196ea63 100644 --- a/kag/builder/component/aligner/spg_post_processor.py +++ b/kag/builder/component/aligner/spg_aligner.py @@ -23,7 +23,7 @@ @AlignerABC.register("spg") -class SPGPostProcessorAligner(AlignerABC): +class SPGAligner(AlignerABC): def __init__(self, **kwargs): super().__init__(**kwargs) self.spg_types = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load() diff --git a/kag/builder/component/external_graph/__init__.py b/kag/builder/component/external_graph/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kag/builder/component/external_graph/external_graph.py b/kag/builder/component/external_graph/external_graph.py new file mode 100644 index 00000000..df0849a3 --- /dev/null +++ b/kag/builder/component/external_graph/external_graph.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 OpenSPG Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
+import jieba +import json +import numpy as np +import logging +from typing import List, Union, Dict +from kag.interface import ExternalGraphLoaderABC, MatchConfig +from kag.common.conf import KAG_PROJECT_CONF +from kag.builder.model.sub_graph import Node, Edge, SubGraph +from knext.schema.client import SchemaClient + +from knext.search.client import SearchClient + + +logger = logging.getLogger() + + +@ExternalGraphLoaderABC.register("base", constructor="from_json_file", as_default=True) +class DefaultExternalGraphLoader(ExternalGraphLoaderABC): + def __init__( + self, + nodes: List[Node], + edges: List[Edge], + match_config: MatchConfig, + ): + + self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load() + for node in nodes: + if node.label not in self.schema: + raise ValueError( + f"Type of node {node.to_dict()} is beyond the schema definition." + ) + for k in node.properties.keys(): + if k not in self.schema[node.label]: + raise ValueError( + f"Property of node {node.to_dict()} is beyond the schema definition." + ) + self.nodes = nodes + self.edges = edges + + self.vocabulary = {} + self.node_labels = set() + for node in self.nodes: + self.vocabulary[node.name] = node + self.node_labels.add(node.label) + + for word in self.vocabulary.keys(): + jieba.add_word(word) + + self.match_config = match_config + self._init_search() + + def _init_search(self): + self._search_client = SearchClient( + KAG_PROJECT_CONF.host_addr, KAG_PROJECT_CONF.project_id + ) + + def _group_by_label(self, data: Union[List[Node], List[Edge]]): + groups = {} + + for item in data: + label = item.label + if label not in groups: + groups[label] = [item] + else: + groups[label].append(item) + return list(groups.values()) + + def _group_by_cnt(self, data, n): + return [data[i : i + n] for i in range(0, len(data), n)] + + def dump(self, max_num_nodes: int = 4096, max_num_edges: int = 4096): + graphs = [] + # process nodes + for item in self._group_by_label(self.nodes): + for grouped_nodes in self._group_by_cnt(item, max_num_nodes): + graphs.append(SubGraph(nodes=grouped_nodes, edges=[])) + + # process edges + for item in self._group_by_label(self.edges): + for grouped_edges in self._group_by_cnt(item, max_num_edges): + graphs.append(SubGraph(nodes=[], edges=grouped_edges)) + + return graphs + + def ner(self, content: str): + output = [] + for word in jieba.cut(content): + if word in self.vocabulary: + output.append(self.vocabulary[word]) + return output + + def get_allowed_labels(self, labels: List[str] = None): + allowed_labels = [] + + namespace = KAG_PROJECT_CONF.namespace + if labels is None: + allowed_labels = [f"{namespace}.{x}" for x in self.node_labels] + else: + for label in labels: + # remove namespace + if label.startswith(KAG_PROJECT_CONF.namespace): + label = label.split(".")[1] + if label in self.node_labels: + allowed_labels.append(f"{namespace}.{label}") + return allowed_labels + + def search_result_to_node(self, search_result: Dict): + output = [] + for label in search_result["__labels__"]: + node = { + "id": search_result["id"], + "name": search_result["name"], + "label": label, + } + output.append(Node.from_dict(node)) + return output + + def text_match(self, query: str, k: int = 1, labels: List[str] = None): + allowed_labels = self.get_allowed_labels(labels) + text_matched = self._search_client.search_text(query, allowed_labels, topk=k) + return text_matched + + def vector_match( + self, + query: Union[List[float], np.ndarray], + k: int = 1, + threshold: float = 0.9, + labels: List[str] = 
None, + ): + allowed_labels = self.get_allowed_labels(labels) + if isinstance(query, np.ndarray): + query = query.tolist() + matched_results = [] + for label in allowed_labels: + vector_matched = self._search_client.search_vector( + label=label, property_key="name", query_vector=query, topk=k + ) + matched_results.extend(vector_matched) + + filtered_results = [] + for item in matched_results: + score = item["score"] + if score >= threshold: + filtered_results.append(item) + return filtered_results + + def match_entity(self, query: Union[str, List[float], np.ndarray]): + if isinstance(query, str): + return self.text_match( + query, k=self.match_config.k, labels=self.match_config.labels + ) + else: + return self.vector_match( + query, + k=self.match_config.k, + labels=self.match_config.labels, + threshold=self.match_config.threshold, + ) + + @classmethod + def from_json_file( + cls, + node_file_path: str, + edge_file_path, + match_config: MatchConfig, + ): + + nodes = [] + for item in json.load(open(node_file_path, "r")): + nodes.append(Node.from_dict(item)) + edges = [] + for item in json.load(open(edge_file_path, "r")): + edges.append(Edge.from_dict(item)) + return cls(nodes=nodes, edges=edges, match_config=match_config) diff --git a/kag/builder/component/extractor/kag_extractor.py b/kag/builder/component/extractor/kag_extractor.py index e959f1d1..6c8055a1 100644 --- a/kag/builder/component/extractor/kag_extractor.py +++ b/kag/builder/component/extractor/kag_extractor.py @@ -16,16 +16,15 @@ from kag.common.llm.llm_client import LLMClient from tenacity import stop_after_attempt, retry -from kag.builder.prompt.spg_prompt import SPG_KGPrompt -from kag.interface import ExtractorABC, PromptABC -from knext.schema.client import OTHER_TYPE, CHUNK_TYPE, BASIC_TYPES +from kag.interface import ExtractorABC, PromptABC, ExternalGraphLoaderABC + from kag.common.conf import KAG_PROJECT_CONF from kag.common.utils import processing_phrases, to_camel_case from kag.builder.model.chunk import Chunk from kag.builder.model.sub_graph import SubGraph +from knext.schema.client import OTHER_TYPE, CHUNK_TYPE, BASIC_TYPES from knext.common.base.runnable import Input, Output from knext.schema.client import SchemaClient -from knext.schema.model.base import SpgTypeEnum logger = logging.getLogger(__name__) @@ -43,6 +42,7 @@ def __init__( ner_prompt: PromptABC = None, std_prompt: PromptABC = None, triple_prompt: PromptABC = None, + external_graph: ExternalGraphLoaderABC = None, ): self.llm = llm print(f"self.llm: {self.llm}") @@ -64,43 +64,7 @@ def __init__( self.triple_prompt = PromptABC.from_config( {"type": f"{biz_scene}_triple", "language": KAG_PROJECT_CONF.language} ) - self.create_extra_prompts() - - def create_extra_prompts(self): - self.kg_types = [] - for type_name, spg_type in self.schema.items(): - if type_name in SPG_KGPrompt.ignored_types: - continue - if spg_type.spg_type_enum == SpgTypeEnum.Concept: - continue - properties = list(spg_type.properties.keys()) - for p in properties: - if p not in SPG_KGPrompt.ignored_properties: - self.kg_types.append(type_name) - break - if self.kg_types: - self.kg_prompt = SPG_KGPrompt( - self.kg_types, - language=KAG_PROJECT_CONF.language, - project_id=KAG_PROJECT_CONF.project_id, - ) - else: - self.kg_prompt = None - - # @classmethod - # def initialize( - # llm: LLMClient, - # ner_prompt: PromptABC = None, - # std_prompt: PromptABC = None, - # triple_prompt: PromptABC = None, - # ): - # print(f"llm = {llm}") - # print(ner_prompt) - # print(std_prompt) - # 
print(triple_prompt) - # extractor = KAGExtractor(llm, ner_prompt, std_prompt, triple_prompt) - # extractor.create_extra_prompts() - # return extractor + self.external_graph = external_graph @property def input_types(self) -> Type[Input]: @@ -119,12 +83,34 @@ def named_entity_recognition(self, passage: str): Returns: The result of the named entity recognition operation. """ - if self.kg_types: - kg_result = self.llm.invoke({"input": passage}, self.kg_prompt) - else: - kg_result = [] ner_result = self.llm.invoke({"input": passage}, self.ner_prompt) - return kg_result + ner_result + if self.external_graph: + extra_ner_result = self.external_graph.ner(passage) + else: + extra_ner_result = [] + output = [] + dedup = set() + for item in extra_ner_result: + name = item.name + label = item.label + description = item.properties.get("desc", "") + semantic_type = item.properties.get("semanticType", label) + if name not in dedup: + dedup.add(name) + output.append( + { + "entity": name, + "type": semantic_type, + "category": label, + "description": description, + } + ) + for item in ner_result: + name = item.get("entity", None) + if name and name not in dedup: + dedup.add(name) + output.append(item) + return output @retry(stop=stop_after_attempt(3)) def named_entity_standardization(self, passage: str, entities: List[Dict]): @@ -357,9 +343,10 @@ def invoke(self, input: Input, **kwargs) -> List[Output]: Returns: List[Output]: A list of processed results, containing subgraph information. """ + title = input.name passage = title + "\n" + input.content - + out = [] try: entities = self.named_entity_recognition(passage) sub_graph, entities = self.assemble_sub_graph_with_spg_records(entities) @@ -371,10 +358,10 @@ def invoke(self, input: Input, **kwargs) -> List[Output]: std_entities = self.named_entity_standardization(passage, filtered_entities) self.append_official_name(entities, std_entities) self.assemble_sub_graph(sub_graph, input, entities, triples) - return [sub_graph] + out.append(sub_graph) except Exception as e: import traceback traceback.print_exc() logger.info(e) - return [] + return out diff --git a/kag/builder/component/postprocessor/__init__.py b/kag/builder/component/postprocessor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kag/builder/component/postprocessor/kag_postprocessor.py b/kag/builder/component/postprocessor/kag_postprocessor.py new file mode 100644 index 00000000..51c2c33b --- /dev/null +++ b/kag/builder/component/postprocessor/kag_postprocessor.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 OpenSPG Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
+from typing import List +from kag.interface import PostProcessorABC +from kag.interface import ExternalGraphLoaderABC +from kag.builder.model.sub_graph import SubGraph +from kag.common.conf import KAGConstants, KAG_PROJECT_CONF +from kag.common.utils import get_vector_field_name +from knext.search.client import SearchClient +from knext.schema.client import SchemaClient + + +@PostProcessorABC.register("base", as_default=True) +class KAGPostProcessor(PostProcessorABC): + def __init__( + self, + similarity_threshold: float = 0.9, + external_graph: ExternalGraphLoaderABC = None, + ): + self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load() + self.similarity_threshold = similarity_threshold + self.external_graph = external_graph + self._init_search() + + def format_label(self, label: str): + namespace = KAG_PROJECT_CONF.namespace + if label.startswith(namespace): + return label + return f"{namespace}.{label}" + + def _init_search(self): + self._search_client = SearchClient( + KAG_PROJECT_CONF.host_addr, KAG_PROJECT_CONF.project_id + ) + + def filter_invalid_data(self, graph: SubGraph): + valid_nodes = [] + valid_edges = [] + for node in graph.nodes: + if not node.id or not node.label: + continue + if node.label not in self.schema: + continue + # for k in node.properties.keys(): + # if k not in self.schema[node.label]: + # continue + valid_nodes.append(node) + for edge in graph.edges: + if edge.label: + valid_edges.append(edge) + return SubGraph(nodes=valid_nodes, edges=valid_edges) + + def _entity_link( + self, graph: SubGraph, property_key: str = "name", labels: List[str] = None + ): + vector_field_name = get_vector_field_name(property_key) + for node in graph.nodes: + if labels is None: + link_labels = [self.format_label(node.label)] + else: + link_labels = [self.format_label(x) for x in labels] + vector = node.properties.get(vector_field_name) + if vector: + all_similar_nodes = [] + for label in link_labels: + similar_nodes = self._search_client.search_vector( + label=label, + property_key=property_key, + query_vector=vector, + topk=1, + ) + all_similar_nodes.extend(similar_nodes) + for item in all_similar_nodes: + score = item["score"] + if score >= self.similarity_threshold: + graph.add_edge( + node.id, + node.label, + KAGConstants.KAG_SIMILAR_EDGE_NAME, + item["node"]["id"], + item["node"]["__labels__"][0], + ) + + def similarity_based_link(self, graph: SubGraph, property_key: str = "name"): + self._entity_link(graph, property_key, None) + + def external_graph_based_link(self, graph: SubGraph, property_key: str = "name"): + if not self.external_graph: + return + labels = self.external_graph.get_allowed_labels() + self._entity_link(graph, property_key, labels) + + def invoke(self, input): + new_graph = self.filter_invalid_data(input) + self.similarity_based_link(new_graph) + self.external_graph_based_link(new_graph) + return [new_graph] diff --git a/kag/builder/component/vectorizer/batch_vectorizer.py b/kag/builder/component/vectorizer/batch_vectorizer.py index f1ca0c49..0cb2c5b6 100644 --- a/kag/builder/component/vectorizer/batch_vectorizer.py +++ b/kag/builder/component/vectorizer/batch_vectorizer.py @@ -15,6 +15,7 @@ from kag.builder.model.sub_graph import SubGraph from kag.common.conf import KAG_PROJECT_CONF from kag.common.vectorizer import Vectorizer +from kag.common.utils import get_vector_field_name from kag.interface import VectorizerABC from knext.schema.client import SchemaClient from knext.schema.model.base import IndexTypeEnum @@ -42,16 +43,9 @@ class 
EmbeddingVectorManager(object): def __init__(self): self._placeholders = [] - def _create_vector_field_name(self, property_key): - from kag.common.utils import to_snake_case - - name = f"{property_key}_vector" - name = to_snake_case(name) - return "_" + name - def get_placeholder(self, properties, vector_field): for property_key, property_value in properties.items(): - field_name = self._create_vector_field_name(property_key) + field_name = get_vector_field_name(property_key) if field_name != vector_field: continue if not property_value: @@ -76,21 +70,27 @@ def _get_text_batch(self): text_batch[property_value].append(placeholder) return text_batch - def _generate_vectors(self, vectorizer, text_batch): + def _generate_vectors(self, vectorizer, text_batch, batch_size=1024): texts = list(text_batch) if not texts: return [] - vectors = vectorizer.vectorize(texts) - return vectors + + n_batchs = len(texts) // batch_size + 1 + embeddings = [] + for idx in range(n_batchs): + start = idx * batch_size + end = min(start + batch_size, len(texts)) + embeddings.extend(vectorizer.vectorize(texts[start:end])) + return embeddings def _fill_vectors(self, vectors, text_batch): for vector, (_text, placeholders) in zip(vectors, text_batch.items()): for placeholder in placeholders: placeholder._embedding_vector = vector - def batch_generate(self, vectorizer): + def batch_generate(self, vectorizer, batch_size=1024): text_batch = self._get_text_batch() - vectors = self._generate_vectors(vectorizer, text_batch) + vectors = self._generate_vectors(vectorizer, text_batch, batch_size) self._fill_vectors(vectors, text_batch) def patch(self): @@ -104,7 +104,7 @@ def __init__(self, vectorizer, vector_index_meta=None, extra_labels=("Entity",)) self._extra_labels = extra_labels self._vector_index_meta = vector_index_meta or {} - def batch_generate(self, node_batch): + def batch_generate(self, node_batch, batch_size=1024): manager = EmbeddingVectorManager() vector_index_meta = self._vector_index_meta for node_item in node_batch: @@ -121,18 +121,19 @@ def batch_generate(self, node_batch): placeholder = manager.get_placeholder(properties, vector_field) if placeholder is not None: properties[vector_field] = placeholder - manager.batch_generate(self._vectorizer) + manager.batch_generate(self._vectorizer, batch_size) manager.patch() @VectorizerABC.register("batch") class BatchVectorizer(VectorizerABC): - def __init__(self, vectorizer_model: Vectorizer): + def __init__(self, vectorizer_model: Vectorizer, batch_size: int = 1024): super().__init__() self.project_id = KAG_PROJECT_CONF.project_id # self._init_graph_store() self.vec_meta = self._init_vec_meta() self.vectorizer_model = vectorizer_model + self.batch_size = batch_size def _init_vec_meta(self): vec_meta = defaultdict(list) @@ -144,18 +145,9 @@ def _init_vec_meta(self): IndexTypeEnum.Vector, IndexTypeEnum.TextAndVector, ]: - vec_meta[type_name].append( - self._create_vector_field_name(prop_name) - ) + vec_meta[type_name].append(get_vector_field_name(prop_name)) return vec_meta - def _create_vector_field_name(self, property_key): - from kag.common.utils import to_snake_case - - name = f"{property_key}_vector" - name = to_snake_case(name) - return "_" + name - def _generate_embedding_vectors(self, input_subgraph: SubGraph) -> SubGraph: node_list = [] node_batch = [] @@ -167,7 +159,7 @@ def _generate_embedding_vectors(self, input_subgraph: SubGraph) -> SubGraph: node_list.append((node, properties)) node_batch.append((node.label, properties.copy())) generator = 
EmbeddingVectorGenerator(self.vectorizer_model, self.vec_meta) - generator.batch_generate(node_batch) + generator.batch_generate(node_batch, self.batch_size) for (node, properties), (_node_label, new_properties) in zip( node_list, node_batch ): diff --git a/kag/builder/component/writer/kg_writer.py b/kag/builder/component/writer/kg_writer.py index 081902af..7f270afc 100644 --- a/kag/builder/component/writer/kg_writer.py +++ b/kag/builder/component/writer/kg_writer.py @@ -81,4 +81,5 @@ def _handle(self, input: Dict, alter_operation: str, **kwargs): """The calling interface provided for SPGServer.""" _input = self.input_types.from_dict(input) _output = self.invoke(_input, alter_operation) # noqa + return None diff --git a/kag/builder/model/sub_graph.py b/kag/builder/model/sub_graph.py index b359ca2b..d44a2bdb 100644 --- a/kag/builder/model/sub_graph.py +++ b/kag/builder/model/sub_graph.py @@ -57,7 +57,7 @@ def from_dict(cls, input: Dict): _id=input["id"], name=input["name"], label=input["label"], - properties=input["properties"], + properties=input.get("properties", {}), ) def __eq__(self, other): @@ -136,7 +136,7 @@ def from_dict(cls, input: Dict): _id=input["to"], name=input["to"], label=input["toType"], properties={} ), label=input["label"], - properties=input["properties"], + properties=input.get("properties", {}), ) def __eq__(self, other): diff --git a/kag/common/conf.py b/kag/common/conf.py index 7d27cc5c..45210ca9 100644 --- a/kag/common/conf.py +++ b/kag/common/conf.py @@ -28,25 +28,42 @@ class KAGConstants(object): KAG_CFG_PREFIX = "KAG" GLOBAL_CONFIG_KEY = "global" PROJECT_CONFIG_KEY = "project" - KAG_PROJECT_ID_KEY = "KAG_PROJECT_ID" - KAG_HOST_ADDR_KEY = "KAG_PROJECT_HOST_ADDR" - KAG_LANGUAGE_KEY = "KAG_LANGUAGE" - KAG_BIZ_SCENE_KEY = "KAG_BIZ_SCENE" + KAG_NAMESPACE_KEY = "namespace" + KAG_PROJECT_ID_KEY = "id" + KAG_PROJECT_HOST_ADDR_KEY = "host_addr" + KAG_LANGUAGE_KEY = "language" + KAG_BIZ_SCENE_KEY = "biz_scene" + ENV_KAG_PROJECT_ID = "KAG_PROJECT_ID" + ENV_KAG_PROJECT_HOST_ADDR = "KAG_PROJECT_HOST_ADDR" + KAG_SIMILAR_EDGE_NAME = "similar" + + KS8_ENV_TF_CONFIG = "TF_CONFIG" + K8S_ENV_MASTER_ADDR = "MASTER_ADDR" + K8S_ENV_MASTER_PORT = "MASTER_PORT" + K8S_ENV_WORLD_SIZE = "WORLD_SIZE" + K8S_ENV_RANK = "RANK" + K8S_ENV_POD_NAME = "POD_NAME" class KAGGlobalConf: def __init__(self): - pass - - def setup(self, **kwargs): - self.project_id = kwargs.pop(KAGConstants.KAG_PROJECT_ID_KEY, "1") - self.host_addr = kwargs.pop( - KAGConstants.KAG_HOST_ADDR_KEY, "http://127.0.0.1:8887" - ) - self.biz_scene = kwargs.pop(KAGConstants.KAG_BIZ_SCENE_KEY, "default") - self.language = kwargs.pop(KAGConstants.KAG_LANGUAGE_KEY, "en") - for k, v in kwargs.items(): - setattr(self, k, v) + self._initialized = False + + def initialize(self, **kwargs): + if not self._initialized: + print(f"kwargs = {kwargs}") + self.project_id = kwargs.pop(KAGConstants.KAG_PROJECT_ID_KEY, "1") + self.host_addr = kwargs.pop( + KAGConstants.KAG_PROJECT_HOST_ADDR_KEY, "http://127.0.0.1:8887" + ) + self.biz_scene = kwargs.pop(KAGConstants.KAG_BIZ_SCENE_KEY, "default") + self.language = kwargs.pop(KAGConstants.KAG_LANGUAGE_KEY, "en") + self.namespace = kwargs.pop(KAGConstants.KAG_NAMESPACE_KEY, None) + for k, v in kwargs.items(): + setattr(self, k, v) + self._initialized = True + else: + print("KAGGlobalConf has been initialized and cannot be initialized again!") def _closest_cfg( @@ -71,8 +88,8 @@ def load_config(prod: bool = False): Get kag config file as a ConfigParser. 
""" if prod: - project_id = os.getenv(KAGConstants.KAG_PROJECT_ID_KEY) - host_addr = os.getenv(KAGConstants.KAG_HOST_ADDR_KEY) + project_id = os.getenv(KAGConstants.ENV_KAG_PROJECT_ID) + host_addr = os.getenv(KAGConstants.ENV_KAG_PROJECT_HOST_ADDR) config = ProjectClient(host_addr=host_addr).get_config(project_id) return config else: @@ -87,8 +104,6 @@ def load_config(prod: bool = False): def init_kag_config(config): - global_config = config.get(KAGConstants.GLOBAL_CONFIG_KEY, {}) - KAG_PROJECT_CONF.setup(**global_config) log_conf = config.get("log", {}) if log_conf: log_level = log_conf.get("level", "INFO") @@ -110,7 +125,7 @@ def initialize(self, prod: bool = True): self.prod = prod self.config = load_config(prod) global_config = self.config.get(KAGConstants.PROJECT_CONFIG_KEY, {}) - self.global_config.setup(**global_config) + self.global_config.initialize(**global_config) init_kag_config(self.config) self._is_initialize = True @@ -126,7 +141,7 @@ def all_config(self): def init_env(): project_id = os.getenv(KAGConstants.KAG_PROJECT_ID_KEY) - host_addr = os.getenv(KAGConstants.KAG_HOST_ADDR_KEY) + host_addr = os.getenv(KAGConstants.KAG_PROJECT_HOST_ADDR_KEY) if project_id and host_addr: prod = True else: diff --git a/kag/common/env.py b/kag/common/env.py new file mode 100644 index 00000000..c587a2ba --- /dev/null +++ b/kag/common/env.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +import os +import json +import time +import datetime +import socket +import traceback +import torch +import torch.distributed as dist +from kag.common.conf import KAGConstants + + +def parse_tf_config(): + tf_config_str = os.environ.get(KAGConstants.KS8_ENV_TF_CONFIG, None) + if tf_config_str is None: + return None + else: + return json.loads(tf_config_str) + + +def get_role_number(config, role_name): + role_info = config["cluster"].get(role_name, None) + if role_info is None: + return 0 + else: + return len(role_info) + + +def get_rank(default=None): + if KAGConstants.K8S_ENV_RANK in os.environ: + return int(os.environ[KAGConstants.K8S_ENV_RANK]) + + tf_config = parse_tf_config() + if tf_config is None: + print(f"no RANK info in env/tf_config, use default value:{default}") + return default + + num_master = get_role_number(tf_config, "master") + task_type = tf_config["task"]["type"] + task_index = tf_config["task"]["index"] + if task_type == "master": + rank = task_index + elif task_type == "worker": + rank = num_master + task_index + else: + rank = default + + return rank + + +def get_world_size(default=None): + if KAGConstants.K8S_ENV_WORLD_SIZE in os.environ: + return os.environ[KAGConstants.K8S_ENV_WORLD_SIZE] + + tf_config = parse_tf_config() + if tf_config is None: + return default + + num_master = get_role_number(tf_config, "master") + num_worker = get_role_number(tf_config, "worker") + + return num_master + num_worker + + +def get_master_port(default=None): + return os.environ.get(KAGConstants.K8S_ENV_MASTER_PORT, default) + + +def get_master_addr(default=None): + if KAGConstants.K8S_ENV_MASTER_ADDR in os.environ: + return os.environ[KAGConstants.K8S_ENV_MASTER_ADDR] + + tf_config = parse_tf_config() + if tf_config is None: + return default + + return tf_config["cluster"]["worker"][0] + + +def host2tensor(master_port): + host_str = socket.gethostbyname(socket.gethostname()) + host = [int(x) for x in host_str.split(".")] + host.append(int(master_port)) + host_tensor = torch.tensor(host) + return host_tensor + + +def tensor2host(host_tensor): + host_tensor = host_tensor.tolist() + host = ".".join([str(x) 
for x in host_tensor[0:4]])
+    port = host_tensor[4]
+    return f"{host}:{port}"
+
+
+def sync_hosts():
+    rank = get_rank()
+    if rank is None:
+        raise ValueError("can't get rank of container")
+    rank = int(rank)
+
+    world_size = get_world_size()
+    if world_size is None:
+        raise ValueError("can't get world_size of container")
+    world_size = int(world_size)
+
+    master_port = get_master_port()
+    if master_port is None:
+        raise ValueError("can't get master_port of container")
+    master_port = int(master_port)
+
+    while True:
+        try:
+            dist.init_process_group(
+                backend="gloo",
+                rank=rank,
+                world_size=world_size,
+                timeout=datetime.timedelta(days=1),
+            )
+            break
+        except Exception as e:
+            error_traceback = traceback.format_exc()
+            print(f"failed to init process group, info: {e}\n\n\n{error_traceback}")
+            time.sleep(60)
+    print("Done init process group, get all hosts...")
+    host_tensors = [torch.tensor([0, 0, 0, 0, 0]) for x in range(world_size)]
+    dist.all_gather(host_tensors, host2tensor(master_port))
+    # we need to destroy the torch process group to release MASTER_PORT, otherwise the server
+    # can't serve on it.
+    print("Done get all hosts, destroy process group...")
+    dist.destroy_process_group()
+    time.sleep(10)
+    return [tensor2host(x) for x in host_tensors]
+
+
+def extract_job_name_from_pod_name(pod_name):
+    if "-ptjob" in pod_name:
+        return pod_name.rsplit("-ptjob", maxsplit=1)[0]
+    elif "-tfjob" in pod_name:
+        return pod_name.rsplit("-tfjob", maxsplit=1)[0]
+    elif "-mpijob" in pod_name:
+        return pod_name.rsplit("-mpijob", maxsplit=1)[0]
+    else:
+        return None
diff --git a/kag/common/llm/mock_llm.py b/kag/common/llm/mock_llm.py
new file mode 100644
index 00000000..6193f0f4
--- /dev/null
+++ b/kag/common/llm/mock_llm.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
+ + +import json +from kag.common.llm.llm_client import LLMClient + + +@LLMClient.register("mock") +class MockLLMClient(LLMClient): + def __init__(self): + pass + + def match_input(self, prompt): + if "You're a very effective entity extraction system" in prompt: + return [ + { + "entity": "The Rezort", + "type": "Movie", + "category": "Works", + "description": "A 2015 British zombie horror film directed by Steve Barker and written by Paul Gerstenberger.", + }, + { + "entity": "2015", + "type": "Year", + "category": "Date", + "description": "The year the movie 'The Rezort' was released.", + }, + ] + if "please attempt to provide the official names of" in prompt: + return [ + { + "entity": "The Rezort", + "type": "Movie", + "category": "Works", + "description": "A 2015 British zombie horror film directed by Steve Barker and written by Paul Gerstenberger.", + }, + { + "entity": "2015", + "type": "Year", + "category": "Date", + "description": "The year the movie 'The Rezort' was released.", + }, + ] + if ( + "You are an expert specializing in carrying out open information extraction" + in prompt + ): + return [ + ["The Rezort", "is", "zombie horror film"], + ["The Rezort", "publish at", "2015"], + ] + return "I am an intelligent assistant" + + def __call__(self, prompt): + + return json.dumps(self.match_input(prompt)) + + def call_with_json_parse(self, prompt): + return self.match_input(prompt) diff --git a/kag/common/utils.py b/kag/common/utils.py index b6891952..dd68ecaa 100644 --- a/kag/common/utils.py +++ b/kag/common/utils.py @@ -203,3 +203,9 @@ def to_snake_case(name): words = re.findall("[A-Za-z][a-z0-9]*", name) result = "_".join(words).lower() return result + + +def get_vector_field_name(property_key: str): + name = f"{property_key}_vector" + name = to_snake_case(name) + return "_" + name diff --git a/kag/examples/musique/builder/indexer.py b/kag/examples/musique/builder/indexer.py index eff3afee..ec869f10 100644 --- a/kag/examples/musique/builder/indexer.py +++ b/kag/examples/musique/builder/indexer.py @@ -49,9 +49,9 @@ def buildKB(file_path): chain_config = KAG_CONFIG.all_config["chain"] chain = MusiqueBuilderChain.from_config(chain_config) - chain.invoke(file_path=file_path, max_workers=20) + chain.invoke(file_path=file_path, max_workers=1) - logger.info(f"\n\nbuildKB successfully for {corpusFilePath}\n\n") + logger.info(f"\n\nbuildKB successfully for {file_path}\n\n") if __name__ == "__main__": diff --git a/kag/interface/__init__.py b/kag/interface/__init__.py index d61bc5f2..862e2ea7 100644 --- a/kag/interface/__init__.py +++ b/kag/interface/__init__.py @@ -17,7 +17,11 @@ from kag.interface.builder.aligner_abc import AlignerABC from kag.interface.builder.writer_abc import SinkWriterABC from kag.interface.builder.vectorizer_abc import VectorizerABC - +from kag.interface.builder.external_graph_abc import ( + ExternalGraphLoaderABC, + MatchConfig, +) +from kag.interface.builder.postprocessor_abc import PostProcessorABC from kag.interface.solver.base import KagBaseModule, Question from kag.interface.solver.kag_generator_abc import KAGGeneratorABC from kag.interface.solver.kag_memory_abc import KagMemoryABC @@ -36,6 +40,9 @@ "AlignerABC", "SinkWriterABC", "VectorizerABC", + "ExternalGraphLoaderABC", + "MatchConfig", + "PostProcessorABC", "KagBaseModule", "Question", "KAGGeneratorABC", diff --git a/kag/interface/builder/external_graph_abc.py b/kag/interface/builder/external_graph_abc.py new file mode 100644 index 00000000..37aa0ef4 --- /dev/null +++ 
b/kag/interface/builder/external_graph_abc.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
+import numpy as np
+from typing import List, Union, Any
+from kag.builder.model.sub_graph import Node, SubGraph
+from kag.common.registry import Registrable
+from kag.interface.builder.base import BuilderComponent
+from knext.common.base.runnable import Input, Output
+
+
+class MatchConfig(Registrable):
+    def __init__(self, k: int = 1, labels: List[str] = None, threshold: float = 0.9):
+        self.k = k
+        self.labels = labels
+        self.threshold = threshold
+
+
+MatchConfig.register(MatchConfig, "base", as_default=True)
+
+
+class ExternalGraphLoaderABC(BuilderComponent):
+    def __init__(self, match_config: MatchConfig):
+        self.match_config = match_config
+
+    def dump(self) -> List[SubGraph]:
+        raise NotImplementedError("dump not implemented yet.")
+
+    def ner(self, content: str) -> List[Node]:
+        raise NotImplementedError("ner not implemented yet.")
+
+    def get_allowed_labels(self, labels: List[str] = None) -> List[str]:
+        raise NotImplementedError("get_allowed_labels not implemented yet.")
+
+    def match_entity(
+        self,
+        query: Union[str, List[float], np.ndarray],
+    ):
+        pass
+
+    @property
+    def input_types(self):
+        return Any
+
+    @property
+    def output_types(self):
+        return SubGraph
+
+    def invoke(self, input: Input, **kwargs) -> List[Output]:
+        return self.dump()
diff --git a/kag/interface/builder/postprocessor_abc.py b/kag/interface/builder/postprocessor_abc.py
new file mode 100644
index 00000000..d3906171
--- /dev/null
+++ b/kag/interface/builder/postprocessor_abc.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 OpenSPG Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied.
+from typing import List
+
+from kag.interface.builder.base import BuilderComponent
+from kag.builder.model.sub_graph import SubGraph
+from knext.common.base.runnable import Input, Output
+
+
+class PostProcessorABC(BuilderComponent):
+    """
+    Interface for builder subgraph post-processors.
+    """
+
+    @property
+    def input_types(self):
+        return SubGraph
+
+    @property
+    def output_types(self):
+        return SubGraph
+
+    def invoke(self, input: Input, **kwargs) -> List[Output]:
+        raise NotImplementedError(
+            f"`invoke` is not currently supported for {self.__class__.__name__}."
+ ) diff --git a/tests/builder/component/test_external_graph.py b/tests/builder/component/test_external_graph.py new file mode 100644 index 00000000..bb1dc6ab --- /dev/null +++ b/tests/builder/component/test_external_graph.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 OpenSPG Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. + +from kag.interface import ExternalGraphLoaderABC, SinkWriterABC, VectorizerABC +from kag.builder.model.sub_graph import Node +from kag.common.conf import KAG_CONFIG +from kag.common.utils import get_vector_field_name + + +def get_config(): + config = { + "type": "base", + "node_file_path": "../data/nodes.json", + "edge_file_path": "../data/edges.json", + "match_config": { + "k": 1, + "threshold": 0.9, + }, + } + return config + + +def _test_eg_base(): + config = get_config() + eg = ExternalGraphLoaderABC.from_config(config) + assert len(eg.nodes) == 38 + assert len(eg.edges) == 19 + + +def _test_eg_dump(): + config = get_config() + eg = ExternalGraphLoaderABC.from_config(config) + graphs = eg.invoke(None) + for graph in graphs: + if len(graph.nodes) > 0: + assert len(graph.edges) == 0 + labels = set() + for node in graph.nodes: + labels.add(node.label) + assert len(labels) == 1 + elif len(graph.edges) > 0: + assert len(graph.nodes) == 0 + labels = set() + for edge in graph.edges: + labels.add(edge.label) + assert len(labels) == 1 + + vectorizer = VectorizerABC.from_config(KAG_CONFIG.all_config["vectorizer"]) + writer = SinkWriterABC.from_config(KAG_CONFIG.all_config["writer"]) + for graph in graphs: + new_graph = vectorizer.invoke(graph)[0] + writer.invoke(new_graph) + + +def _test_eg_query(): + config = get_config() + eg = ExternalGraphLoaderABC.from_config(config) + entities = eg.ner("促生长素抑制素和蛋白酶有什么关系") + assert len(entities) > 0 + for entity in entities: + assert isinstance(entity, Node) + + text_matched = eg.match_entity("蛋白水解酶") + assert len(text_matched) > 0 and text_matched[0]["node"]["name"] == "蛋白水解酶" + vector = text_matched[0]["node"][get_vector_field_name("name")] + vector_matched = eg.match_entity(vector) + assert len(vector_matched) > 0 and vector_matched[0]["node"]["name"] == "蛋白水解酶" + + +def test_eg(): + _test_eg_base() + _test_eg_dump() + _test_eg_query() diff --git a/tests/builder/component/test_post_processor.py b/tests/builder/component/test_post_processor.py index e69de29b..6c9b30c4 100644 --- a/tests/builder/component/test_post_processor.py +++ b/tests/builder/component/test_post_processor.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 OpenSPG Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
+ +from kag.interface import ExternalGraphLoaderABC, PostProcessorABC, VectorizerABC +from kag.builder.model.sub_graph import Node, Edge, SubGraph +from kag.common.conf import KAG_CONFIG +from kag.common.utils import get_vector_field_name + + +def get_config(): + config = { + "type": "base", + "similarity_threshold": 0.1, + } + return config + + +def create_mock_graph(): + nodes = [ + Node("Apple", name="Apple", label="Concept", properties={}), + Node("Banana", name="Banana", label="Concept", properties={}), + Node("Peach", name="Peach", label="Concept", properties={}), + Node("000", name="000", label="Number", properties={}), + Node("Error", name="Peach", label="", properties={}), + ] + + edges = [ + Edge("1", from_node=nodes[0], to_node=nodes[1], label="sim", properties={}), + Edge("2", from_node=nodes[0], to_node=nodes[2], label="sim", properties={}), + Edge("3", from_node=nodes[1], to_node=nodes[2], label="sim", properties={}), + Edge("4", from_node=nodes[0], to_node=nodes[3], label="", properties={}), + Edge("5", from_node=nodes[0], to_node=nodes[4], label="", properties={}), + ] + graph = SubGraph(nodes=nodes, edges=edges) + return graph + + +def test_postprocessor_filter(): + config = get_config() + postprocessor = PostProcessorABC.from_config(config) + graph = create_mock_graph() + new_graph = postprocessor.filter_invalid_data(graph) + assert len(new_graph.nodes) == 3 + assert len(new_graph.edges) == 3 + for node in new_graph.nodes: + assert node.label == "Concept" + for edge in new_graph.edges: + assert edge.label == "sim" + + +def test_postprocessor_add_sim_edges(): + config = get_config() + postprocessor = PostProcessorABC.from_config(config) + graph = create_mock_graph() + vectorizer = VectorizerABC.from_config(KAG_CONFIG.all_config["vectorizer"]) + graph = vectorizer.invoke(graph)[0] + origin_num_edges = len(graph.edges) + postprocessor.similarity_based_link(graph) + assert len(graph.edges) > origin_num_edges + + +def test_postprocessor_add_eg_edges(): + config = get_config() + config["external_graph"] = { + "type": "base", + "node_file_path": "../data/nodes.json", + "edge_file_path": "../data/edges.json", + "match_config": { + "k": 1, + "threshold": 0.9, + }, + } + postprocessor = PostProcessorABC.from_config(config) + graph = create_mock_graph() + vectorizer = VectorizerABC.from_config(KAG_CONFIG.all_config["vectorizer"]) + graph = vectorizer.invoke(graph)[0] + origin_num_edges = len(graph.edges) + postprocessor.external_graph_based_link(graph) + assert len(graph.edges) > origin_num_edges diff --git a/tests/builder/data/edges.json b/tests/builder/data/edges.json new file mode 100644 index 00000000..bbeaaa18 --- /dev/null +++ b/tests/builder/data/edges.json @@ -0,0 +1,154 @@ +[ + { + "id": "(缩)肾上腺皮质激素-促肾上腺皮质激素", + "from": "(缩)肾上腺皮质激素", + "fromType": "Concept", + "to": "促肾上腺皮质激素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "促肾皮素-促肾上腺皮质激素", + "from": "促肾皮素", + "fromType": "Concept", + "to": "促肾上腺皮质激素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "促皮质素-促肾上腺皮质激素", + "from": "促皮质素", + "fromType": "Concept", + "to": "促肾上腺皮质激素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "促肾上腺皮质[激]素-促肾上腺皮质激素", + "from": "促肾上腺皮质[激]素", + "fromType": "Concept", + "to": "促肾上腺皮质激素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "ACTH-促肾上腺皮质激素", + "from": "ACTH", + "fromType": "Concept", + "to": "促肾上腺皮质激素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "促皮质激素-促肾上腺皮质激素", + "from": "促皮质激素", + "fromType": "Concept", + "to": "促肾上腺皮质激素", + 
"toType": "Concept", + "label": "isA" + }, + { + "id": "促肾上腺皮质素-促肾上腺皮质激素", + "from": "促肾上腺皮质素", + "fromType": "Concept", + "to": "促肾上腺皮质激素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "人生长激素-促生长素", + "from": "人生长激素", + "fromType": "Concept", + "to": "促生长素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "生长激素-促生长素", + "from": "生长激素", + "fromType": "Concept", + "to": "促生长素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "生长激素释放抑制激素-生长抑素", + "from": "生长激素释放抑制激素", + "fromType": "Concept", + "to": "生长抑素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "促生长素抑制素-生长抑素", + "from": "促生长素抑制素", + "fromType": "Concept", + "to": "生长抑素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "生长抑素醋酸盐-生长抑素", + "from": "生长抑素醋酸盐", + "fromType": "Concept", + "to": "生长抑素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "胃泌激素-促胃液素", + "from": "胃泌激素", + "fromType": "Concept", + "to": "促胃液素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "胃泌素-促胃液素", + "from": "胃泌素", + "fromType": "Concept", + "to": "促胃液素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "促乳素-催乳素", + "from": "促乳素", + "fromType": "Concept", + "to": "催乳素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "泌乳素-催乳素", + "from": "泌乳素", + "fromType": "Concept", + "to": "催乳素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "催乳激素-催乳素", + "from": "催乳激素", + "fromType": "Concept", + "to": "催乳素", + "toType": "Concept", + "label": "isA" + }, + { + "id": "蛋白水解酶-内肽酶", + "from": "蛋白水解酶", + "fromType": "Concept", + "to": "内肽酶", + "toType": "Concept", + "label": "isA" + }, + { + "id": "蛋白酶-内肽酶", + "from": "蛋白酶", + "fromType": "Concept", + "to": "内肽酶", + "toType": "Concept", + "label": "isA" + } +] \ No newline at end of file diff --git a/tests/builder/data/nodes.json b/tests/builder/data/nodes.json new file mode 100644 index 00000000..2b0c8af8 --- /dev/null +++ b/tests/builder/data/nodes.json @@ -0,0 +1,192 @@ +[ + { + "id": "(缩)肾上腺皮质激素", + "name": "(缩)肾上腺皮质激素", + "label": "Concept" + }, + { + "id": "促肾上腺皮质激素", + "name": "促肾上腺皮质激素", + "label": "Concept" + }, + { + "id": "促肾皮素", + "name": "促肾皮素", + "label": "Concept" + }, + { + "id": "促肾上腺皮质激素", + "name": "促肾上腺皮质激素", + "label": "Concept" + }, + { + "id": "促皮质素", + "name": "促皮质素", + "label": "Concept" + }, + { + "id": "促肾上腺皮质激素", + "name": "促肾上腺皮质激素", + "label": "Concept" + }, + { + "id": "促肾上腺皮质[激]素", + "name": "促肾上腺皮质[激]素", + "label": "Concept" + }, + { + "id": "促肾上腺皮质激素", + "name": "促肾上腺皮质激素", + "label": "Concept" + }, + { + "id": "ACTH", + "name": "ACTH", + "label": "Concept" + }, + { + "id": "促肾上腺皮质激素", + "name": "促肾上腺皮质激素", + "label": "Concept" + }, + { + "id": "促皮质激素", + "name": "促皮质激素", + "label": "Concept" + }, + { + "id": "促肾上腺皮质激素", + "name": "促肾上腺皮质激素", + "label": "Concept" + }, + { + "id": "促肾上腺皮质素", + "name": "促肾上腺皮质素", + "label": "Concept" + }, + { + "id": "促肾上腺皮质激素", + "name": "促肾上腺皮质激素", + "label": "Concept" + }, + { + "id": "人生长激素", + "name": "人生长激素", + "label": "Concept" + }, + { + "id": "促生长素", + "name": "促生长素", + "label": "Concept" + }, + { + "id": "生长激素", + "name": "生长激素", + "label": "Concept" + }, + { + "id": "促生长素", + "name": "促生长素", + "label": "Concept" + }, + { + "id": "生长激素释放抑制激素", + "name": "生长激素释放抑制激素", + "label": "Concept" + }, + { + "id": "生长抑素", + "name": "生长抑素", + "label": "Concept" + }, + { + "id": "促生长素抑制素", + "name": "促生长素抑制素", + "label": "Concept" + }, + { + "id": "生长抑素", + "name": "生长抑素", + "label": "Concept" + }, + { + "id": "生长抑素醋酸盐", + "name": "生长抑素醋酸盐", + "label": "Concept" + }, + 
{ + "id": "生长抑素", + "name": "生长抑素", + "label": "Concept" + }, + { + "id": "胃泌激素", + "name": "胃泌激素", + "label": "Concept" + }, + { + "id": "促胃液素", + "name": "促胃液素", + "label": "Concept" + }, + { + "id": "胃泌素", + "name": "胃泌素", + "label": "Concept" + }, + { + "id": "促胃液素", + "name": "促胃液素", + "label": "Concept" + }, + { + "id": "促乳素", + "name": "促乳素", + "label": "Concept" + }, + { + "id": "催乳素", + "name": "催乳素", + "label": "Concept" + }, + { + "id": "泌乳素", + "name": "泌乳素", + "label": "Concept" + }, + { + "id": "催乳素", + "name": "催乳素", + "label": "Concept" + }, + { + "id": "催乳激素", + "name": "催乳激素", + "label": "Concept" + }, + { + "id": "催乳素", + "name": "催乳素", + "label": "Concept" + }, + { + "id": "蛋白水解酶", + "name": "蛋白水解酶", + "label": "Concept" + }, + { + "id": "内肽酶", + "name": "内肽酶", + "label": "Concept" + }, + { + "id": "蛋白酶", + "name": "蛋白酶", + "label": "Concept" + }, + { + "id": "内肽酶", + "name": "内肽酶", + "label": "Concept" + } +] \ No newline at end of file diff --git a/tests/builder/kag_config.yaml b/tests/builder/kag_config.yaml new file mode 100644 index 00000000..2624c5d0 --- /dev/null +++ b/tests/builder/kag_config.yaml @@ -0,0 +1,23 @@ +project: + biz_scene: default + host_addr: http://127.0.0.1:8887 + id: '1' + language: en + namespace: MuSiQue + project_id: 666 + +llm: &llm_conf + type: mock + +vectorizer: &vec + type: batch + vectorizer_model: + path: ~/.cache/vectorizer/BAAI/bge-base-zh-v1.5 + type: bge + vector_dimensions: 768 + +writer: + type: kg + +log: + level: INFO
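
Illustrative usage sketch (not part of the patch): the snippet below strings the new components together the same way the tests above do — load the external graph, vectorize its subgraphs, run the post-processor, and write the result. The config dicts and file paths mirror tests/builder/kag_config.yaml and tests/builder/data and are assumptions for illustration, not a prescribed setup.

from kag.interface import (
    ExternalGraphLoaderABC,
    PostProcessorABC,
    VectorizerABC,
    SinkWriterABC,
)
from kag.common.conf import KAG_CONFIG

# Build the loader from its registered config
# (type "base" -> DefaultExternalGraphLoader.from_json_file).
external_graph = ExternalGraphLoaderABC.from_config(
    {
        "type": "base",
        "node_file_path": "tests/builder/data/nodes.json",
        "edge_file_path": "tests/builder/data/edges.json",
        "match_config": {"k": 1, "threshold": 0.9},
    }
)

# invoke()/dump() groups nodes and edges into label-homogeneous subgraphs.
subgraphs = external_graph.invoke(None)

# Vectorize, post-process (schema filtering + similarity / external-graph linking), then write.
vectorizer = VectorizerABC.from_config(KAG_CONFIG.all_config["vectorizer"])
postprocessor = PostProcessorABC.from_config(
    {"type": "base", "similarity_threshold": 0.9}
)
writer = SinkWriterABC.from_config(KAG_CONFIG.all_config["writer"])

for graph in subgraphs:
    graph = vectorizer.invoke(graph)[0]
    graph = postprocessor.invoke(graph)[0]
    writer.invoke(graph)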