Skip to content

Commit

Permalink
use transformers for now
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed May 21, 2024
1 parent dc29eec commit a73c3f1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 15 deletions.
11 changes: 1 addition & 10 deletions optimum_benchmark/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,14 @@
cs.store(group="launcher", name=ProcessConfig.name, node=ProcessConfig)
cs.store(group="launcher", name=TorchrunConfig.name, node=TorchrunConfig)

LOGGING_SETUP_DONE = False


def setup_logging_once(*args, **kwargs):
global LOGGING_SETUP_DONE
if not LOGGING_SETUP_DONE:
LOGGING_SETUP_DONE = True
setup_logging(*args, **kwargs)


# optimum-benchmark
@hydra.main(version_base=None)
def main(config: DictConfig) -> None:
log_level = os.environ.get("LOG_LEVEL", "INFO")
log_to_file = os.environ.get("LOG_TO_FILE", "1") == "1"
override_benchmarks = os.environ.get("OVERRIDE_BENCHMARKS", "0") == "1"
setup_logging_once(level=log_level, to_file=log_to_file, prefix="MAIN-PROCESS")
setup_logging(level=log_level, to_file=log_to_file, prefix="MAIN-PROCESS")

if glob.glob("benchmark_report.json") and not override_benchmarks:
LOGGER.warning(
Expand Down
8 changes: 3 additions & 5 deletions optimum_benchmark/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,12 @@
_TIMM_TASKS_TO_MODEL_LOADERS = {
"image-classification": "create_model",
}
_SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS = {
"feature-extraction": "SentenceTransformer",
"sentence-similarity": "SentenceTransformer",
}


_LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP = {
"timm": _TIMM_TASKS_TO_MODEL_LOADERS,
"diffusers": _DIFFUSERS_TASKS_TO_MODEL_LOADERS,
"transformers": _TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
"sentence-transformers": _SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
}

_SYNONYM_TASK_MAP = {
Expand Down Expand Up @@ -122,6 +117,9 @@ def infer_library_from_model_name_or_path(model_name_or_path: str, revision: Opt
if inferred_library_name is None:
raise KeyError(f"Could not find the proper library name for {model_name_or_path}.")

if inferred_library_name == "sentence-transformers":
inferred_library_name = "transformers"

return inferred_library_name


Expand Down

0 comments on commit a73c3f1

Please sign in to comment.