-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathdist.py
19 lines (15 loc) · 852 Bytes
/
dist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from pyro.distributions import Bernoulli
from torch.distributions.utils import broadcast_all
from torch.nn.functional import binary_cross_entropy_with_logits
# from pyro.distributions.torch_distribution import TorchDistribution
class WeightedBernoulli(Bernoulli):
    """Bernoulli distribution whose ``log_prob`` is a weighted cross entropy.

    Intended for imbalanced data, where the positive class should be
    penalized more heavily than the negative class.

    Args:
        *args: Positional arguments forwarded to ``Bernoulli``.
        weight: Keyword-only. Weight applied to the positive-class term; it is
            passed as ``pos_weight`` to ``binary_cross_entropy_with_logits``.
            May be a plain number, a tensor broadcastable against the logits,
            or ``None`` for no weighting. Defaults to ``1.0``.
        **kwargs: Remaining keyword arguments forwarded to ``Bernoulli``.
    """

    def __init__(self, *args, **kwargs):
        # Pull out our own option first so it is not forwarded to the
        # parent constructor.
        self.weight = kwargs.pop('weight', 1.0)
        super(WeightedBernoulli, self).__init__(*args, **kwargs)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution expanded to ``batch_shape``.

        torch's ``Distribution._get_checked_instance`` refuses to expand a
        subclass that overrides ``__init__`` without also overriding
        ``expand``, so this override is required for ``expand`` to work at
        all. The weight is carried over unchanged; ``log_prob`` broadcasts it.
        """
        new = self._get_checked_instance(WeightedBernoulli, _instance)
        new.weight = self.weight
        return super(WeightedBernoulli, self).expand(batch_shape, _instance=new)

    def log_prob(self, value):
        """Return the element-wise weighted log-probability of ``value``.

        Computed as the negated, unreduced
        ``binary_cross_entropy_with_logits`` with ``pos_weight`` set from
        ``self.weight``.

        Args:
            value: Tensor (or number) of observed outcomes, broadcastable
                against the distribution's logits.

        Returns:
            Tensor with the broadcast shape of logits, ``value`` and the
            weight.
        """
        if self._validate_args:
            self._validate_sample(value)
        if self.weight is None:
            logits, value = broadcast_all(self.logits, value)
            pos_weight = None
        else:
            # binary_cross_entropy_with_logits only accepts a tensor (or None)
            # for pos_weight. broadcast_all promotes a plain number to a
            # tensor with the dtype/device of the logits, so numeric weights
            # such as the default 1.0 work.
            logits, value, pos_weight = broadcast_all(
                self.logits, value, self.weight
            )
        return -binary_cross_entropy_with_logits(
            logits, value, reduction='none', pos_weight=pos_weight
        )