Skip to content

Commit

Permalink
Add unittest for safe cast for (u)int4 type; remove safe cast from in…
Browse files Browse the repository at this point in the history
…t8 to int4 types

PiperOrigin-RevId: 573102837
  • Loading branch information
ChromeHearts authored and The ml_dtypes Authors committed Oct 13, 2023
1 parent baca1aa commit 348fd37
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 8 deletions.
12 changes: 4 additions & 8 deletions ml_dtypes/_src/int4_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,10 @@ bool RegisterInt4Casts() {
return false;
}
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_HALF,
NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_FLOAT,
NPY_NOSCALAR) < 0) {
return false;
Expand Down Expand Up @@ -748,14 +752,6 @@ bool RegisterInt4Casts() {
TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UBYTE),
TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
return false;
}
if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BYTE),
TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
return false;
}

return true;
}
Expand Down
58 changes: 58 additions & 0 deletions ml_dtypes/tests/int4_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,64 @@ def testBinop(self, scalar_type, op):
self.assertIsInstance(out, scalar_type)
self.assertEqual(scalar_type(op(v, w)), out, msg=(v, w))

@parameterized.product(
scalar_type=INT4_TYPES,
dtype=[
np.float16,
np.float32,
np.float64,
np.int8,
np.int16,
np.int32,
np.int64,
np.complex64,
np.complex128,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.intc,
np.int_,
np.longlong,
np.uintc,
np.ulonglong,
],
)
def testCanCast(self, scalar_type, dtype):
allowed_casts = [
(np.bool_, int4),
(int4, np.int8),
(int4, np.int16),
(int4, np.int32),
(int4, np.int64),
(int4, np.float16),
(int4, np.float32),
(int4, np.float64),
(int4, np.complex64),
(int4, np.complex128),
(np.bool_, uint4),
(uint4, np.int8),
(uint4, np.int16),
(uint4, np.int32),
(uint4, np.int64),
(uint4, np.uint8),
(uint4, np.uint16),
(uint4, np.uint32),
(uint4, np.uint64),
(uint4, np.float16),
(uint4, np.float32),
(uint4, np.float64),
(uint4, np.complex64),
(uint4, np.complex128),
]

assert ((scalar_type, dtype) in allowed_casts) == np.can_cast(
scalar_type, dtype
)
assert ((dtype, scalar_type) in allowed_casts) == np.can_cast(
dtype, scalar_type
)


# Tests for the Python scalar type
class ArrayTest(parameterized.TestCase):
Expand Down

0 comments on commit 348fd37

Please sign in to comment.