Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuzhongshu123 committed Nov 7, 2024
1 parent 41e95a3 commit c18df1f
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 25 deletions.
39 changes: 23 additions & 16 deletions kag/builder/component/extractor/kag_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
logger = logging.getLogger(__name__)


@ExtractorABC.register("kag", constructor="initialize", as_default=True)
@ExtractorABC.register("kag")
class KAGExtractor(ExtractorABC):
"""
A class for extracting knowledge graph subgraphs from text using a large language model (LLM).
Expand All @@ -45,23 +45,26 @@ def __init__(
triple_prompt: PromptABC = None,
):
self.llm = llm
self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id()).load()
print(f"self.llm: {self.llm}")
self.schema = SchemaClient(project_id=KAG_PROJECT_CONF.project_id).load()
self.ner_prompt = ner_prompt
self.std_prompt = std_prompt
self.triple_prompt = triple_prompt

biz_scene = KAG_PROJECT_CONF.biz_scene
if self.ner_prompt is None:
self.ner_prompt = PromptABC.from_config(
{"type": "default_ner", "language": KAG_PROJECT_CONF.language}
{"type": f"{biz_scene}_ner", "language": KAG_PROJECT_CONF.language}
)
if self.std_prompt is None:
self.std_prompt = PromptABC.from_config(
{"type": "default_std", "language": KAG_PROJECT_CONF.language}
{"type": f"{biz_scene}_std", "language": KAG_PROJECT_CONF.language}
)
if self.triple_prompt is None:
self.std_prompt = PromptABC.from_config(
{"type": "default_triple", "language": KAG_PROJECT_CONF.language}
self.triple_prompt = PromptABC.from_config(
{"type": f"{biz_scene}_triple", "language": KAG_PROJECT_CONF.language}
)
self.create_extra_prompts()

def create_extra_prompts(self):
self.kg_types = []
Expand All @@ -84,16 +87,20 @@ def create_extra_prompts(self):
else:
self.kg_prompt = None

@classmethod
def initialize(
llm: LLMClient,
ner_prompt: PromptABC = None,
std_prompt: PromptABC = None,
triple_prompt: PromptABC = None,
):
extractor = KAGExtractor(llm, ner_prompt, std_prompt, triple_prompt)
extractor.create_extra_prompts()
return extractor
# @classmethod
# def initialize(
# llm: LLMClient,
# ner_prompt: PromptABC = None,
# std_prompt: PromptABC = None,
# triple_prompt: PromptABC = None,
# ):
# print(f"llm = {llm}")
# print(ner_prompt)
# print(std_prompt)
# print(triple_prompt)
# extractor = KAGExtractor(llm, ner_prompt, std_prompt, triple_prompt)
# extractor.create_extra_prompts()
# return extractor

@property
def input_types(self) -> Type[Input]:
Expand Down
2 changes: 1 addition & 1 deletion kag/builder/component/extractor/spg_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
logger = logging.getLogger(__name__)


@ExtractorABC.register("spg", constructor="initialize", as_default=True)
@ExtractorABC.register("spg")
class SPGExtractor(KAGExtractor):
"""
A Builder Component that extracts structured data from long texts by invoking a large language model.
Expand Down
2 changes: 1 addition & 1 deletion kag/builder/prompt/default/ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from knext.schema.client import SchemaClient


@PromptABC.register("ner_default")
@PromptABC.register("default_ner")
class OpenIENERPrompt(PromptABC):

template_en = """
Expand Down
2 changes: 1 addition & 1 deletion kag/builder/prompt/default/std.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from kag.interface import PromptABC


@PromptABC.register("std_default")
@PromptABC.register("default_std")
class OpenIEEntitystandardizationdPrompt(PromptABC):
template_en = """
{
Expand Down
2 changes: 1 addition & 1 deletion kag/builder/prompt/default/triple.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from kag.interface import PromptABC


@PromptABC.register("triple_default")
@PromptABC.register("default_triple")
class OpenIETriplePrompt(PromptABC):
template_en = """
{
Expand Down
2 changes: 1 addition & 1 deletion kag/builder/prompt/medical/ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from knext.schema.client import SchemaClient


@PromptABC.register("ner_medical")
@PromptABC.register("medical_ner")
class OpenIENERPrompt(PromptABC):

template_zh = """
Expand Down
2 changes: 1 addition & 1 deletion kag/builder/prompt/medical/std.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from kag.interface import PromptABC


@PromptABC.register("std_medical")
@PromptABC.register("medical_std")
class OpenIEEntitystandardizationdPrompt(PromptABC):

template_zh = """
Expand Down
2 changes: 1 addition & 1 deletion kag/builder/prompt/medical/triple.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from kag.interface import PromptABC


@PromptABC.register("triple_medical")
@PromptABC.register("medical_triple")
class OpenIETriplePrompt(PromptABC):

template_zh = """
Expand Down
37 changes: 37 additions & 0 deletions tests/builder/component/test_batch_vectorizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
from kag.interface import VectorizerABC
from kag.builder.model.sub_graph import SubGraph


def test_batch_vectorizer():
    """Run the "batch" vectorizer over a keyword-only sub-graph.

    Builds the vectorizer from an inline config, feeds it a SubGraph made of
    eight "Keyword" nodes, and checks that the returned graph keeps the same
    number of nodes and that each node carries a ``_name_vector`` property of
    the configured dimension.
    """
    expected_dim = 768
    model_conf = {
        "type": "bge",
        "path": "~/.cache/vectorizer/BAAI/bge-base-zh-v1.5",
        "url": "",
        "vector_dimensions": expected_dim,
    }
    vectorizer = VectorizerABC.from_config(
        {"type": "batch", "vectorizer_model": model_conf}
    )

    keywords = [
        "精卫填海",
        "海阔天空",
        "空前绝后",
        "后来居上",
        "上下一心",
        "心旷神怡",
        "怡然自得",
        "得心应手",
    ]

    # Each keyword doubles as both the node id and the node name.
    graph = SubGraph([], [])
    for keyword in keywords:
        graph.add_node(id=keyword, name=keyword, label="Keyword")

    # invoke() returns a sequence; only its first element is inspected here.
    vectorized = vectorizer.invoke(graph)[0]

    assert len(graph.nodes) == len(vectorized.nodes)
    for node in vectorized.nodes:
        assert node.name in keywords
        assert "_name_vector" in node.properties
        assert len(node.properties["_name_vector"]) == expected_dim
33 changes: 33 additions & 0 deletions tests/builder/component/test_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
from kag.common.conf import KAG_CONFIG
from kag.builder.model.chunk import Chunk
from kag.interface import ExtractorABC
from kag.builder.model.sub_graph import SubGraph

# The "llm" entry of the global KAG configuration; both tests below pass it
# unchanged as the "llm" value of the extractor config they build.
llm_config = KAG_CONFIG.all_config["llm"]


def test_kag_extractor():
    """Build the "kag" extractor from config and run it on a text chunk.

    The extractor is created through ``ExtractorABC.from_config`` with the
    shared ``llm_config`` and the "default_ner" prompt, then invoked on a
    single Chunk read from disk. The first returned item must be a SubGraph.
    """
    extractor_conf = {
        "type": "kag",
        "llm": llm_config,
        "ner_prompt": {"type": "default_ner"},
    }
    kag_extractor = ExtractorABC.from_config(extractor_conf)

    # NOTE: the path is resolved against the current working directory.
    with open("../data/test_txt.txt", "r") as text_file:
        text = text_file.read()

    result = kag_extractor.invoke(Chunk(id="111", name="test", content=text))[0]
    print(result)
    print(type(result))
    assert isinstance(result, SubGraph)


def test_spg_extractor():
    """Build the "spg" extractor from config and run it on a text chunk.

    The extractor is created through ``ExtractorABC.from_config`` with the
    shared ``llm_config`` and the "default_ner" prompt, then invoked on a
    single Chunk read from disk. The first returned item must be a SubGraph.
    """
    extractor_conf = {
        "type": "spg",
        "llm": llm_config,
        "ner_prompt": {"type": "default_ner"},
    }
    spg_extractor = ExtractorABC.from_config(extractor_conf)

    # NOTE: the path is resolved against the current working directory.
    with open("../data/test_txt.txt", "r") as text_file:
        text = text_file.read()

    result = spg_extractor.invoke(Chunk(id="111", name="test", content=text))[0]
    print(result)
    print(type(result))
    assert isinstance(result, SubGraph)
4 changes: 2 additions & 2 deletions tests/common/kag_config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
global: &global_config
host_addr: http://127.0.0.1:8887
project_id: 666
biz_scene: news
language: zh
biz_scene: default
language: en
project: *global_config
vectorizer:
type: bge
Expand Down

0 comments on commit c18df1f

Please sign in to comment.