Skip to content

Commit

Permalink
Add get_z_score and get_matching_masses to PosteriorMassSamples
Browse files Browse the repository at this point in the history
	modified:   core.py
  • Loading branch information
cweniger committed Aug 24, 2022
1 parent 6eeb6c8 commit 71cd013
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions swyft/lightning/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import yaml

from swyft.lightning.samples import *
from swyft.plot.mass import get_empirical_z_score


class SwyftModule(pl.LightningModule):
Expand Down Expand Up @@ -244,6 +245,32 @@ class PosteriorMassSamples:
masses: torch.Tensor
parnames: np.array

def get_matching_masses(self, *args):
for i, pars in enumerate(self.parnames):
if set(pars) == set(args):
return self.masses[:,i]
return None

def get_z_score(self, max_z_score = 3.5, n_bins = 50, interval_z_score = 1.0):
"""Calculate empirical z-score of highest-density posterior interval.
Args:
max_z_score: upper limit (default 3.5)
n_bins (int): number of bins used when tabulating z-score
interval_z_score: interval used for calculating statistical z-score uncertainties
Returns:
Array with z-score (..., n_bins, 4)
"""
z0, z1, z2 = get_empirical_z_score(self.masses, max_z_score, n_bins, interval_z_score)
z0 = np.tile(z0, (*z1.shape[:-1], 1))
z0 = np.reshape(z0, (*z0.shape, 1))
z1 = z1.reshape(*z1.shape, 1)
z = np.concatenate([z0, z1, z2], axis=-1)
return z



@dataclass
class LogRatioSamples:
"""Handles logratios and the corresponding parameter values.
Expand Down

0 comments on commit 71cd013

Please sign in to comment.