Skip to content

Commit

Permalink
Add integer datatypes and add types to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderKalistratov committed Oct 9, 2024
1 parent c3c194c commit 8842fa6
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 14 deletions.
22 changes: 13 additions & 9 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ jobs:
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-
- name: Install dpnp
run: mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
run: mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest pytest-xdist python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
env:
TEST_CHANNELS: '-c ${{ env.channel-path }} ${{ env.CHANNELS }}'
MAMBA_NO_LOW_SPEED_LIMIT: 1
Expand All @@ -257,7 +257,8 @@ jobs:
- name: Run tests
if: env.RERUN_TESTS_ON_FAILURE != 'true'
run: |
python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
export DPNP_TEST_ALL_TYPES=1
python -m pytest -n auto -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
working-directory: ${{ env.tests-path }}

- name: Run tests
Expand All @@ -266,14 +267,15 @@ jobs:
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
with:
shell: bash
timeout_minutes: 10
timeout_minutes: 45
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
retry_on: any
command: |
. $CONDA/etc/profile.d/conda.sh
conda activate ${{ env.TEST_ENV_NAME }}
cd ${{ env.tests-path }}
python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
export DPNP_TEST_ALL_TYPES=1
python -m pytest -n auto -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
test_windows:
name: Test ['windows-2019', python='${{ matrix.python }}']
Expand Down Expand Up @@ -387,7 +389,7 @@ jobs:
- name: Install dpnp
run: |
@echo on
mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest pytest-xdist python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
env:
TEST_CHANNELS: '-c ${{ env.channel-path }} ${{ env.CHANNELS }}'
MAMBA_NO_LOW_SPEED_LIMIT: 1
Expand All @@ -412,7 +414,8 @@ jobs:
- name: Run tests
if: env.RERUN_TESTS_ON_FAILURE != 'true'
run: |
python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
set DPNP_TEST_ALL_TYPES=1
python -m pytest -n auto -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
working-directory: ${{ env.tests-path }}

- name: Run tests
Expand All @@ -421,13 +424,14 @@ jobs:
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
with:
shell: cmd
timeout_minutes: 15
timeout_minutes: 45
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
retry_on: any
command: >-
mamba activate ${{ env.TEST_ENV_NAME }}
set DPNP_TEST_ALL_TYPES=1
& mamba activate ${{ env.TEST_ENV_NAME }}
& cd ${{ env.tests-path }}
& python -m pytest -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
& python -m pytest -n auto -q -ra --disable-warnings -vv ${{ env.TEST_SCOPE }}
upload:
name: Upload ['${{ matrix.os }}', python='${{ matrix.python }}']
Expand Down
12 changes: 12 additions & 0 deletions dpnp/dpnp_iface_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,14 @@
"inf",
"int",
"int_",
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"integer",
"intc",
"intp",
Expand Down Expand Up @@ -95,8 +101,14 @@
inexact = numpy.inexact
int = numpy.int_
int_ = numpy.int_
int8 = numpy.int8
int16 = numpy.int16
int32 = numpy.int32
int64 = numpy.int64
uint8 = numpy.uint8
uint16 = numpy.uint16
uint32 = numpy.uint32
uint64 = numpy.uint64
integer = numpy.integer
intc = numpy.intc
intp = numpy.intp
Expand Down
3 changes: 3 additions & 0 deletions tests/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os

all_types = int(os.getenv("DPNP_TEST_ALL_TYPES", 0))
13 changes: 13 additions & 0 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from numpy.testing import assert_allclose, assert_array_equal

import dpnp
from tests import config


def assert_dtype_allclose(
Expand Down Expand Up @@ -88,6 +89,18 @@ def get_integer_dtypes():
Build a list of integer types supported by DPNP.
"""

if config.all_types:
return [
dpnp.int8,
dpnp.int16,
dpnp.int32,
dpnp.int64,
dpnp.uint8,
dpnp.uint16,
dpnp.uint32,
dpnp.uint64,
]

return [dpnp.int32, dpnp.int64]


Expand Down
31 changes: 26 additions & 5 deletions tests/third_party/cupy/testing/_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dpctl.tensor._numpy_helper import AxisError

import dpnp as cupy
from tests import config
from tests.third_party.cupy.testing import _array, _parameterized
from tests.third_party.cupy.testing._pytest_impl import is_available

Expand Down Expand Up @@ -1039,19 +1040,39 @@ def _get_supported_complex_dtypes():
return (numpy.complex64,)


def _get_int_dtypes():
if config.all_types:
return _signed_dtypes + _unsigned_dtypes
else:
return (numpy.int64, numpy.int32)


_complex_dtypes = _get_supported_complex_dtypes()
_regular_float_dtypes = _get_supported_float_dtypes()
_float_dtypes = _regular_float_dtypes
_signed_dtypes = ()
_float_dtypes = _regular_float_dtypes + (numpy.float16,)
_signed_dtypes = tuple(numpy.dtype(i).type for i in "bhilq")
_unsigned_dtypes = tuple(numpy.dtype(i).type for i in "BHILQ")
_int_dtypes = _signed_dtypes + _unsigned_dtypes
_int_bool_dtypes = _int_dtypes
_int_dtypes = _get_int_dtypes()
_int_bool_dtypes = _int_dtypes + (numpy.bool_,)
_regular_dtypes = _regular_float_dtypes + _int_bool_dtypes
_dtypes = _float_dtypes + _int_bool_dtypes


def _make_all_dtypes(no_float16, no_bool, no_complex):
return (numpy.int64, numpy.int32) + _get_supported_float_dtypes()
if no_float16:
dtypes = _regular_float_dtypes
else:
dtypes = _float_dtypes

if no_bool:
dtypes += _int_dtypes
else:
dtypes += _int_bool_dtypes

if not no_complex:
dtypes += _complex_dtypes

return dtypes


def for_all_dtypes(
Expand Down

0 comments on commit 8842fa6

Please sign in to comment.