Skip to content

Commit

Permalink
Merge pull request #129 from jakevdp:ci-numpy-nightly
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 592609686
  • Loading branch information
The ml_dtypes Authors committed Dec 20, 2023
2 parents 97d392e + 6632fdb commit 29edcb3
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 29 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,27 @@ jobs:
- name: Run tests
run: |
pytest -n auto
build-nightly:
name: Python 3.12 with nightly numpy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
submodules: true
- name: Set up Python 3.12
uses: actions/setup-python@v4
with:
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install setuptools wheel
python -m pip install -U --pre numpy \
-i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
python -c "import numpy; print(f'{numpy.__version__=}')"
- name: Build ml_dtypes
run: |
python -m pip install .[dev] --no-build-isolation
- name: Run tests
run: |
pytest -n auto
33 changes: 5 additions & 28 deletions ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,9 @@ bool RegisterFloatUFuncs(PyObject* numpy) {
return ok;
}

// TODO(jakevdp): simplify the following. The already_registered check is no
// longer necessary, and heap allocation is probably not important any longer.
//
// Returns true if the numpy type for T is successfully registered, including if
// it was already registered (e.g. by a different library). If
// `already_registered` is non-null, it's set to true if the type was already
Expand All @@ -870,35 +873,9 @@ bool RegisterFloatDtype(PyObject* numpy, bool* already_registered = nullptr) {
if (already_registered != nullptr) {
*already_registered = false;
}
// If another module (presumably either TF or JAX) has registered a bfloat16
// type, use it. We don't want two bfloat16 types if we can avoid it since it
// leads to confusion if we have two different types with the same name. This
// assumes that the other module has a sufficiently complete bfloat16
// implementation. The only known NumPy bfloat16 extension at the time of
// writing is this one (distributed in TF and JAX).
// TODO(phawkins): distribute the bfloat16 extension as its own pip package,
// so we can unambiguously refer to a single canonical definition of bfloat16.
int typenum =
PyArray_TypeNumFromName(const_cast<char*>(TypeDescriptor<T>::kTypeName));
if (typenum != NPY_NOTYPE) {
PyArray_Descr* descr = PyArray_DescrFromType(typenum);
// The test for an argmax function here is to verify that the
// bfloat16 implementation is sufficiently new, and, say, not from
// an older version of TF or JAX.
if (descr && descr->f && descr->f->argmax) {
TypeDescriptor<T>::npy_type = typenum;
TypeDescriptor<T>::type_ptr = reinterpret_cast<PyObject*>(descr->typeobj);
if (already_registered != nullptr) {
*already_registered = true;
}
return true;
}
}

// It's important that we heap-allocate our type. This is because tp_name
// is not a fully-qualified name for a heap-allocated type, and
// PyArray_TypeNumFromName() (above) looks at the tp_name field to find
// types. Existing implementations in JAX and TensorFlow look for "bfloat16",
// is not a fully-qualified name for a heap-allocated type.
// Existing implementations in JAX and TensorFlow look for "bfloat16",
// not "ml_dtypes.bfloat16" when searching for an implementation.
Safe_PyObjectPtr name =
make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));
Expand Down
10 changes: 9 additions & 1 deletion ml_dtypes/tests/custom_float_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@
float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz


try:
# numpy >= 2.0
ComplexWarning = np.exceptions.ComplexWarning
except AttributeError:
# numpy < 2.0
ComplexWarning = np.ComplexWarning


@contextlib.contextmanager
def ignore_warning(**kw):
with warnings.catch_warnings():
Expand Down Expand Up @@ -703,7 +711,7 @@ def testCasts(self, float_type):
self.assertTrue(np.all(x == z))
self.assertEqual(dtype, z.dtype)

@ignore_warning(category=np.ComplexWarning)
@ignore_warning(category=ComplexWarning)
def testConformNumpyComplex(self, float_type):
for dtype in [np.complex64, np.complex128, np.clongdouble]:
x = np.array([1.5, 2.5 + 2.0j, 3.5], dtype=dtype)
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[pytest]
filterwarnings =
error
ignore:numpy.core._multiarray_umat.*:DeprecationWarning

0 comments on commit 29edcb3

Please sign in to comment.