From fd28e848294f967d0477997ac3bf75785ca0e972 Mon Sep 17 00:00:00 2001
From: Sergey Kozub
Date: Tue, 14 Jan 2025 10:26:14 -0800
Subject: [PATCH] PR #21380: Add F4E2M1FN and F8E8M0FNU types
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Imported from GitHub PR https://github.com/openxla/xla/pull/21380

Previous PR https://github.com/openxla/xla/pull/19096 was rolled back, re-trying.

This PR adds F4E2M1FN primitive type (4-bit float with 2 bits exponent and 1 bit mantissa), F8E8M0FNU primitive type (8-bit float with 8 bits exponent, no mantissa and no sign) and enables loads/stores in the same way S4/U4 type is implemented.

This will enable using microscaling (MX) formats ([RFC](https://github.com/openxla/xla/discussions/18085)), such as MXFP4.

```c
F4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5

F8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 − 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```

Related PRs:
- https://github.com/openxla/stablehlo/pull/2582
- https://github.com/jax-ml/ml_dtypes/pull/181
- https://github.com/llvm/llvm-project/pull/95392
- https://github.com/llvm/llvm-project/pull/108877
- https://github.com/jax-ml/ml_dtypes/pull/166
- https://github.com/llvm/llvm-project/pull/107127
- https://github.com/llvm/llvm-project/pull/111028

Copybara import of the project:

--
d7e00c49a4b4f26c06266d6bb941275e67464c01 by Sergey Kozub :

Add F4E2M1FN and F8E8M0FNU types

Merging this change closes #21380

PiperOrigin-RevId: 715434229
---
 tests/Dialect/mhlo/ops.mlir | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tests/Dialect/mhlo/ops.mlir b/tests/Dialect/mhlo/ops.mlir
index d07a178c6..16a64cdc2 100644
--- a/tests/Dialect/mhlo/ops.mlir
+++ b/tests/Dialect/mhlo/ops.mlir
@@ -6844,6 +6844,13 @@ func.func @invalid_dimension_attr(%arg0: tensor) -> tensor {
+  %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor
+  func.return %0 : tensor
+}
+
+// -----
+
 func.func @f8e3m4(%arg0: tensor) -> tensor {
   %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor
   func.return %0 : tensor
@@ -6872,6 +6879,13 @@ func.func @f8e5m2(%arg0: tensor) -> tensor {
 
 // -----
 
+func.func @f8e8m0fnu(%arg0: tensor) -> tensor {
+  %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor
+  func.return %0 : tensor
+}
+
+// -----
+
 func.func @top_k_1d(%arg0 : tensor<16xf32>) {
   %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
   return