Skip to content

Commit

Permalink
fix: #375
Browse files Browse the repository at this point in the history
Signed-off-by: yihong0618 <[email protected]>
  • Loading branch information
yihong0618 committed Jan 28, 2024
1 parent 66e12f1 commit 365d662
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
5 changes: 4 additions & 1 deletion book_maker/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,11 @@ def main():
raise ValueError("`api_base` must be provided when using `deployment_id`")
e.translate_model.set_deployment_id(options.deployment_id)
# TODO refactor, quick fix for gpt4 model
if options.model == "chatgptapi":
print(21232)
e.translate_model.set_gpt35_models()
if options.model == "gpt4":
e.translate_model.set_gpt4_models("gpt4")
e.translate_model.set_gpt4_models()
if options.block_size > 0:
e.block_size = options.block_size

Expand Down
38 changes: 24 additions & 14 deletions book_maker/translator/chatgptapi_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,7 @@ def __init__(
self.system_content = environ.get("OPENAI_API_SYS_MSG") or ""
self.deployment_id = None
self.temperature = temperature
# gpt3 all models for save the limit
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(GPT35_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
self.model_list = None

def rotate_key(self):
self.openai_client.api_key = next(self.keys)
Expand Down Expand Up @@ -314,10 +308,26 @@ def set_deployment_id(self, deployment_id):
azure_deployment=self.deployment_id,
)

def set_gpt4_models(self, model="gpt4"):
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(GPT4_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)
def set_gpt35_models(self):
# gpt3 all models for save the limit
if self.deployment_id:
self.model_list = cycle(["gpt-35-turbo"])
else:
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(GPT35_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)

def set_gpt4_models(self):
# for issue #375 azure can not use model list
if self.deployment_id:
self.model_list = cycle(["gpt-4"])
else:
my_model_list = [
i["id"] for i in self.openai_client.models.list().model_dump()["data"]
]
model_list = list(set(my_model_list) & set(GPT4_MODEL_LIST))
print(f"Using model list {model_list}")
self.model_list = cycle(model_list)

0 comments on commit 365d662

Please sign in to comment.