Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

更新部分Components组件的init方法,新增init方法与tool_eval的kwargs参数检查 #709

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/core/components/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def __init__(
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = False,
**kwargs
):
"""
Args:
Expand All @@ -280,6 +281,7 @@ def __init__(
secret_key (Optional[str], optional): 可选的密钥. Defaults to None.
gateway (str, optional): 网关地址. Defaults to "".
lazy_certification (bool, optional): 延迟认证,为True时在第一次运行时认证. Defaults to False.
**kwargs: 其他关键字参数

"""
super(CompletionBaseComponent, self).__init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = True,
**kwargs,
):
"""初始化幻觉检测组件。

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = True,
**kwargs
):
"""初始化StyleRewrite模型。

Expand Down
4 changes: 3 additions & 1 deletion python/core/components/v2/llms/similar_question/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = True,
**kwargs
):
"""初始化StyleRewrite模型。

Expand All @@ -80,14 +81,15 @@ def __init__(
secret_key (str, 可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", "").
gateway (str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "")
lazy_certification (bool, 可选): 延迟认证,为True时在第一次运行时认证. Defaults to False.
**kwargs: 其他关键字参数.

Returns:
None

"""
super().__init__(
SimilarQuestionMeta, model=model, secret_key=secret_key, gateway=gateway,
lazy_certification=lazy_certification)
lazy_certification=lazy_certification, **kwargs)

@components_run_trace
def run(self, message, stream=False, temperature=1e-10, top_p=0.0, request_id=None):
Expand Down
6 changes: 4 additions & 2 deletions python/core/components/v2/llms/style_rewrite/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = True,
**kwargs
):
"""初始化StyleRewrite模型。

Expand All @@ -77,14 +78,15 @@ def __init__(
secret_key (str, 可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", "").
gateway (str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "")
lazy_certification (bool, 可选): 延迟认证,为True时在第一次运行时认证. Defaults to False.

**kwargs: 其他关键字参数.

Returns:
None

"""
super().__init__(
StyleRewriteArgs, model=model, secret_key=secret_key, gateway=gateway,
lazy_certification=lazy_certification)
lazy_certification=lazy_certification, **kwargs)

@components_run_trace
def run(self, message, style="营销话术", stream=False, temperature=1e-10, top_p=0.0, request_id=None):
Expand Down
4 changes: 3 additions & 1 deletion python/core/components/v2/llms/style_writing/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = True,
**kwargs
):
"""初始化StyleRewrite模型。

Expand All @@ -107,14 +108,15 @@ def __init__(
secret_key (str, 可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", "").
gateway (str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "")
lazy_certification (bool, 可选): 延迟认证,为True时在第一次运行时认证. Defaults to False.
**kwargs: 其他关键字参数.

Returns:
None

"""
super().__init__(
StyleWritingArgs, model=model, secret_key=secret_key, gateway=gateway,
lazy_certification=lazy_certification)
lazy_certification=lazy_certification, **kwargs)

@components_run_trace
def run(self, message, style_query="通用", length=100, stream=False, temperature=1e-10, top_p=0, request_id=None):
Expand Down
3 changes: 0 additions & 3 deletions python/core/components/v2/translate/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ class Translation(Component):
}
]

def __init__(self, **kwargs):
super().__init__(**kwargs)

@HTTPClient.check_param
@components_run_trace
def run(self, message: Message, from_lang: str = "auto", to_lang: str = "en",
Expand Down
65 changes: 61 additions & 4 deletions python/tests/component_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def register_rule(self, rule_name: str, rule_obj: RuleBase):
def remove_rule(self, rule_name: str):
del self.rules[rule_name]

def notify(self, component_cls, component_case) -> tuple[bool, list]:
def notify(self, component_obj, component_case) -> tuple[bool, list]:
check_pass = True
check_details = {}
reasons = []
for rule_name, rule_obj in self.rules.items():
if rule_name == "ToolEvalOutputJsonRule":
res = rule_obj.check(component_cls, component_case)
res = rule_obj.check(component_obj, component_case)
else:
res = rule_obj.check(component_cls)
res = rule_obj.check(component_obj)
check_details[rule_name] = res
if res.check_result == False:
check_pass = False
Expand Down Expand Up @@ -119,6 +119,61 @@ def check(self, component_obj) -> CheckInfo:
'null': None,
}

class InitKwargsRule(RuleBase):
def __init__(self):
super().__init__()
self.rule_name = "InitKwargsRule"

def _accepts_kwargs(self, func):
"""
检查函数是否接受 **kwargs 参数。
"""
sig = inspect.signature(func)
params = sig.parameters
return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())

def check(self, component_obj) -> CheckInfo:
if not self._accepts_kwargs(component_obj.__init__):
return CheckInfo(
check_rule_name=self.rule_name,
check_result=False,
check_detail="组件的__init__初始化方法需要添加**kwargs参数"
)
else:
return CheckInfo(
check_rule_name=self.rule_name,
check_result=True,
check_detail=""
)

class ToolEvalKwargsRule(RuleBase):
def __init__(self):
super().__init__()
self.rule_name = "ToolEvalKwargsRule"

def _accepts_kwargs(self, func):
"""
检查函数是否接受 **kwargs 参数。
"""
sig = inspect.signature(func)
params = sig.parameters
return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())

def check(self, component_obj) -> CheckInfo:
if not self._accepts_kwargs(component_obj.tool_eval):
return CheckInfo(
check_rule_name=self.rule_name,
check_result=False,
check_detail="组件的__init__初始化方法需要添加**kwargs参数"
)
else:
return CheckInfo(
check_rule_name=self.rule_name,
check_result=True,
check_detail=""
)


class MainfestMatchToolEvalRule(RuleBase):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -459,4 +514,6 @@ def write_error_data(txt_file_path, error_df, error_stats):
register_component_check_rule("ManifestValidRule", ManifestValidRule)
register_component_check_rule("MainfestMatchToolEvalRule", MainfestMatchToolEvalRule)
register_component_check_rule("ToolEvalInputNameRule", ToolEvalInputNameRule)
register_component_check_rule("ToolEvalOutputJsonRule", ToolEvalOutputJsonRule)
register_component_check_rule("ToolEvalOutputJsonRule", ToolEvalOutputJsonRule)
register_component_check_rule("InitKwargsRule", InitKwargsRule)
register_component_check_rule("ToolEvalKwargsRule", ToolEvalKwargsRule)
Loading