Skip to content

Commit

Permalink
Merge pull request #42 from gsbDBI/41-how-to-model-outside-option
Browse files Browse the repository at this point in the history
Add implementation of outside option to the main branch.
  • Loading branch information
TianyuDu authored Apr 22, 2024
2 parents 8fff250 + 95a052f commit 9a071c4
Show file tree
Hide file tree
Showing 7 changed files with 840 additions and 7 deletions.
28 changes: 25 additions & 3 deletions torch_choice/data/choice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Update: Jan. 2, 2023
"""
import copy
import warnings
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -206,7 +207,14 @@ def num_users(self) -> int:
if self._num_users is not None:
return self._num_users
elif self.user_index is not None:
# infer from the number of unique items in user_index.
num_unique = len(torch.unique(self.user_index))
expected_num_users = int(self.user_index.max()) + 1
if num_unique != expected_num_users:
warnings.warn(f"The number of users is inferred from the number of unique users in the user_index tensor. The user_index tensor in the ChoiceDataset ranges from {int(self.user_index.min())} to {int(self.user_index.max())}. The ChoiceDataset assumes user_index to be 0-indexed and encoded using consecutive integers. There are {expected_num_users} users expected given max(user_index). However, there are {num_unique} unique values in the user_index . This could be caused by missing users in the dataset (i.e., some users are not in user_index at all). If this is not expected, please check the user_index tensor. For a safer behavior, please provide the number of users explicitly by using the num_users keyword while initializing the ChoiceDataset class.")
else:
warnings.warn(f"The number of users is inferred from the number of unique users in the user_index tensor. This might lead to unexpected behaviors if some users never appeared in the user_index tensor. For a safer behavior, please provide the number of users explicitly by using the num_users keyword while initializing the ChoiceDataset class.")

# infer from the number of unique users using the user_index.
return len(torch.unique(self.user_index))
else:
return 1
Expand All @@ -223,7 +231,15 @@ def num_items(self) -> int:
return self._num_items
else:
# infer the number of items from item_index.
return len(torch.unique(self.item_index))
# -1 is an optional special value denoting the outside option; it is not counted towards the number of items.
num_unique = len(torch.unique(self.item_index[self.item_index != -1]))
expected_num_items = int(self.item_index[self.item_index != -1].max()) + 1
if num_unique != expected_num_items:
warnings.warn(f"The number of items is inferred from the number of unique items, excluding -1's denoting outside options, in the item_index tensor. The item_index tensor in the ChoiceDataset ranges from {int(self.item_index[self.item_index != -1].min())} to {int(self.item_index[self.item_index != -1].max())}, excluding -1's. The ChoiceDataset assumes item_index to be 0-indexed and encoded using consecutive integers. There are {expected_num_items} items expected given max(item_index). However, there are {num_unique} unique values in item_index. This could be caused by missing items in the dataset (i.e., some items are not in item_index at all). If this is not expected, please check the item_index tensor. For a safer behavior, please provide the number of items explicitly by using the num_items keyword while initializing the ChoiceDataset class.")
else:
warnings.warn(f"The number of items is inferred from the number of unique items, excluding -1's denoting outside options, in the item_index tensor. This might lead to unexpected behaviors if some items never appeared in the item_index tensor. For a safer behavior, please provide the number of items explicitly by using the num_items keyword while initializing the ChoiceDataset class.")

return len(torch.unique(self.item_index[self.item_index != -1]))

@property
def num_sessions(self) -> int:
Expand All @@ -236,6 +252,12 @@ def num_sessions(self) -> int:
# return the _num_sessions provided in the constructor.
return self._num_sessions
else:
num_unique = len(torch.unique(self.session_index))
expected_num_sessions = int(self.session_index.max()) + 1
if num_unique != expected_num_sessions:
warnings.warn(f"The number of sessions is inferred from the number of unique sessions in the session_index tensor. The session_index tensor in the ChoiceDataset ranges from {int(self.session_index.min())} to {int(self.session_index.max())}. The ChoiceDataset assumes session_index to be 0-indexed and encoded using consecutive integers. There are {expected_num_sessions} sessions expected given max(session_index). However, there are {num_unique} unique values in the session_index . This could be caused by missing sessions in the dataset (i.e., some sessions are not in session_index at all). If this is not expected, please check the session_index tensor. For a safer behavior, please provide the number of sessions explicitly by using the num_sessions keyword while initializing the ChoiceDataset class.")
else:
warnings.warn(f"The number of sessions is inferred from the number of unique sessions in the session_index tensor. This might lead to unexpected behaviors if some sessions never appeared in the session_index tensor. For a safer behavior, please provide the number of sessions explicitly by using the num_sessions keyword while initializing the ChoiceDataset class.")
# infer the number of sessions from session_index.
return len(torch.unique(self.session_index))

Expand Down Expand Up @@ -451,7 +473,7 @@ def _expand_tensor(self, key: str, val: torch.Tensor) -> torch.Tensor:
else:
raise ValueError(f'Warning: the input key {key} is not an attribute of the dataset, will NOT modify the provided tensor.')

assert out.shape == (len(self), self.num_items, num_params)
assert out.shape == (len(self), self.num_items, num_params), f'Error: the output shape {out.shape} is not correct, expected: {(len(self), self.num_items, num_params)}.'
return out

@staticmethod
Expand Down
25 changes: 24 additions & 1 deletion torch_choice/model/conditional_logit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def __init__(self,
num_users: Optional[int]=None,
regularization: Optional[str]=None,
regularization_weight: Optional[float]=None,
weight_initialization: Optional[Union[str, Dict[str, str]]]=None
weight_initialization: Optional[Union[str, Dict[str, str]]]=None,
model_outside_option: Optional[bool]=False
) -> None:
"""
Args:
Expand Down Expand Up @@ -106,6 +107,12 @@ def __init__(self,
Alternatively, users can pass a dictionary with keys exactly the same as the `coef_variation_dict` dictionary,
and values from {'normal', 'uniform', 'zero'} to initialize coefficients of different types of variables differently.
By default, all coefficients are initialized following a standard normal distribution.
model_outside_option (Optional[bool]): whether to explicitly model the outside option (i.e., the consumer did not buy anything).
When the outside option is modeled, it is indicated by `item_index[n] == -1` in the item-index-tensor.
In this case, the item-index-tensor can contain values in `{-1, 0, 1, ..., num_items-1}`.
Otherwise, if the outside option is not modeled, the item-index-tensor should only contain values in `{0, 1, ..., num_items-1}`.
The utility of the outside option is always set to 0 while computing the probability.
By default, model_outside_option is set to False and the model does not model the outside option.
"""
# ==============================================================================================================
# Check that the model received a valid combination of inputs so that it can be initialized.
Expand Down Expand Up @@ -197,6 +204,7 @@ def __init__(self,
# A ModuleDict is required to properly register all trainable parameters.
# self.parameter() will fail if a python dictionary is used instead.
self.coef_dict = nn.ModuleDict(coef_dict)
self.model_outside_option = model_outside_option

def __repr__(self) -> str:
"""Return a string representation of the model.
Expand Down Expand Up @@ -275,6 +283,13 @@ def forward(self,
if batch.item_availability is not None:
# mask out unavailable items.
total_utility[~batch.item_availability[batch.session_index, :]] = torch.finfo(total_utility.dtype).min / 2

# accommodate the outside option.
if self.model_outside_option:
# the outside option has zero utility.
util_zero = torch.zeros(total_utility.size(0), 1, device=batch.device) # (len(batch), 1) zero tensor.
# outside option is indicated by item_index == -1, we put it at the end.
total_utility = torch.cat((total_utility, util_zero), dim=1) # (len(batch), num_items+1)
return total_utility


Expand All @@ -297,7 +312,15 @@ def negative_log_likelihood(self, batch: ChoiceDataset, y: torch.Tensor, is_trai
self.eval()
# (num_trips, num_items)
total_utility = self.forward(batch)
# check shapes.
if self.model_outside_option:
assert total_utility.shape == (len(batch), self.num_items+1)
assert torch.all(total_utility[:, -1] == 0), "The last column of total_utility should be all zeros, which corresponds to the outside option."
else:
assert total_utility.shape == (len(batch), self.num_items)
logP = torch.log_softmax(total_utility, dim=1)
# since y == -1 indicates the outside option and the last column of total_utility is the outside option, the following
# indexing should correctly retrieve the log-likelihood even for outside options.
nll = - logP[torch.arange(len(y)), y].sum()
return nll

Expand Down
34 changes: 32 additions & 2 deletions torch_choice/model/nested_logit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def __init__(self,
regularization: Optional[str]=None,
regularization_weight: Optional[float]=None,
nest_weight_initialization: Optional[Union[str, Dict[str, str]]]=None,
item_weight_initialization: Optional[Union[str, Dict[str, str]]]=None
item_weight_initialization: Optional[Union[str, Dict[str, str]]]=None,
model_outside_option: Optional[bool]=False
) -> None:
"""Initialization method of the nested logit model.
Expand Down Expand Up @@ -91,6 +92,13 @@ def __init__(self,
{nest, item}_weight_initialization (Optional[Union[str, Dict[str, str]]]): methods to initialize the weights of
coefficients for {nest, item} level model. Please refer to the `weight_initialization` keyword in ConditionalLogitModel's documentation for more details.
model_outside_option (Optional[bool]): whether to explicitly model the outside option (i.e., the consumer did not buy anything).
When the outside option is modeled, it is indicated by `item_index[n] == -1` in the item-index-tensor.
In this case, the item-index-tensor can contain values in `{-1, 0, 1, ..., num_items-1}`.
Otherwise, if the outside option is not modeled, the item-index-tensor should only contain values in `{0, 1, ..., num_items-1}`.
The utility of the outside option is always set to 0 while computing the probability.
By default, model_outside_option is set to False and the model does not model the outside option.
"""
# handle nest level model.
using_formula_to_initiate = (item_formula is not None) and (nest_formula is not None)
Expand Down Expand Up @@ -158,6 +166,8 @@ def __init__(self,
if (self.regularization is None) and (self.regularization_weight is not None):
raise ValueError(f'You specified no regularization but you provide regularization_weight={self.regularization_weight}, you should leave regularization_weight as None if you do not want to regularize the model.')

self.model_outside_option = model_outside_option

@property
def num_params(self) -> int:
"""Get the total number of parameters. For example, if there is only an user-specific coefficient to be multiplied
Expand Down Expand Up @@ -301,7 +311,7 @@ def _forward(self,
Y += coef(item_x_dict[corresponding_observable], user_index)

if item_availability is not None:
Y[~item_availability] =torch.finfo(Y.dtype).min / 2
Y[~item_availability] = torch.finfo(Y.dtype).min / 2

# =============================================================================
# compute the inclusive value of each nest.
Expand All @@ -321,6 +331,13 @@ def _forward(self,
# logP_item[t, i] = log P(ni|Bk), where Bk is the nest item i is in, n is the user in trip t.
logP_item = Y - I # (T, num_items)

if self.model_outside_option:
# if the model explicitly models the outside option, we need to add a column of zeros to logP_item.
# log P(ni|Bk) = 0 for the outside option since Y = 0 and the outside option has its own nest.
logP_item = torch.cat((logP_item, torch.zeros(T, 1).to(device)), dim=1)
assert logP_item.shape == (T, self.num_items+1)
assert torch.all(logP_item[:, -1] == 0)

# =============================================================================
# logP_nest[t, i] = log P(Bk), for item i in trip t, the probability of choosing the nest/bucket
# item i belongs to. logP_nest has shape (T, num_items)
Expand All @@ -331,6 +348,12 @@ def _forward(self,
logit[:, Bk] = (W[:, k] + self.lambdas[k] * inclusive_value[k]).view(-1, 1) # (T, |Bk|)
# only count each nest once in the logsumexp within the nest level model.
cols = [x[0] for x in self.nest_to_item.values()]
if self.model_outside_option:
# the last column corresponds to the outside option, which has W+lambda*I = 0 since W = I = Y = 0 for the outside option.
logit = torch.cat((logit, torch.zeros(T, 1).to(device)), dim=1)
assert logit.shape == (T, self.num_items+1)
# we have already added W+lambda*I for each "actual" nest, now we add the "fake" nest for the outside option.
cols.append(-1)
logP_nest = logit - torch.logsumexp(logit[:, cols], dim=1, keepdim=True)

# =============================================================================
Expand Down Expand Up @@ -372,6 +395,13 @@ def negative_log_likelihood(self,
self.eval()
# (num_trips, num_items)
logP = self.forward(batch)
# check shapes
if self.model_outside_option:
assert logP.shape == (len(batch['item']), self.num_items+1)
else:
assert logP.shape == (len(batch['item']), self.num_items)
# since y == -1 indicates the outside option and the last column of logP is the outside option, the following
# indexing should correctly retrieve the log-likelihood even for outside options.
nll = - logP[torch.arange(len(y)), y].sum()
return nll

Expand Down
16 changes: 16 additions & 0 deletions torch_choice/utils/easy_data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def __init__(self,
price_observable_columns: Optional[List[str]] = None,
itemsession_observable_columns: Optional[List[str]] = None,
useritemsession_observable_columns: Optional[List[str]] = None,
num_items: Optional[int] = None,
num_users: Optional[int] = None,
num_sessions: Optional[int] = None,
# Misc.
device: str = 'cpu'):
"""The initialization method of EasyDatasetWrapper.
Expand Down Expand Up @@ -90,6 +93,12 @@ def __init__(self,
The itemsession_observable_column is an alias for the `price_observable_column` argument for backward compatibility,
all elements of `itemsession_observable_columns` will be appended to `price_observable_column`.
num_items (Optional[int], optional): the number of items in the dataset to pass to the ChoiceDataset. Defaults to None.
num_users (Optional[int], optional): the number of users in the dataset to pass to the ChoiceDataset. Defaults to None.
num_sessions (Optional[int], optional): the number of sessions in the dataset to pass to the ChoiceDataset. Defaults to None.
Raises:
ValueError: _description_
"""
Expand Down Expand Up @@ -139,6 +148,10 @@ def __init__(self,

self.observable_data_to_observable_tensors()

# read in explicit numbers of items, users, and sessions.
self._num_items = num_items
self._num_users = num_users
self._num_sessions = num_sessions
self.create_choice_dataset()
print('Finished Creating Choice Dataset.')

Expand Down Expand Up @@ -334,6 +347,9 @@ def create_choice_dataset(self) -> None:
user_index=torch.LongTensor(self.user_index) if self.user_index is not None else None,
session_index=torch.LongTensor(self.session_index) if self.session_index is not None else None,
item_availability=self.item_availability,
num_items=self._num_items,
num_users=self._num_users,
num_sessions=self._num_sessions,
# keyword arguments for observables.
**self.item_observable_tensors,
**self.user_observable_tensors,
Expand Down
12 changes: 11 additions & 1 deletion torch_choice/utils/run_helper_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,17 @@ def run(model: Union [ConditionalLogitModel, NestedLogitModel],
if isinstance(model, ConditionalLogitModel):
def nll_loss(model):
y_pred = model(dataset_for_std)
return F.cross_entropy(y_pred, dataset_for_std.item_index, reduction='sum')
item_index = dataset_for_std.item_index.clone()
if model.model_outside_option:
assert y_pred.shape == (len(dataset_for_std), model.num_items+1)
# y_pred has shape (len(dataset_for_std), model.num_items+1) since the last column is the utility (logit) of the outside option.
# F.cross_entropy is not smart enough to handle the -1 outside option in y.
# Even though y_pred[:, -1] and y_pred[:, model.num_items] refer to the same column, F.cross_entropy does not know.
# We need to fix it manually.
# manually modify the index for the outside option.
item_index[item_index == -1] = model.num_items

return F.cross_entropy(y_pred, item_index, reduction='sum')
elif isinstance(model, NestedLogitModel):
def nll_loss(model):
d = dataset_for_std[torch.arange(len(dataset_for_std))]
Expand Down
Loading

0 comments on commit 9a071c4

Please sign in to comment.