Skip to content

Commit

Permalink
Improve (simpler, a bit faster) the splitter multithreading code.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 690660645
  • Loading branch information
achoum authored and copybara-github committed Oct 28, 2024
1 parent 195970b commit d0516e0
Show file tree
Hide file tree
Showing 5 changed files with 338 additions and 322 deletions.
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/learner/decision_tree/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ cc_test(
"//yggdrasil_decision_forests/utils:test",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "gtest/gtest.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "yggdrasil_decision_forests/dataset/data_spec.h"
#include "yggdrasil_decision_forests/dataset/data_spec.pb.h"
Expand Down Expand Up @@ -81,35 +82,31 @@ struct FakeLabelStats : LabelStats {};
// A fake consumer that persistently fails to find a valid attribute.
//
// Echoes the request's `manager_data` back in the response, always reports
// `SplitSearchResult::kInvalidAttribute`, and passes an empty value (`{}`) as
// the third constructor argument.
SplitterWorkResponse FakeFindBestConditionConcurrentConsumerAlwaysInvalid(
    SplitterWorkRequest request) {
  return SplitterWorkResponse(request.manager_data,
                              SplitSearchResult::kInvalidAttribute, {});
}

// A fake consumer that sets the split score to 10 times the request index.
//
// Always reports `SplitSearchResult::kBetterSplitFound`. The response carries
// its own freshly allocated `proto::NodeCondition`; the score is written to
// that condition, not to anything held by the request.
SplitterWorkResponse FakeFindBestConditionConcurrentConsumerMultiplicative(
    SplitterWorkRequest request) {
  SplitterWorkResponse response(request.manager_data,
                                SplitSearchResult::kBetterSplitFound,
                                absl::make_unique<proto::NodeCondition>());
  response.condition->set_split_score(request.attribute_idx * 10.f);
  return response;
}

// A fake consumer that fails if the attribute_idx is even, and sets the score
// to 10 times the attribute_idx otherwise.
//
// The response starts as `kBetterSplitFound` and is downgraded to
// `kInvalidAttribute` for even attribute indices. The split score is written
// to the response's own condition in both cases.
SplitterWorkResponse FakeFindBestConditionConcurrentConsumerAlternate(
    SplitterWorkRequest request) {
  auto response = SplitterWorkResponse(
      request.manager_data, SplitSearchResult::kBetterSplitFound,
      absl::make_unique<proto::NodeCondition>());
  if (request.attribute_idx % 2 == 0) {
    response.status = SplitSearchResult::kInvalidAttribute;
  }
  response.condition->set_split_score(request.attribute_idx * 10.f);
  return response;
}

Expand Down Expand Up @@ -2128,7 +2125,6 @@ TEST(DecisionTree, FindBestConditionConcurrentManager_AlwaysInvalid) {

EXPECT_EQ(cache.splitter_cache_list.size(), 2);
EXPECT_EQ(cache.durable_response_list.size(), 20);
EXPECT_EQ(cache.condition_list.size(), 4);
EXPECT_FALSE(result);
}

Expand Down Expand Up @@ -2264,7 +2260,6 @@ TEST(DecisionTree, FindBestConditionConcurrentManagerScaled) {

EXPECT_EQ(cache.splitter_cache_list.size(), 10);
EXPECT_EQ(cache.durable_response_list.size(), 100);
EXPECT_EQ(cache.condition_list.size(), 20);
EXPECT_FALSE(result);

random.seed(4321);
Expand Down
Loading

0 comments on commit d0516e0

Please sign in to comment.