Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CI: add a build against numpy nightly #129

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,27 @@ jobs:
- name: Run tests
run: |
pytest -n auto
build-nightly:
name: Python 3.12 with nightly numpy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
submodules: true
- name: Set up Python 3.12
uses: actions/setup-python@v4
with:
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install setuptools wheel
python -m pip install -U --pre numpy \
-i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
python -c "import numpy; print(f'{numpy.__version__=}')"
- name: Build ml_dtypes
run: |
python -m pip install .[dev] --no-build-isolation
- name: Run tests
run: |
pytest -n auto
33 changes: 5 additions & 28 deletions ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,9 @@ bool RegisterFloatUFuncs(PyObject* numpy) {
return ok;
}

// TODO(jakevdp): simplify the following. The already_registered check is no
// longer necessary, and heap allocation is probably not important any longer.
//
// Returns true if the numpy type for T is successfully registered, including if
// it was already registered (e.g. by a different library). If
// `already_registered` is non-null, it's set to true if the type was already
Expand All @@ -870,35 +873,9 @@ bool RegisterFloatDtype(PyObject* numpy, bool* already_registered = nullptr) {
if (already_registered != nullptr) {
*already_registered = false;
}
// If another module (presumably either TF or JAX) has registered a bfloat16
// type, use it. We don't want two bfloat16 types if we can avoid it since it
// leads to confusion if we have two different types with the same name. This
// assumes that the other module has a sufficiently complete bfloat16
// implementation. The only known NumPy bfloat16 extension at the time of
// writing is this one (distributed in TF and JAX).
// TODO(phawkins): distribute the bfloat16 extension as its own pip package,
jakevdp marked this conversation as resolved.
Show resolved Hide resolved
// so we can unambiguously refer to a single canonical definition of bfloat16.
int typenum =
PyArray_TypeNumFromName(const_cast<char*>(TypeDescriptor<T>::kTypeName));
if (typenum != NPY_NOTYPE) {
PyArray_Descr* descr = PyArray_DescrFromType(typenum);
// The test for an argmax function here is to verify that the
// bfloat16 implementation is sufficiently new, and, say, not from
// an older version of TF or JAX.
if (descr && descr->f && descr->f->argmax) {
TypeDescriptor<T>::npy_type = typenum;
TypeDescriptor<T>::type_ptr = reinterpret_cast<PyObject*>(descr->typeobj);
if (already_registered != nullptr) {
*already_registered = true;
}
return true;
}
}

// It's important that we heap-allocate our type. This is because tp_name
jakevdp marked this conversation as resolved.
Show resolved Hide resolved
// is not a fully-qualified name for a heap-allocated type, and
// PyArray_TypeNumFromName() (above) looks at the tp_name field to find
// types. Existing implementations in JAX and TensorFlow look for "bfloat16",
// is not a fully-qualified name for a heap-allocated type.
// Existing implementations in JAX and TensorFlow look for "bfloat16",
// not "ml_dtypes.bfloat16" when searching for an implementation.
Safe_PyObjectPtr name =
make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));
Expand Down
10 changes: 9 additions & 1 deletion ml_dtypes/tests/custom_float_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@
float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz


try:
# numpy >= 2.0
ComplexWarning = np.exceptions.ComplexWarning
except AttributeError:
# numpy < 2.0
ComplexWarning = np.ComplexWarning


@contextlib.contextmanager
def ignore_warning(**kw):
with warnings.catch_warnings():
Expand Down Expand Up @@ -703,7 +711,7 @@ def testCasts(self, float_type):
self.assertTrue(np.all(x == z))
self.assertEqual(dtype, z.dtype)

@ignore_warning(category=np.ComplexWarning)
@ignore_warning(category=ComplexWarning)
def testConformNumpyComplex(self, float_type):
for dtype in [np.complex64, np.complex128, np.clongdouble]:
x = np.array([1.5, 2.5 + 2.0j, 3.5], dtype=dtype)
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[pytest]
filterwarnings =
error
ignore:numpy.core._multiarray_umat.*:DeprecationWarning