Add --include_label_tables argument to optionally include (time-censored) labels as features in the db #272
We add the --include_label_tables argument, which can take one of three values:

- "none" (default): Label tables won't be added to the db. The resulting graph is as before.
- "task_only": Includes label information for the current task only. I.e., if we're training a GNN for the hm item-sales task, only the labels from that task will be included.
- "all": Includes label information for all the tasks defined on the dataset. I.e., if we're training a GNN for the hm item-sales task, the labels for hm item-sales, item-churn, user-churn, etc. will be included.

Implementation details
The implementation is quite simple. If we wish to include a particular label table, we load it as a pandas dataframe, modify its timestamp column by adding the timedelta of the corresponding task, and add it as a new table to the relbench Database object.

Adding timedelta to the timestamp column is crucial to avoid leakage. For example, in a task meant to predict sales of an item in the next month, the original timestamp column in the label table holds the date as of which the prediction should be made. Let's say this is Jan-01, in order to predict the sales in the month of January. When we add timedelta (which equals 1 month in this case), the new value for the timestamp becomes Feb-01, which means the label information (i.e. the number of sales for the month of January) is now censored until Feb-01. In other words, the GNN will have access to the number of sales for the month of January starting on Feb-01. It is worth noting that this only works so long as we uphold the convention that timedelta is constant within a given task.

In the end, this results in a graph with new "label" nodes that hold the label values at previous times and have an edge to the relevant entity node, thus making that information available to the GNN in a single hop.
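
To make the time-censoring step concrete, here is a minimal sketch in plain pandas. It is not the code from this PR: the helper names (censor_label_table, visible_labels), the toy columns (item_id, sales), and the use of a 31-day timedelta to stand in for "1 month" are all assumptions made for illustration, as is the rule that a row is visible once its timestamp is less than or equal to the seed time.

```python
import pandas as pd


def censor_label_table(
    labels: pd.DataFrame,
    timedelta: pd.Timedelta,
    time_col: str = "timestamp",
) -> pd.DataFrame:
    """Return a copy of the label table with its time column shifted
    forward by the task's timedelta (hypothetical helper)."""
    censored = labels.copy()
    censored[time_col] = censored[time_col] + timedelta
    return censored


def visible_labels(
    censored: pd.DataFrame,
    seed_time: pd.Timestamp,
    time_col: str = "timestamp",
) -> pd.DataFrame:
    """Rows a model may see when predicting at seed_time, assuming a row
    becomes visible once its timestamp is <= seed_time."""
    return censored[censored[time_col] <= seed_time]


# Toy label table: sales during January, with the prediction made as of Jan-01.
labels = pd.DataFrame(
    {
        "item_id": [1, 2],
        "timestamp": pd.to_datetime(["2020-01-01", "2020-01-01"]),
        "sales": [10.0, 12.0],
    }
)

# Assumption: the task's timedelta is 31 days, so Jan-01 maps to Feb-01.
task_timedelta = pd.Timedelta(days=31)
censored = censor_label_table(labels, task_timedelta)

# Mid-January the January labels are still hidden; on Feb-01 they appear.
hidden = visible_labels(censored, pd.Timestamp("2020-01-15"))
shown = visible_labels(censored, pd.Timestamp("2020-02-01"))
```

In this sketch the shifted dataframe would then be registered as an extra table with a foreign key to the entity table, which is what produces the "label" nodes one hop away from each entity.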