Skip to content

Commit

Permalink
PR #19096: Add F4E2M1FN and F8E8M0FNU types
Browse files Browse the repository at this point in the history
Imported from GitHub PR #19096

This PR adds F4E2M1FN primitive type (4-bit float with 2 bits exponent and 1 bit mantissa), F8E8M0FNU primitive type (8-bit float with 8 bits exponent, no mantissa and no sign) and enables loads/stores in the same way S4/U4 type is implemented.

This will enable using microscaling (MX) formats ([RFC](#18085)), such as MXFP4.

```c
F4E2M1FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0
- Min normal number: S.01.0 = ±2^(0) = ±1.0
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5

F8E8M0FNU
- Exponent bias: 127
- Maximum stored exponent value: 254 (binary 1111'1110)
- Maximum unbiased exponent value: 254 - 127 = 127
- Minimum stored exponent value: 0 (binary 0000'0000)
- Minimum unbiased exponent value: 0 − 127 = -127
- Doesn't have zero
- Doesn't have infinity
- NaN is encoded as binary 1111'1111

Additional details:
- Zeros cannot be represented
- Negative values cannot be represented
- Mantissa is always 1
```

Related PRs:
- openxla/stablehlo#2582
- jax-ml/ml_dtypes#181
- llvm/llvm-project#95392
- llvm/llvm-project#108877
- jax-ml/ml_dtypes#166
- llvm/llvm-project#107127
- llvm/llvm-project#111028

The PR is split into multiple commits just to make the review easier, it is possible that some tests could fail if only some (i.e. not all) of these commits are applied.
Copybara import of the project:

--
fa539fb by Sergey Kozub <[email protected]>:

Add F4E2M1FN type: import mxfloat.h

--
2c01403 by Sergey Kozub <[email protected]>:

Add F4E2M1FN type: primitive type

--
e919ed5 by Sergey Kozub <[email protected]>:

Add F4E2M1FN type: literal support

--
ca16839 by Sergey Kozub <[email protected]>:

Add F4E2M1FN type: conversion codegen

--
eedc079 by Sergey Kozub <[email protected]>:

Add F4E2M1FN type: python interface

--
8e0305c by Sergey Kozub <[email protected]>:

Add F4E2M1FN type: FFI

--
aabe9c6 by Sergey Kozub <[email protected]>:

Add F4E2M1FN type: HLO evaluator

--
87da2eb by Sergey Kozub <[email protected]>:

Add F4E2M1FN type: add tests

--
e0ee48c by Sergey Kozub <[email protected]>:

Add F8E8M0FNU type

--
be2e457 by Sergey Kozub <[email protected]>:

Addressing PR#19096 review comments

Merging this change closes #19096

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19096 from openxla:skozub/e2m1 be2e457
PiperOrigin-RevId: 702273510
  • Loading branch information
sergey-kozub authored and Google-ML-Automation committed Dec 3, 2024
1 parent 2500111 commit 1e145f9
Show file tree
Hide file tree
Showing 79 changed files with 1,767 additions and 351 deletions.
1 change: 1 addition & 0 deletions third_party/tsl/tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,7 @@ cc_library(
deps = [
"@ml_dtypes//:float8",
"@ml_dtypes//:intn",
"@ml_dtypes//:mxfloat",
],
)

Expand Down
3 changes: 3 additions & 0 deletions third_party/tsl/tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ limitations under the License.

#include "ml_dtypes/include/float8.h" // from @ml_dtypes
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes

namespace tsl {
using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn;
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
using float8_e5m2 = ::ml_dtypes::float8_e5m2;
using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz;
using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu;

using int2 = ::ml_dtypes::int2;
using uint2 = ::ml_dtypes::uint2;
Expand Down
28 changes: 28 additions & 0 deletions xla/array2d_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,34 @@ TEST(Array2dTest, LinspaceF8E3M4) {
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
}

TEST(Array2dTest, LinspaceF4E2M1FN) {
auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

TEST(Array2dTest, LinspaceF8E8M0FNU) {
auto arr = MakeLinspaceArray2D<tsl::float8_e8m0fnu>(1.0, 3.5, 3, 2);

EXPECT_EQ(arr->n1(), 3);
EXPECT_EQ(arr->n2(), 2);

EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 2.0); // 1.5 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 4.0); // 3.0 rounded up
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
}

TEST(Array2dTest, Stringification) {
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
const std::string expected = R"([[1, 1.5],
Expand Down
9 changes: 7 additions & 2 deletions xla/comparison_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,13 @@ class Comparison {
// -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN
// Reference:
// https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
if constexpr (std::numeric_limits<T>::is_signed) {
using R = SignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
} else {
using R = UnsignedIntegerTypeForSizeType<sizeof(T)>;
return GetComparator<R>()(ToSignMagnitude(a), ToSignMagnitude(b));
}
}
}
// Applies the comparison from this Comparison's direction and ordering.
Expand Down
4 changes: 4 additions & 0 deletions xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ inline std::ostream& operator<<(std::ostream& os,
return os << "C128";
case XLA_FFI_DataType_TOKEN:
return os << "TOKEN";
case XLA_FFI_DataType_F4E2M1FN:
return os << "F4E2M1FN";
case XLA_FFI_DataType_F8E5M2:
return os << "F8E5M2";
case XLA_FFI_DataType_F8E3M4:
Expand All @@ -145,6 +147,8 @@ inline std::ostream& operator<<(std::ostream& os,
return os << "F8E5M2FNUZ";
case XLA_FFI_DataType_F8E4M3FNUZ:
return os << "F8E4M3FNUZ";
case XLA_FFI_DataType_F8E8M0FNU:
return os << "F8E8M0FNU";
}
}

Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ typedef enum {
XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
XLA_FFI_DataType_F8E5M2FNUZ = 24,
XLA_FFI_DataType_F8E4M3FNUZ = 25,
XLA_FFI_DataType_F4E2M1FN = 30,
XLA_FFI_DataType_F8E8M0FNU = 31,
} XLA_FFI_DataType;
// LINT.ThenChange(ffi_test.cc)

Expand Down
6 changes: 6 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ enum class DataType : uint8_t {
F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ,
F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ,
F8E3M4 = XLA_FFI_DataType_F8E3M4,
F4E2M1FN = XLA_FFI_DataType_F4E2M1FN,
F8E8M0FNU = XLA_FFI_DataType_F8E8M0FNU,
};

// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency
Expand Down Expand Up @@ -106,6 +108,8 @@ inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ;
inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ;
inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ;
inline constexpr DataType F8E3M4 = DataType::F8E3M4;
inline constexpr DataType F4E2M1FN = DataType::F4E2M1FN;
inline constexpr DataType F8E8M0FNU = DataType::F8E8M0FNU;

inline std::ostream& operator<<(std::ostream& os, const DataType dtype) {
return os << static_cast<XLA_FFI_DataType>(dtype);
Expand All @@ -127,6 +131,8 @@ constexpr size_t ByteWidth(DataType dtype) {
case DataType::F8E5M2FNUZ:
case DataType::F8E4M3FNUZ:
case DataType::F8E3M4:
case DataType::F4E2M1FN:
case DataType::F8E8M0FNU:
return 1;
case DataType::S16:
case DataType::U16:
Expand Down
6 changes: 6 additions & 0 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ TEST(FfiTest, DataTypeEnumValue) {

EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN));

EXPECT_EQ(encoded(PrimitiveType::F4E2M1FN), encoded(DataType::F4E2M1FN));
EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN));
Expand All @@ -137,6 +138,7 @@ TEST(FfiTest, DataTypeEnumValue) {
EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ));
EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4));
EXPECT_EQ(encoded(PrimitiveType::F8E8M0FNU), encoded(DataType::F8E8M0FNU));
}

TEST(FfiTest, DataTypeByteWidth) {
Expand Down Expand Up @@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) {
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::C128),
ByteWidth(DataType::C128));

EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F4E2M1FN),
ByteWidth(DataType::F4E2M1FN));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2),
ByteWidth(DataType::F8E5M2));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3),
Expand All @@ -193,6 +197,8 @@ TEST(FfiTest, DataTypeByteWidth) {
ByteWidth(DataType::F8E4M3FNUZ));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4),
ByteWidth(DataType::F8E3M4));
EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E8M0FNU),
ByteWidth(DataType::F8E8M0FNU));
}

TEST(FfiTest, ErrorEnumValue) {
Expand Down
2 changes: 2 additions & 0 deletions xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,15 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) {
case PrimitiveType::C64:
case PrimitiveType::C128:
case PrimitiveType::TOKEN:
case PrimitiveType::F4E2M1FN:
case PrimitiveType::F8E5M2:
case PrimitiveType::F8E4M3:
case PrimitiveType::F8E4M3FN:
case PrimitiveType::F8E4M3B11FNUZ:
case PrimitiveType::F8E5M2FNUZ:
case PrimitiveType::F8E4M3FNUZ:
case PrimitiveType::F8E3M4:
case PrimitiveType::F8E8M0FNU:
return static_cast<XLA_FFI_DataType>(primitive_type);
default:
DCHECK(false) << "Unsupported primitive type "
Expand Down
70 changes: 70 additions & 0 deletions xla/fp_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,76 @@ class FP8E4M3DistanceTest : public ::testing::Test {};
using F8E4M3Types = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn>;
TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types);

TEST(FPDistanceTest, F4E2M1FNDistance) {
// a & b are equal
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(4.0)),
0);

// a & b have the same exponents
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(6.0)),
1);

// a & b have different exponents
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
tsl::float4_e2m1fn(2.0), tsl::float4_e2m1fn(4.0)),
2);

// 1 from 0 in the positive direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
tsl::float4_e2m1fn(0)),
1);

// 1 from 0 in the negative direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
-std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
tsl::float4_e2m1fn(0)),
1);

// a & b have different signs
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
-std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
std::numeric_limits<tsl::float4_e2m1fn>::denorm_min()),
2);

// 1 non denorm from 0 in the positive direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
std::numeric_limits<tsl::float4_e2m1fn>::min(),
tsl::float4_e2m1fn(0)),
2);

// 1 non denorm from 0 in the negative direction
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
-std::numeric_limits<tsl::float4_e2m1fn>::min(),
tsl::float4_e2m1fn(0)),
2);

// a & b have different signs
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
-std::numeric_limits<tsl::float4_e2m1fn>::min(),
std::numeric_limits<tsl::float4_e2m1fn>::min()),
4);
}

TEST(FPDistanceTest, F8E8M0FNUDistance) {
// a & b are equal
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e8m0fnu>(
tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(1.0)),
0);

// one step apart
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e8m0fnu>(
tsl::float8_e8m0fnu(1.0), tsl::float8_e8m0fnu(2.0)),
1);

// two steps apart
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e8m0fnu>(
tsl::float8_e8m0fnu(0.5), tsl::float8_e8m0fnu(2.0)),
2);
}

TEST(FPDistanceTest, F8E3M4Distance) {
// a & b are equal
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),
Expand Down
11 changes: 7 additions & 4 deletions xla/hlo/builder/lib/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ XlaOp IsNegZero(XlaOp operand) {
case F32:
return Eq(BitcastConvertType(operand, U32),
ConstantR0WithType(&b, U32, uint32_t{1} << 31));
case F4E2M1FN:
case F8E3M4:
case F8E4M3:
case F8E5M2:
Expand Down Expand Up @@ -971,8 +972,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
PrimitiveType a_x_type = a_shape.element_type();
bool needs_upcast = false;
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type :
{BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down Expand Up @@ -1024,8 +1026,9 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
}
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
bool needs_upcast = false;
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
for (PrimitiveType type :
{BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
if (a_shape.element_type() == type) {
needs_upcast = true;
break;
Expand Down
34 changes: 24 additions & 10 deletions xla/hlo/builder/lib/math_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,22 @@ class MathTypedTest : public MathTest {
Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)});

bool has_inf = std::numeric_limits<T>::has_infinity;
bool has_nan = std::numeric_limits<T>::has_quiet_NaN;
bool has_finite = !has_inf && !has_nan;
bool has_nan_only = !has_inf && has_nan;

auto expected = LiteralUtil::MakeTupleOwned(
LiteralUtil::CreateR1<bool>(
{true, true, true, true, true, false, false, false, false}),
LiteralUtil::CreateR1<bool>({true, true, true, true, true, has_finite,
has_finite, has_finite, has_finite}),
LiteralUtil::CreateR1<bool>({false, false, false, false, false, has_inf,
has_inf, false, false}),
LiteralUtil::CreateR1<bool>(
{false, false, false, false, false, has_inf, false, false, false}),
LiteralUtil::CreateR1<bool>(
{false, false, false, false, false, false, has_inf, false, false}),
LiteralUtil::CreateR1<bool>({false, false, false, false, false,
!has_inf, !has_inf, true, true}));
has_nan_only, has_nan_only, has_nan,
has_nan}));
ComputeAndCompareLiteral(&b, expected, {});
}

Expand All @@ -118,10 +123,11 @@ class MathTypedTest : public MathTest {
LiteralUtil::CreateR1<T>({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}),
&b));

bool is_mx = std::is_same_v<T, tsl::float4_e2m1fn>;
ComputeAndCompareLiteral(
&b,
LiteralUtil::CreateR1<bool>(
{has_negative_zero_v<T>, false, false, false, false, false, false}),
{has_negative_zero_v<T>, false, false, false, false, false, is_mx}),
{}, error_spec_);
}

Expand All @@ -136,6 +142,9 @@ class MathTypedTest : public MathTest {
// For good measure, we also check pow with an exponent other than 0.5.
void TestSqrtPowInequivalence() {
SetFastMathDisabled(true);
if (std::is_same_v<T, tsl::float4_e2m1fn>) {
GTEST_SKIP() << "Skipping due to low precision";
}

// Tests disable constant folding by default, but this test needs it
// enabled, otherwise we don't tickle the bug we're trying to catch.
Expand Down Expand Up @@ -181,19 +190,24 @@ class MathTypedTest : public MathTest {
&b);
Erf(x);

bool has_inf = std::numeric_limits<T>::has_infinity;
std::vector<T> expected = {
has_inf ? T(-1) : nan, has_inf ? T(1) : nan, T(-0), T(0), T(-1), T(1)};
bool inf_as_nan = !std::numeric_limits<T>::has_infinity &&
std::numeric_limits<T>::has_quiet_NaN;
std::vector<T> expected = {inf_as_nan ? nan : T(-1),
inf_as_nan ? nan : T(1),
T(-0),
T(0),
T(-1),
T(1)};

ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
}
};

// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
using TestTypes =
::testing::Types<tsl::float8_e3m4, tsl::float8_e4m3, tsl::float8_e4m3fnuz,
tsl::float8_e4m3b11fnuz, tsl::float8_e5m2,
tsl::float8_e5m2fnuz,
::testing::Types<tsl::float4_e2m1fn, tsl::float8_e3m4, tsl::float8_e4m3,
tsl::float8_e4m3fnuz, tsl::float8_e4m3b11fnuz,
tsl::float8_e5m2, tsl::float8_e5m2fnuz,
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
Eigen::half,
#endif
Expand Down
1 change: 1 addition & 0 deletions xla/hlo/evaluator/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ cc_library(
"hlo_evaluator_typed_visitor_int4.cc",
"hlo_evaluator_typed_visitor_int64.cc",
"hlo_evaluator_typed_visitor_int8.cc",
"hlo_evaluator_typed_visitor_mxfloat.cc",
"hlo_evaluator_typed_visitor_uint16.cc",
"hlo_evaluator_typed_visitor_uint32.cc",
"hlo_evaluator_typed_visitor_uint64.cc",
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/evaluator/hlo_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3722,7 +3722,7 @@ absl::StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
const Shape& result_shape) {
std::function<ResultT(Fp, Uint)> stochastic_convert_op =
[](Fp operand, Uint random) -> ResultT {
bool is_negative = static_cast<bool>(Eigen::numext::signbit(operand));
bool is_negative = static_cast<bool>(SignAndMagnitude(operand).first);
if (Eigen::numext::isinf(operand)) {
return is_negative ? std::numeric_limits<ResultT>::min()
: std::numeric_limits<ResultT>::max();
Expand Down
Loading

0 comments on commit 1e145f9

Please sign in to comment.