Speedup on some models by not upcasting bfloat16 to float32 on mac.
comfyanonymous committed Feb 24, 2025
1 parent 4553891 commit 96d891c
Showing 2 changed files with 8 additions and 7 deletions.
13 changes: 7 additions & 6 deletions comfy/ldm/modules/attention.py
@@ -30,11 +30,12 @@

 FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()

-def get_attn_precision(attn_precision):
+def get_attn_precision(attn_precision, current_dtype):
     if args.dont_upcast_attention:
         return None
-    if FORCE_UPCAST_ATTENTION_DTYPE is not None:
-        return FORCE_UPCAST_ATTENTION_DTYPE
+
+    if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
+        return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
     return attn_precision

 def exists(val):
@@ -81,7 +82,7 @@ def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

 def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, q.dtype)

     if skip_reshape:
         b, _, _, dim_head = q.shape
@@ -150,7 +151,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape


 def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, query.dtype)

     if skip_reshape:
         b, _, _, dim_head = query.shape
@@ -220,7 +221,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     return hidden_states

 def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, q.dtype)

     if skip_reshape:
         b, _, _, dim_head = q.shape
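For context, a minimal standalone sketch of the new lookup (mirroring the diff above, with the args.dont_upcast_attention check elided): get_attn_precision now receives the tensor's current dtype and upcasts only dtypes listed in the mapping, so bfloat16 attention keeps its native precision.

    import torch

    # Sketch: mapping as returned by force_upcast_attention_dtype() on affected Macs.
    FORCE_UPCAST_ATTENTION_DTYPE = {torch.float16: torch.float32}

    def get_attn_precision(attn_precision, current_dtype):
        # Upcast only when the current dtype appears in the mapping.
        if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
            return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
        return attn_precision

    print(get_attn_precision(None, torch.float16))   # torch.float32 -> fp16 is still upcast
    print(get_attn_precision(None, torch.bfloat16))  # None -> no upcast, hence the speedup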
2 changes: 1 addition & 1 deletion comfy/model_management.py
@@ -954,7 +954,7 @@ def force_upcast_attention_dtype():
         upcast = True

     if upcast:
-        return torch.float32
+        return {torch.float16: torch.float32}
     else:
         return None

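Previously force_upcast_attention_dtype() returned a single dtype (torch.float32) that forced upcasting for every input; it now returns a dtype-to-dtype mapping so callers can decide per tensor. A hedged sketch of the new contract, with the macOS version checks that set upcast elided and replaced by a placeholder flag:

    import torch

    def force_upcast_attention_dtype(upcast=True):  # placeholder for the real macOS checks
        # New contract: return a {source_dtype: upcast_dtype} mapping, or None.
        if upcast:
            return {torch.float16: torch.float32}
        return None

    mapping = force_upcast_attention_dtype()
    print(torch.float16 in mapping)   # True  -> fp16 attention is still upcast to fp32
    print(torch.bfloat16 in mapping)  # False -> bf16 attention runs natively on MPS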
