Skip to content

Commit

Permalink
Explicitly enforce int4 bit representation with ignored high bits.
Browse files Browse the repository at this point in the history
The previous implementation using a bit-field leaves the exact
representation implementation-defined, with some platforms storing
bit-fields packed left-to-right, others right-to-left, and the
masked bits unspecified.  This complicates serialization and
vectorized conversions.

With this change, we now explicitly store the value in the lower
4 bits, and leave the upper 4 bits unspecified.  The type is
constructed in such a way that correctness is preserved
regardless of the upper bit values.

PiperOrigin-RevId: 577312207
  • Loading branch information
cantonios authored and The ml_dtypes Authors committed Nov 2, 2023
1 parent ccbc3f9 commit 20b73dc
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 53 deletions.
144 changes: 97 additions & 47 deletions ml_dtypes/include/int4.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,44 @@ limitations under the License.
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>

namespace ml_dtypes {

// Stores the 4-bit integer value in the low four bits of a byte. The upper
// four bits are left unspecified and ignored.
template <typename UnderlyingTy>
struct i4 {
private:
UnderlyingTy v : 4;
UnderlyingTy v_;

static_assert(
std::is_same_v<UnderlyingTy, uint8_t> ||
std::is_same_v<UnderlyingTy, int8_t>,
"The underyling type must be a signed or unsigned 8-bit integer.");

// Mask the upper four bits.
static inline constexpr UnderlyingTy Mask(UnderlyingTy v) { return v & 0x0F; }

// Mask the upper four bits and sign-extend for signed types.
static inline constexpr UnderlyingTy MaskAndSignExtend(UnderlyingTy v) {
return std::is_signed_v<UnderlyingTy> ? Mask(v) | ((v & 0x08) ? 0xF0 : 0x00)
: Mask(v);
}

// Casts to the corresponding UnderlyingTy value.
inline constexpr UnderlyingTy IntValue() const {
return MaskAndSignExtend(v_);
}

public:
constexpr i4() : v(0) {}
constexpr i4(const i4& other) = default;
constexpr i4(i4&& other) = default;
constexpr i4() noexcept : v_(0) {}
constexpr i4(const i4& other) noexcept = default;
constexpr i4(i4&& other) noexcept = default;
constexpr i4& operator=(const i4& other) = default;
constexpr i4& operator=(i4&&) = default;

explicit constexpr i4(UnderlyingTy val) : v(val & 0x0F) {}
explicit constexpr i4(UnderlyingTy val) : v_(Mask(val)) {}
template <typename T>
explicit constexpr i4(T t) : i4(static_cast<UnderlyingTy>(t)) {}

Expand All @@ -50,50 +72,78 @@ struct i4 {

template <typename T>
explicit constexpr operator T() const {
return static_cast<T>(v);
return static_cast<T>(IntValue());
}
// NOLINTNEXTLINE(google-explicit-constructor)
constexpr operator std::optional<int64_t>() const {
return static_cast<int64_t>(v);
}

constexpr i4 operator-() const { return i4(-v); }
constexpr i4 operator+(const i4& other) const { return i4((v + other.v)); }
constexpr i4 operator-(const i4& other) const { return i4((v - other.v)); }
constexpr i4 operator*(const i4& other) const { return i4((v * other.v)); }
constexpr i4 operator/(const i4& other) const { return i4((v / other.v)); }
constexpr i4 operator%(const i4& other) const { return i4((v % other.v)); }

constexpr i4 operator&(const i4& other) const { return i4((v & other.v)); }
constexpr i4 operator|(const i4& other) const { return i4((v | other.v)); }
constexpr i4 operator^(const i4& other) const { return i4((v ^ other.v)); }
constexpr i4 operator~() const { return i4(~v); }
constexpr i4 operator>>(int amount) const { return i4((v >> amount)); }
constexpr i4 operator<<(int amount) const { return i4((v << amount)); }

constexpr bool operator==(const i4& other) const { return v == other.v; }
constexpr bool operator!=(const i4& other) const { return v != other.v; }
constexpr bool operator<(const i4& other) const { return v < other.v; }
constexpr bool operator>(const i4& other) const { return v > other.v; }
constexpr bool operator<=(const i4& other) const { return v <= other.v; }
constexpr bool operator>=(const i4& other) const { return v >= other.v; }

constexpr bool operator==(int64_t other) const { return v == other; }
constexpr bool operator!=(int64_t other) const { return v != other; }
constexpr bool operator<(int64_t other) const { return v < other; }
constexpr bool operator>(int64_t other) const { return v > other; }
constexpr bool operator<=(int64_t other) const { return v <= other; }
constexpr bool operator>=(int64_t other) const { return v >= other; }

friend constexpr bool operator==(int64_t a, const i4& b) { return a == b.v; }
friend constexpr bool operator!=(int64_t a, const i4& b) { return a != b.v; }
friend constexpr bool operator<(int64_t a, const i4& b) { return a < b.v; }
friend constexpr bool operator>(int64_t a, const i4& b) { return a > b.v; }
friend constexpr bool operator<=(int64_t a, const i4& b) { return a <= b.v; }
friend constexpr bool operator>=(int64_t a, const i4& b) { return a >= b.v; }
return static_cast<int64_t>(IntValue());
}

constexpr i4 operator-() const { return i4(-v_); }
constexpr i4 operator+(const i4& other) const { return i4(v_ + other.v_); }
constexpr i4 operator-(const i4& other) const { return i4(v_ - other.v_); }
constexpr i4 operator*(const i4& other) const { return i4(v_ * other.v_); }
constexpr i4 operator/(const i4& other) const {
return i4(IntValue() / other.IntValue());
}
constexpr i4 operator%(const i4& other) const {
return i4((IntValue() % other.IntValue()));
}

constexpr i4 operator&(const i4& other) const { return i4(v_ & other.v_); }
constexpr i4 operator|(const i4& other) const { return i4(v_ | other.v_); }
constexpr i4 operator^(const i4& other) const { return i4(v_ ^ other.v_); }
constexpr i4 operator~() const { return i4(~v_); }
constexpr i4 operator>>(int amount) const { return i4(IntValue() >> amount); }
constexpr i4 operator<<(int amount) const { return i4(v_ << amount); }

constexpr bool operator==(const i4& other) const {
return Mask(v_) == Mask(other.v_);
}
constexpr bool operator!=(const i4& other) const {
return Mask(v_) != Mask(other.v_);
}
constexpr bool operator<(const i4& other) const {
return IntValue() < other.IntValue();
}
constexpr bool operator>(const i4& other) const {
return IntValue() > other.IntValue();
}
constexpr bool operator<=(const i4& other) const {
return IntValue() <= other.IntValue();
}
constexpr bool operator>=(const i4& other) const {
return IntValue() >= other.IntValue();
}

constexpr bool operator==(int64_t other) const { return IntValue() == other; }
constexpr bool operator!=(int64_t other) const { return IntValue() != other; }
constexpr bool operator<(int64_t other) const { return IntValue() < other; }
constexpr bool operator>(int64_t other) const { return IntValue() > other; }
constexpr bool operator<=(int64_t other) const { return IntValue() <= other; }
constexpr bool operator>=(int64_t other) const { return IntValue() >= other; }

friend constexpr bool operator==(int64_t a, const i4& b) {
return a == b.IntValue();
}
friend constexpr bool operator!=(int64_t a, const i4& b) {
return a != b.IntValue();
}
friend constexpr bool operator<(int64_t a, const i4& b) {
return a < b.IntValue();
}
friend constexpr bool operator>(int64_t a, const i4& b) {
return a > b.IntValue();
}
friend constexpr bool operator<=(int64_t a, const i4& b) {
return a <= b.IntValue();
}
friend constexpr bool operator>=(int64_t a, const i4& b) {
return a >= b.IntValue();
}

constexpr i4& operator++() {
v = (v + 1) & 0x0F;
v_ = Mask(v_ + 1);
return *this;
}

Expand All @@ -104,7 +154,7 @@ struct i4 {
}

constexpr i4& operator--() {
v = (v - 1) & 0x0F;
v_ = Mask(v_ - 1);
return *this;
}

Expand Down Expand Up @@ -156,13 +206,13 @@ struct i4 {
}

friend ::std::ostream& operator<<(::std::ostream& os, const i4& num) {
os << static_cast<int16_t>(num.v);
os << static_cast<int16_t>(num);
return os;
}

std::string ToString() const {
std::ostringstream os;
os << static_cast<int16_t>(v);
os << static_cast<int16_t>(*this);
return os.str();
}
};
Expand Down
26 changes: 20 additions & 6 deletions ml_dtypes/tests/int4_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <optional>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>

#include <gtest/gtest.h>
Expand Down Expand Up @@ -95,6 +96,13 @@ TYPED_TEST(Int4Test, NumericLimitsBase) {
EXPECT_EQ(static_cast<int>(std::numeric_limits<Int4>::denorm_min()), 0);
}

TYPED_TEST(Int4Test, TypeTraits) {
using Int4 = TypeParam;
EXPECT_TRUE(std::is_trivially_copyable_v<Int4>);
EXPECT_TRUE(std::is_default_constructible_v<Int4>);
EXPECT_TRUE(std::is_nothrow_constructible_v<Int4>);
}

TYPED_TEST(Int4Test, CreateAndAssign) {
using Int4 = TypeParam;

Expand Down Expand Up @@ -171,6 +179,12 @@ TYPED_TEST(Int4Test, Constexpr) {
TEST_CONSTEXPR(int4(1) <<= 1);
}

template<typename Int4>
Int4 CreateInt4WithRandomHighBits(int val) {
return Eigen::numext::bit_cast<Int4>(
static_cast<uint8_t>(val | (Eigen::internal::random<uint8_t>() << 4)));
}

TYPED_TEST(Int4Test, Casts) {
using Int4 = TypeParam;

Expand All @@ -189,7 +203,7 @@ TYPED_TEST(Int4Test, Casts) {
for (int i = static_cast<int>(std::numeric_limits<Int4>::min());
i <= static_cast<int>(std::numeric_limits<Int4>::max()); ++i) {
// Round-trip.
EXPECT_EQ(static_cast<int>(Int4(i)), i);
EXPECT_EQ(static_cast<int>(CreateInt4WithRandomHighBits<Int4>(i)), i);

// Float truncation.
for (int j = 1; j < 10; ++j) {
Expand All @@ -204,7 +218,7 @@ TYPED_TEST(Int4Test, Operators) {
using Int4 = TypeParam;
for (int i = static_cast<int>(std::numeric_limits<Int4>::min());
i <= static_cast<int>(std::numeric_limits<Int4>::max()); ++i) {
Int4 x = Int4(i);
Int4 x = CreateInt4WithRandomHighBits<Int4>(i);

EXPECT_EQ(-x, Int4(-i));
EXPECT_EQ(~x, Int4(~i));
Expand All @@ -220,7 +234,7 @@ TYPED_TEST(Int4Test, Operators) {

for (int j = static_cast<int>(std::numeric_limits<Int4>::min());
j <= static_cast<int>(std::numeric_limits<Int4>::max()); ++j) {
Int4 y = Int4(j);
Int4 y = CreateInt4WithRandomHighBits<Int4>(j);

EXPECT_EQ(x + y, Int4(i + j));
EXPECT_EQ(x - y, Int4(i - j));
Expand Down Expand Up @@ -279,7 +293,7 @@ TYPED_TEST(Int4Test, ToString) {
using Int4 = TypeParam;
for (int i = static_cast<int>(std::numeric_limits<Int4>::min());
i <= static_cast<int>(std::numeric_limits<Int4>::max()); ++i) {
Int4 x = Int4(i);
Int4 x = CreateInt4WithRandomHighBits<Int4>(i);
std::stringstream ss;
ss << x;
EXPECT_EQ(ss.str(), std::to_string(i));
Expand Down Expand Up @@ -332,7 +346,7 @@ TYPED_TEST(Int4CastTest, CastThroughInt) {
using DestType = typename TypeParam::second_type;

for (int i = 0x00; i <= 0x0F; ++i) {
Int4 x = Int4(i);
Int4 x = CreateInt4WithRandomHighBits<Int4>(i);
DestType dest = static_cast<DestType>(x);
DestType expected = static_cast<DestType>(static_cast<int>(x));
EXPECT_EQ(dest, expected);
Expand Down Expand Up @@ -365,7 +379,7 @@ TYPED_TEST(Int4CastTest, DeviceCast) {
Eigen::TensorMap<Eigen::Tensor<DestType, 1>, Eigen::Aligned> dst_device(
dst_device_buffer, kNumElems);

// Allocate host buffers and initially src memory.
// Allocate host buffers and initialize src memory.
Eigen::Tensor<Int4, 1> src_cpu(kNumElems);
Eigen::Tensor<DestType, 1> dst_cpu(kNumElems);
for (int i = 0; i < kNumElems; ++i) {
Expand Down

0 comments on commit 20b73dc

Please sign in to comment.