Skip to content

Commit

Permalink
Today's work on empirical mass estimation.
Browse files Browse the repository at this point in the history
- Plotting routines work now.
- Turns out that mass estimation is not correctly treating HDP intervals
  • Loading branch information
cweniger committed Aug 25, 2022
1 parent 71cd013 commit 58fe82a
Show file tree
Hide file tree
Showing 4 changed files with 946 additions and 342 deletions.
591 changes: 591 additions & 0 deletions notebooks/0D-Mass.ipynb

Large diffs are not rendered by default.

21 changes: 17 additions & 4 deletions swyft/lightning/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def get_pms(p0, p1):
ms.append(m)
masses = torch.stack(ms, dim = 0)
params = torch.stack(vs, dim = 0)
out = PosteriorMassSamples(params, masses, pred0.parnames)
out = PosteriorMassSamples(params, masses, p0.parnames)
return out

if isinstance(pred0, tuple):
Expand Down Expand Up @@ -251,7 +251,18 @@ def get_matching_masses(self, *args):
return self.masses[:,i]
return None

def get_z_score(self, max_z_score = 3.5, n_bins = 50, interval_z_score = 1.0):
def get_z_scores(self):
out = {}
for i, pars in enumerate(self.parnames):
pars = tuple(pars)
m = self.masses[:,i]
z0, z1, z2 = get_empirical_z_score(m, 3.5, 50, 1.0)
out[pars] = np.array([np.interp([1.0, 2.0, 3.0], z0, z2[:,0]),
np.interp([1.0, 2.0, 3.0], z0, z1), np.interp([1.0, 2.0, 3.0],
z0, z2[:,1])]).T
return out

def get_matching_z_score(self, *args, max_z_score = 3.5, n_bins = 50, interval_z_score = 1.0):
"""Calculate empirical z-score of highest-density posterior interval.
Args:
Expand All @@ -262,15 +273,17 @@ def get_z_score(self, max_z_score = 3.5, n_bins = 50, interval_z_score = 1.0):
Returns:
Array with z-score (..., n_bins, 4)
"""
z0, z1, z2 = get_empirical_z_score(self.masses, max_z_score, n_bins, interval_z_score)
m = self.get_matching_masses(*args)
if m is None:
return None
z0, z1, z2 = get_empirical_z_score(m, max_z_score, n_bins, interval_z_score)
z0 = np.tile(z0, (*z1.shape[:-1], 1))
z0 = np.reshape(z0, (*z0.shape, 1))
z1 = z1.reshape(*z1.shape, 1)
z = np.concatenate([z0, z1, z2], axis=-1)
return z



@dataclass
class LogRatioSamples:
"""Handles logratios and the corresponding parameter values.
Expand Down
Loading

0 comments on commit 58fe82a

Please sign in to comment.