From f55ea6d78e370a1d75e50cdb351a28811748487d Mon Sep 17 00:00:00 2001
From: i30817
Date: Thu, 2 Nov 2023 13:46:35 +0000
Subject: [PATCH] Use a executor to speedup name normalization (#31)

---
 libretrofuzz/__main__.py | 88 +++++++++++++++++++++++++++++++++--------
 1 file changed, 70 insertions(+), 18 deletions(-)

diff --git a/libretrofuzz/__main__.py b/libretrofuzz/__main__.py
index 9f676d1..752c81e 100644
--- a/libretrofuzz/__main__.py
+++ b/libretrofuzz/__main__.py
@@ -17,8 +17,9 @@
 from typing import Optional, List
 from urllib.request import unquote, quote
 from tempfile import TemporaryDirectory
+from concurrent.futures import ProcessPoolExecutor
 from contextlib import asynccontextmanager, contextmanager
-from itertools import chain
+from itertools import chain, repeat
 from struct import unpack
 import json
 import os
@@ -357,8 +358,7 @@ def __call__(self, name, other, score_cutoff=None):
         # TODO remove this when a feature to choose when waiting happens
         if not self.hack and other in self.normcache:
             rest_of_score -= remaining * 0.65
-        return fuzz.WRatio(name, other) * (DEF_SCORE / 100) + rest_of_score
-
+        return rest_of_score + (DEF_SCORE / 100) * fuzz.WRatio(name, other)
 
 # ---------------------------------------------------------------
 # Normalization functions, part of the functions that change both
@@ -1038,6 +1038,70 @@ async def downloadgamenames(client, system, nub_verbose):
         raise Exit(code=1)
     return args
 
+
+async def exitcheck():
+    """Yield once to the event loop, then check the exit key status."""
+    await asyncio.sleep(0)  # update key status
+    checkEscape()  # check exit key status
+
+
+def norm(nometa, hack, n):
+    """Return the pair (n, normalizer(nometa, hack, n)).
+
+    Lives at module level because ProcessPoolExecutor needs picklable callables.
+    """
+    return (n, normalizer(nometa, hack, n))
+
+
+def norm_local(nometa, hack, before, n):
+    """Return the pair (n, normalized local name).
+
+    The normalized input is extractbefore(before, n) with every match of the
+    forbidden pattern replaced by an underscore.
+    """
+    return (n, normalizer(nometa, hack, regex.sub(forbidden, "_", extractbefore(before, n))))
+
+
+async def norm2dict(names, remote_names, nometa, hack, before):
+    """Build the local and remote name normalization caches in worker processes.
+
+    Returns the tuple (normcache, normcache2): dicts mapping every entry of
+    names and remote_names, respectively, to its normalized form. The exit
+    key is checked after each result received from the pool.
+    """
+    normcache = dict()
+    normcache2 = dict()
+    executor = ProcessPoolExecutor()
+    norm_format = style("Preparing names: {remaining_s:2.1f}s", fg=BLUE, bold=True)
+    # os.cpu_count() returns None when the number of CPUs is undetermined
+    cpus = os.cpu_count() or 1
+    try:
+        tasknumber = len(names) + len(remote_names)
+        with tqdm(total=tasknumber, bar_format=norm_format, leave=False) as pbar:
+            chunks = max(1, len(names) // cpus)
+            for k, i in executor.map(
+                norm_local, repeat(nometa), repeat(hack), repeat(before), names, chunksize=chunks
+            ):
+                await exitcheck()
+                normcache[k] = i
+                pbar.update(1)
+            chunks = max(1, len(remote_names) // cpus)
+            for k, i in executor.map(
+                norm, repeat(nometa), repeat(hack), remote_names, chunksize=chunks
+            ):
+                await exitcheck()
+                normcache2[k] = i
+                pbar.update(1)
+    finally:
+        # cancel_futures only exists since python 3.9; compare (major, minor) as
+        # one tuple so that every version after 3.8 takes the first branch
+        if tuple(map(int, platform.python_version_tuple()[:2])) >= (3, 9):
+            executor.shutdown(wait=True, cancel_futures=True)
+        else:
+            # TODO remove when the program drops python 3.8 compatibility (windows vista)
+            executor.shutdown(wait=True)
+    return normcache, normcache2
+
 
 async def downloader(
     names: [str],
@@ -1095,19 +1159,8 @@ async def downloader(
         error(f"Unavailable server thumbnails for system: {system}")
         raise StopPlaylist()
 
-    # preprocess data to build a heuristic later. Do not move
-    # into the later loop because thats when the heuristic is used
-    def norm(n):
-        return normalizer(nometa, hack, n)
-
-    def norm_local(n):
-        return norm(regex.sub(forbidden, "_", extractbefore(before, n)))
-
-    # local names normalization cache
-    normcache = dict(map(lambda n: (n, norm_local(n)), names))
-    # remote names normalization cache
-    normcache2 = dict(map(lambda n: (n, norm(n)), remote_names))
-
+    # concurrently init local and remote names normalization caches
+    normcache, normcache2 = await norm2dict(names, remote_names, nometa, hack, before)
     # short names bool, got from enviromental variable
     short_names = os.getenv("SHORT")
     short_names = short_names and short_names != "0"
@@ -1118,8 +1171,7 @@ def strfy_runtime(s, urldict=None):
     scorer = TitleScorer(normcache, normcache2, hack)
 
     for name, destination in zip(names, dbs):
-        await asyncio.sleep(0)  # update key status
-        checkEscape()  # check exit key status
+        await exitcheck()
         # if the user used filters, filter everything that doesn't match any of the globs
         if filters and not any(map(lambda x: fnmatch.fnmatch(name, x), filters)):
             skipped += 1