Skip to content

Commit

Permalink
add more tests for sycl-objectives
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Dec 6, 2024
1 parent a03b92e commit bd83108
Show file tree
Hide file tree
Showing 13 changed files with 412 additions and 181 deletions.
6 changes: 4 additions & 2 deletions src/common/ranking_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,9 @@ class NDCGCache : public RankingCache {
}

linalg::VectorView<double const> InvIDCG(Context const* ctx) const {
return inv_idcg_.View(ctx->Device());
// This function doesn't have sycl-specific implementation yet.
// For that reason we transfer data to host in case of sycl is used for propper execution.
return inv_idcg_.View(ctx->Device().IsSycl() ? DeviceOrd::CPU() : ctx->Device());
}
common::Span<double const> Discount(Context const* ctx) const {
return ctx->IsCUDA() ? discounts_.ConstDeviceSpan() : discounts_.ConstHostSpan();
Expand All @@ -330,7 +332,7 @@ class NDCGCache : public RankingCache {
dcg_.SetDevice(ctx->Device());
dcg_.Reshape(this->Groups());
}
return dcg_.View(ctx->Device());
return dcg_.View(ctx->Device().IsSycl() ? DeviceOrd::CPU() : ctx->Device());
}
};

Expand Down
4 changes: 2 additions & 2 deletions src/objective/lambdarank_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ class LambdaRankObj : public FitIntercept {
lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
&li_, &lj_, p_cache_);
} else {
cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device()),
lj_full_.View(ctx_->Device()), &ti_plus_, &tj_minus_,
cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->Device().IsSycl() ? DeviceOrd::CPU() : ctx_->Device()),
lj_full_.View(ctx_->Device().IsSycl() ? DeviceOrd::CPU() : ctx_->Device()), &ti_plus_, &tj_minus_,
&li_, &lj_, p_cache_);
}

Expand Down
26 changes: 11 additions & 15 deletions tests/cpp/objective/test_aft_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
#include "xgboost/objective.h"
#include "xgboost/logging.h"
#include "../helpers.h"
#include "test_aft_obj.h"

namespace xgboost::common {
TEST(Objective, DeclareUnifiedTest(AFTObjConfiguration)) {
auto ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> objective(ObjFunction::Create("survival:aft", &ctx));
void TestAFTObjConfiguration(const Context* ctx) {
std::unique_ptr<ObjFunction> objective(ObjFunction::Create("survival:aft", ctx));
objective->Configure({ {"aft_loss_distribution", "logistic"},
{"aft_loss_distribution_scale", "5"} });

Expand Down Expand Up @@ -73,9 +73,8 @@ static inline void CheckGPairOverGridPoints(
}
}

TEST(Objective, DeclareUnifiedTest(AFTObjGPairUncensoredLabels)) {
auto ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &ctx));
void TestAFTObjGPairUncensoredLabels(const Context* ctx) {
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", ctx));

CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "normal",
{ -3.9120f, -3.4013f, -2.8905f, -2.3798f, -1.8691f, -1.3583f, -0.8476f, -0.3368f, 0.1739f,
Expand All @@ -97,9 +96,8 @@ TEST(Objective, DeclareUnifiedTest(AFTObjGPairUncensoredLabels)) {
0.3026f, 0.1816f, 0.1090f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f });
}

TEST(Objective, DeclareUnifiedTest(AFTObjGPairLeftCensoredLabels)) {
auto ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &ctx));
void TestAFTObjGPairLeftCensoredLabels(const Context* ctx) {
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", ctx));

CheckGPairOverGridPoints(obj.get(), 0.0f, 20.0f, "normal",
{ 0.0285f, 0.0832f, 0.1951f, 0.3804f, 0.6403f, 0.9643f, 1.3379f, 1.7475f, 2.1828f, 2.6361f,
Expand All @@ -118,9 +116,8 @@ TEST(Objective, DeclareUnifiedTest(AFTObjGPairLeftCensoredLabels)) {
0.0296f, 0.0179f, 0.0108f, 0.0065f, 0.0039f, 0.0024f, 0.0014f, 0.0008f, 0.0005f, 0.0003f });
}

TEST(Objective, DeclareUnifiedTest(AFTObjGPairRightCensoredLabels)) {
auto ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &ctx));
void TestAFTObjGPairRightCensoredLabels(const Context* ctx) {
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", ctx));

CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits<float>::infinity(), "normal",
{ -3.6583f, -3.1815f, -2.7135f, -2.2577f, -1.8190f, -1.4044f, -1.0239f, -0.6905f, -0.4190f,
Expand All @@ -142,9 +139,8 @@ TEST(Objective, DeclareUnifiedTest(AFTObjGPairRightCensoredLabels)) {
0.1816f, 0.1089f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f, 0.0018f });
}

TEST(Objective, DeclareUnifiedTest(AFTObjGPairIntervalCensoredLabels)) {
auto ctx = MakeCUDACtx(GPUIDX);
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &ctx));
void TestAFTObjGPairIntervalCensoredLabels(const Context* ctx) {
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", ctx));

CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "normal",
{ -2.4435f, -1.9965f, -1.5691f, -1.1679f, -0.7990f, -0.4649f, -0.1596f, 0.1336f, 0.4370f,
Expand Down
23 changes: 23 additions & 0 deletions tests/cpp/objective/test_aft_obj.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/**
* Copyright 2020-2024 by XGBoost Contributors
*/
#ifndef XGBOOST_TEST_AFT_OBJ_H_
#define XGBOOST_TEST_AFT_OBJ_H_

#include <xgboost/context.h> // for Context

namespace xgboost::common {

void TestAFTObjConfiguration(const Context* ctx);

void TestAFTObjGPairUncensoredLabels(const Context* ctx);

void TestAFTObjGPairLeftCensoredLabels(const Context* ctx);

void TestAFTObjGPairRightCensoredLabels(const Context* ctx);

void TestAFTObjGPairIntervalCensoredLabels(const Context* ctx);

} // namespace xgboost::common

#endif // XGBOOST_TEST_AFT_OBJ_H_
41 changes: 41 additions & 0 deletions tests/cpp/objective/test_aft_obj_cpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* Copyright 2020-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <memory>
#include <vector>
#include <limits>
#include <cmath>

#include "xgboost/objective.h"
#include "xgboost/logging.h"
#include "../helpers.h"
#include "test_aft_obj.h"

namespace xgboost::common {
TEST(Objective, DeclareUnifiedTest(AFTObjConfiguration)) {
auto ctx = MakeCUDACtx(GPUIDX);
TestAFTObjConfiguration(&ctx);
}

TEST(Objective, DeclareUnifiedTest(AFTObjGPairUncensoredLabels)) {
auto ctx = MakeCUDACtx(GPUIDX);
TestAFTObjGPairUncensoredLabels(&ctx);
}

TEST(Objective, DeclareUnifiedTest(AFTObjGPairLeftCensoredLabels)) {
auto ctx = MakeCUDACtx(GPUIDX);
TestAFTObjGPairLeftCensoredLabels(&ctx);
}

TEST(Objective, DeclareUnifiedTest(AFTObjGPairRightCensoredLabels)) {
auto ctx = MakeCUDACtx(GPUIDX);
TestAFTObjGPairRightCensoredLabels(&ctx);
}

TEST(Objective, DeclareUnifiedTest(AFTObjGPairIntervalCensoredLabels)) {
auto ctx = MakeCUDACtx(GPUIDX);
TestAFTObjGPairIntervalCensoredLabels(&ctx);
}

} // namespace xgboost::common
10 changes: 4 additions & 6 deletions tests/cpp/objective/test_lambdarank_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ void TestNDCGGPair(Context const* ctx) {
{0, 2, 4},
{2.06611f, -2.06611f, 0.0f, 0.0f},
{2.169331f, 2.169331f, 0.0f, 0.0f});

CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
Expand All @@ -65,7 +64,6 @@ void TestNDCGGPair(Context const* ctx) {
{2.06611f, -2.06611f, 2.06611f, -2.06611f},
{2.169331f, 2.169331f, 2.169331f, 2.169331f});
}

std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)};
obj->Configure(Args{{"lambdarank_pair_method", "topk"}});

Expand Down Expand Up @@ -246,9 +244,9 @@ void TestMAPStat(Context const* ctx) {

predt.SetDevice(ctx->Device());
auto rank_idx =
p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan());
p_cache->SortedIdx(ctx, !ctx->IsCUDA() ? predt.ConstHostSpan() : predt.ConstDeviceSpan());

if (ctx->IsCPU()) {
if (!ctx->IsCUDA()) {
obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
p_cache);
} else {
Expand Down Expand Up @@ -283,9 +281,9 @@ void TestMAPStat(Context const* ctx) {

predt.SetDevice(ctx->Device());
auto rank_idx =
p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan());
p_cache->SortedIdx(ctx, !ctx->IsCUDA() ? predt.ConstHostSpan() : predt.ConstDeviceSpan());

if (ctx->IsCPU()) {
if (!ctx->IsCUDA()) {
obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
p_cache);
} else {
Expand Down
28 changes: 12 additions & 16 deletions tests/cpp/objective/test_objective.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,19 @@ TEST(Objective, PredTransform) {
size_t n = 100;

for (const auto& entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
// SYCL implementations are skipped for this test
const std::string sycl_postfix = "sycl";
if ((entry->name.size() >= sycl_postfix.size()) && !std::equal(sycl_postfix.rbegin(), sycl_postfix.rend(), entry->name.rbegin())) {
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create(entry->name, &tparam)};
if (entry->name.find("multi") != std::string::npos) {
obj->Configure(Args{{"num_class", "2"}});
}
if (entry->name.find("quantile") != std::string::npos) {
obj->Configure(Args{{"quantile_alpha", "0.5"}});
}
HostDeviceVector<float> predts;
predts.Resize(n, 3.14f); // prediction is performed on host.
ASSERT_FALSE(predts.DeviceCanRead());
obj->PredTransform(&predts);
ASSERT_FALSE(predts.DeviceCanRead());
ASSERT_TRUE(predts.HostCanWrite());
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create(entry->name, &tparam)};
if (entry->name.find("multi") != std::string::npos) {
obj->Configure(Args{{"num_class", "2"}});
}
if (entry->name.find("quantile") != std::string::npos) {
obj->Configure(Args{{"quantile_alpha", "0.5"}});
}
HostDeviceVector<float> predts;
predts.Resize(n, 3.14f); // prediction is performed on host.
ASSERT_FALSE(predts.DeviceCanRead());
obj->PredTransform(&predts);
ASSERT_FALSE(predts.DeviceCanRead());
ASSERT_TRUE(predts.HostCanWrite());
}
}

Expand Down
Loading

0 comments on commit bd83108

Please sign in to comment.