diff --git a/tests/static_map/unique_sequence_test.cu b/tests/static_map/unique_sequence_test.cu index 4ab864ab7..8e8d6c9ad 100644 --- a/tests/static_map/unique_sequence_test.cu +++ b/tests/static_map/unique_sequence_test.cu @@ -26,7 +26,6 @@ #include #include #include -#include #include #include @@ -34,17 +33,15 @@ using size_type = int32_t; +int32_t constexpr SENTINEL = -1; + template void test_unique_sequence(Map& map, size_type num_keys) { using Key = typename Map::key_type; using Value = typename Map::mapped_type; - thrust::device_vector d_keys(num_keys); - - thrust::sequence(thrust::device, d_keys.begin(), d_keys.end()); - - auto keys_begin = d_keys.begin(); + auto keys_begin = thrust::counting_iterator{0}; auto pairs_begin = thrust::make_transform_iterator( thrust::make_counting_iterator(0), cuda::proclaim_return_type>( @@ -128,6 +125,27 @@ void test_unique_sequence(Map& map, size_type num_keys) REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal)); } + SECTION("Conditional find should return valid values on even inputs.") + { + auto found_results = thrust::device_vector(num_keys); + auto gold_fn = cuda::proclaim_return_type([] __device__(auto const& i) { + return i % 2 == 0 ? static_cast(i) : Value{SENTINEL}; + }); + + map.find_if(keys_begin, + keys_begin + num_keys, + thrust::counting_iterator{0}, + is_even, + found_results.begin()); + + REQUIRE(cuco::test::equal( + found_results.begin(), + found_results.end(), + thrust::make_transform_iterator(thrust::counting_iterator{0}, gold_fn), + cuda::proclaim_return_type( + [] __device__(auto const& found, auto const& gold) { return found == gold; }))); + } + SECTION("All inserted key-values should be properly retrieved") { thrust::device_vector d_values(num_keys); @@ -188,7 +206,7 @@ TEMPLATE_TEST_CASE_SIG( probe, cuco::cuda_allocator, cuco::storage<2>>{ - extent_type{}, cuco::empty_key{-1}, cuco::empty_value{-1}}; + extent_type{}, cuco::empty_key{SENTINEL}, cuco::empty_value{SENTINEL}}; REQUIRE(map.capacity() == gold_capacity); diff --git a/tests/static_multimap/find_test.cu b/tests/static_multimap/find_test.cu index 51456b088..3fe0ae8bc 100644 --- a/tests/static_multimap/find_test.cu +++ b/tests/static_multimap/find_test.cu @@ -28,6 +28,9 @@ using size_type = int32_t; +int32_t constexpr KEY_SENTINEL = -1; +int32_t constexpr VAL_SENTINEL = -2; + template void test_multimap_find(Map& map, size_type num_keys) { @@ -70,6 +73,29 @@ void test_multimap_find(Map& map, size_type num_keys) REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal)); } + + SECTION("Conditional find should return valid values on even inputs.") + { + auto found_results = thrust::device_vector(num_keys); + auto is_even = + cuda::proclaim_return_type([] __device__(auto const& i) { return i % 2 == 0; }); + auto gold_fn = cuda::proclaim_return_type([] __device__(auto const& i) { + return i % 2 == 0 ? static_cast(i) * 2 : Value{VAL_SENTINEL}; + }); + + map.find_if(keys_begin, + keys_begin + num_keys, + thrust::counting_iterator{0}, + is_even, + found_results.begin()); + + REQUIRE(cuco::test::equal( + found_results.begin(), + found_results.end(), + thrust::make_transform_iterator(thrust::counting_iterator{0}, gold_fn), + cuda::proclaim_return_type( + [] __device__(auto const& found, auto const& gold) { return found == gold; }))); + } } TEMPLATE_TEST_CASE_SIG( @@ -100,7 +126,7 @@ TEMPLATE_TEST_CASE_SIG( probe, cuco::cuda_allocator, cuco::storage<2>>{ - num_keys, cuco::empty_key{-1}, cuco::empty_value{-2}}; + num_keys, cuco::empty_key{KEY_SENTINEL}, cuco::empty_value{VAL_SENTINEL}}; test_multimap_find(map, num_keys); } diff --git a/tests/static_multiset/find_test.cu b/tests/static_multiset/find_test.cu index b0945ab90..6379b60fb 100644 --- a/tests/static_multiset/find_test.cu +++ b/tests/static_multiset/find_test.cu @@ -29,16 +29,14 @@ using size_type = int32_t; +int32_t constexpr SENTINEL = -1; + template void test_unique_sequence(Set& set, size_type num_keys) { using Key = typename Set::key_type; - thrust::device_vector d_keys(num_keys); - - thrust::sequence(d_keys.begin(), d_keys.end()); - - auto keys_begin = d_keys.begin(); + auto keys_begin = thrust::counting_iterator{0}; thrust::device_vector d_contained(num_keys); auto zip_equal = cuda::proclaim_return_type( @@ -66,6 +64,28 @@ void test_unique_sequence(Set& set, size_type num_keys) REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal)); } + + SECTION("Conditional find should return valid values on even inputs.") + { + auto found_results = thrust::device_vector(num_keys); + auto is_even = + cuda::proclaim_return_type([] __device__(auto const& i) { return i % 2 == 0; }); + auto gold_fn = cuda::proclaim_return_type( + [] __device__(auto const& i) { return i % 2 == 0 ? static_cast(i) : Key{SENTINEL}; }); + + set.find_if(keys_begin, + keys_begin + num_keys, + thrust::counting_iterator{0}, + is_even, + found_results.begin()); + + REQUIRE(cuco::test::equal( + found_results.begin(), + found_results.end(), + thrust::make_transform_iterator(thrust::counting_iterator{0}, gold_fn), + cuda::proclaim_return_type( + [] __device__(auto const& found, auto const& gold) { return found == gold; }))); + } } TEMPLATE_TEST_CASE_SIG( @@ -87,8 +107,8 @@ TEMPLATE_TEST_CASE_SIG( cuco::linear_probing>, cuco::double_hashing>>; - auto set = - cuco::static_multiset{num_keys, cuco::empty_key{-1}, {}, probe{}, {}, cuco::storage<2>{}}; + auto set = cuco::static_multiset{ + num_keys, cuco::empty_key{SENTINEL}, {}, probe{}, {}, cuco::storage<2>{}}; test_unique_sequence(set, num_keys); } diff --git a/tests/static_set/unique_sequence_test.cu b/tests/static_set/unique_sequence_test.cu index 5e8299fa0..0cd2924a9 100644 --- a/tests/static_set/unique_sequence_test.cu +++ b/tests/static_set/unique_sequence_test.cu @@ -24,7 +24,6 @@ #include #include #include -#include #include #include