diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index c88a1e70cb..f9fa23df4f 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -6,8 +6,26 @@ import logging import unittest +import json + from apps.shark_studio.api.llm import LanguageModel +from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file +from apps.shark_studio.web.utils.file_utils import get_resource_path + +class SDAPITest(unittest.TestCase): + def testSDSimple(self): + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + import apps.shark_studio.web.utils.globals as global_obj + + global_obj._init() + sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json")) + sd_kwargs = json.loads(sd_json) + for arg in vars(cmd_opts): + if arg in sd_kwargs: + sd_kwargs[arg] = getattr(cmd_opts, arg) + for i in shark_sd_fn_dict_input(sd_kwargs): + print(i) class LLMAPITest(unittest.TestCase): def testLLMSimple(self):