diff --git a/ml_dtypes/include/int4.h b/ml_dtypes/include/int4.h index 419cc9f3..ccd47b2a 100644 --- a/ml_dtypes/include/int4.h +++ b/ml_dtypes/include/int4.h @@ -22,22 +22,44 @@ limitations under the License. #include #include #include +#include namespace ml_dtypes { +// Stores the 4-bit integer value in the low four bits of a byte. The upper +// four bits are left unspecified and ignored. template struct i4 { private: - UnderlyingTy v : 4; + UnderlyingTy v_; + + static_assert( + std::is_same_v || + std::is_same_v, + "The underyling type must be a signed or unsigned 8-bit integer."); + + // Mask the upper four bits. + static inline constexpr UnderlyingTy Mask(UnderlyingTy v) { return v & 0x0F; } + + // Mask the upper four bits and sign-extend for signed types. + static inline constexpr UnderlyingTy MaskAndSignExtend(UnderlyingTy v) { + return std::is_signed_v ? Mask(v) | ((v & 0x08) ? 0xF0 : 0x00) + : Mask(v); + } + + // Casts to the corresponding UnderlyingTy value. + inline constexpr UnderlyingTy IntValue() const { + return MaskAndSignExtend(v_); + } public: - constexpr i4() : v(0) {} - constexpr i4(const i4& other) = default; - constexpr i4(i4&& other) = default; + constexpr i4() noexcept : v_(0) {} + constexpr i4(const i4& other) noexcept = default; + constexpr i4(i4&& other) noexcept = default; constexpr i4& operator=(const i4& other) = default; constexpr i4& operator=(i4&&) = default; - explicit constexpr i4(UnderlyingTy val) : v(val & 0x0F) {} + explicit constexpr i4(UnderlyingTy val) : v_(Mask(val)) {} template explicit constexpr i4(T t) : i4(static_cast(t)) {} @@ -50,50 +72,78 @@ struct i4 { template explicit constexpr operator T() const { - return static_cast(v); + return static_cast(IntValue()); } // NOLINTNEXTLINE(google-explicit-constructor) constexpr operator std::optional() const { - return static_cast(v); - } - - constexpr i4 operator-() const { return i4(-v); } - constexpr i4 operator+(const i4& other) const { return i4((v + other.v)); } - constexpr i4 operator-(const i4& other) const { return i4((v - other.v)); } - constexpr i4 operator*(const i4& other) const { return i4((v * other.v)); } - constexpr i4 operator/(const i4& other) const { return i4((v / other.v)); } - constexpr i4 operator%(const i4& other) const { return i4((v % other.v)); } - - constexpr i4 operator&(const i4& other) const { return i4((v & other.v)); } - constexpr i4 operator|(const i4& other) const { return i4((v | other.v)); } - constexpr i4 operator^(const i4& other) const { return i4((v ^ other.v)); } - constexpr i4 operator~() const { return i4(~v); } - constexpr i4 operator>>(int amount) const { return i4((v >> amount)); } - constexpr i4 operator<<(int amount) const { return i4((v << amount)); } - - constexpr bool operator==(const i4& other) const { return v == other.v; } - constexpr bool operator!=(const i4& other) const { return v != other.v; } - constexpr bool operator<(const i4& other) const { return v < other.v; } - constexpr bool operator>(const i4& other) const { return v > other.v; } - constexpr bool operator<=(const i4& other) const { return v <= other.v; } - constexpr bool operator>=(const i4& other) const { return v >= other.v; } - - constexpr bool operator==(int64_t other) const { return v == other; } - constexpr bool operator!=(int64_t other) const { return v != other; } - constexpr bool operator<(int64_t other) const { return v < other; } - constexpr bool operator>(int64_t other) const { return v > other; } - constexpr bool operator<=(int64_t other) const { return v <= other; } - constexpr bool operator>=(int64_t other) const { return v >= other; } - - friend constexpr bool operator==(int64_t a, const i4& b) { return a == b.v; } - friend constexpr bool operator!=(int64_t a, const i4& b) { return a != b.v; } - friend constexpr bool operator<(int64_t a, const i4& b) { return a < b.v; } - friend constexpr bool operator>(int64_t a, const i4& b) { return a > b.v; } - friend constexpr bool operator<=(int64_t a, const i4& b) { return a <= b.v; } - friend constexpr bool operator>=(int64_t a, const i4& b) { return a >= b.v; } + return static_cast(IntValue()); + } + + constexpr i4 operator-() const { return i4(-v_); } + constexpr i4 operator+(const i4& other) const { return i4(v_ + other.v_); } + constexpr i4 operator-(const i4& other) const { return i4(v_ - other.v_); } + constexpr i4 operator*(const i4& other) const { return i4(v_ * other.v_); } + constexpr i4 operator/(const i4& other) const { + return i4(IntValue() / other.IntValue()); + } + constexpr i4 operator%(const i4& other) const { + return i4((IntValue() % other.IntValue())); + } + + constexpr i4 operator&(const i4& other) const { return i4(v_ & other.v_); } + constexpr i4 operator|(const i4& other) const { return i4(v_ | other.v_); } + constexpr i4 operator^(const i4& other) const { return i4(v_ ^ other.v_); } + constexpr i4 operator~() const { return i4(~v_); } + constexpr i4 operator>>(int amount) const { return i4(IntValue() >> amount); } + constexpr i4 operator<<(int amount) const { return i4(v_ << amount); } + + constexpr bool operator==(const i4& other) const { + return Mask(v_) == Mask(other.v_); + } + constexpr bool operator!=(const i4& other) const { + return Mask(v_) != Mask(other.v_); + } + constexpr bool operator<(const i4& other) const { + return IntValue() < other.IntValue(); + } + constexpr bool operator>(const i4& other) const { + return IntValue() > other.IntValue(); + } + constexpr bool operator<=(const i4& other) const { + return IntValue() <= other.IntValue(); + } + constexpr bool operator>=(const i4& other) const { + return IntValue() >= other.IntValue(); + } + + constexpr bool operator==(int64_t other) const { return IntValue() == other; } + constexpr bool operator!=(int64_t other) const { return IntValue() != other; } + constexpr bool operator<(int64_t other) const { return IntValue() < other; } + constexpr bool operator>(int64_t other) const { return IntValue() > other; } + constexpr bool operator<=(int64_t other) const { return IntValue() <= other; } + constexpr bool operator>=(int64_t other) const { return IntValue() >= other; } + + friend constexpr bool operator==(int64_t a, const i4& b) { + return a == b.IntValue(); + } + friend constexpr bool operator!=(int64_t a, const i4& b) { + return a != b.IntValue(); + } + friend constexpr bool operator<(int64_t a, const i4& b) { + return a < b.IntValue(); + } + friend constexpr bool operator>(int64_t a, const i4& b) { + return a > b.IntValue(); + } + friend constexpr bool operator<=(int64_t a, const i4& b) { + return a <= b.IntValue(); + } + friend constexpr bool operator>=(int64_t a, const i4& b) { + return a >= b.IntValue(); + } constexpr i4& operator++() { - v = (v + 1) & 0x0F; + v_ = Mask(v_ + 1); return *this; } @@ -104,7 +154,7 @@ struct i4 { } constexpr i4& operator--() { - v = (v - 1) & 0x0F; + v_ = Mask(v_ - 1); return *this; } @@ -156,13 +206,13 @@ struct i4 { } friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) { - os << static_cast(num.v); + os << static_cast(num); return os; } std::string ToString() const { std::ostringstream os; - os << static_cast(v); + os << static_cast(*this); return os.str(); } }; diff --git a/ml_dtypes/tests/int4_test.cc b/ml_dtypes/tests/int4_test.cc index d7cf538f..dd67e79c 100644 --- a/ml_dtypes/tests/int4_test.cc +++ b/ml_dtypes/tests/int4_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -95,6 +96,13 @@ TYPED_TEST(Int4Test, NumericLimitsBase) { EXPECT_EQ(static_cast(std::numeric_limits::denorm_min()), 0); } +TYPED_TEST(Int4Test, TypeTraits) { + using Int4 = TypeParam; + EXPECT_TRUE(std::is_trivially_copyable_v); + EXPECT_TRUE(std::is_default_constructible_v); + EXPECT_TRUE(std::is_nothrow_constructible_v); +} + TYPED_TEST(Int4Test, CreateAndAssign) { using Int4 = TypeParam; @@ -171,6 +179,12 @@ TYPED_TEST(Int4Test, Constexpr) { TEST_CONSTEXPR(int4(1) <<= 1); } +template +Int4 CreateInt4WithRandomHighBits(int val) { + return Eigen::numext::bit_cast( + static_cast(val | (Eigen::internal::random() << 4))); +} + TYPED_TEST(Int4Test, Casts) { using Int4 = TypeParam; @@ -189,7 +203,7 @@ TYPED_TEST(Int4Test, Casts) { for (int i = static_cast(std::numeric_limits::min()); i <= static_cast(std::numeric_limits::max()); ++i) { // Round-trip. - EXPECT_EQ(static_cast(Int4(i)), i); + EXPECT_EQ(static_cast(CreateInt4WithRandomHighBits(i)), i); // Float truncation. for (int j = 1; j < 10; ++j) { @@ -204,7 +218,7 @@ TYPED_TEST(Int4Test, Operators) { using Int4 = TypeParam; for (int i = static_cast(std::numeric_limits::min()); i <= static_cast(std::numeric_limits::max()); ++i) { - Int4 x = Int4(i); + Int4 x = CreateInt4WithRandomHighBits(i); EXPECT_EQ(-x, Int4(-i)); EXPECT_EQ(~x, Int4(~i)); @@ -220,7 +234,7 @@ TYPED_TEST(Int4Test, Operators) { for (int j = static_cast(std::numeric_limits::min()); j <= static_cast(std::numeric_limits::max()); ++j) { - Int4 y = Int4(j); + Int4 y = CreateInt4WithRandomHighBits(j); EXPECT_EQ(x + y, Int4(i + j)); EXPECT_EQ(x - y, Int4(i - j)); @@ -279,7 +293,7 @@ TYPED_TEST(Int4Test, ToString) { using Int4 = TypeParam; for (int i = static_cast(std::numeric_limits::min()); i <= static_cast(std::numeric_limits::max()); ++i) { - Int4 x = Int4(i); + Int4 x = CreateInt4WithRandomHighBits(i); std::stringstream ss; ss << x; EXPECT_EQ(ss.str(), std::to_string(i)); @@ -332,7 +346,7 @@ TYPED_TEST(Int4CastTest, CastThroughInt) { using DestType = typename TypeParam::second_type; for (int i = 0x00; i <= 0x0F; ++i) { - Int4 x = Int4(i); + Int4 x = CreateInt4WithRandomHighBits(i); DestType dest = static_cast(x); DestType expected = static_cast(static_cast(x)); EXPECT_EQ(dest, expected); @@ -365,7 +379,7 @@ TYPED_TEST(Int4CastTest, DeviceCast) { Eigen::TensorMap, Eigen::Aligned> dst_device( dst_device_buffer, kNumElems); - // Allocate host buffers and initially src memory. + // Allocate host buffers and initialize src memory. Eigen::Tensor src_cpu(kNumElems); Eigen::Tensor dst_cpu(kNumElems); for (int i = 0; i < kNumElems; ++i) {