Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accuracy fix for llama3.1-70B in eager/torch.compile mode #1746

Merged
merged 1 commit into from
Feb 7, 2025

Conversation

ckvermaAI
Copy link
Contributor

Issue: Low accuracy in Llama3.1-70B with eager/torch.compile mode

(Following details extracted from huggingface/transformers#28685)
Use a number of transformers models that utilize arange for integer enumerations in the calculation of position embeddings with DeepSpeed zero.Init() and a low precision dtype (float16, bfloat16), and the generated embeddings will differ significantly from intended.

  1. Using Llama as an example
    t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
    The inv_freq.dtype == float32. Single precision float can cover the required integer range for the enumeration (I believe it's in the 2k-8k range for Llama?).

  2. However, when DeepSpeed zero.Init is used the init function patching will override the float dtype passed in with a low precision float dtype, so float32 -> bfloat16 or float16. Thus the integer range that can be represented without significant loss drops down to 256 for bfloat16 or 2048 for float16. DeepSpeed's patching has an exception for integer dtype, it will not cast arange to the low precision float dtype if arange dtype is an int type.

Fix: Simply set the dtype as torch.int32 for torch.arange.

torch.int64 is not used because it generates incorrect values (and corresponding JIT_IR graph is not as expected).

Copy link
Contributor

@yafshar yafshar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice fix!

LGTM!

Hi @regisss, this PR is ready for your final review. Could you please take a look?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@libinta libinta added the run-test Run CI for PRs from external contributors label Feb 7, 2025
Copy link
Collaborator

@regisss regisss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@regisss regisss merged commit a0d14d2 into huggingface:main Feb 7, 2025
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
run-test Run CI for PRs from external contributors
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants