Skip to content

Commit

Permalink
undo some changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Dec 7, 2023
1 parent 6fa0f7f commit 7c91a31
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
1 change: 0 additions & 1 deletion src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,6 @@ class LearnerConfiguration : public Learner {
// Rename one of them once binary IO is gone.
cfg_["max_delta_step"] = kMaxDeltaStepDefaultValue;
}

if (obj_ == nullptr || tparam_.objective != old.objective) {
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
}
Expand Down
18 changes: 6 additions & 12 deletions tests/cpp/objective/test_multiclass_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@
namespace xgboost {

void TestSoftmaxMultiClassObjGPair(const Context* ctx) {
std::string obj_name = "multi:softmax";

std::vector<std::pair<std::string, std::string>> args {{"num_class", "3"}};
std::unique_ptr<ObjFunction> obj {
ObjFunction::Create(obj_name, ctx)
ObjFunction::Create("multi:softmax", ctx)
};

obj->Configure(args);
CheckConfigReload(obj, obj_name);
CheckConfigReload(obj, "multi:softmax");

CheckObjFunction(obj,
{1.0f, 0.0f, 2.0f, 2.0f, 0.0f, 1.0f}, // preds
Expand All @@ -38,14 +36,12 @@ void TestSoftmaxMultiClassObjGPair(const Context* ctx) {
}

void TestSoftmaxMultiClassBasic(const Context* ctx) {
std::string obj_name = "multi:softmax";

std::vector<std::pair<std::string, std::string>> args{
std::pair<std::string, std::string>("num_class", "3")};

std::unique_ptr<ObjFunction> obj{ObjFunction::Create(obj_name, ctx)};
std::unique_ptr<ObjFunction> obj{ObjFunction::Create("multi:softmax", ctx)};
obj->Configure(args);
CheckConfigReload(obj, obj_name);
CheckConfigReload(obj, "multi:softmax");

HostDeviceVector<bst_float> io_preds = {2.0f, 0.0f, 1.0f,
1.0f, 0.0f, 2.0f};
Expand All @@ -60,16 +56,14 @@ void TestSoftmaxMultiClassBasic(const Context* ctx) {
}

void TestSoftprobMultiClassBasic(const Context* ctx) {
std::string obj_name = "multi:softprob";

std::vector<std::pair<std::string, std::string>> args {
std::pair<std::string, std::string>("num_class", "3")};

std::unique_ptr<ObjFunction> obj {
ObjFunction::Create(obj_name, ctx)
ObjFunction::Create("multi:softprob", ctx)
};
obj->Configure(args);
CheckConfigReload(obj, obj_name);
CheckConfigReload(obj, "multi:softprob");

HostDeviceVector<bst_float> io_preds = {2.0f, 0.0f, 1.0f};
std::vector<bst_float> out_preds = {0.66524096f, 0.09003057f, 0.24472847f};
Expand Down

0 comments on commit 7c91a31

Please sign in to comment.