Gamma coeff #38

Open · wants to merge 5 commits into main
88 changes: 76 additions & 12 deletions bemb/model/bayesian_coefficient.py
@@ -6,9 +6,11 @@
"""
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.lowrank_multivariate_normal import LowRankMultivariateNormal
from torch.distributions.gamma import Gamma


class BayesianCoefficient(nn.Module):
@@ -21,7 +23,8 @@ def __init__(self,
num_obs: Optional[int] = None,
dim: int = 1,
prior_mean: float = 0.0,
prior_variance: Union[float, torch.Tensor] = 1.0
prior_variance: Union[float, torch.Tensor] = 1.0,
distribution: str = 'gaussian'
) -> None:
"""The Bayesian coefficient object represents a learnable tensor mu_i in R^k, where i is from a family (e.g., user, item)
so there are num_classes * num_obs learnable weights in total.
@@ -63,12 +66,34 @@ def __init__(self,
If a tensor with shape (num_classes, dim) is supplied, it amounts
to specifying a different prior variance for each entry in the coefficient.
Defaults to 1.0.
distribution (str, optional): the distribution of the coefficient. Currently we support 'gaussian' and 'gamma'.
Defaults to 'gaussian'.
"""
super(BayesianCoefficient, self).__init__()
# do we use this at all? TODO: drop self.variation.
assert variation in ['item', 'user', 'constant', 'category']

self.variation = variation

assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
Collaborator:
Suggested change:
-    assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
+    if distribution not in ['gaussian', 'gamma']:
+        raise ValueError(f'Unsupported distribution {distribution}')

if distribution == 'gamma':
'''
assert not obs2prior, 'Gamma distribution is not supported for obs2prior at present.'
mean = 1.0
variance = 10.0
assert mean > 0, 'Gamma distribution requires mean > 0'
assert variance > 0, 'Gamma distribution requires variance > 0'
# shape (concentration) is mean^2/variance, rate is mean/variance for the Gamma distribution.
shape = prior_mean ** 2 / prior_variance
rate = prior_mean / prior_variance
prior_mean = np.log(shape)
prior_variance = rate
'''
prior_mean = np.log(prior_mean)
prior_variance = prior_variance

Collaborator: Did you comment out this code chunk using '''? Can you remove it if we don't need it here?

Contributor Author: Yes, this is an artefact of this being a temporary commit.

(TianyuDu marked this conversation as resolved.)
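
For reference, the moment-matching in the commented-out block follows from the Gamma distribution's moments: with concentration k and rate r, mean = k/r and variance = k/r^2, so k = mean^2/variance and r = mean/variance. A minimal sketch to sanity-check the conversion (the helper name is illustrative, not part of this PR):

import torch
from torch.distributions.gamma import Gamma

def gamma_params_from_moments(mean: float, variance: float):
    # Solve mean = k / r and variance = k / r**2 for (k, r).
    assert mean > 0 and variance > 0, 'Gamma moments must be positive.'
    concentration = mean ** 2 / variance
    rate = mean / variance
    return concentration, rate

k, r = gamma_params_from_moments(mean=1.0, variance=10.0)
dist = Gamma(concentration=torch.tensor(k), rate=torch.tensor(r))
print(dist.mean, dist.variance)  # tensor(1.), tensor(10.)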

self.distribution = distribution

self.obs2prior = obs2prior
if variation == 'constant' or variation == 'category':
if obs2prior:
@@ -89,13 +114,15 @@ def __init__(self,
if self.obs2prior:
# the mean of prior distribution depends on observables.
# initialize a BayesianCoefficient with shape (dim, num_obs) standard Gaussian.
prior_H_dist = 'gaussian'
self.prior_H = BayesianCoefficient(variation='constant',
num_classes=dim,
obs2prior=False,
dim=num_obs,
prior_variance=1.0,
H_zero_mask=self.H_zero_mask,
is_H=True) # this is a distribution responsible for the obs2prior H term.
is_H=True,
distribution=prior_H_dist) # this is a distribution responsible for the obs2prior H term.

else:
self.register_buffer(
@@ -114,13 +141,21 @@ def __init__(self,
num_classes, dim) * self.prior_variance)

# create variational distribution.
self.variational_mean_flexible = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)
if self.distribution == 'gaussian':
self.variational_mean_flexible = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)
# TODO(kanodiaayush): initialize the gamma distribution variational mean in a more principled way.
elif self.distribution == 'gamma':
# initialize using uniform distribution between 0.5 and 1.5
# for a gamma distribution, we store the concentration as log(concentration) = variational_mean_flexible
self.variational_mean_flexible = nn.Parameter(
torch.rand(num_classes, dim) + 0.5, requires_grad=True)

if self.is_H and self.H_zero_mask is not None:
assert self.H_zero_mask.shape == self.variational_mean_flexible.shape, \
f"The H_zero_mask should have exactly the same shape as the H variable; `H_zero_mask`.shape is {self.H_zero_mask.shape}, `H`.shape is {self.variational_mean_flexible.shape}."

# for gamma distribution, we store the rate as log(rate) = variational_logstd
self.variational_logstd = nn.Parameter(
torch.randn(num_classes, dim), requires_grad=True)
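
Under the gamma parameterization noted in the comments above, variational_mean_flexible stores log(concentration) and variational_logstd stores log(rate), so the implied variational mean is exp(log k) / exp(log r) — exactly what the variational_mean property below computes. A small sketch of that mapping (shapes are illustrative):

import torch

num_classes, dim = 4, 3
log_concentration = torch.rand(num_classes, dim) + 0.5  # plays the role of variational_mean_flexible
log_rate = torch.randn(num_classes, dim)                # plays the role of variational_logstd

# Mean of Gamma(k, r) is k / r.
gamma_mean = log_concentration.exp() / log_rate.exp()
assert gamma_mean.shape == (num_classes, dim)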

@@ -163,6 +198,10 @@ def variational_mean(self) -> torch.Tensor:
else:
M = self.variational_mean_fixed + self.variational_mean_flexible

if self.distribution == 'gamma':
# M = torch.pow(M, 2) + 0.000001
M = M.exp() / self.variational_logstd.exp()

if self.is_H and (self.H_zero_mask is not None):
# a H-variable with zero-entry restriction.
# multiply zeros to entries with H_zero_mask[i, j] = 1.
@@ -196,7 +235,11 @@ def log_prior(self,
Returns:
torch.Tensor: the log prior of the variable with shape (num_seeds, num_classes).
"""
# p(sample)
# DEBUG_MARKER
'''
print(sample)
print('log_prior')
'''

Collaborator: Remove the debug marker?
num_seeds, num_classes, dim = sample.shape
# shape (num_seeds, num_classes)
if self.obs2prior:
@@ -211,9 +254,19 @@

else:
mu = self.prior_zero_mean
out = LowRankMultivariateNormal(loc=mu,
cov_factor=self.prior_cov_factor,
cov_diag=self.prior_cov_diag).log_prob(sample)

if self.distribution == 'gaussian':
out = LowRankMultivariateNormal(loc=mu,
cov_factor=self.prior_cov_factor,
cov_diag=self.prior_cov_diag).log_prob(sample)
elif self.distribution == 'gamma':
concentration = torch.exp(mu)
rate = self.prior_variance
out = Gamma(concentration=concentration,
rate=rate).log_prob(sample)
Collaborator: This might be a bug! The rate and prior_variance are different, can you double-check it?

Contributor Author: Yes, I meant for gamma prior_variance to represent the rate. Earlier too I set prior_variance to the rate. This should be correct.

Contributor Author: Thanks for flagging it though.

Collaborator: Is prior_variance the rate of the prior or the log(rate) of the prior?
# sum over the last dimension
out = torch.sum(out, dim=-1)

assert out.shape == (num_seeds, num_classes)
return out
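
The gamma branch above evaluates the prior density elementwise and then sums over the coefficient dimension. A standalone sketch of that computation, under the convention that prior_mean holds log(shape) and prior_variance holds the rate (shapes are illustrative):

import torch
from torch.distributions.gamma import Gamma

num_seeds, num_classes, dim = 2, 5, 3
sample = torch.rand(num_seeds, num_classes, dim) + 0.1  # Gamma support is strictly positive.
mu = torch.zeros(num_classes, dim)    # prior_mean, stored as log(shape).
rate = torch.ones(num_classes, dim)   # prior_variance, interpreted as the rate.

out = Gamma(concentration=mu.exp(), rate=rate).log_prob(sample)  # broadcasts over num_seeds
out = out.sum(dim=-1)
assert out.shape == (num_seeds, num_classes)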

@@ -250,6 +303,7 @@ def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor]
"""
value_sample = self.variational_distribution.rsample(
torch.Size([num_seeds]))
# DEBUG_MARKER

Collaborator: Please remove this.
if self.obs2prior:
# sample obs2prior H as well.
H_sample = self.prior_H.rsample(num_seeds=num_seeds)
@@ -258,12 +312,22 @@ def rsample(self, num_seeds: int = 1) -> Union[torch.Tensor, Tuple[torch.Tensor]
return value_sample

@property
def variational_distribution(self) -> LowRankMultivariateNormal:
def variational_distribution(self) -> Union[LowRankMultivariateNormal, Gamma]:
"""Constructs the current variational distribution of the coefficient from current variational mean and covariance.
"""
return LowRankMultivariateNormal(loc=self.variational_mean,
cov_factor=self.variational_cov_factor,
cov_diag=torch.exp(self.variational_logstd))
if self.distribution == 'gaussian':
return LowRankMultivariateNormal(loc=self.variational_mean,
cov_factor=self.variational_cov_factor,
cov_diag=torch.exp(self.variational_logstd))
elif self.distribution == 'gamma':
# for a gamma distribution, we store the concentration as log(concentration) = variational_mean_flexible
assert self.variational_mean_fixed is None, 'Gamma distribution does not support a fixed mean.'
concentration = self.variational_mean_flexible.exp()
# for gamma distribution, we store the rate as log(rate) = variational_logstd
rate = torch.exp(self.variational_logstd)
return Gamma(concentration=concentration, rate=rate)
else:
raise NotImplementedError("Unknown variational distribution type.")
Collaborator:
Suggested change:
-    raise NotImplementedError("Unknown variational distribution type.")
+    raise NotImplementedError(f"Unknown variational distribution type {self.distribution}.")
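
End to end, the gamma variational path builds the distribution from the two stored log-parameters and draws reparameterized samples, which is what rsample above relies on. A hedged sketch (shapes illustrative; torch's Gamma supports rsample via implicit reparameterization):

import torch
from torch.distributions.gamma import Gamma

num_seeds, num_classes, dim = 2, 4, 3
log_concentration = torch.rand(num_classes, dim) + 0.5  # variational_mean_flexible
log_rate = torch.randn(num_classes, dim)                # variational_logstd

q = Gamma(concentration=log_concentration.exp(), rate=log_rate.exp())
value_sample = q.rsample(torch.Size([num_seeds]))  # differentiable w.r.t. both parameters
assert value_sample.shape == (num_seeds, num_classes, dim)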


@property
def device(self) -> torch.device:
57 changes: 53 additions & 4 deletions bemb/model/bemb.py
@@ -39,21 +39,26 @@ def parse_utility(utility_string: str) -> List[Dict[str, Union[List[str], None]]]
A helper function that parses a utility string into a list of additive terms.

Example:
utility_string = 'lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs'
utility_string = 'lambda_item + theta_user * alpha_item - gamma_user * beta_item * price_obs'
output = [
{
'coefficient': ['lambda_item'],
'observable': None,
'sign': 1.0,
},
{
'coefficient': ['theta_user', 'alpha_item'],
'observable': None,
'sign': 1.0,
},
{
'coefficient': ['gamma_user', 'beta_item'],
'observable': 'price_obs',
'sign': -1.0,
}
]
Note that minus signs are allowed in the utility string. If the first term is negative, its leading minus must be attached directly to the term, with no space after it.
"""
# split additive terms
coefficient_suffix = ('_item', '_user', '_constant', '_category')
@@ -76,10 +81,16 @@ def is_coefficient(name: str) -> bool:
def is_observable(name: str) -> bool:
return any(name.startswith(prefix) for prefix in observable_prefix)

utility_string = utility_string.replace(' - ', ' + -')
Collaborator: We might want to find a way to parse utilities even when the user does not put spaces around + or -; let's do this later.

Contributor Author: Agree.

additive_terms = utility_string.split(' + ')
additive_decomposition = list()
for term in additive_terms:
atom = {'coefficient': [], 'observable': None}
if term.startswith('-'):
sign = -1.0
term = term[1:]
else:
sign = 1.0
atom = {'coefficient': [], 'observable': None, 'sign': sign}
# split multiplicative terms.
for x in term.split(' * '):
assert not (is_observable(x) and is_coefficient(x)), f"The element {x} is ambiguous, it follows naming convention of both an observable and a coefficient."
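
The sign handling added here can be read in isolation: rewriting ' - ' as ' + -' lets a single split on ' + ' produce signed terms. A self-contained sketch (the function name is illustrative):

def split_signed_terms(utility_string: str):
    # Rewrite ' - ' as ' + -' so one split on ' + ' yields signed terms.
    utility_string = utility_string.replace(' - ', ' + -')
    terms = []
    for term in utility_string.split(' + '):
        sign = 1.0
        if term.startswith('-'):
            sign = -1.0
            term = term[1:]
        terms.append({'term': term, 'sign': sign})
    return terms

print(split_signed_terms('lambda_item + theta_user * alpha_item - gamma_user * beta_item * price_obs'))
# [{'term': 'lambda_item', 'sign': 1.0},
#  {'term': 'theta_user * alpha_item', 'sign': 1.0},
#  {'term': 'gamma_user * beta_item * price_obs', 'sign': -1.0}]
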
@@ -113,6 +124,7 @@ def __init__(self,
num_items: int,
pred_item: bool,
num_classes: int = 2,
coef_dist_dict: Dict[str, str] = {'default' : 'gaussian'},
H_zero_mask_dict: Optional[Dict[str, torch.BoolTensor]] = None,
prior_mean: Union[float, Dict[str, float]] = 0.0,
prior_variance: Union[float, Dict[str, float]] = 1.0,
@@ -140,6 +152,14 @@
lambda_item + theta_user * alpha_item + gamma_user * beta_item * price_obs
See the doc-string of parse_utility for an example.

coef_dist_dict (Dict[str, str]): a dictionary mapping coefficient name to coefficient distribution name.
The coefficient distribution name can be one of the following:
1. 'gaussian'
2. 'gamma' - obs2prior is not supported for gamma coefficients
If a coefficient does not appear in the dictionary, it is assigned the distribution specified
by the 'default' key ('gaussian' unless overridden).
For coefficients with gamma distributions, the prior mean and variance MUST be specified via the
prior_mean and prior_variance arguments when obs2prior is False for that coefficient; when
obs2prior is True, prior_variance is still required.

obs2prior_dict (Dict[str, bool]): a dictionary maps coefficient name (e.g., 'lambda_item')
to a boolean indicating if observable (e.g., item_obs) enters the prior of the coefficient.

@@ -184,6 +204,8 @@
If no `prior_mean['default']` is provided, the default prior mean will be 0.0 for those coefficients
not in the prior_mean.keys().

For coefficients with gamma distributions, prior_mean specifies the shape (concentration) parameter of the gamma prior.

Defaults to 0.0.

prior_variance (Union[float, Dict[str, float], Dict[str, torch.Tensor]]): the variance of the prior distribution
@@ -203,6 +225,8 @@
If no `prior_variance['default']` is provided, the default prior variance will be 1.0 for those coefficients
not in the prior_variance.keys().

For coefficients with gamma distributions, prior_variance specifies the rate parameter of the gamma prior.

Defaults to 1.0, which means all priors have identity matrix as the covariance matrix.

num_users (int, optional): number of users, required only if coefficient or observable
@@ -233,6 +257,7 @@
self.utility_formula = utility_formula
self.obs2prior_dict = obs2prior_dict
self.coef_dim_dict = coef_dim_dict
self.coef_dist_dict = coef_dist_dict
if H_zero_mask_dict is not None:
self.H_zero_mask_dict = H_zero_mask_dict
else:
@@ -325,6 +350,21 @@
for additive_term in self.formula:
for coef_name in additive_term['coefficient']:
variation = coef_name.split('_')[-1]

if coef_name not in self.coef_dist_dict.keys():
if 'default' in self.coef_dist_dict.keys():
self.coef_dist_dict[coef_name] = self.coef_dist_dict['default']
else:
warnings.warn(f"You provided a dictionary of coef_dist_dict, but coefficient {coef_name} is not a key in it. Supply a value for 'default' in the coef_dist_dict dictionary to use that as default value (e.g., coef_dist_dict['default'] = 'gaussian'); now using distribution='gaussian' since this is not supplied.")
self.coef_dist_dict[coef_name] = 'gaussian'

elif self.coef_dist_dict[coef_name] == 'gamma':
if not self.obs2prior_dict[coef_name]:
assert isinstance(self.prior_mean, dict) and coef_name in self.prior_mean.keys(), \
f"Prior mean for {coef_name} needs to be provided because it's posterior is estimated as a gamma distribution."
assert isinstance(self.prior_variance, dict) and coef_name in self.prior_variance.keys(), \
f"Prior variance for {coef_name} needs to be provided because it's posterior is estimated as a gamma distribution."

if isinstance(self.prior_mean, dict):
# the user didn't specify prior mean for this coefficient.
if coef_name not in self.prior_mean.keys():
@@ -367,7 +407,8 @@
prior_mean=mean,
prior_variance=s2,
H_zero_mask=H_zero_mask,
is_H=False)
is_H=False,
distribution=self.coef_dist_dict[coef_name])
self.coef_dict = nn.ModuleDict(coef_dict)

# ==============================================================================================================
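
Putting the pieces together, a hypothetical configuration for one gamma-distributed coefficient could look like the following; all values are illustrative, and only the argument names shown above are taken from the code:

# gamma_user is gamma-distributed without obs2prior, so its prior shape (prior_mean)
# and rate (prior_variance) must be supplied explicitly per the checks above.
coef_dist_dict = {'gamma_user': 'gamma', 'default': 'gaussian'}
obs2prior_dict = {'lambda_item': False, 'theta_user': False, 'alpha_item': False,
                  'gamma_user': False, 'beta_item': False}
prior_mean = {'gamma_user': 1.0, 'default': 0.0}       # shape parameter for gamma coefficients
prior_variance = {'gamma_user': 10.0, 'default': 1.0}  # rate parameter for gamma coefficients
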
@@ -907,6 +948,7 @@ def reshape_observable(obs, name):
sample_dict[coef_name], coef_name)
assert coef_sample.shape == (R, P, I, 1)
additive_term = coef_sample.view(R, P, I)
additive_term *= term['sign']

# Type II: factorized coefficient, e.g., <theta_user, lambda_item>.
elif len(term['coefficient']) == 2 and term['observable'] is None:
@@ -922,6 +964,7 @@ def reshape_observable(obs, name):
R, P, I, positive_integer)

additive_term = (coef_sample_0 * coef_sample_1).sum(dim=-1)
additive_term *= term['sign']

# Type III: single coefficient multiplied by observable, e.g., theta_user * x_obs_item.
elif len(term['coefficient']) == 1 and term['observable'] is not None:
Expand All @@ -935,6 +978,7 @@ def reshape_observable(obs, name):
assert obs.shape == (R, P, I, positive_integer)

additive_term = (coef_sample * obs).sum(dim=-1)
additive_term *= term['sign']

# Type IV: factorized coefficient multiplied by observable.
# e.g., gamma_user * beta_item * price_obs.
@@ -965,6 +1009,7 @@ def reshape_observable(obs, name):
coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

additive_term = (coef * obs).sum(dim=-1)
additive_term *= term['sign']

else:
raise ValueError(f'Undefined term type: {term}')
@@ -1138,6 +1183,7 @@ def reshape_observable(obs, name):
sample_dict[coef_name], coef_name)
assert coef_sample.shape == (R, total_computation, 1)
additive_term = coef_sample.view(R, total_computation)
additive_term *= term['sign']

# Type II: factorized coefficient, e.g., <theta_user, lambda_item>.
elif len(term['coefficient']) == 2 and term['observable'] is None:
@@ -1153,6 +1199,7 @@ def reshape_observable(obs, name):
R, total_computation, positive_integer)

additive_term = (coef_sample_0 * coef_sample_1).sum(dim=-1)
additive_term *= term['sign']

# Type III: single coefficient multiplied by observable, e.g., theta_user * x_obs_item.
elif len(term['coefficient']) == 1 and term['observable'] is not None:
@@ -1167,6 +1214,7 @@ def reshape_observable(obs, name):
assert obs.shape == (R, total_computation, positive_integer)

additive_term = (coef_sample * obs).sum(dim=-1)
additive_term *= term['sign']

# Type IV: factorized coefficient multiplied by observable.
# e.g., gamma_user * beta_item * price_obs.
@@ -1196,6 +1244,7 @@ def reshape_observable(obs, name):
coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

additive_term = (coef * obs).sum(dim=-1)
additive_term *= term['sign']
Collaborator: Can we just add one single additive_term *= term['sign'] outside the if/elif chain?

Contributor Author: Yes, should be possible, but I'll make sure.
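
A sketch of the refactor suggested here — each branch computes the unsigned term and the sign is applied once at the end (the 'kind' field is an illustrative stand-in for the length/observable checks above):

import torch

def signed_additive_term(term, tensors):
    # Each branch computes the unsigned additive term; the sign is applied exactly once.
    if term['kind'] == 'single':
        additive_term = tensors[term['coefficient'][0]]
    elif term['kind'] == 'factorized':
        a, b = (tensors[c] for c in term['coefficient'])
        additive_term = (a * b).sum(dim=-1)
    else:
        raise ValueError(f'Undefined term type: {term}')
    return additive_term * term['sign']

tensors = {'theta_user': torch.ones(2, 3), 'alpha_item': torch.ones(2, 3)}
term = {'kind': 'factorized', 'coefficient': ['theta_user', 'alpha_item'], 'sign': -1.0}
print(signed_additive_term(term, tensors))  # tensor([-3., -3.])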


else:
raise ValueError(f'Undefined term type: {term}')
8 changes: 8 additions & 0 deletions bemb/model/bemb_flex_lightning.py
@@ -79,6 +79,14 @@ def training_step(self, batch, batch_idx):
loss = - elbo
return loss

# DEBUG_MARKER
'''
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
print(f"Epoch {self.current_epoch} has ended")
breakpoint()
'''
# DEBUG_MARKER

Collaborator: Let's remove the debug marker.

def _get_performance_dict(self, batch):
if self.model.pred_item:
log_p = self.model(batch, return_type='log_prob',