Refreshing API and #2

Open
8 tasks
TianyuDu opened this issue Aug 6, 2022 · 0 comments
TianyuDu commented Aug 6, 2022

Towards BEMB v1.0.

Corresponding branch: api-update

We are planning to refine and expand the current API of BEMBFlex.

The pred_item and multiple class prediction.

  • Note: this is a non-trivial extension. For each observation (a consumer picks an item), we compute a scalar utility for each available item. Only a single scalar utility is computed from the chosen item (item_index[i]), which only allows binary classification. We might have to drop this feature.
  • Currently the model supports predicting binary batch.label or multi-class batch.item_index. We plan to support arbitrary multi-class classifications.
    • In particular, you don't need to change anything if pred_item=True: the model knows that the number of classes is exactly the num_items parameter. In this case, your ChoiceDataset object does not need a label attribute, since the model uses item_index as the ground truth for training.
    • In contrast, if pred_item=False, you now need to supply num_classes to the BEMBFlex.__init__() method and include a label attribute in the ChoiceDataset object. The label attribute should be a LongTensor with values in {0, 1, ..., num_classes - 1}.
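
The rule above for choosing the number of classes can be sketched in plain Python. This is only an illustration of the described behaviour, not the BEMBFlex implementation; the helper name resolve_num_classes is hypothetical.

```python
from typing import Optional


def resolve_num_classes(pred_item: bool, num_items: int,
                        num_classes: Optional[int] = None) -> int:
    """Hypothetical helper: how many classes the model would predict."""
    if pred_item:
        # Ground truth is item_index, so the classes are exactly the items.
        return num_items
    if num_classes is None:
        # With pred_item=False the user has to state the number of classes.
        raise ValueError("num_classes is required when pred_item=False")
    return num_classes


print(resolve_num_classes(True, num_items=10))                  # 10
print(resolve_num_classes(False, num_items=10, num_classes=3))  # 3
```

Under this sketch, num_classes is simply ignored when pred_item=True; whether the real constructor warns or raises in that case is not specified in this issue.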

Post-Estimation

  • Thanks to feedback from our valued users, we are planning to reorganize our post-estimation prediction methods for a better user experience.
    • We will implement a method called predict_proba(), the same name as inference methods of scikit-learn models.
    • This method will have @torch.no_grad() as a decorator, so you can use it however you want without worrying about gradient tracking.
    • With pred_item=True, the batch needs an item_index attribute only if it is involved in the utility computation (e.g., within-category computation).
    • With pred_item=False, the batch does not need to have a label attribute.
    • The preliminary API of predict_proba() is as follows:
batch = ChoiceDataset(...)
bemb = BEMBFlex(..., pred_item=True, ...)
proba = bemb.predict_proba(batch)  # shape = (len(batch), num_items)

batch = ChoiceDataset(...)
# note that batch doesn't need to have a label attribute.
bemb = BEMBFlex(..., pred_item=False, num_classes=..., ...)
proba = bemb.predict_proba(batch)  # shape = (len(batch), num_classes)
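
As a rough sketch of how the returned probabilities might be consumed, the snippet below takes the most likely class per observation and computes accuracy. A nested list stands in for the (len(batch), num_classes) tensor, and the helper names predicted_classes and accuracy are hypothetical; with a real tensor, proba.argmax(dim=1) would do the same job.

```python
def predicted_classes(proba):
    """Index of the largest probability in each row of a (num_obs, num_classes) table."""
    return [max(range(len(row)), key=row.__getitem__) for row in proba]


def accuracy(proba, targets):
    """Fraction of observations whose most likely class equals the target."""
    preds = predicted_classes(proba)
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


# Made-up stand-in for bemb.predict_proba(batch): 3 observations, 3 classes.
proba = [
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.4, 0.5, 0.1],
]
targets = [0, 2, 0]  # item_index or label values, depending on pred_item

print(predicted_classes(proba))  # [0, 2, 1]
print(accuracy(proba, targets))  # 2 of 3 correct
```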

Renaming Variables.

  • We received feedback that the name price-variation is ambiguous. We propose changing it to sessionitem-variation, which is precisely the definition of such variables.
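
To illustrate what a sessionitem-variation variable is, the toy sketch below stores one value per (session, item) pair and looks up the row that applies to a given observation. The data, the session_index list, and the helper values_for_observation are all made up for illustration and are not part of the proposed API.

```python
# Toy variable that varies by session AND item: 2 sessions x 3 items.
price = [
    [1.0, 2.5, 3.0],  # session 0
    [1.2, 2.5, 2.8],  # session 1
]

# Assumed mapping from each observation to the session it happened in.
session_index = [0, 0, 1, 1]


def values_for_observation(variable, session_index, i):
    """Per-item values that apply to observation i, selected via its session."""
    return variable[session_index[i]]


print(values_for_observation(price, session_index, 2))  # [1.2, 2.5, 2.8]
```

Nothing here is specific to prices, which is the motivation for the more general name.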
@TianyuDu TianyuDu self-assigned this Aug 6, 2022