
Bugfix for NPU not supporting float64 #10123

Open · wants to merge 4 commits into main
Conversation

Contributor
@baymax591 commented Dec 5, 2024

What does this PR do?

When using the FLUX model on NPU devices, an error occurs in the embeddings.py file. The root cause is that freqs_dtype is float64, which is not supported on NPU. To resolve the issue, a check on device.type was added: when running on an NPU, float32 is used instead.
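A minimal sketch of that check (the helper name select_freqs_dtype is illustrative, not the actual diff in embeddings.py):

import torch

def select_freqs_dtype(device: torch.device) -> torch.dtype:
    # Ascend NPU kernels have no float64 support, so fall back to
    # float32 there; keep float64 elsewhere for precision.
    if device.type == "npu":
        return torch.float32
    return torch.float64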

Example used:

import torch
from diffusers import FluxPipeline

# Load FLUX.1-dev from a local checkpoint in bfloat16; device_map="balanced"
# spreads the model components across the available devices.
pipe = FluxPipeline.from_pretrained(
    "/data/baymax/models/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),  # fixed seed for reproducibility
).images[0]
image.save("flux-dev.png")

Before this PR

(baymax) [root@modelfoundry-prod-node-0002 baymax]# python test_diffusers.py 
Loading checkpoint shards: 100%|███████████████████████████████████████████████| 2/2 [00:02<00:00,  1.18s/it]
Loading pipeline components...:  71%|██████████████████████████████            | 5/7 [00:04<00:01,  1.40it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|██████████████████████████████████████████| 7/7 [00:12<00:00,  1.82s/it]
  0%|                                                                                 | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/data/baymax/test_diffusers.py", line 11, in <module>
    image = pipe(
  File "/root/miniconda3/envs/baymax/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/miniconda3/envs/baymax/lib/python3.10/site-packages/diffusers/pipelines/flux/pipeline_flux.py", line 730, in __call__
    noise_pred = self.transformer(
  File "/root/miniconda3/envs/baymax/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/baymax/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/baymax/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/miniconda3/envs/baymax/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 475, in forward
    image_rotary_emb = self.pos_embed(ids)
  File "/root/miniconda3/envs/baymax/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/baymax/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/baymax/lib/python3.10/site-packages/diffusers/models/embeddings.py", line 761, in forward
    cos, sin = get_1d_rotary_pos_embed(
  File "/root/miniconda3/envs/baymax/lib/python3.10/site-packages/diffusers/models/embeddings.py", line 683, in get_1d_rotary_pos_embed
    freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
RuntimeError: call aclnnRepeatInterleaveIntWithDim failed, detail:EZ1001: [PID: 483017] 2024-12-04-17:12:59.064.468 self not implemented for DT_DOUBLE, should be in dtype support list [DT_UINT8,DT_INT8,DT_INT16,DT_INT32,DT_INT64,DT_BOOL,DT_FLOAT16,DT_FLOAT,DT_BFLOAT16,].

[ERROR] 2024-12-04-17:12:59 (PID:483017, Device:4, RankID:-1) ERR01005 OPS internal error
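The failing line boils down to calling repeat_interleave on a float64 tensor, which the NPU operator library rejects. A hypothetical minimal reproduction (assuming a working torch_npu install; the shapes are illustrative):

import torch
import torch_npu  # Ascend adapter; registers the "npu" device type

# Build a float64 frequency table the way get_1d_rotary_pos_embed does.
freqs = torch.outer(torch.arange(16), torch.arange(8)).double().to("npu")
freqs.cos().repeat_interleave(2, dim=1)   # RuntimeError: no DT_DOUBLE kernel
freqs.float().cos().repeat_interleave(2, dim=1)  # works once cast to float32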

After this PR

(baymax) [root@modelfoundry-prod-node-0002 baymax]# python test_diffusers.py 
Loading checkpoint shards: 100%|███████████████████████████████████████████| 2/2 [00:03<00:00,  1.67s/it]
Loading pipeline components...:  57%|█████████████████████▋                | 4/7 [00:17<00:16,  5.51s/it]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|██████████████████████████████████████| 7/7 [00:18<00:00,  2.70s/it]
100%|████████████████████████████████████████████████████████████████████| 50/50 [01:05<00:00,  1.31s/it]

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

cc @yiyixuxu @sayakpaul

@sayakpaul sayakpaul requested review from yiyixuxu and a-r-r-o-w and removed request for yiyixuxu December 5, 2024 02:41

github-actions bot commented Jan 4, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Jan 4, 2025
Member
@a-r-r-o-w left a comment

I think the changes should be relatively safe, as this PR only contains device-dependent changes.

It may be cleaner to do a follow-up that handles this kind of device-specific dtype selection with a better design in the scheduler. Off to @yiyixuxu for review; I can handle any pipelines we're missing or that were newly added while this PR was stale.

@a-r-r-o-w a-r-r-o-w added wip and removed stale Issues that haven't received updates labels Jan 7, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator
@hlky left a comment

I've applied this to the newer code and changed it from a combined is_mps_or_is_npu check to separate is_mps and is_npu flags, matching what was done in FluxPosEmbed.
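For reference, the separated flags look roughly like this (a sketch in the style of FluxPosEmbed.forward, not the exact committed code):

is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
# Neither backend supports float64, so compute the rotary frequencies in float32 there.
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64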

Contributor Author
@baymax591

cc @yiyixuxu

@hlky hlky added close-to-merge and removed wip labels Jan 20, 2025