NPU adaptation for FLUX
J石页 committed Jan 6, 2025
1 parent b572635 commit c074fe4
Showing 2 changed files with 9 additions and 12 deletions.
10 changes: 0 additions & 10 deletions examples/controlnet/train_controlnet_flux.py
@@ -473,9 +473,6 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
-    parser.add_argument(
-        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
-    )
     parser.add_argument(
         "--set_grads_to_none",
         action="store_true",
@@ -970,13 +967,6 @@ def load_model_hook(models, input_dir):
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
 
-    if args.enable_npu_flash_attention:
-        if is_torch_npu_available():
-            logger.info("npu flash attention enabled.")
-            flux_transformer.enable_npu_flash_attention()
-        else:
-            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
-
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
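With both hunks above, the training script no longer exposes or needs an explicit --enable_npu_flash_attention switch: the transformer changes below select the NPU attention processors automatically. A minimal sketch, not part of this commit, of how a script can still confirm which path will be taken; the checkpoint id is the standard FLUX repository and error handling is omitted:

# Sketch only (not part of the diff): confirm which attention path FLUX will use.
# After this commit no enable_npu_flash_attention() call is required in the
# training script; the NPU-specific processors are chosen inside the model.
from diffusers import FluxTransformer2DModel
from diffusers.utils.import_utils import is_torch_npu_available

flux_transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)

if is_torch_npu_available():
    print("Ascend NPU detected: NPU flash-attention processors are selected automatically.")
else:
    print("No NPU detected: the standard scaled_dot_product_attention processors are used.")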
11 changes: 9 additions & 2 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -29,6 +29,7 @@
     FluxAttnProcessor2_0,
     FluxAttnProcessor2_0_NPU,
     FusedFluxAttnProcessor2_0,
+    FusedFluxAttnProcessor2_0_NPU,
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -140,7 +141,10 @@ def __init__(
         self.norm1_context = AdaLayerNormZero(dim)
 
         if hasattr(F, "scaled_dot_product_attention"):
-            processor = FluxAttnProcessor2_0()
+            if is_torch_npu_available():
+                processor = FluxAttnProcessor2_0_NPU()
+            else:
+                processor = FluxAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
@@ -405,7 +409,10 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
+        if is_torch_npu_available():
+            self.set_attn_processor(FusedFluxAttnProcessor2_0_NPU())
+        else:
+            self.set_attn_processor(FusedFluxAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
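For completeness, a hedged usage sketch of the fused path: after this commit, fuse_qkv_projections() installs FusedFluxAttnProcessor2_0_NPU when an NPU is available and FusedFluxAttnProcessor2_0 otherwise. The checkpoint id, dtype, and device handling below are illustrative:

# Sketch: fuse the QKV projections; the NPU- or SDPA-fused attention processor
# is chosen internally by fuse_qkv_projections() after this commit.
import torch

from diffusers import FluxTransformer2DModel
from diffusers.utils.import_utils import is_torch_npu_available

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.to("npu" if is_torch_npu_available() else "cuda")

transformer.fuse_qkv_projections()
# ... run inference or training with fused projections ...
transformer.unfuse_qkv_projections()  # restores the original attention processors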
