From 96d891cb94d90f220e066cebad349887137f07a6 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 24 Feb 2025 05:41:07 -0500
Subject: [PATCH] Speedup on some models by not upcasting bfloat16 to float32
 on mac.

---
 comfy/ldm/modules/attention.py | 13 +++++++------
 comfy/model_management.py      |  2 +-
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 24fb9d95080..2758f9508b5 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -30,11 +30,12 @@
 
 FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
 
-def get_attn_precision(attn_precision):
+def get_attn_precision(attn_precision, current_dtype):
     if args.dont_upcast_attention:
         return None
-    if FORCE_UPCAST_ATTENTION_DTYPE is not None:
-        return FORCE_UPCAST_ATTENTION_DTYPE
+
+    if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
+        return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
     return attn_precision
 
 def exists(val):
@@ -81,7 +82,7 @@ def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
 def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, q.dtype)
 
     if skip_reshape:
         b, _, _, dim_head = q.shape
@@ -150,7 +151,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 
 
 def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, query.dtype)
 
     if skip_reshape:
         b, _, _, dim_head = query.shape
@@ -220,7 +221,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     return hidden_states
 
 def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, q.dtype)
 
     if skip_reshape:
         b, _, _, dim_head = q.shape
diff --git a/comfy/model_management.py b/comfy/model_management.py
index f4a63c6d371..1e6599be20a 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -954,7 +954,7 @@ def force_upcast_attention_dtype():
         upcast = True
 
     if upcast:
-        return torch.float32
+        return {torch.float16: torch.float32}
     else:
         return None
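
Note: the following is a minimal sketch, not part of the patch, of how the new mapping-based
upcast behaves (the args.dont_upcast_attention early return is omitted for brevity). With the
change, force_upcast_attention_dtype() returns a dtype-to-dtype dict instead of a bare
torch.float32, so float16 attention is still upcast while bfloat16 is left untouched, which is
where the speedup on Mac comes from.

    import torch

    # Mapping as returned by the patched force_upcast_attention_dtype() on
    # affected Macs; on other platforms it stays None and nothing is upcast.
    FORCE_UPCAST_ATTENTION_DTYPE = {torch.float16: torch.float32}

    def get_attn_precision(attn_precision, current_dtype):
        # Only upcast when the tensor's current dtype is listed in the mapping.
        if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
            return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
        return attn_precision

    print(get_attn_precision(None, torch.float16))   # torch.float32 -> still upcast
    print(get_attn_precision(None, torch.bfloat16))  # None -> no forced upcast, hence the speedup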