From f9b3afd238056e1ca7e61ae499c5e07a4deb5205 Mon Sep 17 00:00:00 2001 From: Chengmo Date: Sat, 17 Aug 2024 12:14:54 +0800 Subject: [PATCH] Component Text2Image add float math judge (#474) * add math judge * update * update ut * update doc --- appbuilder/core/components/text_to_image/component.py | 6 ++++-- appbuilder/core/components/text_to_image/model.py | 8 ++++---- appbuilder/tests/test_text_to_image.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/appbuilder/core/components/text_to_image/component.py b/appbuilder/core/components/text_to_image/component.py index 56cd34aba..660aa14c3 100644 --- a/appbuilder/core/components/text_to_image/component.py +++ b/appbuilder/core/components/text_to_image/component.py @@ -16,6 +16,7 @@ """ import time import json +import math from appbuilder.core.component import Component from appbuilder.core.message import Message @@ -105,9 +106,10 @@ def run( request.task_id = taskId text2ImageQueryResponse = self.queryText2ImageData(request, request_id=request_id) if text2ImageQueryResponse.data.task_progress is not None: - task_progress = text2ImageQueryResponse.data.task_progress - if task_progress == 1: + task_progress = float(text2ImageQueryResponse.data.task_progress) + if math.isclose(1.0, task_progress, rel_tol=1e-9, abs_tol=0.0): break + # NOTE(chengmo):文生图组件的返回时间在10s以上,查询过于频繁会被限流,导致异常报错 # 此处采用 yangyongzhen老师提供的方案,前三次查询间隔3s,后三次查询间隔逐渐增大 if task_request_time <= 3: diff --git a/appbuilder/core/components/text_to_image/model.py b/appbuilder/core/components/text_to_image/model.py index 8e74dc500..ccca09b99 100644 --- a/appbuilder/core/components/text_to_image/model.py +++ b/appbuilder/core/components/text_to_image/model.py @@ -186,8 +186,8 @@ class Text2ImageQueryData(proto.Message): 任务 ID. task_status(str): 计算总状态。有 INIT(初始化),WAIT(排队中), RUNNING(生成中), FAILED(失败), SUCCESS(成功)四种状态,只有 SUCCESS 为成功状态。 - task_progress(int): - 图片生成总进度,进度包含2种,0为未处理完,1为处理完成。 + task_progress(float): + 图片生成总进度,0到1之间的浮点数表示进度,0为未处理完,1为处理完成。 sub_task_result_list(Text2ImageSubTaskResultList): 子任务生成结果列表。 """ @@ -216,8 +216,8 @@ class Text2ImageSubTaskResultList(proto.Message): 参数: sub_task_status(int): 单风格图片状态。有 INIT(初始化),WAIT(排队中), RUNNING(生成中), FAILED(失败), SUCCESS(成功)四种状态,只有 SUCCESS 为成功状态。 - sub_task_progress(int): - 单任务图片生成进度,进度包含2种,0为未处理完,1为处理完成。 + sub_task_progress(float): + 单任务图片生成进度,0到1之间的浮点数表示进度,0为未处理完,1为处理完成。 sub_task_error_code(str): 单风格任务错误码。0:正常;501:文本黄反拦截;201:模型生图失败。 final_image_list(Text2ImageFinalImageList): diff --git a/appbuilder/tests/test_text_to_image.py b/appbuilder/tests/test_text_to_image.py index f29b70f42..ed48fbc43 100644 --- a/appbuilder/tests/test_text_to_image.py +++ b/appbuilder/tests/test_text_to_image.py @@ -85,7 +85,7 @@ def test_extract_img_urls(self): """ response = Text2ImageQueryResponse() - response.data.task_progress = 1 + response.data.task_progress = 1.0 response.data.sub_task_result_list = [{'final_image_list': [{'img_url': 'http://example.com'}]}] img_urls = self.text2Image.extract_img_urls(response) self.assertEqual(img_urls, ['http://example.com'])