Skip to content

Commit

Permalink
Add find APIs for map with custom hash and key equality functions (#665)
Browse files Browse the repository at this point in the history
Related to #662

This PR adds `find` APIs for `static_map` to take a custom hash and key
comparator.
  • Loading branch information
PointKernel authored Feb 6, 2025
1 parent 67110a5 commit 2dcd6f2
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 12 deletions.
59 changes: 58 additions & 1 deletion include/cuco/detail/static_map/static_map.inl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -491,6 +491,30 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
impl_->find_async(first, last, output_begin, ref(op::find), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt, typename ProbeEqual, typename ProbeHash, typename OutputIt>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_async(
InputIt first,
InputIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputIt output_begin,
cuda::stream_ref stream) const
{
impl_->find_async(first,
last,
output_begin,
ref(op::find).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash),
stream);
}

template <class Key,
class T,
class Extent,
Expand Down Expand Up @@ -532,6 +556,39 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream);
}

template <class Key,
class T,
class Extent,
cuda::thread_scope Scope,
class KeyEqual,
class ProbingScheme,
class Allocator,
class Storage>
template <typename InputIt,
typename StencilIt,
typename Predicate,
typename ProbeEqual,
typename ProbeHash,
typename OutputIt>
void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if_async(
InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputIt output_begin,
cuda::stream_ref stream) const
{
impl_->find_if_async(first,
last,
stencil,
pred,
output_begin,
ref(op::find).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash),
stream);
}

template <class Key,
class T,
class Extent,
Expand Down
71 changes: 70 additions & 1 deletion include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
* Copyright (c) 2020-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -767,6 +767,33 @@ class static_map {
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief For all keys in the range `[first, last)`, asynchronously finds an element with key
* equivalent to the query key.
*
* @note If the key `*(first + i)` has a matched `element` in the map, copies the payload of
* `element` to `(output_begin + i)`. Else, copies the empty value sentinel.
*
* @tparam InputIt Device accessible input iterator
* @tparam ProbeEqual Binary callable equal type
* @tparam ProbeHash Unary callable hasher type
* @tparam OutputIt Device accessible output iterator assignable from the map's `mapped_type`
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param probe_equal The binary function to compare map keys and probe keys for equality
* @param probe_hash The unary function to hash probe keys
* @param output_begin Beginning of the sequence of elements retrieved for each key
* @param stream Stream used for executing the kernels
*/
template <typename InputIt, typename ProbeEqual, typename ProbeHash, typename OutputIt>
void find_async(InputIt first,
InputIt last,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief For all keys in the range `[first, last)`, finds a match with its key equivalent to the
* query key.
Expand Down Expand Up @@ -831,6 +858,48 @@ class static_map {
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief For all keys in the range `[first, last)`, asynchronously finds
* a match with its key equivalent to the query key.
*
* @note If `pred( *(stencil + i) )` is true, stores the payload of the
* matched key or the `empty_value_sentienl` to `(output_begin + i)`. If `pred( *(stencil + i) )`
* is false, always stores the `empty_value_sentienl` to `(output_begin + i)`.
*
* @tparam InputIt Device accessible input iterator
* @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to
* Predicate's argument type
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
* @tparam ProbeEqual Binary callable equal type
* @tparam ProbeHash Unary callable hasher type
* @tparam OutputIt Device accessible output iterator
*
* @param first Beginning of the sequence of keys
* @param last End of the sequence of keys
* @param stencil Beginning of the stencil sequence
* @param pred Predicate to test on every element in the range `[stencil, stencil +
* std::distance(first, last))`
* @param probe_equal The binary function to compare map keys and probe keys for equality
* @param probe_hash The unary function to hash probe keys
* @param output_begin Beginning of the sequence of matches retrieved for each key
* @param stream Stream used for executing the kernels
*/
template <typename InputIt,
typename StencilIt,
typename Predicate,
typename ProbeEqual,
typename ProbeHash,
typename OutputIt>
void find_if_async(InputIt first,
InputIt last,
StencilIt stencil,
Predicate pred,
ProbeEqual const& probe_equal,
ProbeHash const& probe_hash,
OutputIt output_begin,
cuda::stream_ref stream = {}) const;

/**
* @brief Applies the given function object `callback_op` to the copy of every filled slot in the
* container
Expand Down
53 changes: 43 additions & 10 deletions tests/static_map/find_test.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
* Copyright (c) 2020-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,14 @@ using size_type = int32_t;

int32_t constexpr SENTINEL = -1;

struct always_false {
template <typename T>
__device__ bool operator()(T const&, T const&) const
{
return false;
}
};

template <typename Map>
void test_unique_sequence(Map& map, size_type num_keys)
{
Expand All @@ -52,10 +60,10 @@ void test_unique_sequence(Map& map, size_type num_keys)
auto is_even =
cuda::proclaim_return_type<bool>([] __device__(auto const& i) { return i % 2 == 0; });

thrust::device_vector<Value> d_results(num_keys);

SECTION("Non-inserted keys have no matches")
{
thrust::device_vector<Value> d_results(num_keys);

map.find(keys_begin, keys_begin + num_keys, d_results.begin());
auto zip = thrust::make_zip_iterator(thrust::make_tuple(
d_results.begin(), thrust::constant_iterator<Key>{map.empty_key_sentinel()}));
Expand All @@ -67,34 +75,59 @@ void test_unique_sequence(Map& map, size_type num_keys)

SECTION("All inserted keys should be correctly recovered during find")
{
thrust::device_vector<Value> d_results(num_keys);

map.find(keys_begin, keys_begin + num_keys, d_results.begin());
auto zip = thrust::make_zip_iterator(thrust::make_tuple(d_results.begin(), keys_begin));

REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal));
}

SECTION("No keys should be found with custom always_false equal")
{
map.find_async(
keys_begin, keys_begin + num_keys, always_false{}, map.hash_function(), d_results.begin());
CUCO_CUDA_TRY(cudaDeviceSynchronize());
auto zip = thrust::make_zip_iterator(thrust::make_tuple(
d_results.begin(), thrust::constant_iterator<Value>{map.empty_value_sentinel()}));

REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal));
}

SECTION("Conditional find should return valid values on even inputs.")
{
auto found_results = thrust::device_vector<Key>(num_keys);
auto gold_fn = cuda::proclaim_return_type<Value>([] __device__(auto const& i) {
auto gold_fn = cuda::proclaim_return_type<Value>([] __device__(auto const& i) {
return i % 2 == 0 ? static_cast<Value>(i) : Value{SENTINEL};
});

map.find_if(keys_begin,
keys_begin + num_keys,
thrust::counting_iterator<std::size_t>{0},
is_even,
found_results.begin());
d_results.begin());

REQUIRE(cuco::test::equal(
found_results.begin(),
found_results.end(),
d_results.begin(),
d_results.end(),
thrust::make_transform_iterator(thrust::counting_iterator<Key>{0}, gold_fn),
cuda::proclaim_return_type<bool>(
[] __device__(auto const& found, auto const& gold) { return found == gold; })));
}

SECTION("Conditional find with always_false should always get sentinel.")
{
map.find_if_async(keys_begin,
keys_begin + num_keys,
thrust::counting_iterator<std::size_t>{0},
is_even,
always_false{},
map.hash_function(),
d_results.begin());

CUCO_CUDA_TRY(cudaDeviceSynchronize());
auto zip = thrust::make_zip_iterator(thrust::make_tuple(
d_results.begin(), thrust::constant_iterator<Value>{map.empty_value_sentinel()}));

REQUIRE(cuco::test::all_of(zip, zip + num_keys, zip_equal));
}
}

TEMPLATE_TEST_CASE_SIG(
Expand Down

0 comments on commit 2dcd6f2

Please sign in to comment.