[Sharktank][Llama][FP8] Minimal changes for numerically correct fp8 #859
Conversation
Force-pushed from 20e7316 to 6643fb3.
Force-pushed from 99d3a50 to 74344e0.
I mostly have a bunch of naive questions, so I won't give approval or request changes.
```diff
@@ -88,7 +87,7 @@ def forward(self, x):
         # level to do this, but for now its here.
         if not isinstance(y, QuantizedTensor):
             if y.dtype == torch.float8_e4m3fnuz:
-                y = ops.to(y, torch.float16)
+                y = ops.to(y, torch.bfloat16)
```
Is this change because the specific float8 type accumulates to something which only truncates safely to `bfloat16` instead of `float16`?
No, this is an artifact of the way the model was quantized. The actual fp8 matmul intrinsic accumulates into f32, which IREE can truncate, but in Python we just cast to match the reference model.
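For context, a minimal eager-mode sketch of the behavior described here (illustrative only, not the sharktank implementation):

```python
import torch

# Toy fp8 inputs (float8_e4m3fnuz is the fnuz variant used on MI300-class GPUs).
x = torch.randn(4, 8).to(torch.float8_e4m3fnuz)
w = torch.randn(8, 8).to(torch.float8_e4m3fnuz)

# Eager PyTorch has no direct fp8 matmul, so emulate it: upcast, accumulate in
# f32 (what the hardware intrinsic does), then cast to the reference model's dtype.
y = torch.matmul(x.to(torch.float32), w.to(torch.float32)).to(torch.bfloat16)
```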
I'm not a fan that the Python implementation can really only compare against one specific quantization method like this. I don't have an answer off the top of my head, so it's fine for now, but ideally it would be good to make this more agnostic somehow.
```python
xk = (
    self.cache_quantizer.quantize(xk)
    .unpack()
    .dequant()
    .to(torch.float16)
)
xv = (
    self.cache_quantizer.quantize(xv)
    .unpack()
    .dequant()
    .to(torch.float16)
)
```
Why was this removed?
Quark's model loader didn't support the fp8 KV cache. We are still doing it for export, but it is missing in the Python comparison.
```python
if attention_mask is not None:
    attention_mask = attention_mask.to(torch.bfloat16)
```
I don't see any quantization stuff here. Is the indentation incorrect?
```python
if self.cache_quantizer and not self.fake_quant:
```
Yeah, that's what I'm asking about. Do we need `cache_quantizer` and `not fake_quant` for the attention mask to be in `bfloat16`? I guess that's probably the case.
```diff
@@ -82,7 +86,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):

         self.add_module(
             "token_embedding",
-            TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
+            TokenEmbeddingLayer(theta("token_embd"), dtype=self.activation_dtype),
```
From earlier in `__init__`, it doesn't look like `self.activation_dtype` is different from `config.activation_dtype`.
Yeah, this change is a no-op. Style preference.
For consistency, we can use either `self` or `config`.
```diff
@@ -55,11 +62,30 @@ def __init__(self):
         # Map of module_name to last used index for duplicated tensors.
         self.duplicate_tensors = {}

+    def before_forward(self, module_name, module, *args, **kwargs):
```
It might be useful to have a docstring for this. In isolation, I'm not really sure what it does, or what should be passed as `args` and `kwargs`. It seems like the intent is that some input tensors are passed as `args` in a specific order, and this function will add them to `self.tensors` while managing duplicates appropriately.
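For instance, a docstring along these lines (the wording is mine, inferred from the code below, not text from the PR):

```python
def before_forward(self, module_name, module, *args, **kwargs):
    """Record a module's input tensors before its forward pass runs.

    Each positional torch.Tensor argument is detached, copied to the CPU,
    and stored in self.tensors under f"{module_name}_input_{idx}". When the
    same module name is seen more than once, entries are suffixed with
    "#0", "#1", ... via self.duplicate_tensors so earlier captures are not
    overwritten.
    """
```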
```python
for idx, arg in enumerate(args):
    if not isinstance(arg, torch.Tensor):
        continue
    result_tensor = torch.detach(arg).contiguous().to(device="cpu").clone()
```
From the `name_base`, would this more appropriately be called `input_tensor`?
Now that I think about it, perhaps adding this method warrants renaming the class to something like `SaveInputAndOutputTensorsPatch` and updating the docstrings for it.
```python
if name_base in self.tensors:
    orig_dup = self.tensors[name_base]
    del self.tensors[name_base]
    self.duplicate_tensors[name_base] = 0
    self.tensors[f"{name_base}#0"] = orig_dup
elif name_base in self.duplicate_tensors:
    index = self.duplicate_tensors[name_base] + 1
    self.duplicate_tensors[name_base] = index
    self.tensors[f"{name_base}#{index}"] = result_tensor
else:
    self.tensors[name_base] = result_tensor
```
I'm trying to parse this... Let's say you are passing one tensor `x` in through `args`, but `f"{module_name}_input_0"` is already a key in `self.tensors`. Then `x` isn't going to get added to anything; instead you just rename the original `tensors` element. Because I don't really know what the point of this function is, my naive assumption is that the `elif` should just be an `if`, so that `x` would get added to `self.tensors` as `f"{name_base}#1"`.
Since it looks like this logic is used in `after_forward`, would it be useful to factor out something like the following?

```python
def insert_tensor(self, tensor: torch.Tensor, name_base: str):
    """Adds a tensor to self.tensors while updating duplicate counts."""
```
Actually, looking around more, I don't really see any instance where `name_base` would appear as one of the keys of `self.tensors`, so maybe the initial if statement is completely unnecessary?
It is to cover cases of modules with the same name. That won't happen often in our own models, but it does happen in the wild.
Force-pushed from b600453 to 9e479b0.
Do you have a test or a command I can run to verify the numerics are correct?
```diff
@@ -49,7 +49,7 @@ def from_gguf_props(p: dict[str, Any]):
     name_prefix = p.get("general.architecture", "llama")
     default_expert_count = 0
     default_expert_used_count = 0
-    default_rope_freq_base = 10000.0
+    default_rope_freq_base = 500000.0
```
Can we pass this separately? I have noticed `rope_freq_base` not being explicitly set in some models, and it might need to default to 10000.
Perplexity seems to be passing, so not a blocker.
Llama 3 defaults to 500000, so I think we should use that.
Do we know where that 10000 came from?
llama2
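For reference, a minimal sketch of the fallback being discussed (the `general.architecture` key comes from the diff above; the `*.rope.freq_base` key name and the helper itself are assumptions, not the PR's code):

```python
def rope_freq_base(p: dict, default: float = 500000.0) -> float:
    # Prefer the value recorded in the GGUF metadata; fall back to the default
    # (500000.0 for Llama 3, versus the old 10000.0 Llama 2 value) only when
    # the key is absent.
    arch = p.get("general.architecture", "llama")
    return float(p.get(f"{arch}.rope.freq_base", default))
```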
```diff
@@ -160,6 +157,23 @@ def prefill(self):
         attention_mask = replicate(attention_mask, tp)
         seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)

+        if self.dump_bins:
```
You can use generate_data.py to fetch input data, given a model & a prompt.
That doesn't work for values not supported by numpy.
Can we update generate_data.py to use torch tensors instead of numpy arrays? That should work around this issue.
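For what it's worth, a small sketch of the torch-based dump being suggested (file names are placeholders): torch can hold fp8 tensors, while numpy has no fp8 dtype.

```python
import torch

t = torch.randn(2, 3).to(torch.float8_e4m3fnuz)

# Option 1: pickle the tensor directly; round-trips the fp8 dtype.
torch.save(t, "prefill_arg0.pt")

# Option 2: dump raw bytes for tools that expect flat .bin files, reinterpreting
# the 1-byte fp8 elements as uint8 first (numpy cannot represent fp8 itself).
t.view(torch.uint8).numpy().tofile("prefill_arg0.bin")
```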
```diff
-        return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True).to(
-            torch.float16
-        )
+        return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True)
```
If we do a return here, we're not actually inserting the mmt kernel, which is what's intended from this script. This function is called when the input tensors are quantized, and at least punet still expects to have the quantized kernel inserted if and when that happens.
What does the mmt kernel do? Because we can just lower the fp8 matmul with torch.
You can run `pytest sharktank/tests/models/llama/quark_parity_test.py` on MI300X right now. I plan to make it run anywhere in a follow-up patch.
Mostly looks good to me, some minor nit comments
It would be great if you could add to the documentation so it's easy for anyone else to pick up and test/use fp8 models.
I'll add documentation on halo models this afternoon.
10000 is the default for Llama 2.
This patch enables the use of Quark-quantized models of the latest generation. Many changes were required to reach parity with the source model, which is very sensitive to any numerical fluctuations. A test has been added to maintain this parity, but it is disabled until I can get the relevant data set up on the CI machine.
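As an illustration only (the env var, paths, and tolerances below are assumptions, not the actual quark_parity_test), one common way to keep such a test checked in but skipped until the reference data exists on the CI machine:

```python
import os

import pytest
import torch

# Hypothetical location of the golden data; not the PR's actual configuration.
GOLDEN_DIR = os.environ.get("QUARK_PARITY_DATA")


@pytest.mark.skipif(
    GOLDEN_DIR is None or not os.path.isdir(GOLDEN_DIR),
    reason="quark parity reference data is not set up on this machine",
)
def test_quark_parity():
    reference = torch.load(os.path.join(GOLDEN_DIR, "reference_logits.pt"))
    ours = torch.load(os.path.join(GOLDEN_DIR, "sharktank_logits.pt"))
    torch.testing.assert_close(ours, reference, atol=1e-2, rtol=1e-2)
```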