[Sharktank][Llama][FP8] Minimal changes for numerically correct fp8 #859

Merged
dan-garvey merged 14 commits into main from users/dan_garvey/fp8_staging on Jan 29, 2025

Conversation

dan-garvey (Member) commented on Jan 22, 2025:

This patch enables the use of Quark-quantized models of the latest generation. Many changes were required to reach parity with the source model, which is very sensitive to any numerical fluctuation. A test has been added to maintain this parity, but it is disabled until I can get the relevant data set up on the CI machine.
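As a rough sketch of what such a parity check can look like (the file layout, helper name, and tolerances below are placeholders, not the actual quark_parity_test):

```python
import torch

def check_parity(ours_path: str, golden_path: str, atol: float = 1e-2, rtol: float = 1e-2) -> None:
    """Sketch: compare tensors dumped from our model against tensors
    dumped from the reference (golden) model, name by name."""
    ours = torch.load(ours_path)      # dict of name -> tensor
    golden = torch.load(golden_path)  # dict of name -> tensor
    for name, ref in golden.items():
        torch.testing.assert_close(
            ours[name].to(torch.float32),
            ref.to(torch.float32),
            atol=atol,
            rtol=rtol,
        )
```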

dan-garvey force-pushed the users/dan_garvey/fp8_staging branch 4 times, most recently from 20e7316 to 6643fb3 on January 28, 2025 at 00:53
dan-garvey marked this pull request as ready for review on January 28, 2025 at 00:54
dan-garvey force-pushed the users/dan_garvey/fp8_staging branch 3 times, most recently from 99d3a50 to 74344e0 on January 28, 2025 at 23:15
zjgarvey left a comment:

I mostly have a bunch of naive questions, so I won't give approval or request changes.

sharktank/sharktank/examples/paged_llm_v1.py (resolved)
@@ -88,7 +87,7 @@ def forward(self, x):
        # level to do this, but for now its here.
        if not isinstance(y, QuantizedTensor):
            if y.dtype == torch.float8_e4m3fnuz:
-                y = ops.to(y, torch.float16)
+                y = ops.to(y, torch.bfloat16)

zjgarvey: Is this change because the specific float8 type accumulates to something which only truncates safely to bfloat16 instead of float16?

dan-garvey (Member, Author): No, this is an artifact of the way the model was quantized. The actual fp8 matmul intrinsic accumulates into f32, which IREE can truncate, but in Python we just cast to match the reference model.
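A minimal sketch of that distinction in plain PyTorch, assuming the float8_e4m3fnuz dtype; the helper name is illustrative and this is not the sharktank ops path:

```python
import torch

def fp8_matmul_reference(x_fp8: torch.Tensor, w_fp8: torch.Tensor) -> torch.Tensor:
    """Sketch: mimic an fp8 matmul intrinsic that accumulates into f32,
    then cast the result to bfloat16 to match the reference model."""
    # torch.matmul has no direct float8 path here, so upcast the operands;
    # the f32 accumulate is what a compiler like IREE can later truncate.
    acc = torch.matmul(x_fp8.to(torch.float32), w_fp8.to(torch.float32))
    return acc.to(torch.bfloat16)
```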

IanNod (Contributor): I'm not a fan that the Python implementation can really only compare against one specific quantization method like this. I don't have an answer off the top of my head, so it's fine for now, but ideally it would be good to make this more agnostic somehow.

Comment on lines -133 to -144
xk = (
    self.cache_quantizer.quantize(xk)
    .unpack()
    .dequant()
    .to(torch.float16)
)
xv = (
    self.cache_quantizer.quantize(xv)
    .unpack()
    .dequant()
    .to(torch.float16)
)

zjgarvey: Why was this removed?

dan-garvey (Member, Author): Quark's model loader didn't support the fp8 KV cache. We still do it for export, but it is missing in the Python comparison.

Comment on lines +166 to +167
if attention_mask is not None:
    attention_mask = attention_mask.to(torch.bfloat16)

zjgarvey: I don't see any quantization stuff here. Is the indentation incorrect?

dan-garvey (Member, Author):

if self.cache_quantizer and not self.fake_quant:

zjgarvey: Yeah, that's what I'm asking about. Do we need cache_quantizer and not fake_quant for the attention mask to be in bfloat16? I guess that's probably the case.
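A minimal sketch of how those two pieces fit together, with the attribute names taken from the diff context; the helper itself is hypothetical:

```python
import torch

def maybe_cast_attention_mask(attention_mask, cache_quantizer, fake_quant: bool):
    """Sketch: on the real-quantized-cache path (not fake-quant), cast the
    mask to bfloat16 so its dtype matches the activations used there."""
    if cache_quantizer is not None and not fake_quant:
        if attention_mask is not None:
            attention_mask = attention_mask.to(torch.bfloat16)
    return attention_mask
```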

@@ -82,7 +86,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):

        self.add_module(
            "token_embedding",
-            TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
+            TokenEmbeddingLayer(theta("token_embd"), dtype=self.activation_dtype),

zjgarvey: From earlier in __init__, it doesn't look like self.activation_dtype is different from config.activation_dtype.

dan-garvey (Member, Author): Yeah, this change is a no-op. Style preference.

Collaborator: For consistency, we can use self or config.

sharktank/sharktank/utils/export_artifacts.py (resolved)
sharktank/sharktank/utils/patching.py (outdated, resolved)
@@ -55,11 +62,30 @@ def __init__(self):
        # Map of module_name to last used index for duplicated tensors.
        self.duplicate_tensors = {}

    def before_forward(self, module_name, module, *args, **kwargs):

zjgarvey: It might be useful to have a docstring for this. In isolation, I'm not really sure what it does or what should be passed as args and kwargs.

It seems like the intent is that some input tensors are passed as args in a specific order, and this function adds them to self.tensors while managing duplicates appropriately.

for idx, arg in enumerate(args):
    if not isinstance(arg, torch.Tensor):
        continue
    result_tensor = torch.detach(arg).contiguous().to(device="cpu").clone()

zjgarvey: From the name_base, would this more appropriately be called input_tensor?

zjgarvey: Now that I think about it, perhaps adding this method warrants renaming the class to something like SaveInputAndOutputTensorsPatch and updating its docstrings.

Comment on lines 71 to 81
if name_base in self.tensors:
    orig_dup = self.tensors[name_base]
    del self.tensors[name_base]
    self.duplicate_tensors[name_base] = 0
    self.tensors[f"{name_base}#0"] = orig_dup
elif name_base in self.duplicate_tensors:
    index = self.duplicate_tensors[name_base] + 1
    self.duplicate_tensors[name_base] = index
    self.tensors[f"{name_base}#{index}"] = result_tensor
else:
    self.tensors[name_base] = result_tensor

zjgarvey: I'm trying to parse this...

Let's say you are passing one tensor x in through args, but f"{module_name}_input_0" is already a key in self.tensors. Then x isn't going to get added to anything; instead you just rename the original tensors element.

Because I don't really know what the point of this function is, my naive assumption is that the elif should just be an if, so that x would get added to self.tensors as f"{name_base}#1".

zjgarvey: Since it looks like this logic is used in after_forward, would it be useful to factor out something like the following?

def insert_tensor(self, tensor: torch.Tensor, name_base: str):
    """Adds a tensor to self.tensors while updating duplicate counts."""

zjgarvey: Actually, looking around more, I don't really see any instance where name_base would appear as one of the keys to self.tensors, so maybe the initial if statement is completely unnecessary?

dan-garvey (Member, Author): It is to cover cases of modules with the same name. That won't happen often in our own models, but it does happen in the wild.
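A minimal sketch of the factored-out helper suggested above, assuming the duplicate-counting behavior discussed in this thread and additionally storing the incoming tensor in the first-duplicate branch (which the quoted snippet appears to drop); the class and method names are illustrative, not the merged API:

```python
import torch

class SaveTensorsPatchSketch:
    def __init__(self):
        self.tensors = {}
        # Map of name_base to the last used index for duplicated tensors.
        self.duplicate_tensors = {}

    def insert_tensor(self, tensor: torch.Tensor, name_base: str):
        """Adds a tensor to self.tensors, renaming entries when the same
        name_base shows up more than once (e.g. duplicated module names)."""
        if name_base in self.tensors:
            # First duplicate: move the original under "#0", store the new one as "#1".
            orig = self.tensors.pop(name_base)
            self.duplicate_tensors[name_base] = 1
            self.tensors[f"{name_base}#0"] = orig
            self.tensors[f"{name_base}#1"] = tensor
        elif name_base in self.duplicate_tensors:
            index = self.duplicate_tensors[name_base] + 1
            self.duplicate_tensors[name_base] = index
            self.tensors[f"{name_base}#{index}"] = tensor
        else:
            self.tensors[name_base] = tensor
```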

dan-garvey force-pushed the users/dan_garvey/fp8_staging branch from b600453 to 9e479b0 on January 29, 2025 at 02:50. Commits:

- temp
- remove cast to f32
- temp
- temp
- working using llama embed
- passes numerics and compiles
- seven flavors of absolute trash
- make decode great again
- guard no mask
- first round cleanup
- fix rms norm again
- rebase rotary and add dtype for rotary to llama.py
aviator19941 (Collaborator) left a comment:

Do you have a test or a command I can run to verify the numerics are correct?

sharktank/tests/models/llama/quark_parity_test.py (outdated, resolved)
sharktank/tests/models/llama/quark_parity_test.py (outdated, resolved)
@@ -49,7 +49,7 @@ def from_gguf_props(p: dict[str, Any]):
    name_prefix = p.get("general.architecture", "llama")
    default_expert_count = 0
    default_expert_used_count = 0
-    default_rope_freq_base = 10000.0
+    default_rope_freq_base = 500000.0
archana-ramalingam (Collaborator) commented on Jan 29, 2025:

Can we pass this separately? I have noticed rope_freq_base not being explicitly set in some models, which might need to default to 10000.

Collaborator: Perplexity seems to be passing, so not a blocker.

dan-garvey (Member, Author): llama3 defaults to 500000, so I think we should use that.

IanNod (Contributor): Do we know where that 10000 came from?

dan-garvey (Member, Author): llama2.
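For illustration, a minimal sketch of reading the value from GGUF-style props with the llama3-style fallback; the exact key name is an assumption, not copied from the file:

```python
from typing import Any

def rope_freq_base_from_gguf_props(p: dict[str, Any], name_prefix: str = "llama") -> float:
    """Sketch: prefer an explicit rope frequency base from the props,
    falling back to llama3's 500000.0 when the key is absent."""
    # Key layout mirrors other "{arch}.*" gguf keys; treat it as an assumption.
    return float(p.get(f"{name_prefix}.rope.freq_base", 500000.0))
```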

@@ -160,6 +157,23 @@ def prefill(self):
            attention_mask = replicate(attention_mask, tp)
            seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)

        if self.dump_bins:
Collaborator: You can use generate_data.py to fetch input data, given a model and a prompt.

dan-garvey (Member, Author): That doesn't work for values not supported by numpy.

IanNod (Contributor): Can we update generate_data.py to use torch tensors instead of numpy arrays? That should work around this issue.
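A minimal sketch of that workaround, using torch.save so dtypes that numpy lacks (float8_e4m3fnuz, bfloat16) survive serialization; the function name and file layout are placeholders:

```python
import torch

def dump_input_tensors(tensors: dict[str, torch.Tensor], prefix: str) -> None:
    """Sketch: persist model inputs with torch.save instead of numpy,
    since numpy has no float8_e4m3fnuz or bfloat16 dtype."""
    for name, t in tensors.items():
        torch.save(t.detach().to("cpu").contiguous(), f"{prefix}_{name}.pt")

# Usage sketch:
# dump_input_tensors({"tokens": tokens, "seq_lens": seq_lens}, prefix="prefill")
```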

-    return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True).to(
-        torch.float16
-    )
+    return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True)
nithinsubbiah (Contributor) commented on Jan 29, 2025:

If we return here, we're not actually inserting the mmt kernel, which is what's intended from this script. This function is called when the input tensors are quantized, and punet at least still expects to have the quantized kernel inserted if and when that happens.

dan-garvey (Member, Author): What does the mmt kernel do? Because we can just lower fp8 matmul with torch.

dan-garvey (Member, Author):

> Do you have a test or a command I can run to verify the numerics are correct?

You can run pytest sharktank/tests/models/llama/quark_parity_test.py on mi300x right now. I plan to make it run anywhere in a follow-up patch.

IanNod (Contributor) left a comment:

Mostly looks good to me, some minor nit comments.


IanNod (Contributor) commented on Jan 29, 2025:

> Do you have a test or a command I can run to verify the numerics are correct?
>
> you can run pytest sharktank/tests/models/llama/quark_parity_test.py on mi300x right now. I plan to make it run anywhere in a follow-up patch

Would be great if you could add this to the documentation so it's easy for anyone else to pick up and test/use fp8 models.

dan-garvey (Member, Author): #880 #881 @IanNod

dan-garvey (Member, Author): I'll add documentation on halo models this afternoon.

dan-garvey (Member, Author): 10000 is the default for llama2.

dan-garvey merged commit 1392a2e into main on Jan 29, 2025
33 checks passed
dan-garvey deleted the users/dan_garvey/fp8_staging branch on January 29, 2025 at 20:09