Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --include_label_tables argument to optionally include (time-censored) labels as features in the db #272

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

adobles96
Copy link
Collaborator

We add the --include_label_tables argument which can take one of three values:

  1. "none" (default): Label tables won't be added to the db. The resulting graph is as before.
  2. "task_only": Will include label information for the current task only. I.e. if we're training a GNN for the hm item-sales task, only the labels from that task will be included.
  3. "all": Includes label information for all the tasks defined on the dataset. I.e. if we're training a GNN for the hm item-sales task, the labels for hm item-sales, item-churn, user-churn, etc. will be included.

Implementation details

The implementation is quite simple. If we wish to include a particular label table, we load it as a pandas dataframe, modify it's timestamp column by adding the timedelta of the corresponding task, and add it as a new table to the relbench Database object.

The adding of timedelta to the timestamp column is crucial to avoid leakage. For example, in a task meant to predict sales of an item in the next month, the original timestamp column in the label table has the date as of which the prediction should be made. Let's say this is Jan-01, in order to predict the sales in the month of January. When we add timedelta (which equals 1 month in this case), the new value for the timestamp becomes Feb-01, which means the label information (i.e. the number of sales for the month of January) is now censored before Feb-01. In other words, the GNN will have access to the number of sales for the month of January starting on Feb-01. It is worth noting that this only works so long as we uphold the convention that timedelta is constant within a given task.

In the end, this results in a graph with new "label" nodes that hold the label values at previous times and have an edge to the relevant entity node, thus making that information available to the GNN in a single hop.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants