Accuracy fix for llama3.1-70B in eager/torch.compile mode #1746
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Issue: Low accuracy in Llama3.1-70B with eager/torch.compile mode
(Following details extracted from huggingface/transformers#28685)
Use a number of transformers models that utilize arange for integer enumerations in the calculation of position embeddings with DeepSpeed zero.Init() and a low precision dtype (float16, bfloat16), and the generated embeddings will differ significantly from intended.
Using Llama as an example
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
The inv_freq.dtype == float32. Single precision float can cover the required integer range for the enumeration (I believe it's in the 2k-8k range for Llama?).
However, when DeepSpeed zero.Init is used the init function patching will override the float dtype passed in with a low precision float dtype, so float32 -> bfloat16 or float16. Thus the integer range that can be represented without significant loss drops down to 256 for bfloat16 or 2048 for float16. DeepSpeed's patching has an exception for integer dtype, it will not cast arange to the low precision float dtype if arange dtype is an int type.
Fix: Simply set the dtype as torch.int32 for torch.arange.
torch.int64 is not used because it generates incorrect values (and corresponding JIT_IR graph is not as expected).