diff --git a/book_maker/cli.py b/book_maker/cli.py index 171ddeb0..50cbc9c8 100644 --- a/book_maker/cli.py +++ b/book_maker/cli.py @@ -403,8 +403,11 @@ def main(): raise ValueError("`api_base` must be provided when using `deployment_id`") e.translate_model.set_deployment_id(options.deployment_id) # TODO refactor, quick fix for gpt4 model + if options.model == "chatgptapi": + print(21232) + e.translate_model.set_gpt35_models() if options.model == "gpt4": - e.translate_model.set_gpt4_models("gpt4") + e.translate_model.set_gpt4_models() if options.block_size > 0: e.block_size = options.block_size diff --git a/book_maker/translator/chatgptapi_translator.py b/book_maker/translator/chatgptapi_translator.py index c9daf32d..849c9201 100644 --- a/book_maker/translator/chatgptapi_translator.py +++ b/book_maker/translator/chatgptapi_translator.py @@ -66,13 +66,7 @@ def __init__( self.system_content = environ.get("OPENAI_API_SYS_MSG") or "" self.deployment_id = None self.temperature = temperature - # gpt3 all models for save the limit - my_model_list = [ - i["id"] for i in self.openai_client.models.list().model_dump()["data"] - ] - model_list = list(set(my_model_list) & set(GPT35_MODEL_LIST)) - print(f"Using model list {model_list}") - self.model_list = cycle(model_list) + self.model_list = None def rotate_key(self): self.openai_client.api_key = next(self.keys) @@ -314,10 +308,26 @@ def set_deployment_id(self, deployment_id): azure_deployment=self.deployment_id, ) - def set_gpt4_models(self, model="gpt4"): - my_model_list = [ - i["id"] for i in self.openai_client.models.list().model_dump()["data"] - ] - model_list = list(set(my_model_list) & set(GPT4_MODEL_LIST)) - print(f"Using model list {model_list}") - self.model_list = cycle(model_list) + def set_gpt35_models(self): + # gpt3 all models for save the limit + if self.deployment_id: + self.model_list = cycle(["gpt-35-turbo"]) + else: + my_model_list = [ + i["id"] for i in self.openai_client.models.list().model_dump()["data"] + ] + model_list = list(set(my_model_list) & set(GPT35_MODEL_LIST)) + print(f"Using model list {model_list}") + self.model_list = cycle(model_list) + + def set_gpt4_models(self): + # for issue #375 azure can not use model list + if self.deployment_id: + self.model_list = cycle(["gpt-4"]) + else: + my_model_list = [ + i["id"] for i in self.openai_client.models.list().model_dump()["data"] + ] + model_list = list(set(my_model_list) & set(GPT4_MODEL_LIST)) + print(f"Using model list {model_list}") + self.model_list = cycle(model_list)