Skip to content

Commit

Permalink
trivial tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Jan 9, 2024
1 parent 62e1ca8 commit a3cb9bf
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions cobaya/likelihoods/base_classes/bao.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,13 @@ def initialize(self):
"needs to be specified.")
self.data["observable"] = [self.observable_1]
x = self.grid_data[:, 0]
Nx = x.shape[0]
chi2 = np.log(self.grid_data[:, 1])
self.interpolator = UnivariateSpline(x, chi2, s=0, ext=2)
elif self.grid_data.shape[1] == 3:
self.use_grid_1d = False
self.use_grid_2d = True
self.use_grid_3d = False
if (not self.observable_1) or (not self.observable_2):
if not (self.observable_1 and self.observable_2):
raise LoggedError(
self.log, "If using grid data, 'observable_1' and 'observable_2'"
"need to be specified.")
Expand All @@ -293,10 +292,11 @@ def initialize(self):
self.use_grid_1d = False
self.use_grid_2d = False
self.use_grid_3d = True
if (not self.observable_1) or (not self.observable_2) or (not self.observable_3):
if not (self.observable_1 and self.observable_2 and self.observable_3):
raise LoggedError(
self.log, "If using grid data, 'observable_1', 'observable_2' and 'observable_3'"
"need to be specified.")
self.log,
"If using grid data, 'observable_1', 'observable_2' "
"and 'observable_3' need to be specified.")
self.data["observable"] = [self.observable_1, self.observable_2,
self.observable_3]

Expand Down Expand Up @@ -390,21 +390,22 @@ def get_requirements(self):
if obs not in theory_reqs])
if len(obs_used_not_implemented):
raise LoggedError(
self.log, "This likelihood refers to observables '%s' that have not been"
" implemented yet. Did you mean any of %s? "
self.log, "This likelihood refers to observables '%s' that "
"have not been implemented yet. Did you mean any of %s? "
"If you didn't, please, open an issue in github.",
obs_used_not_implemented, list(theory_reqs))
requisites = {}
if self.has_type:
for observable in self.data["observable"].unique():
for req, req_values in theory_reqs[observable].items():
if req not in requisites.keys():
if req not in requisites:
requisites[req] = req_values
else:
if isinstance(req_values, dict):
for k, v in req_values.items():
if v is not None:
requisites[req][k] = np.unique(np.concatenate((requisites[req][k], v)))
requisites[req][k] = np.unique(
np.concatenate((requisites[req][k], v)))
return requisites

def theory_fun(self, z, observable):
Expand All @@ -419,8 +420,8 @@ def theory_fun(self, z, observable):
elif observable == "rs_over_DV":
return np.cbrt(
((1 + z) * self.provider.get_angular_diameter_distance(z)) ** 2 *
Const.c_km_s * z / self.provider.get_Hubble(z, units="km/s/Mpc")) ** (
-1) * self.rs()
Const.c_km_s * z / self.provider.get_Hubble(z, units="km/s/Mpc")
) ** (-1) * self.rs()
# Comoving angular diameter distance, over sound horizon radius
elif observable == "DM_over_rs":
return (1 + z) * self.provider.get_angular_diameter_distance(z) / self.rs()
Expand Down

0 comments on commit a3cb9bf

Please sign in to comment.