Gamma coeff #38
base: main
Conversation
""" | ||
super(BayesianCoefficient, self).__init__() | ||
# do we use this at all? TODO: drop self.variation. | ||
assert variation in ['item', 'user', 'constant', 'category'] | ||
|
||
self.variation = variation | ||
|
||
assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}' |
Suggested change:
-        assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
+        if distribution not in ['gaussian', 'gamma']:
+            raise ValueError(f'Unsupported distribution {distribution}')
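(One reason to prefer raise over assert here: assert statements are stripped when Python runs with -O, so the validation would silently disappear in optimized mode.)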
        assert distribution in ['gaussian', 'gamma'], f'Unsupported distribution {distribution}'
        if distribution == 'gamma':
            '''
Did you comment out this code chunk using '''? Can you remove it if we don't need it here?
Yes, this is an artefact of this being a temporary commit.
@@ -196,7 +235,11 @@ def log_prior(self,
        Returns:
            torch.Tensor: the log prior of the variable with shape (num_seeds, num_classes).
        """
        # p(sample)
        # DEBUG_MARKER
Remove the debug marker?
        concentration = torch.exp(mu)
        rate = self.prior_variance
        out = Gamma(concentration=concentration,
                    rate=rate).log_prob(sample)
This might be a bug! The rate and prior_variance are different; can you double-check it?
Yes, I meant for the gamma prior_variance to represent the rate. Earlier too I set prior_variance to rate. This should be correct.
Thanks for flagging it though.
Is prior_variance the rate of the prior or the log(rate) of the prior?
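For reference, a minimal standalone sketch of the gamma log-prior above, assuming (per the author's reply) that prior_variance stores the rate itself rather than its log; the function name here is hypothetical:

import torch
from torch.distributions import Gamma

# Hypothetical standalone version of the gamma log-prior above, assuming
# prior_variance stores the rate directly (not log(rate)).
def gamma_log_prior(sample: torch.Tensor, mu: torch.Tensor,
                    prior_variance: float) -> torch.Tensor:
    concentration = torch.exp(mu)  # exp keeps the concentration positive
    rate = torch.full_like(concentration, prior_variance)
    return Gamma(concentration=concentration, rate=rate).log_prob(sample)

# Gamma(concentration=1, rate=1) is Exponential(1), so log p(1) = -1.
print(gamma_log_prior(torch.ones(3), torch.zeros(3), prior_variance=1.0))
# tensor([-1., -1., -1.])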
            rate = torch.exp(self.variational_logstd)
            return Gamma(concentration=concentration, rate=rate)
        else:
            raise NotImplementedError("Unknown variational distribution type.")
raise NotImplementedError("Unknown variational distribution type.") | |
raise NotImplementedError(f"Unknown variational distribution type {self.distribution}.") |
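For context, a self-contained sketch of how this dispatch could sit in a module; the gaussian branch and any names beyond the excerpt (variational_mean, the class itself) are assumptions, not the repo's actual code:

import torch
import torch.nn as nn
from torch.distributions import Gamma, Normal

class VariationalFactor(nn.Module):
    # Hypothetical container; only the gamma branch mirrors the excerpt above.
    def __init__(self, dim: int, distribution: str = 'gaussian'):
        super().__init__()
        self.distribution = distribution
        self.variational_mean = nn.Parameter(torch.zeros(dim))
        self.variational_logstd = nn.Parameter(torch.zeros(dim))

    def variational_distribution(self):
        if self.distribution == 'gaussian':
            # logstd parameterization keeps the scale positive.
            return Normal(loc=self.variational_mean,
                          scale=torch.exp(self.variational_logstd))
        elif self.distribution == 'gamma':
            # exp keeps both gamma parameters positive.
            concentration = torch.exp(self.variational_mean)
            rate = torch.exp(self.variational_logstd)
            return Gamma(concentration=concentration, rate=rate)
        else:
            raise NotImplementedError(
                f"Unknown variational distribution type {self.distribution}.")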
@@ -76,10 +81,16 @@ def is_coefficient(name: str) -> bool:
def is_observable(name: str) -> bool:
    return any(name.startswith(prefix) for prefix in observable_prefix)

utility_string = utility_string.replace(' - ', ' + -')
We might want to find a way to parse utilities even when the user does not put spaces around + or -; let's do this later.
agree
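A possible direction for that later fix, sketched as a regex tokenizer; the function name and exact behavior are assumptions, not the repo's parser:

import re

# Hypothetical sketch: split a utility string into signed additive terms
# without requiring spaces around '+' or '-'.
def split_terms(utility_string: str):
    # Normalize 'a - b' and 'a-b' alike into 'a+-b' so each term carries
    # its own sign, mirroring the replace(' - ', ' + -') trick above.
    normalized = re.sub(r'\s*-\s*', '+-', utility_string)
    return [term.strip() for term in normalized.split('+') if term.strip()]

print(split_terms('alpha*x - beta*y+gamma*z'))
# ['alpha*x', '-beta*y', 'gamma*z']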
@@ -1196,6 +1244,7 @@ def reshape_observable(obs, name):
    coef = (coef_sample_0 * coef_sample_1).sum(dim=-1)

    additive_term = (coef * obs).sum(dim=-1)
    additive_term *= term['sign']
Can we just add one single additive_term *= term['sign'] outside the if/elif chain?
Yes, that should be possible, but I'll make sure.
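For illustration, the suggested refactor on hypothetical branches (the term types below are placeholders, not the repo's actual cases): each branch computes additive_term, and the sign is applied once after the chain.

import torch

def additive_term_for(term, coef, obs):
    if term['type'] == 'coef_times_obs':    # placeholder branch
        additive_term = (coef * obs).sum(dim=-1)
    elif term['type'] == 'coef_only':       # placeholder branch
        additive_term = coef.sum(dim=-1)
    else:
        raise ValueError(f"Unknown term type {term['type']}")
    additive_term *= term['sign']  # applied once, outside the if/elif chain
    return additive_term

coef = torch.ones(2, 3)
obs = torch.full((2, 3), 2.0)
print(additive_term_for({'type': 'coef_times_obs', 'sign': -1.0}, coef, obs))
# tensor([-6., -6.])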
@@ -79,6 +79,14 @@ def training_step(self, batch, batch_idx):
    loss = - elbo
    return loss

    # DEBUG_MARKER
Let's remove the debug marker.
I did not review the configurations and main in your supermarket-specific script.
Yes, I'll take care of those.
Temporary working commit for Tianyu to review.