Skip to content

Commit

Permalink
[Wasm] Fix computation of initial HT capacity.
Browse files Browse the repository at this point in the history
Up to now, we potentially stored a value larger than a `uint32_t` in
`initial_capacity` leading to undefined behavior. We now properly cast
the value to a `uint64_t` and check if it fits into a `uint32_t`,
otherwise we return the maximal value of `uint32_t`.

Furthermore, we put the repetitive computation of the initial hash table
capacity into its own function.
  • Loading branch information
JorisNix committed May 1, 2024
1 parent 159dd21 commit 519a3c3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 32 deletions.
3 changes: 3 additions & 0 deletions src/backend/WasmAlgo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,9 @@ OpenAddressingHashTable<IsGlobal, ValueInPlace>::OpenAddressingHashTable(const S
}

/*----- Initialize capacity and absolute high watermark. -----*/
M_insist(initial_capacity < std::numeric_limits<uint32_t>::max(),
"incremented initial capacity would exceed data type");
++initial_capacity; // since at least one entry must always be unoccupied for lookups
/* at least capacity 4 to ensure absolute high watermark of at least 1 even for minimal percentage of 0.5 */
const auto capacity_init = std::max<uint32_t>(4, ceil_to_pow_2(initial_capacity));
const auto mask_init = capacity_init - 1U;
Expand Down
51 changes: 19 additions & 32 deletions src/backend/WasmOperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,22 @@ Ptr<void> get_base_address(const ThreadSafePooledString &table_name) {
return Module::Get().get_global<void*>(oss.str().c_str());
}

/** Computes the initial hash table capacity for \p op. The function ensures that the initial capacity is in the range
* [0, 2^32 - 1] such that the capacity does *not* exceed the `uint32_t` value limit. */
uint32_t compute_initial_ht_capacity(const Operator &op, double load_factor) {
uint64_t initial_capacity;
if (options::hash_table_initial_capacity) {
initial_capacity = *options::hash_table_initial_capacity;
} else {
if (op.has_info())
initial_capacity = static_cast<uint64_t>(std::ceil(op.info().estimated_cardinality / load_factor));
else if (auto scan = cast<const ScanOperator>(&op))
initial_capacity = static_cast<uint64_t>(std::ceil(scan->store().num_rows() / load_factor));
else
initial_capacity = 1024; // fallback
}
return std::in_range<uint32_t>(initial_capacity) ? initial_capacity : std::numeric_limits<uint32_t>::max();
}

/*======================================================================================================================
* NoOp
Expand Down Expand Up @@ -1447,22 +1463,13 @@ void HashBasedGrouping::execute(const Match<HashBasedGrouping> &M, setup_t setup
}

/*----- Compute initial capacity of hash table. -----*/
uint32_t initial_capacity;
if (options::hash_table_initial_capacity) {
initial_capacity = *options::hash_table_initial_capacity;
} else {
if (M.grouping.has_info())
initial_capacity = std::ceil(M.grouping.info().estimated_cardinality / M.load_factor);
else
initial_capacity = 1024; // fallback
}
uint32_t initial_capacity = compute_initial_ht_capacity(M.grouping, M.load_factor);

/*----- Create hash table. -----*/
std::unique_ptr<HashTable> ht;
std::vector<HashTable::index_t> key_indices(num_keys);
std::iota(key_indices.begin(), key_indices.end(), 0);
if (M.use_open_addressing_hashing) {
++initial_capacity; // since at least one entry must always be unoccupied for lookups
if (aggregates_size_in_bits < AGGREGATES_SIZE_THRESHOLD_IN_BITS)
ht = std::make_unique<GlobalOpenAddressingInPlaceHashTable>(ht_schema, std::move(key_indices),
initial_capacity);
Expand Down Expand Up @@ -3924,25 +3931,14 @@ void SimpleHashJoin<UniqueBuild, Predicated>::execute(const Match<SimpleHashJoin
}

/*----- Compute initial capacity of hash table. -----*/
uint32_t initial_capacity;
if (options::hash_table_initial_capacity) {
initial_capacity = *options::hash_table_initial_capacity;
} else {
if (M.build.has_info())
initial_capacity = std::ceil(M.build.info().estimated_cardinality / M.load_factor);
else if (auto scan = cast<const ScanOperator>(&M.build))
initial_capacity = std::ceil(scan->store().num_rows() / M.load_factor);
else
initial_capacity = 1024; // fallback
}
uint32_t initial_capacity = compute_initial_ht_capacity(M.build, M.load_factor);

/*----- Create hash table for build child. -----*/
std::unique_ptr<HashTable> ht;
std::vector<HashTable::index_t> build_key_indices;
for (auto &build_key : build_keys)
build_key_indices.push_back(ht_schema[build_key].first);
if (M.use_open_addressing_hashing) {
++initial_capacity; // since at least one entry must always be unoccupied for lookups
if (payload_size_in_bits < PAYLOAD_SIZE_THRESHOLD_IN_BITS)
ht = std::make_unique<GlobalOpenAddressingInPlaceHashTable>(ht_schema, std::move(build_key_indices),
initial_capacity);
Expand Down Expand Up @@ -4546,22 +4542,13 @@ void HashBasedGroupJoin::execute(const Match<HashBasedGroupJoin> &M, setup_t set
M_insist(build_keys.size() == num_keys);

/*----- Compute initial capacity of hash table. -----*/
uint32_t initial_capacity;
if (options::hash_table_initial_capacity) {
initial_capacity = *options::hash_table_initial_capacity;
} else {
if (M.grouping.has_info())
initial_capacity = std::ceil(M.grouping.info().estimated_cardinality / M.load_factor);
else
initial_capacity = 1024; // fallback
}
uint32_t initial_capacity = compute_initial_ht_capacity(M.grouping, M.load_factor);

/*----- Create hash table for build relation. -----*/
std::unique_ptr<HashTable> ht;
std::vector<HashTable::index_t> key_indices(num_keys);
std::iota(key_indices.begin(), key_indices.end(), 0);
if (M.use_open_addressing_hashing) {
++initial_capacity; // since at least one entry must always be unoccupied for lookups
if (aggregates_size_in_bits < AGGREGATES_SIZE_THRESHOLD_IN_BITS)
ht = std::make_unique<GlobalOpenAddressingInPlaceHashTable>(ht_schema, std::move(key_indices),
initial_capacity);
Expand Down

0 comments on commit 519a3c3

Please sign in to comment.