env variable to select rounding mode #3515

Open · wants to merge 1 commit into base: main
30 changes: 30 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/envs.py
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import os

from typing import Any, Callable, Dict

# pyre-ignore[5]
environment_variables: Dict[str, Callable[[], Any]] = {
    # Decide which rounding mode to use when quantizing to / dequantizing from MX4.
    # See https://fburl.com/code/rohboxgv for the available modes.
    "MX4_QUANT_ROUNDING_MODE": lambda: os.getenv("MX4_QUANT_ROUNDING_MODE", "nearest"),
}


# pyre-ignore[3]
def __getattr__(name: str):
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


# pyre-ignore[3]
def __dir__():
    return list(environment_variables.keys())
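Because the module resolves attributes through `__getattr__`, every access re-reads the environment. A minimal usage sketch of the lookup, using only what the module above defines:

```python
import os

# Select a mode before the attribute is read; each access re-evaluates the lambda.
os.environ["MX4_QUANT_ROUNDING_MODE"] = "even"

from fbgemm_gpu import envs

print(envs.MX4_QUANT_ROUNDING_MODE)  # "even" -- resolved via envs.__getattr__

# With the variable unset, the default baked into the dict is returned.
del os.environ["MX4_QUANT_ROUNDING_MODE"]
print(envs.MX4_QUANT_ROUNDING_MODE)  # "nearest"
```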
15 changes: 13 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/quantize_utils.py
@@ -11,6 +11,7 @@
 from typing import Optional, Union

 import torch
+from fbgemm_gpu import envs

 from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
@@ -36,7 +37,7 @@ def fp32_to_mx4(
     group_size: int = 32,
     ebits: int = 2,
     mbits: int = 1,
-    rounding_mode: Optional[Union[RoundingMode, int]] = RoundingMode.even,
+    rounding_mode: Optional[Union[RoundingMode, int]] = None,
     stochastic_casting: bool = False,
     use_triton: bool = True,
 ) -> torch.Tensor:
@@ -58,7 +59,17 @@
     # Accelerated MX4 is only available on CUDA; if the input is on CPU, use the Python path.
     # Operate on flattened input.
     if rounding_mode is None:
-        rounding_mode = RoundingMode.even
+        try:
+            rounding_mode = RoundingMode.__members__.get(
+                envs.MX4_QUANT_ROUNDING_MODE, RoundingMode.even
+            )
+        except Exception as e:
+            logger.error(
+                f"Failed to read MX4_QUANT_ROUNDING_MODE env var: {e}. "
+                "Falling back to RoundingMode.even."
+            )
+            rounding_mode = RoundingMode.even

     if not tensor.is_cuda:
         return py_quantize_mx4(
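For reference, `RoundingMode.__members__.get` maps the env string to an enum member and silently falls back to `RoundingMode.even` for unrecognized values. A small sketch of that resolution, assuming `nearest` is a member name as the default in envs.py suggests:

```python
from fbgemm_gpu.triton import RoundingMode

# A valid member name resolves to the corresponding enum member.
assert RoundingMode.__members__.get("nearest", RoundingMode.even) is RoundingMode.nearest

# An unrecognized string quietly yields the fallback; no exception is raised.
assert RoundingMode.__members__.get("bogus-mode", RoundingMode.even) is RoundingMode.even
```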
9 changes: 4 additions & 5 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
@@ -83,7 +83,7 @@ def py_quantize_mx4(
     e.g.
     Input with shape [1, 8192] will be quantized to [1, 4096 + 256] as
     each value contains two elements packed into an int8 and
-    there are 32 groups in each row.
+    there are 256 (8192 / group_size) groups in each row.
     """
     # Define helpful constants.
     FP32_MIN_NORMAL = 2 ** (-126)
@@ -150,16 +150,15 @@ def py_quantize_mx4(
     biased_exp = torch.bitwise_and(a, FP32_EXP_MASK)
     # Shift exponent over to least significant bits.
     biased_exp = torch.bitwise_right_shift(biased_exp, FP32_EXP_OFFSET).to(torch.int8)
-
-    # Finally extract the mantissa.
-    trailing_mantissa = torch.bitwise_and(a, FP32_MANTISSA_MASK)
     new_biased_exp = biased_exp - FP32_EXP_BIAS + FP4_EXP_BIAS

     # Compute difference between ideal exponent and what can be represented.
     exp_diff = torch.where(new_biased_exp <= 0, 1 - new_biased_exp, 0)
     # Clip this difference to the maximum number of fp32 mantissa bits (23 + implicit).
     exp_diff = torch.clamp(exp_diff, max=MAX_FP32_MANTISSA_BITS)

+    # Finally extract the mantissa.
+    trailing_mantissa = torch.bitwise_and(a, FP32_MANTISSA_MASK)
+
     # Now perform mantissa rounding down to fp4.
     is_subnorm = biased_exp == 0
     # Add implied 1 to normal values.
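The hunk above only moves the mantissa extraction after the exponent-difference computation; the computed values are unchanged. As a worked illustration of the bit manipulation this code relies on (the constants are standard IEEE-754 single-precision values, assumed to match the FP32_* definitions earlier in the file):

```python
import torch

# Standard IEEE-754 fp32 layout; assumed to match quantize_ref.py's constants.
FP32_EXP_MASK = 0x7F800000       # bits 30..23 hold the biased exponent
FP32_MANTISSA_MASK = 0x007FFFFF  # bits 22..0 hold the trailing mantissa
FP32_EXP_OFFSET = 23
FP32_EXP_BIAS = 127

x = torch.tensor([1.5], dtype=torch.float32)
a = x.view(torch.int32)  # reinterpret the raw float bits as int32

biased_exp = torch.bitwise_right_shift(
    torch.bitwise_and(a, FP32_EXP_MASK), FP32_EXP_OFFSET
)
trailing_mantissa = torch.bitwise_and(a, FP32_MANTISSA_MASK)

print(biased_exp.item() - FP32_EXP_BIAS)  # 0, since 1.5 = 1.1b * 2^0
print(hex(trailing_mantissa.item()))      # 0x400000, the 0.5 fraction bit
```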
57 changes: 57 additions & 0 deletions fbgemm_gpu/test/quantize/mx4_test.py
@@ -319,6 +319,63 @@ def test_mx4_index_overflow_large_input(self) -> None:
        # We just need to check that everything ran without an illegal memory access.
        assert mx_dequantized[0][0] == 0

    @unittest.skipIf(
        not (
            torch.cuda.is_available() and torch.cuda.mem_get_info()[0] / (1024**3) >= 32
        ),
        "Test requires a GPU with at least 32GB of free memory.",
    )
    # pyre-ignore[56]
    @given(
        shape=st.sampled_from(
            [
                [2**31 - 1],  # 1D shape at the int32 limit to exercise large indexing.
                [1024 * 1024, 1024],  # Large multi-dimensional shape.
                [16, 1028],  # Multiple rows that require padding.
                [4, 30],  # Multiple small rows with padding.
            ]
        ),
        group_size=st.sampled_from([32, 64]),
        magnitude=st.sampled_from([1.0, 1e3, 1e-3]),
        mx4_format=st.sampled_from([(2, 1)]),
        device=st.sampled_from(["cuda"]),
    )
    @settings(verbosity=Verbosity.verbose, max_examples=50, deadline=None)
    def test_mx4_large_cases(
        self,
        shape: List[int],
        group_size: int,
        magnitude: float,
        mx4_format: Tuple[int, int],
        device: str,
    ) -> None:
        """Test correctness of mx4 routines with random inputs and unusual shapes."""
        # Unpack the mx4 format under test.
        ebits, mbits = mx4_format

        # Generate a random input with the specified magnitude.
        input = torch.randn(shape, device=device, dtype=torch.float32) * magnitude

        # Perform quant then dequant to check that proper shape is maintained and
        # outputs are reasonably correct.
        mx_quantized = fp32_to_mx4(input, group_size, ebits=ebits, mbits=mbits)
        mx_dequantized = mx4_to_fp32(mx_quantized, group_size, ebits=ebits, mbits=mbits)

        # If the rows of input are not divisible by group_size, we expect the output
        # to be padded.
        if input.shape[-1] % group_size != 0:
            pad = group_size - (input.shape[-1] % group_size)
            input = torch.nn.functional.pad(input, (0, pad))

        # Check that output shape matches input shape.
        assert mx_dequantized.shape == input.shape

        # Check that values are reasonably close, based on expected variance.
        # We allow quite a bit of wiggle room so this check is not flaky.
        torch.testing.assert_close(input, mx_dequantized, rtol=1.0, atol=magnitude / 2)
        assert not torch.isnan(mx_dequantized).any()
        assert not torch.isinf(mx_dequantized).any()


if __name__ == "__main__":
    unittest.main()
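Taken together, the change lets a caller pick the MX4 rounding mode via the environment instead of threading it through every call site. A hedged end-to-end sketch (set the variable before the first quantization call, since the lookup only happens when rounding_mode is left as None):

```python
import os

os.environ["MX4_QUANT_ROUNDING_MODE"] = "even"

import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32

x = torch.randn(1, 8192, dtype=torch.float32)

# rounding_mode defaults to None, so fp32_to_mx4 consults MX4_QUANT_ROUNDING_MODE.
q = fp32_to_mx4(x, group_size=32)
y = mx4_to_fp32(q, group_size=32)
assert y.shape == x.shape
```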