Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(kag): refactor builder components #36

Closed
Closed
Prev Previous commit
Next Next commit
solver refactor
zhuzhongshu123 committed Nov 14, 2024
commit d3ff9d3d1679b9ae4662140281710e4480c1090d
4 changes: 3 additions & 1 deletion kag/__init__.py
Original file line number Diff line number Diff line change
@@ -210,7 +210,9 @@
import kag.builder.component
import kag.builder.prompt
import kag.solver.prompt

import kag.common.vectorize_model
import kag.common.llm
import kag.solver
from kag.common.conf import init_env

init_env()
2 changes: 1 addition & 1 deletion kag/builder/component/extractor/kag_extractor.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
import logging
from typing import Dict, Type, List

from kag.common.llm.llm_client import LLMClient
from kag.interface import LLMClient
from tenacity import stop_after_attempt, retry

from kag.interface import ExtractorABC, PromptABC, ExternalGraphLoaderABC
12 changes: 12 additions & 0 deletions kag/builder/component/postprocessor/kag_postprocessor.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import logging
from typing import List
from kag.interface import PostProcessorABC
from kag.interface import ExternalGraphLoaderABC
@@ -19,6 +20,9 @@
from knext.schema.client import SchemaClient


logger = logging.getLogger()


@PostProcessorABC.register("base", as_default=True)
class KAGPostProcessor(PostProcessorABC):
def __init__(
@@ -100,7 +104,15 @@ def external_graph_based_link(self, graph: SubGraph, property_key: str = "name")
self._entity_link(graph, property_key, labels)

def invoke(self, input):
    """Post-process a graph and return it wrapped in a single-element list.

    The input is passed through ``filter_invalid_data``; the graph that call
    returns is then handed to ``similarity_based_link`` and
    ``external_graph_based_link`` (in that order). Node and edge counts from
    before and after processing are reported at debug level.

    Args:
        input: Object exposing ``nodes`` and ``edges`` collections
            (presumably a ``SubGraph`` -- confirm against callers).

    Returns:
        A list whose only element is the processed graph.
    """
    # Snapshot the sizes before any processing step gets to see the input.
    origin_num_nodes, origin_num_edges = len(input.nodes), len(input.edges)

    new_graph = self.filter_invalid_data(input)
    # Both linking calls discard any return value and operate on new_graph.
    self.similarity_based_link(new_graph)
    self.external_graph_based_link(new_graph)

    new_num_nodes, new_num_edges = len(new_graph.nodes), len(new_graph.edges)
    logger.debug(
        f"origin: {origin_num_nodes}/{origin_num_edges}, processed: {new_num_nodes}/{new_num_edges}"
    )
    return [new_graph]
4 changes: 2 additions & 2 deletions kag/builder/component/reader/dataset_reader.py
Original file line number Diff line number Diff line change
@@ -77,9 +77,9 @@ def invoke(self, input: str, **kwargs) -> List[Output]:
corpusList = input
chunks = []

for item in corpusList:
for idx, item in enumerate(corpusList):
chunk = Chunk(
id=item[id_column],
id=f"{item[id_column]}#{idx}",
name=item[name_column],
content=item[content_column],
)
2 changes: 1 addition & 1 deletion kag/builder/component/reader/docx_reader.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
from typing import List, Type, Union

from docx import Document
from kag.common.llm import LLMClient
from kag.interface import LLMClient
from kag.builder.model.chunk import Chunk
from kag.interface import SourceReaderABC
from kag.builder.prompt.outline_prompt import OutlinePrompt
2 changes: 1 addition & 1 deletion kag/builder/component/reader/markdown_reader.py
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@

from kag.interface import SourceReaderABC
from kag.builder.model.chunk import Chunk, ChunkTypeEnum
from kag.common.llm import LLMClient
from kag.interface import LLMClient
from kag.common.conf import KAG_PROJECT_CONF
from kag.builder.prompt.analyze_table_prompt import AnalyzeTablePrompt
from knext.common.base.runnable import Output, Input
2 changes: 1 addition & 1 deletion kag/builder/component/reader/pdf_reader.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@
from kag.interface import SourceReaderABC

from kag.builder.prompt.outline_prompt import OutlinePrompt
from kag.common.llm import LLMClient
from kag.interface import LLMClient
from kag.common.conf import KAG_PROJECT_CONF
from knext.common.base.runnable import Input, Output
from pdfminer.high_level import extract_text
2 changes: 1 addition & 1 deletion kag/builder/component/reader/yuque_reader.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
from kag.builder.component.reader.markdown_reader import MarkDownReader
from kag.builder.model.chunk import Chunk
from kag.interface import SourceReaderABC
from kag.common.llm import LLMClient
from kag.interface import LLMClient
from knext.common.base.runnable import Input, Output


2 changes: 1 addition & 1 deletion kag/builder/component/splitter/outline_splitter.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
from kag.builder.prompt.outline_prompt import OutlinePrompt
from kag.builder.model.chunk import Chunk
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.llm import LLMClient
from kag.interface import LLMClient
from knext.common.base.runnable import Input, Output


2 changes: 1 addition & 1 deletion kag/builder/component/splitter/semantic_splitter.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
from kag.interface import SplitterABC
from kag.builder.prompt.semantic_seg_prompt import SemanticSegPrompt
from kag.builder.model.chunk import Chunk
from kag.common.llm import LLMClient
from kag.interface import LLMClient
from kag.common.conf import KAG_PROJECT_CONF
from knext.common.base.runnable import Input, Output

10 changes: 5 additions & 5 deletions kag/builder/component/vectorizer/batch_vectorizer.py
Original file line number Diff line number Diff line change
@@ -14,9 +14,9 @@

from kag.builder.model.sub_graph import SubGraph
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.vectorizer import Vectorizer

from kag.common.utils import get_vector_field_name
from kag.interface import VectorizerABC
from kag.interface import VectorizerABC, VectorizeModelABC
from knext.schema.client import SchemaClient
from knext.schema.model.base import IndexTypeEnum
from knext.common.base.runnable import Input, Output
@@ -127,12 +127,12 @@ def batch_generate(self, node_batch, batch_size=1024):

@VectorizerABC.register("batch")
class BatchVectorizer(VectorizerABC):
def __init__(self, vectorizer_model: Vectorizer, batch_size: int = 1024):
def __init__(self, vectorize_model: VectorizeModelABC, batch_size: int = 1024):
super().__init__()
self.project_id = KAG_PROJECT_CONF.project_id
# self._init_graph_store()
self.vec_meta = self._init_vec_meta()
self.vectorizer_model = vectorizer_model
self.vectorize_model = vectorize_model
self.batch_size = batch_size

def _init_vec_meta(self):
@@ -158,7 +158,7 @@ def _generate_embedding_vectors(self, input_subgraph: SubGraph) -> SubGraph:
properties.update(node.properties)
node_list.append((node, properties))
node_batch.append((node.label, properties.copy()))
generator = EmbeddingVectorGenerator(self.vectorizer_model, self.vec_meta)
generator = EmbeddingVectorGenerator(self.vectorize_model, self.vec_meta)
generator.batch_generate(node_batch, self.batch_size)
for (node, properties), (_node_label, new_properties) in zip(
node_list, node_batch
1 change: 1 addition & 0 deletions kag/builder/component/writer/kg_writer.py
Original file line number Diff line number Diff line change
@@ -70,6 +70,7 @@ def invoke(
Returns:
List[Output]: A list of output objects (currently always [None]).
"""

self.client.write_graph(
sub_graph=input.to_dict(),
operation=alter_operation,
1 change: 0 additions & 1 deletion kag/common/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,6 @@
# or implied.


from kag.common.llm.llm_client import LLMClient
from kag.common.llm.openai_client import OpenAIClient
from kag.common.llm.vllm_client import VLLMClient
from kag.common.llm.ollama_client import OllamaClient
2 changes: 1 addition & 1 deletion kag/common/llm/llm_config_checker.py
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ def check(self, config: str) -> str:
:rtype: str
:raises RuntimeError: if the config is invalid
"""
from kag.common.llm import LLMClient
from kag.interface import LLMClient

config = json.loads(config)
llm_client = LLMClient.from_config(config)
2 changes: 1 addition & 1 deletion kag/common/llm/mock_llm.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@


import json
from kag.common.llm.llm_client import LLMClient
from kag.interface import LLMClient


@LLMClient.register("mock")
2 changes: 1 addition & 1 deletion kag/common/llm/ollama_client.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
import logging
from ollama import Client

from kag.common.llm.llm_client import LLMClient
from kag.interface import LLMClient


# logging.basicConfig(level=logging.DEBUG)
2 changes: 1 addition & 1 deletion kag/common/llm/openai_client.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
from openai import OpenAI
import logging

from kag.common.llm.llm_client import LLMClient
from kag.interface import LLMClient

# logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion kag/common/llm/vllm_client.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
import json
import logging
import requests
from kag.common.llm.llm_client import LLMClient
from kag.interface import LLMClient


# logging.basicConfig(level=logging.DEBUG)
159 changes: 84 additions & 75 deletions kag/common/registry/registrable.py
Original file line number Diff line number Diff line change
@@ -660,89 +660,98 @@ def from_config(
)

registered_subclasses = Registrable._registry.get(cls)
try:
# instantiate object from base class
if registered_subclasses and not constructor_to_call:

as_registrable = cast(Type[Registrable], cls)
default_choice = as_registrable.default_implementation
# call with BaseClass.from_prams, should use `type` to point out which subclasss to use
choice = params.pop("type", default_choice)
choices = as_registrable.list_available()
# if cls has subclass and choice not found in params, we'll instantiate cls itself
if choice is None:
subclass, constructor_name = cls, None
# invalid choice encountered, raise
elif choice not in choices:
message = (
f"{choice} not in acceptable choices for type: {choices}. "
"You should make sure the class is correctly registerd. "
)
raise ConfigurationError(message)

# instantiate object from base class
if registered_subclasses and not constructor_to_call:

as_registrable = cast(Type[Registrable], cls)
default_choice = as_registrable.default_implementation
# call with BaseClass.from_prams, should use `type` to point out which subclasss to use
choice = params.pop("type", default_choice)
choices = as_registrable.list_available()
# if cls has subclass and choice not found in params, we'll instantiate cls itself
if choice is None:
subclass, constructor_name = cls, None
# invalid choice encountered, raise
elif choice not in choices:
message = (
f"{choice} not in acceptable choices for type: {choices}. "
"You should make sure the class is correctly registerd. "
)
raise ConfigurationError(message)

else:
subclass, constructor_name = as_registrable.resolve_class_name(choice)
else:
subclass, constructor_name = as_registrable.resolve_class_name(
choice
)

# See the docstring for an explanation of what's going on here.
if not constructor_name:
constructor_to_inspect = subclass.__init__
constructor_to_call = subclass # type: ignore
else:
constructor_to_inspect = cast(
Callable[..., RegistrableType], getattr(subclass, constructor_name)
)
constructor_to_call = constructor_to_inspect
# See the docstring for an explanation of what's going on here.
if not constructor_name:
constructor_to_inspect = subclass.__init__
constructor_to_call = subclass # type: ignore
else:
constructor_to_inspect = cast(
Callable[..., RegistrableType],
getattr(subclass, constructor_name),
)
constructor_to_call = constructor_to_inspect

retyped_subclass = cast(Type[RegistrableType], subclass)
retyped_subclass = cast(Type[RegistrableType], subclass)

instant = retyped_subclass.from_config(
params=params,
constructor_to_call=constructor_to_call,
constructor_to_inspect=constructor_to_inspect,
)
instant = retyped_subclass.from_config(
params=params,
constructor_to_call=constructor_to_call,
constructor_to_inspect=constructor_to_inspect,
)

setattr(instant, "__register_type__", choice)
setattr(instant, "__original_parameters__", original_params)
# return ins
else:
# pop unused type declaration
register_type = params.pop("type", None)

if not constructor_to_inspect:
constructor_to_inspect = cls.__init__
if not constructor_to_call:
constructor_to_call = cls

if constructor_to_inspect == object.__init__:
# This class does not have an explicit constructor, so don't give it any kwargs.
# Without this logic, create_kwargs will look at object.__init__ and see that
# it takes *args and **kwargs and look for those.
accepts_kwargs, kwargs = False, {}
setattr(instant, "__register_type__", choice)
setattr(instant, "__original_parameters__", original_params)
# return ins
else:
# This class has a constructor, so create kwargs for it.
constructor_to_inspect = cast(
Callable[..., RegistrableType], constructor_to_inspect
)
accepts_kwargs, kwargs = create_kwargs(
constructor_to_inspect,
cls,
params,
# pop unused type declaration
register_type = params.pop("type", None)

if not constructor_to_inspect:
constructor_to_inspect = cls.__init__
if not constructor_to_call:
constructor_to_call = cls

if constructor_to_inspect == object.__init__:
# This class does not have an explicit constructor, so don't give it any kwargs.
# Without this logic, create_kwargs will look at object.__init__ and see that
# it takes *args and **kwargs and look for those.
accepts_kwargs, kwargs = False, {}
else:
# This class has a constructor, so create kwargs for it.
constructor_to_inspect = cast(
Callable[..., RegistrableType], constructor_to_inspect
)
accepts_kwargs, kwargs = create_kwargs(
constructor_to_inspect,
cls,
params,
)

instant = constructor_to_call(**kwargs) # type: ignore
setattr(instant, "__register_type__", register_type)
setattr(
instant,
"__constructor_called__",
functools.partial(constructor_to_call, **kwargs),
)
setattr(instant, "__original_parameters__", original_params)
# if constructor takes kwargs, they can't be infered from constructor. Therefore we should record
# which attrs are created by kwargs to correctly restore the configs by `to_config`.
if accepts_kwargs:
remaining_kwargs = set(params)
params.clear()
setattr(instant, "__from_config_kwargs__", remaining_kwargs)
except Exception as e:
import traceback

instant = constructor_to_call(**kwargs) # type: ignore
setattr(instant, "__register_type__", register_type)
setattr(
instant,
"__constructor_called__",
functools.partial(constructor_to_call, **kwargs),
)
setattr(instant, "__original_parameters__", original_params)
# if constructor takes kwargs, they can't be infered from constructor. Therefore we should record
# which attrs are created by kwargs to correctly restore the configs by `to_config`.
if accepts_kwargs:
remaining_kwargs = set(params)
params.clear()
setattr(instant, "__from_config_kwargs__", remaining_kwargs)
logger.error(f"failed to initialize class {cls}, info: ")
traceback.print_exc()
raise e
if len(params) > 0:
raise ConfigurationError(
f"These params are not used for constructing {cls}:\n{params}"
Loading