# #19 Support multiple translation processing at once with persistently loaded model (#28)
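At a glance, the PR replaces the per-request `translate_context` context manager with a `Service` that loads the SentencePiece tokenizer and CTranslate2 translator once and guards them with a lock, while the task-fetch loop hands work to a thread pool. A hedged usage sketch of the resulting API follows; the config values and file names are placeholders (not the app's real defaults), and this only runs inside the app tree with an actual model present at `model_path`:

```python
# Sketch only: placeholder config; assumes a real CTranslate2 model directory
# and SentencePiece file exist at these (hypothetical) paths.
from concurrent.futures import ThreadPoolExecutor

from lib.Service import Service, TranslateRequest

config = {
    "log_level": "INFO",
    "tokenizer_file": "spm.model",                  # hypothetical file name
    "loader": {"model_path": "models/translate"},   # hypothetical path
    "inference": {},                                # inference options from the app's config
}

service = Service(config)  # the model is loaded once, here

requests = [
    TranslateRequest(input="Hello world", target_language="de"),
    TranslateRequest(input="Good morning", target_language="fr"),
]

# Several submissions can be in flight at once; Service._lock serializes
# access to the shared tokenizer/translator.
with ThreadPoolExecutor(max_workers=4) as pool:
    for translation in pool.map(service.translate, requests):
        print(translation)
```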
base: main
```diff
@@ -0,0 +1,6 @@
+
+            SSUUMMMMAARRYY OOFF LLEESSSS CCOOMMMMAANNDDSS
+
+        Commands marked with * may be preceded by a number, _N.
+        Notes in parentheses indicate the behavior if _N is given.
+        A key preceded by a caret indicates the Ctrl key; thus ^K is ctrl-K.
```
**lib/Service.py**
```diff
@@ -7,14 +7,14 @@
 import json
 import logging
 import os
-from contextlib import contextmanager
+import threading
 from copy import deepcopy
 from time import perf_counter
 from typing import TypedDict

 import ctranslate2
 from sentencepiece import SentencePieceProcessor
-from util import clean_text
+from lib.util import clean_text


 GPU_ACCELERATED = os.getenv("COMPUTE_DEVICE", "CPU") != "CPU"

@@ -29,46 +29,41 @@ class TranslateRequest(TypedDict):
 if os.getenv("CI") is not None:
     ctranslate2.set_random_seed(420)


-@contextmanager
-def translate_context(config: dict):
-    try:
-        tokenizer = SentencePieceProcessor()
-        tokenizer.Load(os.path.join(config["loader"]["model_path"], config["tokenizer_file"]))
-
-        translator = ctranslate2.Translator(
-            **{
-                "device": "cuda" if GPU_ACCELERATED else "cpu",
-                **config["loader"],
-            }
-        )
-    except KeyError as e:
-        raise Exception("Incorrect config file, ensure all required keys are present from the default config") from e
-    except Exception as e:
-        raise Exception("Error loading the translation model") from e
-
-    try:
-        start = perf_counter()
-        yield (tokenizer, translator)
-        elapsed = perf_counter() - start
-        logger.info(f"time taken: {elapsed:.2f}s")
-    except Exception as e:
-        raise Exception("Error translating the input text") from e
-    finally:
-        del tokenizer
-        # todo: offload to cpu?
-        del translator
-
+# Removed the translate_context function in favor of initializing tokenizer and translator directly in the Service class.
+# Simplifies the code by avoiding repeated creation and deletion of these objects for every translation request.

 class Service:
     def __init__(self, config: dict):
+        # Use a threading lock to ensure thread safety when processing concurrent translation requests.
+        self._lock = threading.Lock()
         global logger
         try:
+            log_level = config.get("log_level", logging.INFO)
+            if isinstance(log_level, str):
+                log_level = getattr(logging, log_level.upper(), logging.INFO)
+            logger.setLevel(log_level)
+
+            ctranslate2.set_log_level(log_level)
+            logger.setLevel(log_level)
             self.load_config(config)
-            ctranslate2.set_log_level(config["log_level"])
-            logger.setLevel(config["log_level"])

-            with open("languages.json") as f:
+            self.tokenizer = SentencePieceProcessor()
+            self.tokenizer.Load(os.path.join(self.config["loader"]["model_path"], config["tokenizer_file"]))
+
+            self.translator = ctranslate2.Translator(
+                **{
+                    "device": "cuda" if GPU_ACCELERATED else "cpu",
+                    **self.config["loader"],
+                }
+            )
+
+            # Updated the path resolution for languages.json
+            # so that it works regardless of the current working directory.
+            languages_path = os.path.join(os.path.dirname(__file__), "..", "languages.json")
+            if not os.path.exists(languages_path):
+                raise FileNotFoundError(f"languages.json not found at {languages_path}")
+
+            with open(languages_path) as f:
                 self.languages = json.loads(f.read())
         except Exception as e:
             raise Exception(
```
```diff
@@ -80,29 +75,45 @@ def get_languages(self) -> dict[str, str]:

     def load_config(self, config: dict):
         config_copy = deepcopy(config)
-        config_copy["loader"].pop("model_name", None)
+        # Remove 'model_name' and resolve it if necessary
+        if "model_name" in config_copy["loader"]:
+            model_name = config_copy["loader"].pop("model_name")
+            # Assuming "model_path" is dynamically resolved,
+            # for example to a default local directory.
+            resolved_model_path = os.path.join("models", model_name.replace("/", "_"))
```
> **Review comment:** cool, but let's make these changes to only `model_path` and not `model_name`, and let's use the `"persistent_storage"/models` path for this purpose. Also, it would be good to not modify the model_paths that aren't already prefixed with `"persistent_storage"/models` or
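One possible reading of this suggestion, as a sketch (the helper name and the prefix check are assumptions, not code from the PR; `APP_PERSISTENT_STORAGE` is the environment variable the app already uses for its model cache directory):

```python
import os


def resolve_model_path(loader: dict) -> dict:
    """Hypothetical helper: anchor a relative model_path under
    <persistent storage>/models, leaving absolute or already-anchored
    paths untouched. This interprets, not reproduces, the suggestion."""
    models_root = os.path.join(os.getenv("APP_PERSISTENT_STORAGE", "persistent_storage"), "models")
    path = loader.get("model_path", "")
    if path and not os.path.isabs(path) and not path.startswith(models_root):
        return {**loader, "model_path": os.path.join(models_root, path)}
    return loader
```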
```diff
+            if not os.path.exists(resolved_model_path):
+                raise Exception(
+                    f"Model '{model_name}' not found. Please download or set up the model at {resolved_model_path}."
+                )
+            config_copy["loader"]["model_path"] = resolved_model_path
+        elif "model_path" not in config_copy["loader"]:
+            raise KeyError("The configuration must contain either 'model_name' or 'model_path' under the 'loader' key.")
+        self.config = config_copy

-        if "hf_model_path" in config_copy["loader"]:
-            config_copy["loader"]["model_path"] = config_copy["loader"].pop("hf_model_path")
```
> **Review comment on lines -85 to -86:** this needs to be added back
```diff
-
-        self.config = config_copy

     def translate(self, data: TranslateRequest) -> str:

         logger.debug(f"translating text to: {data['target_language']}")

-        with translate_context(self.config) as (tokenizer, translator):
-            input_tokens = tokenizer.Encode(f"<2{data['target_language']}> {clean_text(data['input'])}", out_type=str)
-            results = translator.translate_batch(
+        with self._lock:
+            input_tokens = self.tokenizer.Encode(
+                f"<2{data['target_language']}> {clean_text(data['input'])}", out_type=str
+            )
+            results = self.translator.translate_batch(
                 [input_tokens],
                 batch_type="tokens",
                 **self.config["inference"],
             )

         if len(results) == 0 or len(results[0].hypotheses) == 0:
             raise Exception("Empty result returned from translator")

         # todo: handle multiple hypotheses
-        translation = tokenizer.Decode(results[0].hypotheses[0])
+        translation = self.tokenizer.Decode(results[0].hypotheses[0])
         logger.debug(f"Translated string: {translation}")
         return translation

+    def close(self):
+        # Cleanup resources during service shutdown.
+        del self.tokenizer
+        del self.translator
+        logger.info("Service resources released.")
```
> **Review comment on lines +115 to +119:** when an ex-app is disabled, the container is shut down, so this cleanup is not required.
**lib/main.py**
```diff
@@ -2,10 +2,10 @@
 # SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
 # SPDX-License-Identifier: MIT
 #
-"""The main module of the translate2 app"""
-
-import logging
+"""The main module of the translate2 app"""
 import os
+import logging
 import threading
 import traceback
 from contextlib import asynccontextmanager
```
@@ -18,56 +18,58 @@ | |
from nc_py_api import AsyncNextcloudApp, NextcloudApp, NextcloudException | ||
from nc_py_api.ex_app import LogLvl, run_app, set_handlers | ||
from nc_py_api.ex_app.providers.task_processing import ShapeEnumValue, TaskProcessingProvider | ||
from Service import Service, TranslateRequest | ||
from util import load_config_file, save_config_file | ||
|
||
#Instead of "from Service import Service, TranslateRequest" used "from lib.Service import Service, TranslateRequest" | ||
from lib.Service import Service, TranslateRequest | ||
from lib.util import load_config_file | ||
|
||
import concurrent.futures | ||
> **Review comment on lines +22 to +26:** please remove those comments from the code and add them to the extended commit messages.
```diff


 load_dotenv()

 config = load_config_file()



 # logging config
 logging.basicConfig()
 logger = logging.getLogger(__name__)
 logger.setLevel(config["log_level"])


-class ModelConfig(dict):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def __setitem__(self, key, value):
-        if key == "path":
-            config["loader"]["hf_model_path"] = value
-            save_config_file(config)
-
-        super().__setitem__(key, value)
+# Removed ModelConfig and simplified code by directly handling model configuration with the config dictionary.

 # download models if "model_name" key is present in the config
 models_to_fetch = None
 cache_dir = os.getenv("APP_PERSISTENT_STORAGE", "models/")
 if "model_name" in config["loader"]:
-    models_to_fetch = { config["loader"]["model_name"]: ModelConfig({ "cache_dir": cache_dir }) }
+    models_to_fetch = { config["loader"]["model_name"]: { "cache_dir": cache_dir } }


 app_enabled = threading.Event()
 @asynccontextmanager
 async def lifespan(_: FastAPI):
     global app_enabled
-    set_handlers(
-        fast_api_app=APP,
-        enabled_handler=enabled_handler,  # type: ignore
-        models_to_fetch=models_to_fetch,  # type: ignore
-    )
+    service = Service(config)
+    try:
+        set_handlers(
+            fast_api_app=APP,
+            enabled_handler=enabled_handler,  # type: ignore
+            models_to_fetch=models_to_fetch,  # type: ignore
+        )

-    nc = NextcloudApp()
-    if nc.enabled_state:
-        app_enabled.set()
-        worker = threading.Thread(target=task_fetch_thread, args=(Service(config),))
-        worker.start()
+        nc = NextcloudApp()
+        if nc.enabled_state:
+            app_enabled.set()
+            worker = threading.Thread(target=task_fetch_thread, args=(service,))
+            worker.start()

-    yield
-    app_enabled.clear()
+        yield
+    finally:
+        # Clean up the Service instance and background workers.
+        app_enabled.clear()
+        service.close()


 APP_ID = "translate2"
```
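The restructured `lifespan` follows FastAPI's standard pattern: acquire resources before `yield`, release them in `finally` so teardown runs even if startup or the app itself raises. A minimal, self-contained sketch of the same shape, with a placeholder standing in for the real service wiring:

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    resource = {"loaded": True}  # stands in for Service(config)
    try:
        # startup work happens here (handlers, worker threads, ...)
        yield  # the application serves requests while suspended here
    finally:
        # teardown runs on shutdown, even after startup errors
        resource.clear()  # stands in for service.close()


APP = FastAPI(lifespan=lifespan)
```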
```diff
@@ -105,41 +107,54 @@ def task_fetch_thread(service: Service):
     global app_enabled

     nc = NextcloudApp()
-    while True:
-        if not app_enabled.is_set():
-            logger.debug("Shutting down task fetch worker, app not enabled")
-            break
-
-        try:
-            task = nc.providers.task_processing.next_task([APP_ID], [TASK_TYPE_ID])
-        except Exception as e:
-            logger.error(f"Error fetching task: {e}")
-            sleep(IDLE_POLLING_INTERVAL)
-            continue
-
-        if not task:
-            logger.debug("No tasks found")
-            sleep(IDLE_POLLING_INTERVAL)
-            continue
-
-        logger.debug(f"Processing task: {task}")
-
-        input_ = task.get("task", {}).get("input")
-        if input_ is None or not isinstance(input_, dict):
-            logger.error("Invalid task object received, expected task object with input key")
-            continue
-
-        try:
-            request = TranslateRequest(**input_)
-            translation = service.translate(request)
-            output = translation
-            nc.providers.task_processing.report_result(
-                task_id=task["task"]["id"],
-                output={"output": output},
-            )
-        except Exception as e:
-            report_error(task, e)
+    # Use a thread pool for concurrent processing.
+    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+        while True:
+            if not app_enabled.is_set():
+                logger.debug("Shutting down task fetch worker, app not enabled")
+                break
+
+            try:
+                task = nc.providers.task_processing.next_task([APP_ID], [TASK_TYPE_ID])
+            except Exception as e:
+                logger.error(f"Error fetching task: {e}")
+                sleep(IDLE_POLLING_INTERVAL)
+                continue
+
+            if not task:
+                logger.debug(f"No tasks found. Sleeping for {IDLE_POLLING_INTERVAL}s")
+                sleep(IDLE_POLLING_INTERVAL)
+                continue
+
+            logger.debug(f"Processing task: {task}")
+
+            input_ = task.get("task", {}).get("input")
+            if input_ is None or not isinstance(input_, dict):
+                logger.error("Invalid task object received, expected task object with input key")
+                continue
+
+            try:
+                request = TranslateRequest(**input_)
+                executor.submit(process_task, service, nc, task, request)
+            except Exception as e:
+                logger.error(f"Error submitting task to executor: {e}")
+                report_error(task, e)
+
+
+def process_task(service: Service, nc: NextcloudApp, task: dict, request: TranslateRequest):
+    try:
+        translation = service.translate(request)
+        nc.providers.task_processing.report_result(
+            task_id=task["task"]["id"],
+            output={"output": translation},
+        )
+        logger.info(f"Successfully processed task {task['task']['id']}")
+    except NextcloudException as e:
+        logger.error(f"Nextcloud exception while reporting result: {e}")
+        report_error(task, e)
+    except Exception as e:
+        report_error(task, e)


 enabled_lock = threading.Lock()


 async def enabled_handler(enabled: bool, nc: AsyncNextcloudApp) -> str:
     global app_enabled
```
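Two properties of this worker design are worth noting: `process_task` handles its own exceptions, so the `Future` objects returned by `executor.submit` are never inspected; and because `Service.translate` acquires `self._lock`, the four pool threads queue on the single model rather than translating truly in parallel. If one wanted to surface anything that still escapes a submitted task, a done-callback is a common pattern; a small sketch (not part of the PR):

```python
import concurrent.futures
import logging

logger = logging.getLogger(__name__)


def submit_logged(executor: concurrent.futures.ThreadPoolExecutor, fn, *args) -> concurrent.futures.Future:
    """Submit fn and log any exception that escapes it; submit() otherwise
    swallows errors until the Future is inspected."""
    future = executor.submit(fn, *args)

    def _log_failure(done: concurrent.futures.Future) -> None:
        exc = done.exception()
        if exc is not None:
            logger.error(f"Worker task failed: {exc}")

    future.add_done_callback(_log_failure)
    return future
```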
```diff
@@ -181,8 +196,9 @@ async def enabled_handler(enabled: bool, nc: AsyncNextcloudApp) -> str:

     if not app_enabled.is_set():
         app_enabled.set()
-        worker = threading.Thread(target=task_fetch_thread, args=(service,))
-        worker.start()
+        if not hasattr(service, "_worker"):
+            service._worker = threading.Thread(target=task_fetch_thread, args=(service,))
+            service._worker.start()
```
> **Review comment on lines -184 to +201:** this is not really required since this
```diff

     return ""
```
> **Review comment:** extra file