diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/task.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/task.py index f5054582dd..695e8882e6 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/batch/task.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/batch/task.py @@ -33,7 +33,7 @@ def __init__( self, name: str, openai_organization: str, - config: Dict[str, Any] = {}, + config: Dict[str, Any], **kwargs, ): super().__init__( diff --git a/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py b/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py index 1f0ff30b51..209bd0d981 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/batch/workflow.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator +from typing import Any, Dict, Iterator, Optional from flytekit import Resources, Workflow from flytekit.models.security import Secret @@ -18,7 +18,7 @@ def create_batch( name: str, openai_organization: str, secret: Secret, - config: Dict[str, Any] = {}, + config: Optional[Dict[str, Any]] = None, is_json_iterator: bool = True, file_upload_mem: str = "700Mi", file_download_mem: str = "700Mi", @@ -45,6 +45,8 @@ def create_batch( name=f"openai-file-upload-{name.replace('.', '')}", task_config=OpenAIFileConfig(openai_organization=openai_organization, secret=secret), ) + if config is None: + config = {} batch_endpoint_task_obj = BatchEndpointTask( name=f"openai-batch-{name.replace('.', '')}", openai_organization=openai_organization,