We are planning to refine and expand the current API of BEMBFlex.
The pred_item option and multi-class prediction.
Note: this is a non-trivial extension. For each observation (a consumer picking an item), we compute a scalar utility for each available item. There is only a single scalar utility for the chosen item (item_index[i]), which only allows us to do binary classification. We might have to drop this feature.
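To make the limitation concrete, here is a minimal pure-Python sketch, not BEMBFlex's actual code. It assumes the per-item utilities are already computed, that item prediction applies a softmax across items, and that label prediction applies a sigmoid to the chosen item's utility. The helper names and the numbers are hypothetical.

```python
import math

def softmax(utilities):
    """Turn one observation's per-item utilities into choice probabilities."""
    m = max(utilities)  # subtract the max for numerical stability
    exps = [math.exp(u - m) for u in utilities]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(x):
    """Map a single scalar utility to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical utilities: 2 observations, 3 items.
utility = [[0.5, 1.5, -0.2],
           [2.0, 0.1, 0.3]]
item_index = [1, 0]  # the item chosen in each observation

# Item prediction: all num_items utilities are used, one probability per item.
item_proba = [softmax(row) for row in utility]

# Label prediction: only the chosen item's utility is used. One scalar per
# observation can parameterize a binary outcome, but not num_classes > 2.
chosen_utility = [utility[i][item_index[i]] for i in range(len(utility))]
label_proba = [sigmoid(u) for u in chosen_utility]
```

Under these assumptions, supporting more than two classes would need one score per class for each observation, which is why the extension is non-trivial.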
Currently the model supports predicting a binary batch.label or a multi-class batch.item_index. We plan to support arbitrary multi-class classification.
In particular, you don't need to change anything if pred_item=True: the model knows that the number of classes is exactly the num_items parameter. In this case, your ChoiceDataset object also does not need a label attribute, since the model uses item_index as the ground truth for training.
In contrast, if pred_item=False, you now need to pass num_classes to the BEMBFlex.__init__() method, and your ChoiceDataset object needs a label attribute. The label attribute should be a LongTensor with values in {0, 1, ..., num_classes - 1}.
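Before building the dataset, it may help to sanity-check the labels. The sketch below is a hypothetical helper, not part of BEMBFlex. It works on a plain Python list (for a tensor, something like label.tolist() would produce one) and assumes classes are indexed from 0, so the largest valid label is num_classes - 1.

```python
def check_labels(labels, num_classes):
    """Raise ValueError unless every label is an int in [0, num_classes - 1].

    Hypothetical helper for illustration; BEMBFlex may validate differently.
    """
    for i, y in enumerate(labels):
        # bool is a subclass of int in Python, so reject it explicitly.
        if not isinstance(y, int) or isinstance(y, bool):
            raise ValueError(f"label at position {i} is not an integer: {y!r}")
        if not 0 <= y < num_classes:
            raise ValueError(
                f"label at position {i} is {y}, expected 0..{num_classes - 1}"
            )

check_labels([0, 2, 1, 2], num_classes=3)  # valid labels, passes silently
```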
Post-Estimation
Thanks to feedback from our valued users, we are planning to reorganize our post-estimation prediction methods for a better user experience.
We will implement a method called predict_proba(), the same name as inference methods of scikit-learn models.
This method will have @torch.no_grad() as a decorator, so you can use it however you want without being worried about gradient tracking.
With pred_item=True, the batch needs an item_index attribute only if it is involved in the utility computation (e.g., within-category computation).
With pred_item=False, the batch does not need to have a label attribute.
The preliminary API of predict_proba() is as follows:
batch = ChoiceDataset(...)
bemb = BEMBFlex(..., pred_item=True, ...)
proba = bemb.predict_proba(batch) # shape = (len(batch), num_items)
batch = ChoiceDataset(...)
# note that batch doesn't need to have a label attribute.
bemb = BEMBFlex(..., pred_item=False, num_classes=..., ...)
proba = bemb.predict_proba(batch) # shape = (len(batch), num_classes)
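Once predict_proba() returns, a common next step is to turn probabilities into hard predictions. The sketch below is only an illustration under stated assumptions: proba has been converted to nested Python lists (for example with proba.tolist()), each row sums to one, and predicted_classes is a hypothetical helper rather than a BEMBFlex method.

```python
def predicted_classes(proba, tol=1e-6):
    """Return the index of the most likely class for each row of proba."""
    predictions = []
    for i, row in enumerate(proba):
        # Each row is assumed to be a probability distribution over classes.
        if abs(sum(row) - 1.0) > tol:
            raise ValueError(f"row {i} does not sum to 1")
        predictions.append(max(range(len(row)), key=row.__getitem__))
    return predictions

# Hypothetical output for 2 observations and 3 classes (or items).
proba = [[0.1, 0.7, 0.2],
         [0.6, 0.3, 0.1]]
predictions = predicted_classes(proba)
```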
Renaming Variables.
We received feedback that the name price-variation is ambiguous. We propose changing it to sessionitem-variation instead (this is precisely the definition of such variables).
Towards BEMB v1.0. Corresponding branch: api-update