diff --git a/optimum_benchmark/cli.py b/optimum_benchmark/cli.py index 0976fc12..7b44f621 100644 --- a/optimum_benchmark/cli.py +++ b/optimum_benchmark/cli.py @@ -53,15 +53,6 @@ cs.store(group="launcher", name=ProcessConfig.name, node=ProcessConfig) cs.store(group="launcher", name=TorchrunConfig.name, node=TorchrunConfig) -LOGGING_SETUP_DONE = False - - -def setup_logging_once(*args, **kwargs): - global LOGGING_SETUP_DONE - if not LOGGING_SETUP_DONE: - LOGGING_SETUP_DONE = True - setup_logging(*args, **kwargs) - # optimum-benchmark @hydra.main(version_base=None) @@ -69,7 +60,7 @@ def main(config: DictConfig) -> None: log_level = os.environ.get("LOG_LEVEL", "INFO") log_to_file = os.environ.get("LOG_TO_FILE", "1") == "1" override_benchmarks = os.environ.get("OVERRIDE_BENCHMARKS", "0") == "1" - setup_logging_once(level=log_level, to_file=log_to_file, prefix="MAIN-PROCESS") + setup_logging(level=log_level, to_file=log_to_file, prefix="MAIN-PROCESS") if glob.glob("benchmark_report.json") and not override_benchmarks: LOGGER.warning( diff --git a/optimum_benchmark/task_utils.py b/optimum_benchmark/task_utils.py index 5dbe72c2..cf1701b5 100644 --- a/optimum_benchmark/task_utils.py +++ b/optimum_benchmark/task_utils.py @@ -47,17 +47,12 @@ _TIMM_TASKS_TO_MODEL_LOADERS = { "image-classification": "create_model", } -_SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS = { - "feature-extraction": "SentenceTransformer", - "sentence-similarity": "SentenceTransformer", -} _LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP = { "timm": _TIMM_TASKS_TO_MODEL_LOADERS, "diffusers": _DIFFUSERS_TASKS_TO_MODEL_LOADERS, "transformers": _TRANSFORMERS_TASKS_TO_MODEL_LOADERS, - "sentence-transformers": _SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS, } _SYNONYM_TASK_MAP = { @@ -122,6 +117,9 @@ def infer_library_from_model_name_or_path(model_name_or_path: str, revision: Opt if inferred_library_name is None: raise KeyError(f"Could not find the proper library name for {model_name_or_path}.") + if inferred_library_name == "sentence-transformers": + inferred_library_name = "transformers" + return inferred_library_name