From 61c83c1b580125d401964ab902c4a8ab8f40828a Mon Sep 17 00:00:00 2001 From: yinjiaqi Date: Tue, 7 Jan 2025 18:30:40 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E9=83=A8=E5=88=86Components?= =?UTF-8?q?=E7=BB=84=E4=BB=B6=E7=9A=84init=E6=96=B9=E6=B3=95,=E6=96=B0?= =?UTF-8?q?=E5=A2=9Einit=E6=96=B9=E6=B3=95=E4=B8=8Etool=5Feval=E7=9A=84kwa?= =?UTF-8?q?rgs=E5=8F=82=E6=95=B0=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/core/components/llms/base.py | 2 + .../v2/llms/similar_question/component.py | 4 +- .../v2/llms/style_rewrite/component.py | 6 +- .../v2/llms/style_writing/component.py | 4 +- .../core/components/v2/translate/component.py | 3 - python/tests/component_check.py | 65 +++++++++++++++++-- 6 files changed, 73 insertions(+), 11 deletions(-) diff --git a/python/core/components/llms/base.py b/python/core/components/llms/base.py index 346a4774..ba07e259 100644 --- a/python/core/components/llms/base.py +++ b/python/core/components/llms/base.py @@ -272,6 +272,7 @@ def __init__( secret_key: Optional[str] = None, gateway: str = "", lazy_certification: bool = False, + **kwargs ): """ Args: @@ -280,6 +281,7 @@ def __init__( secret_key (Optional[str], optional): 可选的密钥. Defaults to None. gateway (str, optional): 网关地址. Defaults to "". lazy_certification (bool, optional): 延迟认证,为True时在第一次运行时认证. Defaults to False. + **kwargs: 其他关键字参数 """ super(CompletionBaseComponent, self).__init__( diff --git a/python/core/components/v2/llms/similar_question/component.py b/python/core/components/v2/llms/similar_question/component.py index 4b200bd3..1f1d2fa1 100644 --- a/python/core/components/v2/llms/similar_question/component.py +++ b/python/core/components/v2/llms/similar_question/component.py @@ -72,6 +72,7 @@ def __init__( secret_key: Optional[str] = None, gateway: str = "", lazy_certification: bool = True, + **kwargs ): """初始化StyleRewrite模型。 @@ -80,6 +81,7 @@ def __init__( secret_key (str, 可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", ""). gateway (str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "") lazy_certification (bool, 可选): 延迟认证,为True时在第一次运行时认证. Defaults to False. + **kwargs: 其他关键字参数. Returns: None @@ -87,7 +89,7 @@ def __init__( """ super().__init__( SimilarQuestionMeta, model=model, secret_key=secret_key, gateway=gateway, - lazy_certification=lazy_certification) + lazy_certification=lazy_certification, **kwargs) @components_run_trace def run(self, message, stream=False, temperature=1e-10, top_p=0.0, request_id=None): diff --git a/python/core/components/v2/llms/style_rewrite/component.py b/python/core/components/v2/llms/style_rewrite/component.py index 051325c1..bc6b590a 100644 --- a/python/core/components/v2/llms/style_rewrite/component.py +++ b/python/core/components/v2/llms/style_rewrite/component.py @@ -69,6 +69,7 @@ def __init__( secret_key: Optional[str] = None, gateway: str = "", lazy_certification: bool = True, + **kwargs ): """初始化StyleRewrite模型。 @@ -77,14 +78,15 @@ def __init__( secret_key (str, 可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", ""). gateway (str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "") lazy_certification (bool, 可选): 延迟认证,为True时在第一次运行时认证. Defaults to False. - + **kwargs: 其他关键字参数. + Returns: None """ super().__init__( StyleRewriteArgs, model=model, secret_key=secret_key, gateway=gateway, - lazy_certification=lazy_certification) + lazy_certification=lazy_certification, **kwargs) @components_run_trace def run(self, message, style="营销话术", stream=False, temperature=1e-10, top_p=0.0, request_id=None): diff --git a/python/core/components/v2/llms/style_writing/component.py b/python/core/components/v2/llms/style_writing/component.py index 573de925..b0566bbe 100644 --- a/python/core/components/v2/llms/style_writing/component.py +++ b/python/core/components/v2/llms/style_writing/component.py @@ -99,6 +99,7 @@ def __init__( secret_key: Optional[str] = None, gateway: str = "", lazy_certification: bool = True, + **kwargs ): """初始化StyleRewrite模型。 @@ -107,6 +108,7 @@ def __init__( secret_key (str, 可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", ""). gateway (str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "") lazy_certification (bool, 可选): 延迟认证,为True时在第一次运行时认证. Defaults to False. + **kwargs: 其他关键字参数. Returns: None @@ -114,7 +116,7 @@ def __init__( """ super().__init__( StyleWritingArgs, model=model, secret_key=secret_key, gateway=gateway, - lazy_certification=lazy_certification) + lazy_certification=lazy_certification, **kwargs) @components_run_trace def run(self, message, style_query="通用", length=100, stream=False, temperature=1e-10, top_p=0, request_id=None): diff --git a/python/core/components/v2/translate/component.py b/python/core/components/v2/translate/component.py index abe7c75c..3ded4419 100644 --- a/python/core/components/v2/translate/component.py +++ b/python/core/components/v2/translate/component.py @@ -82,9 +82,6 @@ class Translation(Component): } ] - def __init__(self, **kwargs): - super().__init__(**kwargs) - @HTTPClient.check_param @components_run_trace def run(self, message: Message, from_lang: str = "auto", to_lang: str = "en", diff --git a/python/tests/component_check.py b/python/tests/component_check.py index 0351512c..abad933f 100644 --- a/python/tests/component_check.py +++ b/python/tests/component_check.py @@ -38,15 +38,15 @@ def register_rule(self, rule_name: str, rule_obj: RuleBase): def remove_rule(self, rule_name: str): del self.rules[rule_name] - def notify(self, component_cls, component_case) -> tuple[bool, list]: + def notify(self, component_obj, component_case) -> tuple[bool, list]: check_pass = True check_details = {} reasons = [] for rule_name, rule_obj in self.rules.items(): if rule_name == "ToolEvalOutputJsonRule": - res = rule_obj.check(component_cls, component_case) + res = rule_obj.check(component_obj, component_case) else: - res = rule_obj.check(component_cls) + res = rule_obj.check(component_obj) check_details[rule_name] = res if res.check_result == False: check_pass = False @@ -119,6 +119,61 @@ def check(self, component_obj) -> CheckInfo: 'null': None, } +class InitKwargsRule(RuleBase): + def __init__(self): + super().__init__() + self.rule_name = "InitKwargsRule" + + def _accepts_kwargs(self, func): + """ + 检查函数是否接受 **kwargs 参数。 + """ + sig = inspect.signature(func) + params = sig.parameters + return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + + def check(self, component_obj) -> CheckInfo: + if not self._accepts_kwargs(component_obj.__init__): + return CheckInfo( + check_rule_name=self.rule_name, + check_result=False, + check_detail="组件的__init__初始化方法需要添加**kwargs参数" + ) + else: + return CheckInfo( + check_rule_name=self.rule_name, + check_result=True, + check_detail="" + ) + +class ToolEvalKwargsRule(RuleBase): + def __init__(self): + super().__init__() + self.rule_name = "ToolEvalKwargsRule" + + def _accepts_kwargs(self, func): + """ + 检查函数是否接受 **kwargs 参数。 + """ + sig = inspect.signature(func) + params = sig.parameters + return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + + def check(self, component_obj) -> CheckInfo: + if not self._accepts_kwargs(component_obj.tool_eval): + return CheckInfo( + check_rule_name=self.rule_name, + check_result=False, + check_detail="组件的__init__初始化方法需要添加**kwargs参数" + ) + else: + return CheckInfo( + check_rule_name=self.rule_name, + check_result=True, + check_detail="" + ) + + class MainfestMatchToolEvalRule(RuleBase): def __init__(self): super().__init__() @@ -459,4 +514,6 @@ def write_error_data(txt_file_path, error_df, error_stats): register_component_check_rule("ManifestValidRule", ManifestValidRule) register_component_check_rule("MainfestMatchToolEvalRule", MainfestMatchToolEvalRule) register_component_check_rule("ToolEvalInputNameRule", ToolEvalInputNameRule) -register_component_check_rule("ToolEvalOutputJsonRule", ToolEvalOutputJsonRule) \ No newline at end of file +register_component_check_rule("ToolEvalOutputJsonRule", ToolEvalOutputJsonRule) +register_component_check_rule("InitKwargsRule", InitKwargsRule) +register_component_check_rule("ToolEvalKwargsRule", ToolEvalKwargsRule) \ No newline at end of file