From e08e8c51d9c24c487131824ad6866293ca1f5f4e Mon Sep 17 00:00:00 2001 From: Daniel Ng Date: Wed, 20 Sep 2023 21:38:11 -0700 Subject: [PATCH] Add float8 & int4 numpy integration PiperOrigin-RevId: 567178215 --- ml_dtypes/_src/int4_numpy.h | 12 +++----- ml_dtypes/tests/int4_test.py | 58 ++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/ml_dtypes/_src/int4_numpy.h b/ml_dtypes/_src/int4_numpy.h index 605a31a0..2b065628 100644 --- a/ml_dtypes/_src/int4_numpy.h +++ b/ml_dtypes/_src/int4_numpy.h @@ -718,6 +718,10 @@ bool RegisterInt4Casts() { return false; } } + if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_HALF, + NPY_NOSCALAR) < 0) { + return false; + } if (PyArray_RegisterCanCast(&TypeDescriptor::npy_descr, NPY_FLOAT, NPY_NOSCALAR) < 0) { return false; @@ -748,14 +752,6 @@ bool RegisterInt4Casts() { TypeDescriptor::Dtype(), NPY_NOSCALAR) < 0) { return false; } - if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UBYTE), - TypeDescriptor::Dtype(), NPY_NOSCALAR) < 0) { - return false; - } - if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BYTE), - TypeDescriptor::Dtype(), NPY_NOSCALAR) < 0) { - return false; - } return true; } diff --git a/ml_dtypes/tests/int4_test.py b/ml_dtypes/tests/int4_test.py index 9332cb46..d62bf9f0 100644 --- a/ml_dtypes/tests/int4_test.py +++ b/ml_dtypes/tests/int4_test.py @@ -155,6 +155,64 @@ def testBinop(self, scalar_type, op): self.assertIsInstance(out, scalar_type) self.assertEqual(scalar_type(op(v, w)), out, msg=(v, w)) + @parameterized.product( + scalar_type=INT4_TYPES, + dtype=[ + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.complex64, + np.complex128, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.intc, + np.int_, + np.longlong, + np.uintc, + np.ulonglong, + ], + ) + def testCanCast(self, scalar_type, dtype): + allowed_casts = [ + (np.bool_, int4), + (int4, np.int8), + (int4, np.int16), + (int4, np.int32), + (int4, np.int64), + (int4, np.float16), + (int4, np.float32), + (int4, np.float64), + (int4, np.complex64), + (int4, np.complex128), + (np.bool_, uint4), + (uint4, np.int8), + (uint4, np.int16), + (uint4, np.int32), + (uint4, np.int64), + (uint4, np.uint8), + (uint4, np.uint16), + (uint4, np.uint32), + (uint4, np.uint64), + (uint4, np.float16), + (uint4, np.float32), + (uint4, np.float64), + (uint4, np.complex64), + (uint4, np.complex128), + ] + + assert ((scalar_type, dtype) in allowed_casts) == np.can_cast( + scalar_type, dtype + ) + assert ((dtype, scalar_type) in allowed_casts) == np.can_cast( + dtype, scalar_type + ) + # Tests for the Python scalar type class ArrayTest(parameterized.TestCase):