auto unload models if models_ttl is reached
hcharbonnier authored and committed Feb 4, 2025
1 parent 126bb08 commit f7a4cbb
Showing 13 changed files with 103 additions and 16 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -345,6 +345,8 @@ Colorizer: **mc2**
--kernel-size KERNEL_SIZE Set the convolution kernel size of the text erasure area to
completely clean up text residues
--config-file CONFIG_FILE path to the config file
--models-ttl MODELS_TTL How long to keep models in memory in seconds after last use (0 means
forever)
```

<!-- Auto generated end -->
2 changes: 2 additions & 0 deletions README_CN.md
@@ -120,6 +120,8 @@ FIL: Filipino (Tagalog)
--kernel-size KERNEL_SIZE Set the convolution kernel size of the text erasure area to
completely clean up text residues
--config-file CONFIG_FILE path to the config file
--models-ttl MODELS_TTL How long to keep models in memory in seconds after last use (0 means
forever)
```

<!-- Auto generated end -->
4 changes: 4 additions & 0 deletions manga_translator/args.py
@@ -93,6 +93,8 @@ def general_parser(g_parser):
help='Path to the post-translation dictionary file')
g_parser.add_argument('--kernel-size', default=3, type=int,
help='Set the convolution kernel size of the text erasure area to completely clean up text residues')
g_parser.add_argument('--models-ttl', default=0, type=int,
help='How long to keep models in memory in seconds after last use (0 means forever)')



@@ -129,12 +131,14 @@ def reparse(arr: list):
parser_ws.add_argument('--port', default=5003, type=int, help='Port for WebSocket service')
parser_ws.add_argument('--nonce', default=os.getenv('MT_WEB_NONCE', ''), type=str, help='Nonce for securing internal WebSocket communication')
parser_ws.add_argument('--ws-url', default='ws://localhost:5000', type=str, help='Server URL for WebSocket mode')
parser_ws.add_argument('--models-ttl', default=0, type=int, help='How long to keep models in memory in seconds after last use (0 means forever)')

# API mode
parser_api = subparsers.add_parser('shared', help='Run in API mode')
parser_api.add_argument('--host', default='127.0.0.1', type=str, help='Host for API service')
parser_api.add_argument('--port', default=5003, type=int, help='Port for API service')
parser_api.add_argument('--nonce', default=os.getenv('MT_WEB_NONCE', ''), type=str, help='Nonce for securing internal API server communication')
parser_api.add_argument("--report", default=None,type=str, help='reports to server to register instance')
parser_api.add_argument('--models-ttl', default=0, type=int, help='How long to keep models in memory in seconds after last use (0 means forever)')

subparsers.add_parser('config-help', help='Print help information for config file')
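
For orientation, here is a minimal sketch of how the new flag travels from the command line into the translator. The standalone parser below is a simplified stand-in for `general_parser`; the `parse_init_params` line quoted in the comment comes from the manga_translator.py diff further down.

```python
import argparse

# Simplified stand-in for general_parser(): only the flag added by this commit.
parser = argparse.ArgumentParser()
parser.add_argument('--models-ttl', default=0, type=int,
                    help='How long to keep models in memory in seconds after last use (0 means forever)')

args = parser.parse_args(['--models-ttl', '600'])
params = vars(args)                     # argparse maps --models-ttl to the models_ttl key
assert params['models_ttl'] == 600

# In MangaTranslator.parse_init_params() the value is then picked up with:
#   self.models_ttl = params.get('models_ttl', 0)
```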
3 changes: 3 additions & 0 deletions manga_translator/colorization/__init__.py
@@ -27,3 +27,6 @@ async def dispatch(key: Colorizer, device: str = 'cpu', **kwargs) -> Image.Image
if isinstance(colorizer, OfflineColorizer):
await colorizer.load(device)
return await colorizer.colorize(**kwargs)

async def unload(key: Colorizer):
colorizer_cache.pop(key, None)
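
The same three-line `unload()` helper is added to each pipeline stage below (detection, inpainting, OCR, translation, upscaling). A minimal sketch of the shared pattern, assuming each module keeps a module-level cache dict keyed by the model enum, as the diff suggests:

```python
from typing import Any, Dict

# Module-level cache of instantiated models, keyed by the model enum (assumed shape).
colorizer_cache: Dict[Any, Any] = {}

async def unload(key: Any) -> None:
    # pop() with a default never raises, so unloading a model that was never
    # loaded (or was already unloaded) is a harmless no-op; the next dispatch()
    # call simply re-instantiates and re-loads the model on demand.
    colorizer_cache.pop(key, None)
```

Dropping the cached reference only releases the Python-side handle; the caller in manga_translator.py (`_unload_model`) additionally calls `torch.cuda.empty_cache()` so the allocator releases the now-unused GPU memory.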
3 changes: 3 additions & 0 deletions manga_translator/detection/__init__.py
@@ -36,3 +36,6 @@ async def dispatch(detector_key: Detector, image: np.ndarray, detect_size: int,
if isinstance(detector, OfflineDetector):
await detector.load(device)
return await detector.detect(image, detect_size, text_threshold, box_threshold, unclip_ratio, invert, gamma_correct, rotate, auto_rotate, verbose)

async def unload(detector_key: Detector):
detector_cache.pop(detector_key, None)
3 changes: 3 additions & 0 deletions manga_translator/inpainting/__init__.py
@@ -40,3 +40,6 @@ async def dispatch(inpainter_key: Inpainter, image: np.ndarray, mask: np.ndarray
await inpainter.load(device)
config = config or InpainterConfig()
return await inpainter.inpaint(image, mask, config, inpainting_size, verbose)

async def unload(inpainter_key: Inpainter):
inpainter_cache.pop(inpainter_key, None)
88 changes: 74 additions & 14 deletions manga_translator/manga_translator.py
@@ -1,3 +1,4 @@
import asyncio
import cv2
import json
import langcodes
@@ -23,18 +24,19 @@
sort_regions,
)

from .detection import dispatch as dispatch_detection, prepare as prepare_detection
from .upscaling import dispatch as dispatch_upscaling, prepare as prepare_upscaling
from .ocr import dispatch as dispatch_ocr, prepare as prepare_ocr
from .detection import dispatch as dispatch_detection, prepare as prepare_detection, unload as unload_detection
from .upscaling import dispatch as dispatch_upscaling, prepare as prepare_upscaling, unload as unload_upscaling
from .ocr import dispatch as dispatch_ocr, prepare as prepare_ocr, unload as unload_ocr
from .textline_merge import dispatch as dispatch_textline_merge
from .mask_refinement import dispatch as dispatch_mask_refinement
from .inpainting import dispatch as dispatch_inpainting, prepare as prepare_inpainting
from .inpainting import dispatch as dispatch_inpainting, prepare as prepare_inpainting, unload as unload_inpainting
from .translators import (
LANGDETECT_MAP,
dispatch as dispatch_translation,
prepare as prepare_translation,
unload as unload_translation,
)
from .colorization import dispatch as dispatch_colorization, prepare as prepare_colorization
from .colorization import dispatch as dispatch_colorization, prepare as prepare_colorization, unload as unload_colorization
from .rendering import dispatch as dispatch_rendering, dispatch_eng_render

# Will be overwritten by __main__.py if module is being run directly (with python -m)
@@ -90,6 +92,7 @@ class MangaTranslator:
_gpu_limited_memory: bool
device: Optional[str]
kernel_size: Optional[int]
models_ttl: int
_progress_hooks: list[Any]
result_sub_folder: str

@@ -103,6 +106,7 @@ def __init__(self, params: dict = None):
self._gpu_limited_memory = False
self.ignore_errors = False
self.verbose = False
self.models_ttl = 0

self._progress_hooks = []
self._add_logger_hook()
@@ -118,10 +122,14 @@ def __init__(self, params: dict = None):
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

self._model_usage_timestamps = {}
self._detector_cleanup_task = None

def parse_init_params(self, params: dict):
self.verbose = params.get('verbose', False)
self.use_mtpe = params.get('use_mtpe', False)
self.font_path = params.get('font_path', None)
self.models_ttl = params.get('models_ttl', 0)

self.ignore_errors = params.get('ignore_errors', False)
# check mps for apple silicon or cuda for nvidia
@@ -167,19 +175,24 @@ async def translate(self, image: Image.Image, config: Config) -> Context:
ctx.result = None

# preload and download models (not strictly necessary, remove to lazy load)
logger.info('Loading models')
if config.upscale.upscale_ratio:
await prepare_upscaling(config.upscale.upscaler)
await prepare_detection(config.detector.detector)
await prepare_ocr(config.ocr.ocr, self.device)
await prepare_inpainting(config.inpainter.inpainter, self.device)
await prepare_translation(config.translator.translator_gen)
if config.colorizer.colorizer != Colorizer.none:
await prepare_colorization(config.colorizer.colorizer)
if self.models_ttl == 0:
logger.info('Loading models')
if config.upscale.upscale_ratio:
await prepare_upscaling(config.upscale.upscaler)
await prepare_detection(config.detector.detector)
await prepare_ocr(config.ocr.ocr, self.device)
await prepare_inpainting(config.inpainter.inpainter, self.device)
await prepare_translation(config.translator.translator_gen)
if config.colorizer.colorizer != Colorizer.none:
await prepare_colorization(config.colorizer.colorizer)

# translate
return await self._translate(config, ctx)

async def _translate(self, config: Config, ctx: Context) -> Context:
# Start the background cleanup job once if not already started.
if self._detector_cleanup_task is None:
self._detector_cleanup_task = asyncio.create_task(self._detector_cleanup_job())
# -- Colorization
if config.colorizer.colorizer != Colorizer.none:
await self._report_progress('colorizing')
@@ -302,20 +315,59 @@ async def _revert_upscale(self, config: Config, ctx: Context):
return ctx

async def _run_colorizer(self, config: Config, ctx: Context):
current_time = time.time()
self._model_usage_timestamps[("colorization", config.colorizer.colorizer)] = current_time
# TODO: ctx appears to be unused here; does it need to be passed in?
return await dispatch_colorization(config.colorizer.colorizer, device=self.device, image=ctx.input, **ctx)

async def _run_upscaling(self, config: Config, ctx: Context):
current_time = time.time()
self._model_usage_timestamps[("upscaling", config.upscale.upscaler)] = current_time
return (await dispatch_upscaling(config.upscale.upscaler, [ctx.img_colorized], config.upscale.upscale_ratio, self.device))[0]

async def _run_detection(self, config: Config, ctx: Context):
current_time = time.time()
self._model_usage_timestamps[("detection", config.detector.detector)] = current_time
return await dispatch_detection(config.detector.detector, ctx.img_rgb, config.detector.detection_size, config.detector.text_threshold,
config.detector.box_threshold,
config.detector.unclip_ratio, config.detector.det_invert, config.detector.det_gamma_correct, config.detector.det_rotate,
config.detector.det_auto_rotate,
self.device, self.verbose)

async def _unload_model(self, tool: str, model: str):
logger.info(f"Unloading {tool} model: {model}")
match tool:
case 'colorization':
await unload_colorization(model)
case 'detection':
await unload_detection(model)
case 'inpainting':
await unload_inpainting(model)
case 'ocr':
await unload_ocr(model)
case 'upscaling':
await unload_upscaling(model)
case 'translation':
await unload_translation(model)
if torch.cuda.is_available():
torch.cuda.empty_cache() # empty CUDA cache

# Background models cleanup job.
async def _detector_cleanup_job(self):
while True:
if self.models_ttl == 0:
await asyncio.sleep(1)
continue
now = time.time()
for (tool, model), last_used in list(self._model_usage_timestamps.items()):
if now - last_used > self.models_ttl:
await self._unload_model(tool, model)
del self._model_usage_timestamps[(tool, model)]
await asyncio.sleep(1)

async def _run_ocr(self, config: Config, ctx: Context):
current_time = time.time()
self._model_usage_timestamps[("ocr", config.ocr.ocr)] = current_time
textlines = await dispatch_ocr(config.ocr.ocr, ctx.img_rgb, ctx.textlines, config.ocr, self.device, self.verbose)

new_textlines = []
@@ -329,6 +381,8 @@ async def _run_ocr(self, config: Config, ctx: Context):
return new_textlines

async def _run_textline_merge(self, config: Config, ctx: Context):
current_time = time.time()
self._model_usage_timestamps[("textline_merge", "textline_merge")] = current_time
text_regions = await dispatch_textline_merge(ctx.textlines, ctx.img_rgb.shape[1], ctx.img_rgb.shape[0],
verbose=self.verbose)
# Filter out languages to skip
@@ -417,6 +471,8 @@ async def _run_textline_merge(self, config: Config, ctx: Context):
return text_regions

async def _run_text_translation(self, config: Config, ctx: Context):
current_time = time.time()
self._model_usage_timestamps[("translation", config.translator.translator)] = current_time
if self.load_text:
input_filename = os.path.splitext(os.path.basename(self.input_files[0]))[0]
with open(self._result_path(f"{input_filename}_translations.txt"), "r") as f:
@@ -643,10 +699,14 @@ async def _run_mask_refinement(self, config: Config, ctx: Context):
config.mask_dilation_offset, config.ocr.ignore_bubble, self.verbose,self.kernel_size)

async def _run_inpainting(self, config: Config, ctx: Context):
current_time = time.time()
self._model_usage_timestamps[("inpainting", config.inpainter.inpainter)] = current_time
return await dispatch_inpainting(config.inpainter.inpainter, ctx.img_rgb, ctx.mask, config.inpainter, config.inpainter.inpainting_size, self.device,
self.verbose)

async def _run_text_rendering(self, config: Config, ctx: Context):
current_time = time.time()
self._model_usage_timestamps[("rendering", config.render.renderer)] = current_time
if config.render.renderer == Renderer.none:
output = ctx.img_inpainted
# manga2eng currently only supports horizontal left to right rendering
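
Taken together, the manga_translator.py changes amount to a small TTL tracker: each `_run_*` step stamps the (tool, model) pair it just used, and a background task evicts pairs that have been idle longer than `models_ttl`. A condensed, self-contained sketch of that mechanism (names simplified, the real unload call stubbed out with a print; the 1-second poll interval mirrors the diff):

```python
import asyncio
import time

class ModelTTLTracker:
    """Condensed sketch of the bookkeeping this commit adds to MangaTranslator."""

    def __init__(self, models_ttl: int = 0):
        self.models_ttl = models_ttl                # 0 means "keep models forever"
        self._model_usage_timestamps: dict[tuple[str, str], float] = {}

    def touch(self, tool: str, model: str) -> None:
        # Each _run_* step records "this (tool, model) pair was just used".
        self._model_usage_timestamps[(tool, model)] = time.time()

    async def _unload_model(self, tool: str, model: str) -> None:
        print(f"Unloading {tool} model: {model}")   # stand-in for the real unload_* calls

    async def cleanup_job(self) -> None:
        # Mirrors _detector_cleanup_job(): poll once per second, evict expired entries.
        while True:
            if self.models_ttl:
                now = time.time()
                for key, last_used in list(self._model_usage_timestamps.items()):
                    if now - last_used > self.models_ttl:
                        await self._unload_model(*key)
                        del self._model_usage_timestamps[key]
            await asyncio.sleep(1)

async def demo():
    tracker = ModelTTLTracker(models_ttl=2)
    cleanup = asyncio.create_task(tracker.cleanup_job())
    tracker.touch("detection", "default")
    await asyncio.sleep(4)      # after ~2 idle seconds the detection entry is evicted
    cleanup.cancel()

asyncio.run(demo())
```

Note that despite its name, `_detector_cleanup_job` covers every model type that records a timestamp. Also, when `models_ttl` is non-zero the eager "Loading models" pre-load in `translate()` is skipped, so models are loaded lazily on first use and dropped again once idle.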
4 changes: 3 additions & 1 deletion manga_translator/ocr/__init__.py
@@ -1,6 +1,5 @@
import numpy as np
from typing import List, Optional

from .common import CommonOCR, OfflineOCR
from .model_32px import Model32pxOCR
from .model_48px import Model48pxOCR
@@ -37,3 +36,6 @@ async def dispatch(ocr_key: Ocr, image: np.ndarray, regions: List[Quadrilateral]
await ocr.load(device)
config = config or OcrConfig()
return await ocr.recognize(image, regions, config, verbose)

async def unload(ocr_key: Ocr):
ocr_cache.pop(ocr_key, None)
1 change: 0 additions & 1 deletion manga_translator/ocr/model_32px.py
@@ -52,7 +52,6 @@ async def _load(self, device: str):
if self.use_gpu:
self.model = self.model.to(device)


async def _unload(self):
del self.model

3 changes: 3 additions & 0 deletions manga_translator/translators/__init__.py
@@ -135,3 +135,6 @@ async def dispatch(chain: TranslatorChain, queries: List[str], translator_config
'id': 'IND',
'tl': 'FIL'
}

async def unload(key: Translator):
translator_cache.pop(key, None)
3 changes: 3 additions & 0 deletions manga_translator/upscaling/__init__.py
@@ -34,3 +34,6 @@ async def dispatch(upscaler_key: Upscaler, image_batch: List[Image.Image], upsca
if isinstance(upscaler, OfflineUpscaler):
await upscaler.load(device)
return await upscaler.upscale(image_batch, upscale_ratio)

async def unload(upscaler_key: Upscaler):
upscaler_cache.pop(upscaler_key, None)
1 change: 1 addition & 0 deletions server/args.py
@@ -12,6 +12,7 @@ def parse_arguments():
help='If a translator should be launched automatically')
parser.add_argument('--ignore-errors', action='store_true', help='Skip image on encountered error.')
parser.add_argument('--nonce', default=os.getenv('MT_WEB_NONCE', ''), type=str, help='Nonce for securing internal web server communication')
parser.add_argument('--models-ttl', default=0, type=int, help='How long to keep models in memory in seconds after last use (0 means forever)')
g = parser.add_mutually_exclusive_group()
g.add_argument('--use-gpu', action='store_true', help='Turn on/off gpu (auto switch between mps and cuda)')
g.add_argument('--use-gpu-limited', action='store_true', help='Turn on/off gpu (excluding offline translator)')
2 changes: 2 additions & 0 deletions server/main.py
@@ -156,6 +156,8 @@ def start_translator_client_proc(host: str, port: int, nonce: str, params: Names
cmds.append('--ignore-errors')
if params.verbose:
cmds.append('--verbose')
if params.models_ttl:
cmds.append('--models-ttl=%s' % params.models_ttl)
base_path = os.path.dirname(os.path.abspath(__file__))
parent = os.path.dirname(base_path)
proc = subprocess.Popen(cmds, cwd=parent)
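
Finally, the web server forwards the flag to the translator subprocess it spawns. A rough sketch of what the assembled command might look like; the `python -m manga_translator shared --host/--port` base command here is an assumption for illustration, and only the `--models-ttl` handling is taken from the diff above:

```python
from argparse import Namespace

# Hypothetical server-side params object (only the fields used here).
params = Namespace(verbose=True, models_ttl=600)

# Assumed base command; the real one is built earlier in start_translator_client_proc().
cmds = ['python', '-m', 'manga_translator', 'shared', '--host', '127.0.0.1', '--port', '5003']
if params.verbose:
    cmds.append('--verbose')
if params.models_ttl:                      # a TTL of 0 (the default) is simply not forwarded
    cmds.append('--models-ttl=%s' % params.models_ttl)

print(' '.join(cmds))
# python -m manga_translator shared --host 127.0.0.1 --port 5003 --verbose --models-ttl=600
```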
