Skip to content

Commit

Permalink
modify test to avoid false-positive results
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Feb 28, 2024
1 parent baa24c7 commit 862afb2
Showing 1 changed file with 17 additions and 22 deletions.
39 changes: 17 additions & 22 deletions tests/cpp/plugin/test_sycl_partition_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <string>
#include <utility>
#include <vector>
#include <numeric>

#include "../../../plugin/sycl/common/partition_builder.h"
#include "../../../plugin/sycl/device_manager.h"
Expand Down Expand Up @@ -66,58 +67,52 @@ void TestPartitioning(float sparsity, int max_bins) {
partition_builder.MergeToArray(0, data_result, &event);
qu.wait_and_throw();

std::stringstream ss;
bst_float split_pt = gmat.cut.Values()[split_conditions[0]];
ss << "split_pt = " << split_pt << "\n";

std::vector<size_t> ridx_left;
std::vector<size_t> ridx_right;
std::vector<uint8_t> ridx_left(num_rows, 0);
std::vector<uint8_t> ridx_right(num_rows, 0);
for (auto &batch : gmat.p_fmat->GetBatches<SparsePage>()) {
const auto& data_vec = batch.data.HostVector();
const auto& offset_vec = batch.offset.HostVector();

size_t begin = offset_vec[0];
for (size_t idx = 0; idx < offset_vec.size() - 1; ++idx) {
size_t end = offset_vec[idx + 1];
ss << "idx = " << idx << "\tbegin = " << begin << "\tend = " << end;
if (begin < end) {
const auto& entry = data_vec[begin];
ss << "fvalue = " << entry.fvalue;
if (entry.fvalue < split_pt) {
ridx_left.push_back(idx);
ridx_left[idx] = 1;
} else {
ridx_right.push_back(idx);
ridx_right[idx] = 1;
}
} else {
// missing value
ss << "default_left = " << tree[0].DefaultLeft();
if (tree[0].DefaultLeft()) {
ridx_left.push_back(idx);
ridx_left[idx] = 1;
} else {
ridx_right.push_back(idx);
ridx_right[idx] = 1;
}
}
ss << "\n";
begin = end;
}
}
auto n_left = std::accumulate(ridx_left.begin(), ridx_left.end(), 0);
auto n_right = std::accumulate(ridx_right.begin(), ridx_right.end(), 0);

std::vector<size_t> row_indices_host(num_rows);
qu.memcpy(row_indices_host.data(), row_indices.Data(), num_rows * sizeof(size_t));
qu.wait_and_throw();
ss << "row_indices = ";
for (auto idx : row_indices_host) {
ss << idx << "\t";
}

ASSERT_EQ(ridx_left.size(), partition_builder.GetNLeftElems(0)) << ss.str();
for (size_t i = 0; i < ridx_left.size(); ++i) {
ASSERT_EQ(ridx_left[i], row_indices_host[i]) << ss.str();
ASSERT_EQ(n_left, partition_builder.GetNLeftElems(0));
for (size_t i = 0; i < n_left; ++i) {
auto idx = row_indices_host[i];
ASSERT_EQ(ridx_left[idx], 1);
}

ASSERT_EQ(ridx_right.size(), partition_builder.GetNRightElems(0)) << ss.str();
for (size_t i = 0; i < ridx_right.size(); ++i) {
ASSERT_EQ(ridx_right[i], row_indices_host[num_rows - 1 - i]) << ss.str();
ASSERT_EQ(n_right, partition_builder.GetNRightElems(0));
for (size_t i = 0; i < n_right; ++i) {
auto idx = row_indices_host[num_rows - 1 - i];
ASSERT_EQ(ridx_right[idx], 1);
}
}

Expand Down

0 comments on commit 862afb2

Please sign in to comment.