Skip to content

Commit

Permalink
[PYDF] Add a test for ranking
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 577760913
  • Loading branch information
rstz authored and copybara-github committed Oct 30, 2023
1 parent 5e6a00e commit 5d8f933
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 29 deletions.
12 changes: 11 additions & 1 deletion yggdrasil_decision_forests/port/python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@

### Features

- Add support for ranking tasks
- Add support for reading datasets from a file path, supporting multiple
dataset formats (csv, tfrecord, ...)

#### Release music

Rhapsody in Blue. George Gershwin

### Features

- Add `model.distance(...)` to compute pairwise distance between examples.

## 0.0.3 - 2023-10-20
Expand All @@ -16,7 +26,7 @@
- Tree leaves retrieval
- C++ base updated to 1.7.0

### Release music
#### Release music

Schweigt stille, plaudert nicht (BWV 211). Johann Sebastian Bach

Expand Down
61 changes: 33 additions & 28 deletions yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,27 @@ def test_with_validation(self):
label="income", num_trees=50
).train(train_ds, valid=test_ds)

def test_compare_pandas_and_path(self):
  """Checks that training from a CSV path matches training from pandas.

  A Random Forest is trained on the adult training set twice: once from a
  pandas DataFrame and once directly from the CSV path. The path-based
  learner is given the data spec of the pandas-trained model. Both models
  are evaluated on the same pandas test set, and their accuracies are
  expected to be almost equal.
  """
  data_dir = os.path.join(test_utils.ydf_test_data_path(), "dataset")
  train_csv = os.path.join(data_dir, "adult_train.csv")
  test_csv = os.path.join(data_dir, "adult_test.csv")

  train_df = pd.read_csv(train_csv)
  test_df = pd.read_csv(test_csv)

  # Reference model: trained from the in-memory pandas DataFrame.
  pandas_model = specialized_learners.RandomForestLearner(
      label="income"
  ).train(train_df)
  pandas_accuracy = pandas_model.evaluate(test_df).accuracy

  # Candidate model: trained from the CSV path, reusing the reference model's
  # data spec (presumably so both trainings interpret the columns the same
  # way -- confirm against the learner's data_spec semantics).
  path_model = specialized_learners.RandomForestLearner(
      label="income", data_spec=pandas_model.data_spec()
  ).train(train_csv)
  path_accuracy = path_model.evaluate(test_df).accuracy

  self.assertAlmostEqual(path_accuracy, pandas_accuracy)


class CARTLearnerTest(LearnerTest):

Expand Down Expand Up @@ -517,38 +538,22 @@ def test_with_validation(self):
logging.info("evaluation:\n%s", evaluation)
self.assertAlmostEqual(evaluation.accuracy, 0.87, 1)

def test_adult_from_csv(self):
def test_ranking(self):
dataset_directory = os.path.join(test_utils.ydf_test_data_path(), "dataset")
train_path = os.path.join(dataset_directory, "adult_train.csv")
test_path = os.path.join(dataset_directory, "adult_test.csv")
label = "income"
train_path = os.path.join(dataset_directory, "synthetic_ranking_train.csv")
test_path = os.path.join(dataset_directory, "synthetic_ranking_test.csv")
label = "LABEL"
ranking_group = "GROUP"

learner = specialized_learners.RandomForestLearner(label=label)

model = learner.train(train_path)
accuracy = model.evaluate(test_path).accuracy
self.assertGreaterEqual(accuracy, 0.864)

def test_compare_pandas_and_path(self):
dataset_directory = os.path.join(test_utils.ydf_test_data_path(), "dataset")
train_path = os.path.join(dataset_directory, "adult_train.csv")
test_path = os.path.join(dataset_directory, "adult_test.csv")
label = "income"

pd_train = pd.read_csv(train_path)
pd_test = pd.read_csv(test_path)

learner = specialized_learners.RandomForestLearner(label=label)
model_from_pd = learner.train(pd_train)
accuracy_from_pd = model_from_pd.evaluate(pd_test).accuracy

learner_from_path = specialized_learners.RandomForestLearner(
label=label, data_spec=model_from_pd.data_spec()
learner = specialized_learners.GradientBoostedTreesLearner(
label=label,
ranking_group=ranking_group,
task=generic_learner.Task.RANKING,
)
model_from_path = learner_from_path.train(train_path)
accuracy_from_path = model_from_path.evaluate(pd_test).accuracy

self.assertAlmostEqual(accuracy_from_path, accuracy_from_pd)
model = learner.train(train_path)
evaluation = model.evaluate(test_path)
self.assertAlmostEqual(evaluation.ndcg, 0.71, places=1)


class UtilityTest(LearnerTest):
Expand Down

0 comments on commit 5d8f933

Please sign in to comment.