Commit 6a09a94

add component of load external graph and postprocess subgraphs

zhuzhongshu123 committed Nov 13, 2024
1 parent 40ae2bb commit 6a09a94
Showing 24 changed files with 1,258 additions and 108 deletions.
15 changes: 11 additions & 4 deletions kag/builder/component/__init__.py
@@ -10,10 +10,15 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from kag.builder.component.external_graph.external_graph import (
DefaultExternalGraphLoader,
)
from kag.builder.component.extractor.kag_extractor import KAGExtractor
from kag.builder.component.extractor.spg_extractor import SPGExtractor
from kag.builder.component.aligner.kag_post_processor import KAGPostProcessorAligner
from kag.builder.component.aligner.spg_post_processor import SPGPostProcessorAligner
from kag.builder.component.aligner.kag_aligner import KAGAligner
from kag.builder.component.aligner.spg_aligner import SPGAligner
from kag.builder.component.postprocessor.kag_postprocessor import KAGPostProcessor

from kag.builder.component.mapping.spg_type_mapping import SPGTypeMapping
from kag.builder.component.mapping.relation_mapping import RelationMapping
from kag.builder.component.mapping.spo_mapping import SPOMapping
@@ -38,10 +43,12 @@


__all__ = [
"DefaultExternalGraphLoader",
"KAGExtractor",
"SPGExtractor",
"KAGPostProcessorAligner",
"SPGPostProcessorAligner",
"KAGAligner",
"SPGAligner",
"KAGPostProcessor",
"KGWriter",
"SPGTypeMapping",
"RelationMapping",
@@ -18,7 +18,7 @@


@AlignerABC.register("kag")
class KAGPostProcessorAligner(AlignerABC):
class KAGAligner(AlignerABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)

@@ -23,7 +23,7 @@


@AlignerABC.register("spg")
class SPGPostProcessorAligner(AlignerABC):
class SPGAligner(AlignerABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.spg_types = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load()
Empty file.
185 changes: 185 additions & 0 deletions kag/builder/component/external_graph/external_graph.py
@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import jieba
import json
import numpy as np
import logging
from typing import List, Union, Dict
from kag.interface import ExternalGraphLoaderABC, MatchConfig
from kag.common.conf import KAG_PROJECT_CONF
from kag.builder.model.sub_graph import Node, Edge, SubGraph
from knext.schema.client import SchemaClient

from knext.search.client import SearchClient


logger = logging.getLogger()


@ExternalGraphLoaderABC.register("base", constructor="from_json_file", as_default=True)
class DefaultExternalGraphLoader(ExternalGraphLoaderABC):
def __init__(
self,
nodes: List[Node],
edges: List[Edge],
match_config: MatchConfig,
):

self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load()
for node in nodes:
if node.label not in self.schema:
raise ValueError(
f"Type of node {node.to_dict()} is beyond the schema definition."
)
for k in node.properties.keys():
if k not in self.schema[node.label]:
raise ValueError(
f"Property of node {node.to_dict()} is beyond the schema definition."
)
self.nodes = nodes
self.edges = edges

self.vocabulary = {}
self.node_labels = set()
for node in self.nodes:
self.vocabulary[node.name] = node
self.node_labels.add(node.label)

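# Register each entity name with jieba so ner() can segment exact mentions out of raw text.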
for word in self.vocabulary.keys():
jieba.add_word(word)

self.match_config = match_config
self._init_search()

def _init_search(self):
self._search_client = SearchClient(
KAG_PROJECT_CONF.host_addr, KAG_PROJECT_CONF.project_id
)

def _group_by_label(self, data: Union[List[Node], List[Edge]]):
groups = {}

for item in data:
label = item.label
if label not in groups:
groups[label] = [item]
else:
groups[label].append(item)
return list(groups.values())

def _group_by_cnt(self, data, n):
return [data[i : i + n] for i in range(0, len(data), n)]

def dump(self, max_num_nodes: int = 4096, max_num_edges: int = 4096):
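# Batch nodes and edges into label-homogeneous SubGraphs, capped at max_num_nodes / max_num_edges per batch.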
graphs = []
# process nodes
for item in self._group_by_label(self.nodes):
for grouped_nodes in self._group_by_cnt(item, max_num_nodes):
graphs.append(SubGraph(nodes=grouped_nodes, edges=[]))

# process edges
for item in self._group_by_label(self.edges):
for grouped_edges in self._group_by_cnt(item, max_num_edges):
graphs.append(SubGraph(nodes=[], edges=grouped_edges))

return graphs

def ner(self, content: str):
output = []
for word in jieba.cut(content):
if word in self.vocabulary:
output.append(self.vocabulary[word])
return output

def get_allowed_labels(self, labels: List[str] = None):
allowed_labels = []

namespace = KAG_PROJECT_CONF.namespace
if labels is None:
allowed_labels = [f"{namespace}.{x}" for x in self.node_labels]
else:
for label in labels:
# remove namespace
if label.startswith(KAG_PROJECT_CONF.namespace):
label = label.split(".")[1]
if label in self.node_labels:
allowed_labels.append(f"{namespace}.{label}")
return allowed_labels

def search_result_to_node(self, search_result: Dict):
output = []
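# A search hit may carry several labels; emit one Node per label.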
for label in search_result["__labels__"]:
node = {
"id": search_result["id"],
"name": search_result["name"],
"label": label,
}
output.append(Node.from_dict(node))
return output

def text_match(self, query: str, k: int = 1, labels: List[str] = None):
allowed_labels = self.get_allowed_labels(labels)
text_matched = self._search_client.search_text(query, allowed_labels, topk=k)
return text_matched

def vector_match(
self,
query: Union[List[float], np.ndarray],
k: int = 1,
threshold: float = 0.9,
labels: List[str] = None,
):
allowed_labels = self.get_allowed_labels(labels)
if isinstance(query, np.ndarray):
query = query.tolist()
matched_results = []
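# search_vector queries one label at a time, so loop over the allowed labels and merge the hits.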
for label in allowed_labels:
vector_matched = self._search_client.search_vector(
label=label, property_key="name", query_vector=query, topk=k
)
matched_results.extend(vector_matched)

filtered_results = []
for item in matched_results:
score = item["score"]
if score >= threshold:
filtered_results.append(item)
return filtered_results

def match_entity(self, query: Union[str, List[float], np.ndarray]):
if isinstance(query, str):
return self.text_match(
query, k=self.match_config.k, labels=self.match_config.labels
)
else:
return self.vector_match(
query,
k=self.match_config.k,
labels=self.match_config.labels,
threshold=self.match_config.threshold,
)

@classmethod
def from_json_file(
cls,
node_file_path: str,
edge_file_path: str,
match_config: MatchConfig,
):

with open(node_file_path, "r") as f:
    nodes = [Node.from_dict(item) for item in json.load(f)]
with open(edge_file_path, "r") as f:
    edges = [Edge.from_dict(item) for item in json.load(f)]
return cls(nodes=nodes, edges=edges, match_config=match_config)
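
For orientation, a minimal usage sketch of the loader added above — assuming MatchConfig exposes the k, labels, and threshold fields the matcher reads, and that an OpenSPG service is reachable (the constructor loads the project schema and opens a search client); file paths, labels, and node payloads are hypothetical:

import json

from kag.interface import MatchConfig
from kag.builder.component.external_graph.external_graph import DefaultExternalGraphLoader

# Hypothetical dumps in the shape Node.from_dict / Edge.from_dict expect.
with open("nodes.json", "w") as f:
    json.dump(
        [{"id": "n1", "name": "aspirin", "label": "Drug", "properties": {"desc": "pain reliever"}}],
        f,
    )
with open("edges.json", "w") as f:
    json.dump([], f)

# Assumed constructor; the loader only reads .k, .labels, and .threshold from it.
match_config = MatchConfig(k=1, labels=None, threshold=0.9)

loader = DefaultExternalGraphLoader.from_json_file(
    node_file_path="nodes.json",
    edge_file_path="edges.json",
    match_config=match_config,
)
subgraphs = loader.dump()  # label-homogeneous SubGraph batches for the writer
hits = loader.ner("The patient was given aspirin.")  # exact vocabulary hits via jieba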
85 changes: 36 additions & 49 deletions kag/builder/component/extractor/kag_extractor.py
@@ -16,16 +16,15 @@
from kag.common.llm.llm_client import LLMClient
from tenacity import stop_after_attempt, retry

from kag.builder.prompt.spg_prompt import SPG_KGPrompt
from kag.interface import ExtractorABC, PromptABC
from knext.schema.client import OTHER_TYPE, CHUNK_TYPE, BASIC_TYPES
from kag.interface import ExtractorABC, PromptABC, ExternalGraphLoaderABC

from kag.common.conf import KAG_PROJECT_CONF
from kag.common.utils import processing_phrases, to_camel_case
from kag.builder.model.chunk import Chunk
from kag.builder.model.sub_graph import SubGraph
from knext.schema.client import OTHER_TYPE, CHUNK_TYPE, BASIC_TYPES
from knext.common.base.runnable import Input, Output
from knext.schema.client import SchemaClient
from knext.schema.model.base import SpgTypeEnum

logger = logging.getLogger(__name__)

@@ -43,6 +42,7 @@ def __init__(
ner_prompt: PromptABC = None,
std_prompt: PromptABC = None,
triple_prompt: PromptABC = None,
external_graph: ExternalGraphLoaderABC = None,
):
self.llm = llm
print(f"self.llm: {self.llm}")
@@ -64,43 +64,7 @@ def __init__(
self.triple_prompt = PromptABC.from_config(
{"type": f"{biz_scene}_triple", "language": KAG_PROJECT_CONF.language}
)
self.create_extra_prompts()

def create_extra_prompts(self):
self.kg_types = []
for type_name, spg_type in self.schema.items():
if type_name in SPG_KGPrompt.ignored_types:
continue
if spg_type.spg_type_enum == SpgTypeEnum.Concept:
continue
properties = list(spg_type.properties.keys())
for p in properties:
if p not in SPG_KGPrompt.ignored_properties:
self.kg_types.append(type_name)
break
if self.kg_types:
self.kg_prompt = SPG_KGPrompt(
self.kg_types,
language=KAG_PROJECT_CONF.language,
project_id=KAG_PROJECT_CONF.project_id,
)
else:
self.kg_prompt = None

# @classmethod
# def initialize(
# llm: LLMClient,
# ner_prompt: PromptABC = None,
# std_prompt: PromptABC = None,
# triple_prompt: PromptABC = None,
# ):
# print(f"llm = {llm}")
# print(ner_prompt)
# print(std_prompt)
# print(triple_prompt)
# extractor = KAGExtractor(llm, ner_prompt, std_prompt, triple_prompt)
# extractor.create_extra_prompts()
# return extractor
self.external_graph = external_graph

@property
def input_types(self) -> Type[Input]:
@@ -119,12 +83,34 @@ def named_entity_recognition(self, passage: str):
Returns:
The result of the named entity recognition operation.
"""
if self.kg_types:
kg_result = self.llm.invoke({"input": passage}, self.kg_prompt)
else:
kg_result = []
ner_result = self.llm.invoke({"input": passage}, self.ner_prompt)
return kg_result + ner_result
if self.external_graph:
extra_ner_result = self.external_graph.ner(passage)
else:
extra_ner_result = []
output = []
dedup = set()
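# External-graph matches take precedence; LLM NER entities are appended only when the name is unseen.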
for item in extra_ner_result:
name = item.name
label = item.label
description = item.properties.get("desc", "")
semantic_type = item.properties.get("semanticType", label)
if name not in dedup:
dedup.add(name)
output.append(
{
"entity": name,
"type": semantic_type,
"category": label,
"description": description,
}
)
for item in ner_result:
name = item.get("entity", None)
if name and name not in dedup:
dedup.add(name)
output.append(item)
return output

@retry(stop=stop_after_attempt(3))
def named_entity_standardization(self, passage: str, entities: List[Dict]):
@@ -357,9 +343,10 @@ def invoke(self, input: Input, **kwargs) -> List[Output]:
Returns:
List[Output]: A list of processed results, containing subgraph information.
"""

title = input.name
passage = title + "\n" + input.content

out = []
try:
entities = self.named_entity_recognition(passage)
sub_graph, entities = self.assemble_sub_graph_with_spg_records(entities)
@@ -371,10 +358,10 @@ def invoke(self, input: Input, **kwargs) -> List[Output]:
std_entities = self.named_entity_standardization(passage, filtered_entities)
self.append_official_name(entities, std_entities)
self.assemble_sub_graph(sub_graph, input, entities, triples)
return [sub_graph]
out.append(sub_graph)
except Exception as e:
import traceback

traceback.print_exc()
logger.info(e)
return []
return out
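
The extractor now accepts an external_graph argument and merges its exact-match NER hits ahead of the LLM's. Given the registration header on DefaultExternalGraphLoader (type "base", constructor "from_json_file"), a config-driven wiring along these lines seems intended — the exact from_config schema is an assumption, not shown in this diff:

from kag.interface import ExternalGraphLoaderABC

# Assumed from_config contract: "type" selects the registered class and the
# remaining keys feed the registered constructor (from_json_file here).
external_graph = ExternalGraphLoaderABC.from_config(
    {
        "type": "base",  # registered name of DefaultExternalGraphLoader
        "node_file_path": "nodes.json",  # hypothetical paths
        "edge_file_path": "edges.json",
        "match_config": {"k": 1, "labels": None, "threshold": 0.9},
    }
)
# KAGExtractor(llm=..., external_graph=external_graph) would then pick up
# the external vocabulary during named_entity_recognition.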
Empty file.
