[Sharktank][Llama][FP8] Minimal changes for numerically correct fp8 #859

Merged: 14 commits, Jan 29, 2025
62 changes: 51 additions & 11 deletions sharktank/sharktank/examples/paged_llm_v1.py
@@ -39,15 +39,14 @@ def __init__(
page_cache_size: int = 128,
# Need to look at the model more for this.
end_token: int = 2,
dump_bins: bool = False,
):
self.model = model
self.tokenizer = tokenizer
if self.model.config.kv_cache_type == "paged":
self.shared_cache_state = model.cache.allocate(page_cache_size)
self.free_pages = list(range(1, page_cache_size))
else:
self.shared_cache_state = None
self.shared_cache_state = model.cache.allocate(page_cache_size)
self.free_pages = list(range(1, page_cache_size))
self.end_token = end_token
self.dump_bins = dump_bins

@property
def block_seq_stride(self) -> int:
@@ -64,13 +63,9 @@ def begin_batch(self, prompts: list[str]):
cache_state = self.shared_cache_state
else:
cache_state = self.model.cache.allocate(bs=len(prompts))
return Batch(self, token_ids, seq_lens, cache_state)
return Batch(self, token_ids, seq_lens, cache_state, dump_bins=self.dump_bins)

def alloc_page(self) -> int:
if self.model.config.kv_cache_type == "direct":
# We don't allocate block ids for the direct cache.
return 0

return self.free_pages.pop()

def release_page(self, index: int):
@@ -86,6 +81,7 @@ def __init__(
token_ids: torch.Tensor,
seq_lens: torch.Tensor,
cache_state: list[torch.Tensor],
dump_bins: bool = False,
):
self.bs = token_ids.shape[0]
assert seq_lens.shape[0] == self.bs
@@ -95,6 +91,7 @@ def __init__(
self.cache_state = cache_state
self.results: list[list[int]] = [[] for _ in range(self.bs)]
self.done_result_indices: set[int] = set()
self.dump_bins = dump_bins

# Assemble the batch.
seq_stride = self.parent.block_seq_stride
@@ -160,6 +157,23 @@ def prefill(self):
attention_mask = replicate(attention_mask, tp)
seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)

if self.dump_bins:
Collaborator: You can use generate_data.py to fetch input data, given a model and a prompt.

Member Author: Doesn't work for values not supported by numpy.

Contributor: Can we update generate_data.py to use torch tensors instead of numpy arrays? That should work around this issue.
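A note on the numpy limitation above: the dumped cache state is float8_e4m3fnuz, which numpy has no dtype for, while torch.save/torch.load round-trip it directly. A minimal sketch, not part of this PR, assuming a torch build with float8_e4m3fnuz support:

import torch

# float8_e4m3fnuz tensors cannot be converted to numpy arrays, which is why a
# numpy-based generate_data.py path fails for the fp8 cache state.
cache = torch.zeros(4, 16, dtype=torch.float8_e4m3fnuz)
# cache.numpy()  # would raise: numpy has no float8 dtype

# torch.save / torch.load preserve the dtype without a numpy round trip.
torch.save(cache, "cache_state.bin")
restored = torch.load("cache_state.bin")
assert restored.dtype == torch.float8_e4m3fnuz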

torch.save(
token_ids,
f"prefill_token_ids_{'_'.join([str(x) for x in token_ids.shape])}.bin",
)
torch.save(
torch.tensor(token_ids.shape[0]).to(torch.int64),
f"prefill_seq_lens_1.bin",
)
torch.save(
seq_block_ids_tensor,
f"prefill_seq_block_ids_{'_'.join([str(x) for x in seq_block_ids_tensor.shape])}.bin",
)
torch.save(
self.cache_state[0].to(torch.float8_e4m3fnuz),
f"prefill_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin",
)
logits = model.prefill(
token_ids,
attention_mask=attention_mask,
@@ -204,6 +218,27 @@ def decode(self):
seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)
decode_attention_mask = replicate(decode_attention_mask, tp)

if self.dump_bins:
torch.save(
self.next_tokens,
f"decode_next_tokens_{'_'.join([str(x)for x in self.next_tokens.shape])}.bin",
)
torch.save(
start_positions,
f"decode_start_positions_{'_'.join([str(x)for x in start_positions.shape])}.bin",
)
torch.save(
seq_block_ids_tensor,
f"decode_seq_block_ids_tensor_{'_'.join([str(x)for x in seq_block_ids_tensor.shape])}.bin",
)
torch.save(
torch.tensor(self.next_tokens.shape[0]).to(torch.int64),
f"decode_seq_lens_1.bin",
)
torch.save(
self.cache_state[0].to(torch.float8_e4m3fnuz),
f"decode_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin",
)
logits = model.decode(
self.next_tokens,
attention_mask=decode_attention_mask,
@@ -238,6 +273,11 @@ def main():
"--save_intermediates_path",
help="save module forward outputs to safetensors, ex: run_0 will save to run_0_prefill.savetensors",
)
parser.add_argument(
"--dump-bins",
help="dump input tensors to bin files",
action="store_true",
)
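Usage sketch (not from this PR): reloading the .bin files written by --dump-bins with torch.load. prefill_seq_lens_1.bin is a literal name from the code above; the other file names encode tensor shapes, so the glob below is only illustrative.

import glob
import torch

# Load the fixed-name sequence-length dump, then everything else by pattern.
seq_lens = torch.load("prefill_seq_lens_1.bin")
for path in sorted(glob.glob("prefill_*.bin")):
    t = torch.load(path)
    print(path, tuple(t.shape), t.dtype)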
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
cli.add_quantization_options(parser)
@@ -274,7 +314,7 @@ def main():

intermediates_saver = SaveModuleResultTensorsPatch()
intermediates_saver.patch_child_modules(model)
generator = TorchGenerator(model, tokenizer)
generator = TorchGenerator(model, tokenizer, dump_bins=args.dump_bins)

print(f":: Prompting:")
for prompt in prompts:
1 change: 1 addition & 0 deletions sharktank/sharktank/kernels/bitcast.py
@@ -31,6 +31,7 @@
]

_ftype_to_ctype_table = {
torch.bfloat16: torch.complex32,
torch.float16: torch.complex32,
torch.float32: torch.complex64,
}
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/configs/llm_configs.py
@@ -49,7 +49,7 @@ def from_gguf_props(p: dict[str, Any]):
name_prefix = p.get("general.architecture", "llama")
default_expert_count = 0
default_expert_used_count = 0
default_rope_freq_base = 10000.0
default_rope_freq_base = 500000.0
Collaborator (@archana-ramalingam, Jan 29, 2025): Can we pass this separately? I have noticed rope_freq_base not being explicitly set in some models, which might need to default to 10000.

Collaborator: Perplexity seems to be passing, so not a blocker.

Member Author: llama3 defaults to 500000, so I think we should use that.

Contributor: Do we know where that 10000 came from?

Member Author: llama2.
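For context on the numbers above: rope_freq_base is the base of the standard RoPE inverse-frequency schedule, so raising it from 10000 (Llama 2) to 500000 (Llama 3) stretches the low-frequency end. An illustrative sketch of that schedule, not code from this repo:

import torch

def rope_inv_freq(rope_freq_base: float, rope_dimension_count: int = 128) -> torch.Tensor:
    # Standard RoPE schedule: inv_freq[i] = base^(-2i / d) for i in [0, d/2).
    exponents = torch.arange(0, rope_dimension_count, 2, dtype=torch.float32)
    return 1.0 / (rope_freq_base ** (exponents / rope_dimension_count))

llama2_style = rope_inv_freq(10000.0)   # old default
llama3_style = rope_inv_freq(500000.0)  # new default in this change
print(llama2_style[-1].item(), llama3_style[-1].item())  # lowest frequencies differ sharply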

default_rope_dimension_count = 128
attention_head_count = _int_prop(p, f"{name_prefix}.attention.head_count")
rope_dimension_count = _optional_int_prop(
11 changes: 8 additions & 3 deletions sharktank/sharktank/layers/ffn_block.py
@@ -25,15 +25,20 @@ def __init__(
theta: Theta,
is_gated: bool = True,
activation_fn: Callable[[AnyTensor], AnyTensor] = F.silu,
fake_quant: bool = False,
):
super().__init__(theta)

self.is_gated = is_gated
self.activation_fn = activation_fn
if self.is_gated:
self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
self.add_module("ffn_down", LinearLayer(theta("ffn_down")))
self.add_module(
"ffn_gate", LinearLayer(theta("ffn_gate"), fake_quant=fake_quant)
)
self.add_module("ffn_up", LinearLayer(theta("ffn_up"), fake_quant=fake_quant))
self.add_module(
"ffn_down", LinearLayer(theta("ffn_down"), fake_quant=fake_quant)
)

def forward(
self,
3 changes: 1 addition & 2 deletions sharktank/sharktank/layers/linear.py
@@ -78,7 +78,6 @@ def forward(self, x):
x = qdq_input.quantize(x).unpack().dequant()

y = ops.linear(x, weight, bias)

# Unconditionally dequantize.
if isinstance(y, QuantizedTensor):
y = y.unpack().dequant()
@@ -88,7 +87,7 @@ def forward(self, x):
# level to do this, but for now its here.
if not isinstance(y, QuantizedTensor):
if y.dtype == torch.float8_e4m3fnuz:
y = ops.to(y, torch.float16)
y = ops.to(y, torch.bfloat16)

Reviewer: Is this change because the specific float8 type accumulates to something which only truncates safely to bfloat16 instead of float16?

Member Author: No, this is an artifact of the way the model was quantized. The actual fp8 matmul intrinsic accumulates into f32, which IREE can truncate, but in Python we just cast to match the reference model.

Contributor: I'm not a fan that the Python implementation can really only compare against one specific quantization method like this. I don't have an answer off the top of my head, so fine for now, but ideally it would be good to make this more agnostic somehow.
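A rough sketch of the behaviour described above (illustrative only; the function name and the float32 emulation are assumptions, not this repo's ops): the fp8 matmul accumulates in float32, and the eager path then truncates to bfloat16 to match the quark reference model.

import torch

def fp8_linear_reference(x_fp8: torch.Tensor, w_fp8: torch.Tensor) -> torch.Tensor:
    # Emulate an fp8 x fp8 matmul that accumulates in float32 (what the real
    # intrinsic does), then truncate to bfloat16 to match the reference model.
    acc = torch.matmul(x_fp8.to(torch.float32), w_fp8.to(torch.float32).T)
    return acc.to(torch.bfloat16)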

return y
if qdq_output is not None:
y = qdq_output.quantize(y).unpack().dequant()
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/norm.py
@@ -34,7 +34,7 @@ def __init__(
def forward(self, x: torch.Tensor):
orig_dtype = x.dtype
x = ops.to(x, self.dtype)
norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon)
norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon, orig_dtype=orig_dtype)
# Will automatically upcast to the dtype of the weight, which is
# often in higher precision. Downcast back to expected.
norm = ops.to(norm, orig_dtype)
27 changes: 6 additions & 21 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -100,9 +100,7 @@ def forward(
cache_state: list[torch.Tensor] = None,
):
assert bool(start_index is not None) ^ bool(embedding_batch_mask is not None)

x = self.attn_norm(h)

bs, batch_seq_len, feature_dim = x.shape
assert feature_dim == self.head_count * self.head_dim

@@ -128,22 +126,7 @@ def forward(

# Used by fp8_e4m3fnuz model
if self.cache_quantizer is not None:
# For fake quant, store the fp16 qdq value in the cache
if self.fake_quant:
xk = (
self.cache_quantizer.quantize(xk)
.unpack()
.dequant()
.to(torch.float16)
)
xv = (
self.cache_quantizer.quantize(xv)
.unpack()
.dequant()
.to(torch.float16)
)
Reviewer (on lines -133 to -144): Why was this removed?

Member Author: Quark's model loader didn't support the fp8 KV cache. We are still doing it for export, but it is missing in the Python comparison.
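For readers following the diff, the two cache-storage modes can be summarised as below. This is a pseudocode-style sketch based only on the calls visible in this hunk (quantize().unpack().dequant() versus .qs); it is not repo code and assumes a cache_quantizer object with that API.

def store_kv_in_cache(x, cache_quantizer, fake_quant: bool):
    if fake_quant:
        # Fake quant: quantize-dequantize (QDQ) round trip, so the cache holds a
        # higher-precision value whose error matches the quantized numerics.
        return cache_quantizer.quantize(x).unpack().dequant()
    # Real quant: store the raw fp8 payload (qs) and dequantize on read.
    return cache_quantizer.quantize(x).unpack().qs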

# For real quant, store the quantized fp8 value in the cache
else:
if not self.fake_quant:
# TODO: this seems like a bastardization of our quantized tensor api
# Probably want to add support for using quantized tensors more directly
xk = self.cache_quantizer.quantize(xk).unpack().qs
@@ -175,11 +158,14 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
# Fake quant is already dequantized when stored in the cache.
if self.cache_quantizer and not self.fake_quant:
xk = self.cache_quantizer.dequantize_raw_tensor(
xk, torch.float16, name="xk_deq"
xk, torch.bfloat16, name="xk_deq"
)
xv = self.cache_quantizer.dequantize_raw_tensor(
xv, torch.float16, name="xv_deq"
xv, torch.bfloat16, name="xv_deq"
)
if attention_mask is not None:
attention_mask = attention_mask.to(torch.bfloat16)
Reviewer (on lines +166 to +167): I don't see any quantization stuff here. Is the indentation incorrect?

Member Author: if self.cache_quantizer and not self.fake_quant:

Reviewer: Yeah, that's what I'm asking about. Do we need cache_quantizer and not fake_quant for the attention mask to be in bfloat16? I guess that's probably the case.

# Transpose into [bs, heads, sl, dim]
xq = xq.transpose(1, 2)
keys = xk.transpose(1, 2)
@@ -223,7 +209,6 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.flatten(2, 3)

# Project.
attn_output = self.attn_output(attn_output)
attn_output = self.attn_output_norm(attn_output)
8 changes: 7 additions & 1 deletion sharktank/sharktank/models/llama/llama.py
@@ -18,6 +18,10 @@
from ...utils.create_cache import *
from ... import ops


from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

__all__ = [
"PagedLlamaModelV1",
]
@@ -82,7 +86,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):

self.add_module(
"token_embedding",
TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
TokenEmbeddingLayer(theta("token_embd"), dtype=self.activation_dtype),

Reviewer: From earlier in __init__, it doesn't look like self.activation_dtype is different from config.activation_dtype.

Member Author: Yeah, this change is a no-op. Style preference.

Collaborator: For consistency, we can use self or config.

)
self.add_module(
"attention_embedding",
@@ -93,6 +97,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
device=self.device,
use_hf=self.use_hf,
tensor_parallelism_size=config.tensor_parallelism_size,
dtype=config.activation_dtype,
),
)
self.add_module(
@@ -258,6 +263,7 @@ def __init__(
"ffn",
FFN(
theta=theta,
fake_quant=fake_quant,
),
)
self.add_module(
@@ -113,7 +113,7 @@ def apply_per_layer_quant(

# It looks dumb but, this step is required for numerical correctness against quark.
# weight = weight.view(torch.float8_e4m3fn)
weight = (weight.to(torch.float64) * weight_quant_scale).to(torch.float16)
weight = (weight.to(torch.float64) * weight_quant_scale).to(torch.bfloat16)

weight_quant_zero_point = layer_theta.optional_tensor("weight_zero_point")
if weight_quant_zero_point == None:
@@ -148,6 +148,7 @@ def quantize_weight(
weight_quant = weight_quantizer.quantize(weight, name=weight_name)
updated_tensors[weight_quant.name] = weight_quant

# In older quark models the qkv layer is fused. Unfuse.
if "qkv" in layer_name:
# The qkv layer is fused in the quark model, decompose back into individual q, k , and v weights
q_weight, k_weight, v_weight = torch.split(weight, split_sizes)
@@ -275,8 +276,6 @@ def single_replace(
updated_tensors: dict[str, InferenceTensor],
):
data = quant_theta(layer_name).tensor("weight").as_torch()
if data.dtype == torch.bfloat16:
data = data.to(torch.float32)
updated_tensors[gguf_name] = DefaultPrimitiveTensor(name=gguf_name, data=data)


9 changes: 6 additions & 3 deletions sharktank/sharktank/ops/default_impls.py
@@ -400,11 +400,14 @@ def reshape_default(input: Union[PrimitiveTensor, Tensor], shape: List[int]) ->

# RMS norm
@rms_norm.override(AllOfType(Tensor, InferenceTensor))
def rms_norm_default(x, weight, *, epsilon: float) -> Tensor:
def rms_norm_default(
x, weight, *, epsilon: float, orig_dtype: Union[None, torch.dtype]
) -> Tensor:
if orig_dtype is None:
orig_dtype = x.dtype
variance = x.pow(2).mean(-1, keepdim=True)
output = x * elementwise(torch.rsqrt, variance + epsilon)
# The cast here is to match the hf implementation, affects numerics
output = elementwise(torch.mul, weight, to(output, weight.dtype))
output = elementwise(torch.mul, weight, to(output, orig_dtype))
return output
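A small, self-contained illustration of why the orig_dtype cast matters (plain torch, not repo code): rounding the normalized activations back to their original dtype before the weight multiply changes the numerics, which is what the HF-matching comment above refers to.

import torch

x = torch.randn(2, 8, dtype=torch.bfloat16)
weight = torch.randn(8, dtype=torch.float32)
epsilon = 1e-6

x32 = x.to(torch.float32)
normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + epsilon)

old_style = weight * normed               # multiply in weight.dtype (float32)
new_style = weight * normed.to(x.dtype)   # round back to orig_dtype first, as above
print((old_style - new_style).abs().max())  # small but nonzero difference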


4 changes: 1 addition & 3 deletions sharktank/sharktank/ops/qlinear_impls.py
@@ -52,9 +52,7 @@ def qlinear_tensor_scaled(
if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point:
if x_layout.qs.dtype == torch.float8_e4m3fnuz:
# assume quark
return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True).to(
torch.float16
)
return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True)
Contributor (@nithinsubbiah, Jan 29, 2025): If we do a return here, we're not actually inserting the mmt kernel, which is what's intended from this script. This function is called when the input tensors are quantized, and at least punet still expects the quantized kernel to be inserted if and when that happens.

Member Author: What does the mmt kernel do? Because we can just lower fp8 matmul with torch.

else:
return NotImplemented

13 changes: 10 additions & 3 deletions sharktank/sharktank/ops/signatures.py
@@ -769,18 +769,25 @@ def _module_register_buffer_trampoline(


@overridable
def rms_norm(x: AnyTensor, weight: AnyTensor, *, epsilon: float) -> AnyTensor:
def rms_norm(
x: AnyTensor, weight: AnyTensor, *, epsilon: float, orig_dtype: torch.dtype
) -> AnyTensor:
"""Computes the full, unbiased RMS normalization of an input."""
raise NotImplementedError


@rms_norm.trampoline
def _rms_norm_trampoline(
d: SignatureDispatcher, x: AnyTensor, weight: AnyTensor, *, epsilon: float
d: SignatureDispatcher,
x: AnyTensor,
weight: AnyTensor,
*,
epsilon: float,
orig_dtype: torch.dtype,
):
tensors = (x, weight)
for override in d.find_overrides(tensors):
result = override(x, weight, epsilon=epsilon)
result = override(x, weight, epsilon=epsilon, orig_dtype=orig_dtype)
if result is not NotImplemented:
return override, result
else: