Skip to content

Commit

Permalink
更新部分Components组件的init方法,新增init方法与tool_eval的kwargs参数检查
Browse files Browse the repository at this point in the history
  • Loading branch information
yinjiaqi authored and yinjiaqi committed Jan 7, 2025
1 parent 9a830a7 commit 61c83c1
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 11 deletions.
2 changes: 2 additions & 0 deletions python/core/components/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def __init__(
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = False,
**kwargs
):
"""
Args:
Expand All @@ -280,6 +281,7 @@ def __init__(
secret_key (Optional[str], optional): 可选的密钥. Defaults to None.
gateway (str, optional): 网关地址. Defaults to "".
lazy_certification (bool, optional): 延迟认证,为True时在第一次运行时认证. Defaults to False.
**kwargs: 其他关键字参数
"""
super(CompletionBaseComponent, self).__init__(
Expand Down
4 changes: 3 additions & 1 deletion python/core/components/v2/llms/similar_question/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = True,
**kwargs
):
"""初始化StyleRewrite模型。
Expand All @@ -80,14 +81,15 @@ def __init__(
secret_key (str, 可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", "").
gateway (str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "")
lazy_certification (bool, 可选): 延迟认证,为True时在第一次运行时认证. Defaults to False.
**kwargs: 其他关键字参数.
Returns:
None
"""
super().__init__(
SimilarQuestionMeta, model=model, secret_key=secret_key, gateway=gateway,
lazy_certification=lazy_certification)
lazy_certification=lazy_certification, **kwargs)

@components_run_trace
def run(self, message, stream=False, temperature=1e-10, top_p=0.0, request_id=None):
Expand Down
6 changes: 4 additions & 2 deletions python/core/components/v2/llms/style_rewrite/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = True,
**kwargs
):
"""初始化StyleRewrite模型。
Expand All @@ -77,14 +78,15 @@ def __init__(
secret_key (str, 可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", "").
gateway (str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "")
lazy_certification (bool, 可选): 延迟认证,为True时在第一次运行时认证. Defaults to False.
**kwargs: 其他关键字参数.
Returns:
None
"""
super().__init__(
StyleRewriteArgs, model=model, secret_key=secret_key, gateway=gateway,
lazy_certification=lazy_certification)
lazy_certification=lazy_certification, **kwargs)

@components_run_trace
def run(self, message, style="营销话术", stream=False, temperature=1e-10, top_p=0.0, request_id=None):
Expand Down
4 changes: 3 additions & 1 deletion python/core/components/v2/llms/style_writing/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = True,
**kwargs
):
"""初始化StyleRewrite模型。
Expand All @@ -107,14 +108,15 @@ def __init__(
secret_key (str, 可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", "").
gateway (str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "")
lazy_certification (bool, 可选): 延迟认证,为True时在第一次运行时认证. Defaults to False.
**kwargs: 其他关键字参数.
Returns:
None
"""
super().__init__(
StyleWritingArgs, model=model, secret_key=secret_key, gateway=gateway,
lazy_certification=lazy_certification)
lazy_certification=lazy_certification, **kwargs)

@components_run_trace
def run(self, message, style_query="通用", length=100, stream=False, temperature=1e-10, top_p=0, request_id=None):
Expand Down
3 changes: 0 additions & 3 deletions python/core/components/v2/translate/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ class Translation(Component):
}
]

def __init__(self, **kwargs):
super().__init__(**kwargs)

@HTTPClient.check_param
@components_run_trace
def run(self, message: Message, from_lang: str = "auto", to_lang: str = "en",
Expand Down
65 changes: 61 additions & 4 deletions python/tests/component_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def register_rule(self, rule_name: str, rule_obj: RuleBase):
def remove_rule(self, rule_name: str):
del self.rules[rule_name]

def notify(self, component_cls, component_case) -> tuple[bool, list]:
def notify(self, component_obj, component_case) -> tuple[bool, list]:
check_pass = True
check_details = {}
reasons = []
for rule_name, rule_obj in self.rules.items():
if rule_name == "ToolEvalOutputJsonRule":
res = rule_obj.check(component_cls, component_case)
res = rule_obj.check(component_obj, component_case)
else:
res = rule_obj.check(component_cls)
res = rule_obj.check(component_obj)
check_details[rule_name] = res
if res.check_result == False:
check_pass = False
Expand Down Expand Up @@ -119,6 +119,61 @@ def check(self, component_obj) -> CheckInfo:
'null': None,
}

class InitKwargsRule(RuleBase):
def __init__(self):
super().__init__()
self.rule_name = "InitKwargsRule"

def _accepts_kwargs(self, func):
"""
检查函数是否接受 **kwargs 参数。
"""
sig = inspect.signature(func)
params = sig.parameters
return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())

def check(self, component_obj) -> CheckInfo:
if not self._accepts_kwargs(component_obj.__init__):
return CheckInfo(
check_rule_name=self.rule_name,
check_result=False,
check_detail="组件的__init__初始化方法需要添加**kwargs参数"
)
else:
return CheckInfo(
check_rule_name=self.rule_name,
check_result=True,
check_detail=""
)

class ToolEvalKwargsRule(RuleBase):
def __init__(self):
super().__init__()
self.rule_name = "ToolEvalKwargsRule"

def _accepts_kwargs(self, func):
"""
检查函数是否接受 **kwargs 参数。
"""
sig = inspect.signature(func)
params = sig.parameters
return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())

def check(self, component_obj) -> CheckInfo:
if not self._accepts_kwargs(component_obj.tool_eval):
return CheckInfo(
check_rule_name=self.rule_name,
check_result=False,
check_detail="组件的__init__初始化方法需要添加**kwargs参数"
)
else:
return CheckInfo(
check_rule_name=self.rule_name,
check_result=True,
check_detail=""
)


class MainfestMatchToolEvalRule(RuleBase):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -459,4 +514,6 @@ def write_error_data(txt_file_path, error_df, error_stats):
register_component_check_rule("ManifestValidRule", ManifestValidRule)
register_component_check_rule("MainfestMatchToolEvalRule", MainfestMatchToolEvalRule)
register_component_check_rule("ToolEvalInputNameRule", ToolEvalInputNameRule)
register_component_check_rule("ToolEvalOutputJsonRule", ToolEvalOutputJsonRule)
register_component_check_rule("ToolEvalOutputJsonRule", ToolEvalOutputJsonRule)
register_component_check_rule("InitKwargsRule", InitKwargsRule)
register_component_check_rule("ToolEvalKwargsRule", ToolEvalKwargsRule)

0 comments on commit 61c83c1

Please sign in to comment.