Skip to content

Commit

Permalink
fix(tests): enshure sort keys are unique
Browse files Browse the repository at this point in the history
In tests which involve sorting by keys, we need to enshure that the keys
are not repeated.
Fixes #104
  • Loading branch information
nauaneed committed Sep 20, 2024
1 parent 3309757 commit 2db3578
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions compyle/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,13 @@ def test_sort_by_keys(backend):
check_import(backend)

# Given
nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
pre_nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
pre_nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)

## drop non unique values
nparr1, indices = np.unique(pre_nparr1, return_index=True)
nparr2 = pre_nparr2[indices]

dev_array1, dev_array2 = array.wrap(nparr1, nparr2, backend=backend)

# When
Expand All @@ -286,8 +291,13 @@ def test_radix_sort_by_keys():
for use_openmp in [True, False]:
get_config().use_openmp = use_openmp
# Given
nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
pre_nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
pre_nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)

## drop non unique values
nparr1, indices = np.unique(pre_nparr1, return_index=True)
nparr2 = pre_nparr2[indices]

dev_array1, dev_array2 = array.wrap(nparr1, nparr2, backend=backend)

# When
Expand All @@ -298,6 +308,8 @@ def test_radix_sort_by_keys():
order = np.argsort(nparr1)
act_result1 = np.take(nparr1, order)
act_result2 = np.take(nparr2, order)
if not np.all(out_array1.get() == act_result1) or not np.all(out_array2.get() == act_result2):
print('About to fail')
assert np.all(out_array1.get() == act_result1)
assert np.all(out_array2.get() == act_result2)
get_config().use_openmp = False
Expand All @@ -310,8 +322,13 @@ def test_sort_by_keys_with_output(backend):
check_import(backend)

# Given
nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)
pre_nparr1 = np.random.randint(0, 100, 16, dtype=np.int32)
pre_nparr2 = np.random.randint(0, 100, 16, dtype=np.int32)

## drop non unique values
nparr1, indices = np.unique(pre_nparr1, return_index=True)
nparr2 = pre_nparr2[indices]

dev_array1, dev_array2 = array.wrap(nparr1, nparr2, backend=backend)
out_arrays = [
array.zeros_like(dev_array1),
Expand Down

0 comments on commit 2db3578

Please sign in to comment.