Skip to content

Commit

Permalink
Distributed kNN scalability and optimizations (uxlfoundation#2558)
Browse files Browse the repository at this point in the history
* Profiling additions for benchmarking

* dblock cap+last iter,split_table profile,var names

* trying revert of data_management

* custom max and split table event

* address some todos and cleanup finalize

* remove temp_resp_ + clang

* send recv replace debug

* updated debug

* extended profiling

* temporary for CI build

* cleanup and removal of unneeded profiling

* syncing data_management with master

* I_MPI_OFFLOAD condition for green bazel

* temporary conditionals add for bench

* for bench only

* detailed select_indexed profiling

* removing select_indexed_local calls

* restoring communicator (see uxlfoundation#2577)

* select_indexed debugging removals

* search_dpc debugging cleanup

* knn cleanup and clang

* single gpu/distributed unification

* addressing comments

* correction to previous

* clean up comments

* addressing some comments

* clang
  • Loading branch information
ethanglaser authored Dec 21, 2023
1 parent a1075af commit 4ff1b59
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 177 deletions.
3 changes: 3 additions & 0 deletions cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "oneapi/dal/table/row_accessor.hpp"

#include "oneapi/dal/detail/common.hpp"
#include "oneapi/dal/detail/profiler.hpp"

namespace oneapi::dal::knn::backend {

Expand Down Expand Up @@ -166,6 +167,7 @@ class knn_callback {
pr::ndview<idx_t, 2>& inp_indices,
pr::ndview<Float, 2>& inp_distances,
const bk::event_vector& deps = {}) {
ONEDAL_PROFILER_TASK(query_loop.callback, queue_);
sycl::event copy_indices, copy_distances, comp_responses;

const auto bounds = this->block_bounds(qb_id);
Expand Down Expand Up @@ -473,6 +475,7 @@ sycl::event bf_kernel(sycl::queue& queue,
distance_impl->get_daal_distance_type() == daal_distance_t::cosine;
const bool is_euclidean_distance =
is_minkowski_distance && (distance_impl->get_degree() == 2.0);
ONEDAL_ASSERT(is_minkowski_distance ^ is_chebyshev_distance ^ is_cosine_distance);

sycl::event search_event;

Expand Down
223 changes: 87 additions & 136 deletions cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc_distr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "oneapi/dal/table/row_accessor.hpp"

#include "oneapi/dal/detail/common.hpp"
#include "oneapi/dal/detail/profiler.hpp"

namespace oneapi::dal::knn::backend {

Expand Down Expand Up @@ -72,13 +73,7 @@ class knn_callback_distr {
result_options_(results),
query_block_(query_block),
query_length_(query_length),
k_neighbors_(k_neighbors) {
if (result_options_.test(result_options::responses)) {
this->temp_resp_ = pr::ndarray<res_t, 2>::empty(q,
{ query_block, k_neighbors },
sycl::usm::alloc::device);
}
}
k_neighbors_(k_neighbors) {}

auto& set_euclidean_distance(bool is_euclidean_distance) {
this->compute_sqrt_ = is_euclidean_distance;
Expand Down Expand Up @@ -196,50 +191,27 @@ class knn_callback_distr {
return *this;
}

/// Emits the final results for one query block.
///
/// For every result option enabled in `result_options_`, the matching
/// `output_*` helper is invoked on the bounds returned by `block_bounds(qb_id)`.
/// The response step is submitted after the index and distance steps.
///
/// @param qb_id         Query block identifier passed to `block_bounds`.
/// @param inp_indices   Neighbor indices forwarded to the output helpers.
/// @param inp_distances Neighbor distances forwarded to the output helpers.
/// @param deps          Events the output helpers are submitted after.
/// @return A default-constructed event; the call blocks on all submitted work
///         through `wait_and_throw` before it returns.
sycl::event finalize(std::int64_t qb_id,
                     pr::ndview<idx_t, 2>& inp_indices,
                     pr::ndview<Float, 2>& inp_distances,
                     const bk::event_vector& deps = {}) {
    const auto block_range = this->block_bounds(qb_id);

    // Each event stays default-constructed when its result option is disabled.
    sycl::event indices_done;
    sycl::event distances_done;
    sycl::event responses_done;

    if (result_options_.test(result_options::indices)) {
        indices_done = this->output_indices(block_range, inp_indices, deps);
    }

    if (result_options_.test(result_options::distances)) {
        distances_done = this->output_distances(block_range, inp_distances, deps);
    }

    if (result_options_.test(result_options::responses)) {
        // NOTE(review): presumably required so the `+` on event vectors resolves — confirm.
        using namespace bk;
        const auto response_deps = deps + indices_done + distances_done;
        responses_done =
            this->output_responses(block_range, inp_indices, inp_distances, response_deps);
    }

    // Host-side synchronization point: nothing is left pending for the caller.
    sycl::event::wait_and_throw({ indices_done, distances_done, responses_done });
    return sycl::event{};
}

sycl::event operator()(std::int64_t qb_id,
pr::ndview<idx_t, 2>& inp_indices,
pr::ndview<Float, 2>& inp_distances,
const bk::event_vector& deps = {}) {
ONEDAL_PROFILER_TASK(query_loop.callback, queue_);
sycl::event copy_actual_dist_event, copy_current_dist_event, copy_actual_indc_event,
copy_current_indc_event, copy_actual_resp_event, copy_current_resp_event;
const auto& [first, last] = this->block_bounds(qb_id);
const auto& bounds = this->block_bounds(qb_id);
const auto& [first, last] = bounds;
const auto len = last - first;
ONEDAL_ASSERT(last > first);
ONEDAL_ASSERT(inp_indices.get_dimension(0) == len);
ONEDAL_ASSERT(inp_indices.get_dimension(1) == k_neighbors_);
ONEDAL_ASSERT(inp_distances.get_dimension(0) == len);
ONEDAL_ASSERT(inp_distances.get_dimension(1) == k_neighbors_);

auto inp_responses = this->temp_resp_.get_row_slice(0, len);
auto current_min_resp_dest = part_responses_.get_col_slice(k_neighbors_, 2 * k_neighbors_)
.get_row_slice(first, last);

auto select_inp_resp_event =
pr::select_indexed(queue_, inp_indices, train_responses_, inp_responses, deps);
copy_current_resp_event =
pr::select_indexed(queue_, inp_indices, train_responses_, current_min_resp_dest, deps);

const pr::ndshape<2> typical_blocking(last - first, 2 * k_neighbors_);
auto select = selc_t(queue_, typical_blocking, k_neighbors_);
Expand All @@ -250,8 +222,10 @@ class knn_callback_distr {

// add global offset value to input indices
ONEDAL_ASSERT(global_index_offset_ != -1);
auto treat_event =
pr::treat_indices(queue_, inp_indices, global_index_offset_, { select_inp_resp_event });
auto treat_event = pr::treat_indices(queue_,
inp_indices,
global_index_offset_,
{ copy_current_resp_event });

auto actual_min_dist_copy_dest =
part_distances_.get_col_slice(0, k_neighbors_).get_row_slice(first, last);
Expand All @@ -273,39 +247,48 @@ class knn_callback_distr {

auto actual_min_resp_copy_dest =
part_responses_.get_col_slice(0, k_neighbors_).get_row_slice(first, last);
auto current_min_resp_dest = part_responses_.get_col_slice(k_neighbors_, 2 * k_neighbors_)
.get_row_slice(first, last);
copy_actual_resp_event =
pr::copy(queue_, actual_min_resp_copy_dest, min_resp_dest, { treat_event });
copy_current_resp_event =
pr::copy(queue_, current_min_resp_dest, inp_responses, { treat_event });

auto kselect_block = part_distances_.get_row_slice(first, last);
auto selt_event = select(queue_,
kselect_block,
k_neighbors_,
min_dist_dest,
min_indc_dest,
{ copy_actual_dist_event,
copy_current_dist_event,
copy_actual_indc_event,
copy_current_indc_event,
copy_actual_resp_event,
copy_current_resp_event });
auto resps_event = select_indexed(queue_,
min_indc_dest,
part_responses_.get_row_slice(first, last),
min_resp_dest,
{ selt_event });
auto final_event = select_indexed(queue_,
min_indc_dest,
part_indices_.get_row_slice(first, last),
min_indc_dest,
{ resps_event });

sycl::event select_event;
{
ONEDAL_PROFILER_TASK(query_loop.selection, queue_);
auto kselect_block = part_distances_.get_row_slice(first, last);
select_event = select(queue_,
kselect_block,
k_neighbors_,
min_dist_dest,
min_indc_dest,
{ copy_actual_dist_event,
copy_current_dist_event,
copy_actual_indc_event,
copy_current_indc_event,
copy_actual_resp_event,
copy_current_resp_event });
}
auto select_resp_event = select_indexed(queue_,
min_indc_dest,
part_responses_.get_row_slice(first, last),
min_resp_dest,
{ select_event });
auto select_indc_event = select_indexed(queue_,
min_indc_dest,
part_indices_.get_row_slice(first, last),
min_indc_dest,
{ select_resp_event });
if (last_iteration_) {
final_event = finalize(qb_id, indices_, distances_, { final_event });
sycl::event copy_sqrt_event;
if (this->compute_sqrt_) {
copy_sqrt_event =
copy_with_sqrt(queue_, min_dist_dest, min_dist_dest, { select_indc_event });
}
auto final_event = this->output_responses(bounds,
indices_,
distances_,
{ select_indc_event, copy_sqrt_event });
return final_event;
}
return final_event;
return select_indc_event;
}

protected:
Expand All @@ -320,45 +303,6 @@ class knn_callback_distr {
return std::make_pair(first, last);
}

/// Writes a block of distances into rows [first, last) of `distances_`.
///
/// When `compute_sqrt_` is set the transfer goes through `copy_with_sqrt`,
/// otherwise through a plain `pr::copy`.
///
/// @param bnds    Pair (first, last) selecting the destination rows.
/// @param inp_dts Source distances; must hold data and have `last - first` rows.
/// @param deps    Events the copy is submitted after.
/// @return Event of the submitted copy.
sycl::event output_distances(const std::pair<idx_t, idx_t>& bnds,
                             const pr::ndview<dst_t, 2>& inp_dts,
                             const bk::event_vector& deps = {}) {
    ONEDAL_ASSERT(inp_dts.has_data());
    ONEDAL_ASSERT(this->result_options_.test(result_options::distances));

    const idx_t first = bnds.first;
    const idx_t last = bnds.second;
    ONEDAL_ASSERT(last > first);

    auto destination = this->distances_.get_row_slice(first, last);
    ONEDAL_ASSERT((last - first) == inp_dts.get_dimension(0));
    ONEDAL_ASSERT((last - first) == destination.get_dimension(0));

    if (this->compute_sqrt_) {
        // Argument order differs from pr::copy: here the source comes first.
        return copy_with_sqrt(this->queue_, inp_dts, destination, deps);
    }
    return pr::copy(this->queue_, destination, inp_dts, deps);
}

/// Copies a block of neighbor indices into rows [first, last) of `indices_`.
///
/// @param bnds    Pair (first, last) selecting the destination rows.
/// @param inp_ids Source indices; must hold data and match the destination shape.
/// @param deps    Events the copy is submitted after.
/// @return Event of the submitted `pr::copy`.
sycl::event output_indices(const std::pair<idx_t, idx_t>& bnds,
                           const pr::ndview<idx_t, 2>& inp_ids,
                           const bk::event_vector& deps = {}) {
    ONEDAL_ASSERT(inp_ids.has_data());
    ONEDAL_ASSERT(this->result_options_.test(result_options::indices));

    const idx_t first = bnds.first;
    const idx_t last = bnds.second;
    ONEDAL_ASSERT(last > first);

    auto destination = this->indices_.get_row_slice(first, last);
    ONEDAL_ASSERT((last - first) == inp_ids.get_dimension(0));
    ONEDAL_ASSERT((last - first) == destination.get_dimension(0));
    ONEDAL_ASSERT(inp_ids.get_shape() == destination.get_shape());

    return pr::copy(this->queue_, destination, inp_ids, deps);
}

template <typename T = Task, typename = detail::enable_if_classification_t<T>>
sycl::event do_ucls(const std::pair<idx_t, idx_t>& bnds,
const pr::ndview<res_t, 2>& tmp_rps,
Expand Down Expand Up @@ -481,7 +425,6 @@ class knn_callback_distr {
const result_option_id result_options_;
const std::int64_t query_block_, query_length_, k_neighbors_;
pr::ndview<res_t, 1> train_responses_;
pr::ndarray<res_t, 2> temp_resp_;
pr::ndview<res_t, 1> responses_;
pr::ndview<res_t, 2> part_responses_;
pr::ndview<res_t, 2> intermediate_responses_;
Expand Down Expand Up @@ -519,7 +462,7 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
// Input arrays test section
ONEDAL_ASSERT(train.has_data());
ONEDAL_ASSERT(query.has_data());
[[maybe_unused]] auto tcount = train.get_row_count();
const auto tcount = train.get_row_count();
const auto qcount = query.get_dimension(0);
const auto fcount = train.get_column_count();
const auto kcount = desc.get_neighbor_count();
Expand Down Expand Up @@ -558,6 +501,13 @@ sycl::event bf_kernel_distr(sycl::queue& queue,

comm.allgather(tcount, node_sample_counts.flatten()).wait();

// TODO: implement max/min for ndarray
std::int64_t max_tcount = 0;
for (std::int64_t index = 0; index < node_sample_counts.get_count(); ++index) {
max_tcount = std::max(node_sample_counts.at(index), max_tcount);
}
block_size = std::min(max_tcount, block_size);

auto current_rank = comm.get_rank();
auto prev_node = (current_rank - 1 + rank_count) % rank_count;
auto next_node = (current_rank + 1) % rank_count;
Expand Down Expand Up @@ -631,12 +581,14 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
distance_impl->get_daal_distance_type() == daal_distance_t::cosine;
const bool is_euclidean_distance =
is_minkowski_distance && (distance_impl->get_degree() == 2.0);
ONEDAL_ASSERT(is_minkowski_distance ^ is_chebyshev_distance ^ is_cosine_distance);

const auto it = std::find(nodes.begin(), nodes.end(), current_rank);
auto first_block_index = std::distance(nodes.begin(), it);
auto relative_block_offset = std::distance(nodes.begin(), it);
ONEDAL_ASSERT(it != nodes.end());

for (std::int64_t block_number = 0; block_number < block_count; ++block_number) {
for (std::int64_t relative_block_idx = 0; relative_block_idx < block_count;
++relative_block_idx) {
auto current_block = train_block_queue.front();
train_block_queue.pop_front();
ONEDAL_ASSERT(current_block.has_data());
Expand All @@ -646,19 +598,20 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
pr::ndview<res_t, 1>::wrap(current_tresps.get_data(), { current_tresps.get_count() });
tresps_queue.pop_front();

auto block_index = (block_number + first_block_index) % block_count;
ONEDAL_ASSERT(block_index + 1 < bounds_size);
auto actual_rows_in_block = boundaries.at(block_index + 1) - boundaries.at(block_index);
auto absolute_block_idx = (relative_block_idx + relative_block_offset) % block_count;
ONEDAL_ASSERT(absolute_block_idx + 1 < bounds_size);
auto actual_rows_in_block =
boundaries.at(absolute_block_idx + 1) - boundaries.at(absolute_block_idx);

auto sc = current_block.get_dimension(0);
ONEDAL_ASSERT(sc >= actual_rows_in_block);
auto curr_k = std::min(actual_rows_in_block, kcount);
auto actual_current_block = current_block.get_row_slice(0, actual_rows_in_block);
auto actual_current_tresps = current_tresps_1d.get_slice(0, actual_rows_in_block);

callback.set_global_index_offset(boundaries.at(block_index));
callback.set_global_index_offset(boundaries.at(absolute_block_idx));
callback.set_train_responses(actual_current_tresps);
if (block_number == block_count - 1) {
if (relative_block_idx == block_count - 1) {
callback.set_last_iteration(true);
}
if (is_cosine_distance) {
Expand Down Expand Up @@ -699,26 +652,24 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
next_event = search(query, callback, qbcount, curr_k, { next_event });
}

auto send_count = current_block.get_count();
ONEDAL_ASSERT(send_count >= 0);
ONEDAL_ASSERT(send_count <= de::limits<int>::max());
// send recv replace
comm.sendrecv_replace(array<Float>::wrap(queue,
current_block.get_mutable_data(),
send_count,
{ next_event }),
prev_node,
next_node)
.wait();
train_block_queue.emplace_back(current_block);
comm.sendrecv_replace(array<res_t>::wrap(queue,
current_tresps.get_mutable_data(),
current_tresps.get_count(),
{ next_event }),
prev_node,
next_node)
.wait();
tresps_queue.emplace_back(current_tresps);
if (relative_block_idx < block_count - 1) {
ONEDAL_PROFILER_TASK(distributed_loop.sendrecv_replace, queue);
auto send_count = current_block.get_count();
ONEDAL_ASSERT(send_count >= 0);
ONEDAL_ASSERT(send_count <= de::limits<int>::max());
auto send_train_block = array<Float>::wrap(queue,
current_block.get_mutable_data(),
send_count,
{ next_event });
comm.sendrecv_replace(send_train_block, prev_node, next_node).wait();
train_block_queue.emplace_back(current_block);
auto send_resps_block = array<res_t>::wrap(queue,
current_tresps.get_mutable_data(),
current_tresps.get_count(),
{ next_event });
comm.sendrecv_replace(send_resps_block, prev_node, next_node).wait();
tresps_queue.emplace_back(current_tresps);
}
}

return next_event;
Expand Down
Loading

0 comments on commit 4ff1b59

Please sign in to comment.