Skip to content

Commit

Permalink
add ollama support
Browse files Browse the repository at this point in the history
  • Loading branch information
b1tg committed Apr 19, 2024
1 parent 205c482 commit 1d3d2d7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
20 changes: 19 additions & 1 deletion book_maker/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ def main():
metavar="MODEL",
help="model to use, available: {%(choices)s}",
)
parser.add_argument(
"--ollama_model",
dest="ollama_model",
type=str,
default="ollama_model",
metavar="MODEL",
help="use ollama",
)
parser.add_argument(
"--language",
type=str,
Expand Down Expand Up @@ -308,6 +316,9 @@ def main():
):
API_KEY = OPENAI_API_KEY
# patch
elif options.ollama_model:
# any string is ok, can't be empty
API_KEY = "ollama"
else:
raise Exception(
"OpenAI API key not provided, please google how to obtain it",
Expand Down Expand Up @@ -365,6 +376,10 @@ def main():
# change api_base for issue #42
model_api_base = options.api_base

if options.ollama_model and not model_api_base:
# ollama default api_base
model_api_base = "http://localhost:11434/v1"

e = book_loader(
options.book_name,
translate_model,
Expand Down Expand Up @@ -418,7 +433,10 @@ def main():
)
# TODO refactor, quick fix for gpt4 model
if options.model == "chatgptapi":
e.translate_model.set_gpt35_models()
if options.ollama_model:
e.translate_model.set_gpt35_models(ollama_model=options.ollama_model)
else:
e.translate_model.set_gpt35_models()
if options.model == "gpt4":
e.translate_model.set_gpt4_models()
if options.block_size > 0:
Expand Down
5 changes: 4 additions & 1 deletion book_maker/translator/chatgptapi_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,10 @@ def set_deployment_id(self, deployment_id):
azure_deployment=self.deployment_id,
)

def set_gpt35_models(self):
def set_gpt35_models(self, ollama_model=""):
if ollama_model:
self.model_list = cycle([ollama_model])
return
# gpt3 all models for save the limit
if self.deployment_id:
self.model_list = cycle(["gpt-35-turbo"])
Expand Down

0 comments on commit 1d3d2d7

Please sign in to comment.