Skip to content

Commit

Permalink
Fix namespace error of dtype args (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
puneeshkhanna authored Jan 15, 2025
1 parent fe5e7c8 commit 84c7fb8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions scripts/benchmark_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

def get_transformers_pipeline(args: Namespace):
if "dtype" in args:
assert args["dtype"] in {"float16", "bfloat16", "float32"}
assert args.dtype in {"float16", "bfloat16", "float32"}

return raw_pipeline(
model=args.model,
torch_dtype=args["dtype"],
torch_dtype=args.dtype,
model_kwargs={
"device_map": "balanced",
"max_memory": {0: "20GiB", "cpu": "64GiB"},
Expand Down

0 comments on commit 84c7fb8

Please sign in to comment.