diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d4cdfa74..9e1a50a5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -50,3 +50,27 @@ jobs: - name: Run tests run: | pytest -n auto + build-nightly: + name: Python 3.12 with nightly numpy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Set up Python 3.12 + uses: actions/setup-python@v4 + with: + python-version: "3.12" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install setuptools wheel + python -m pip install -U --pre numpy \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -c "import numpy; print(f'{numpy.__version__=}')" + - name: Build ml_dtypes + run: | + python -m pip install .[dev] --no-build-isolation + - name: Run tests + run: | + pytest -n auto diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index c3afc74a..c1cccc12 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -861,6 +861,9 @@ bool RegisterFloatUFuncs(PyObject* numpy) { return ok; } +// TODO(jakevdp): simplify the following. The already_registered check is no +// longer necessary, and heap allocation is probably not important any longer. +// // Returns true if the numpy type for T is successfully registered, including if // it was already registered (e.g. by a different library). If // `already_registered` is non-null, it's set to true if the type was already @@ -870,35 +873,9 @@ bool RegisterFloatDtype(PyObject* numpy, bool* already_registered = nullptr) { if (already_registered != nullptr) { *already_registered = false; } - // If another module (presumably either TF or JAX) has registered a bfloat16 - // type, use it. We don't want two bfloat16 types if we can avoid it since it - // leads to confusion if we have two different types with the same name. This - // assumes that the other module has a sufficiently complete bfloat16 - // implementation. The only known NumPy bfloat16 extension at the time of - // writing is this one (distributed in TF and JAX). - // TODO(phawkins): distribute the bfloat16 extension as its own pip package, - // so we can unambiguously refer to a single canonical definition of bfloat16. - int typenum = - PyArray_TypeNumFromName(const_cast(TypeDescriptor::kTypeName)); - if (typenum != NPY_NOTYPE) { - PyArray_Descr* descr = PyArray_DescrFromType(typenum); - // The test for an argmax function here is to verify that the - // bfloat16 implementation is sufficiently new, and, say, not from - // an older version of TF or JAX. - if (descr && descr->f && descr->f->argmax) { - TypeDescriptor::npy_type = typenum; - TypeDescriptor::type_ptr = reinterpret_cast(descr->typeobj); - if (already_registered != nullptr) { - *already_registered = true; - } - return true; - } - } - // It's important that we heap-allocate our type. This is because tp_name - // is not a fully-qualified name for a heap-allocated type, and - // PyArray_TypeNumFromName() (above) looks at the tp_name field to find - // types. Existing implementations in JAX and TensorFlow look for "bfloat16", + // is not a fully-qualified name for a heap-allocated type. + // Existing implementations in JAX and TensorFlow look for "bfloat16", // not "ml_dtypes.bfloat16" when searching for an implementation. Safe_PyObjectPtr name = make_safe(PyUnicode_FromString(TypeDescriptor::kTypeName)); diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index d71ae8b1..6ca15b21 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ b/ml_dtypes/tests/custom_float_test.py @@ -37,6 +37,14 @@ float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz +try: + # numpy >= 2.0 + ComplexWarning = np.exceptions.ComplexWarning +except AttributeError: + # numpy < 2.0 + ComplexWarning = np.ComplexWarning + + @contextlib.contextmanager def ignore_warning(**kw): with warnings.catch_warnings(): @@ -703,7 +711,7 @@ def testCasts(self, float_type): self.assertTrue(np.all(x == z)) self.assertEqual(dtype, z.dtype) - @ignore_warning(category=np.ComplexWarning) + @ignore_warning(category=ComplexWarning) def testConformNumpyComplex(self, float_type): for dtype in [np.complex64, np.complex128, np.clongdouble]: x = np.array([1.5, 2.5 + 2.0j, 3.5], dtype=dtype) diff --git a/pytest.ini b/pytest.ini index 9e90ef7c..2842b4cf 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,4 @@ [pytest] filterwarnings = error + ignore:numpy.core._multiarray_umat.*:DeprecationWarning