Skip to content

Commit

Permalink
Fix MX4 Illegal Memory Access for Large Inputs (#3229)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#327

This diff adds handling for very large input tensors to fbgemm's mx4 routines. By default, triton uses int32 indexing which can overflow if the number of elements in an input exceeds the expressible limit of int32. We add a simple check to see if int64 indexing is needed to prevent this failure mode. Notably, using int64 indexing does hurt performance a bit so this diff avoids using it if not necessary.

Reviewed By: qchip

Differential Revision: D63992362
  • Loading branch information
jwfromm authored and facebook-github-bot committed Oct 7, 2024
1 parent e7f89e4 commit 40e4c59
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
18 changes: 18 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def _kernel_quantize_mx4(
STOCHASTIC_CASTING: tl.constexpr,
FP4_EXP_BIAS: tl.constexpr,
GROUP_LOAD: tl.constexpr,
USE_INT64: tl.constexpr,
) -> None:
"""Quantize a 1D float tensor into a packed MX4 tensor.
Expand All @@ -122,6 +123,7 @@ def _kernel_quantize_mx4(
STOCHASTIC_CASTING (bool): Whether to use stochastic rounding when downcasting.
FP4_EXP_BIAS (int): Exponent bias of target mx4 format.
GROUP_LOAD (int): Number of groups to process simultaneously.
USE_INT64 (bool): Whether to use int64 for indexing. This is needed for large tensors.
"""
# Define Constant Expressions.
FP32_EXP_MASK: tl.constexpr = 0x7F800000 # type: ignore[Incompatible variable type]
Expand All @@ -147,6 +149,9 @@ def _kernel_quantize_mx4(

# Get the current thread number.
pid = tl.program_id(0)
# For very large inputs, we need to use int64 indexes. This is slower but necessary.
if USE_INT64:
pid = pid.to(tl.int64)
# Find starting offsets for this thread. These are calculated before adjusting for padding.
input_start = pid * (GROUPS_PER_THREAD * GROUP_SIZE)
output_start = pid * OUTPUT_CHUNK_SIZE
Expand Down Expand Up @@ -428,6 +433,8 @@ def triton_quantize_mx4(
else:
rand_bits = None

# Check if we need to use int64 for indexing.
use_int64 = a.numel() > 2**31 - 1
# Invoke triton quantization kernel over rows.
grid = (num_threads,)
_kernel_quantize_mx4[grid](
Expand All @@ -452,6 +459,8 @@ def triton_quantize_mx4(
FP4_EXP_BIAS=get_mx4_exp_bias(ebits),
# pyre-ignore[6]
GROUP_LOAD=GROUP_LOAD,
# pyre-ignore[6]
USE_INT64=use_int64,
)
# Inputs are now fully quantized and ready to return.
# Try to return in the original shape if possible.
Expand All @@ -472,6 +481,7 @@ def _kernel_dequantize_mx4(
GROUPS_PER_THREAD,
GROUP_SIZE: tl.constexpr,
GROUP_LOAD: tl.constexpr,
USE_INT64: tl.constexpr,
) -> None:
"""Dequantize a packed MX4 tensor and apply scaling.
Expand All @@ -483,6 +493,7 @@ def _kernel_dequantize_mx4(
GROUPS_PER_THREAD (int): Number of groups each thread is responsible for.
GROUP_SIZE (int): Size of chunks that use the same shared exponent.
GROUP_LOAD (int): Number of groups to process simultaneously.
USE_INT64 (bool): Whether to use int64 for indexing.
"""
# Define constants.
MX4_BIT_MASK: tl.constexpr = 0xF # type: ignore[Incompatible variable type]
Expand All @@ -495,6 +506,9 @@ def _kernel_dequantize_mx4(

# Get the current thread number.
pid = tl.program_id(0)
# For very large tensors, use int64 for indexing. This is slower but necessary.
if USE_INT64:
pid = pid.to(tl.int64)
# Find the starting offsets for this thread.
input_start = pid * (GROUPS_PER_THREAD * PACKED_GROUP_SIZE)
exp_start = input_start + GROUP_SIZE // 2
Expand Down Expand Up @@ -600,6 +614,8 @@ def triton_dequantize_mx4(
# Create output tensor.
output_elems = num_groups * group_size
out = torch.empty([output_elems], device=a.device, dtype=torch.float)
# Check if we need to use int64 for indexing.
use_int64 = a.numel() > 2**31 - 1
# Invoke triton dequantization kernel over rows.
grid = (num_threads,)
_kernel_dequantize_mx4[grid](
Expand All @@ -612,6 +628,8 @@ def triton_dequantize_mx4(
GROUP_SIZE=group_size,
# pyre-ignore[6]
GROUP_LOAD=GROUP_LOAD,
# pyre-ignore[6]
USE_INT64=use_int64,
)

out_shape = list(orig_shape[:-1]) + [-1]
Expand Down
10 changes: 10 additions & 0 deletions fbgemm_gpu/test/quantize/mx4_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,16 @@ def test_mx4_cases(
# I give quite a bit of wiggle room to make sure this isnt flaky.
torch.testing.assert_close(input, mx_dequantized, rtol=1.0, atol=magnitude / 2)

# pyre-fixme[56]:
@unittest.skipIf(*gpu_unavailable)
def test_mx4_index_overflow(self) -> None:
    """Tests that mx4 quantization kernels can handle inputs that would overflow int32 indices.

    Round-trips a tensor of 2**32 elements (one past the int32-expressible
    element count) through mx4 quantize/dequantize. The guarded failure mode
    is the kernel itself crashing with an illegal memory access, so simply
    completing the round trip is most of the check.
    """
    # Allocate directly on the GPU: zeros(...).to("cuda") would first build
    # a ~16 GB float32 buffer in host memory and then copy it to the device.
    large_input = torch.zeros(2**32, dtype=torch.float32, device="cuda")
    mx_quantized = fp32_to_mx4(large_input, 32)
    mx_dequantized = mx4_to_fp32(mx_quantized, 32)
    # We just need to check that everything ran without an illegal memory access.
    assert mx_dequantized[0] == 0


if __name__ == "__main__":
unittest.main()

0 comments on commit 40e4c59

Please sign in to comment.