Skip to content

Commit

Permalink
add list registered command (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuzhongshu123 authored Nov 25, 2024
1 parent ee731a5 commit 3460392
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 24 deletions.
1 change: 1 addition & 0 deletions kag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,4 @@
import kag.common.vectorize_model
import kag.common.llm
import kag.solver
import kag.bin.commands
1 change: 0 additions & 1 deletion kag/common/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def __init__(self):

def initialize(self, **kwargs):
if not self._initialized:
print(f"kwargs = {kwargs}")
self.project_id = kwargs.pop(KAGConstants.KAG_PROJECT_ID_KEY, "1")
self.host_addr = kwargs.pop(
KAGConstants.KAG_PROJECT_HOST_ADDR_KEY, "http://127.0.0.1:8887"
Expand Down
51 changes: 47 additions & 4 deletions kag/common/registry/registrable.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ def pop_and_construct_arg(
default: Any,
actual_params: ConfigTree,
) -> Any:

annotation = remove_optional(annotation)
popped_params = (
actual_params.pop(argument_name, default)
Expand All @@ -299,7 +298,6 @@ def construct_arg(
annotation: Type,
default: Any,
) -> Any:

origin = get_origin(annotation)
args = get_args(annotation)

Expand Down Expand Up @@ -615,6 +613,53 @@ def list_available(cls) -> List[str]:
else:
return [default] + [k for k in keys if k != default]

@classmethod
def list_available_with_detail(cls) -> Dict:
"""List default first if it exists"""
register_dict = Registrable._registry[cls]
availables = {}
for k, v in register_dict.items():
params = extract_parameters(v[0], v[1])
required_params = []
optional_params = []
sample_config = {"type": k}
for arg_name, arg_def in params.items():
if arg_name.strip() == "self":
continue
annotation = arg_def.annotation
if annotation == inspect.Parameter.empty:
annotation = None
default = arg_def.default
required = default == inspect.Parameter.empty
# if default == inspect.Parameter.empty:
# default = None
if required:
arg_info = (
f"{arg_name}: {annotation.__name__ if annotation else 'Any'}"
)
required_params.append(arg_info)
else:
arg_info = f"{arg_name}: {annotation.__name__ if annotation else 'Any'} = {default}"
optional_params.append(arg_info)
if required:
sample_config[arg_name] = f"Your {arg_name} config"
else:
sample_config[arg_name] = default

# if default != None:
# sample_config[arg_name] = default
availables[k] = {
"class": f"{v[0].__module__}.{v[0].__name__}",
"doc": inspect.getdoc(v[0]),
"params": {
"required_params": required_params,
"optional_params": optional_params,
},
# "default_config": default_conf,
"sample_useage": f"{cls.__name__}.from_config({sample_config})",
}
return availables

@classmethod
def from_config(
cls: Type[RegistrableType],
Expand Down Expand Up @@ -663,7 +708,6 @@ def from_config(
try:
# instantiate object from base class
if registered_subclasses and not constructor_to_call:

as_registrable = cast(Type[Registrable], cls)
default_choice = as_registrable.default_implementation
# call with BaseClass.from_prams, should use `type` to point out which subclasss to use
Expand Down Expand Up @@ -747,7 +791,6 @@ def from_config(
params.clear()
setattr(instant, "__from_config_kwargs__", remaining_kwargs)
except Exception as e:

logger.warn(f"Failed to initialize class {cls}, info: {e}")
raise e
if len(params) > 0:
Expand Down
7 changes: 3 additions & 4 deletions kag/solver/logic/core_modules/parser/logic_node_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
SPOBase,
SPOEntity,
SPORelation,
Identifer,
Identifier,
TypeInfo,
LogicNode,
)
from kag.solver.logic.core_modules.common.schema_utils import SchemaUtils

logger = logging.getLogger(__name__)


# get_spg(s, p, o)
class GetSPONode(LogicNode):
def __init__(self, operator, args):
Expand Down Expand Up @@ -346,9 +347,7 @@ def to_dsl(self):
def parse_node(input_str):
params = set(input_str.split(","))
alias_set = [Identifier(p) for p in params]
ex_node = ExtractorNode("extractor", {
"alias_set": alias_set
})
ex_node = ExtractorNode("extractor", {"alias_set": alias_set})
ex_node.alias_set = alias_set
return ex_node

Expand Down
18 changes: 5 additions & 13 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,6 @@
license += "#\n"
line = rf.readline()

# Generate kag.__init__.py
with open(os.path.join(cwd, "kag/__init__.py"), "w") as wf:
content = f"""{license}
__package_name__ = "{package_name}"
__version__ = "{version}"
from kag.common.env import init_env
init_env()
"""
wf.write(content)

setup(
name=package_name,
version=version,
Expand Down Expand Up @@ -79,4 +66,9 @@
package_data={
"bin": ["*"],
},
entry_points={
"console_scripts": [
"kag = kag.bin.kag_cmds:main",
]
},
)
10 changes: 8 additions & 2 deletions tests/unit/common/registry/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
# -*- coding: utf-8 -*-

import json
from typing import List, Dict, Union
from pyhocon import ConfigTree, ConfigFactory
from kag.common.registry import Registrable, Lazy, Functor
import numpy as np


def test_list_available():
from kag.interface import LLMClient

ava = LLMClient.list_available_with_detail()
print(json.dumps(ava, indent=4))


class MockModel(Registrable):
def __init__(self, name: str = "mock_model"):
self.name = name
Expand Down Expand Up @@ -248,7 +255,6 @@ def test_to_config():


def test_multi_constructor():

# without type key, will use default_implementation
params = ConfigFactory.from_dict({"count": 32})
ins = BaseCount.from_config(params)
Expand Down

0 comments on commit 3460392

Please sign in to comment.