post: bugfix: chains should be detempered/reweighted at once (or bad weights) #322

Merged
6 commits merged on Oct 11, 2023
Changes from all commits
6 changes: 4 additions & 2 deletions CHANGELOG.md
```diff
@@ -4,8 +4,10 @@

 - Created a general `load_samples` function to load Cobaya results natively or as GetDist MCSamples.
 - Improved `.products()` method for samplers (MCMC and PolyChord) and post-processing: samples can now be retrieved simultaneously for all MPI processes, and converted to GetDist. Also added `.samples()` methods to retrieve just the samples.
-- Fixed a bug with mpi runs partly stalling when run with many chains (#308)
-- Python 12 support (removed all dependence on distutils)
+- Collections are now aware of whether they are part of a parallel batch, and warn if trying to reweight/detemper individually (fixes #321).
+- Fixed a bug with mpi runs partly stalling when run with many chains (#308, thanks @vivianmiranda @lukashergt for reporting and testing).
+- Fixed a bug with overzealous checks when loading chains (#306, thanks @mishakb for reporting).
+- Python 3.12 support (removed all dependence on distutils)

 ### Cosmology
```
162 changes: 116 additions & 46 deletions cobaya/collection.py
```diff
@@ -111,19 +111,26 @@ def compute_temperature(logpost, logprior, loglike, check=True):
     return temp


-def detempering_weights_factor(tempered_logpost, temperature):
+def detempering_weights_factor(tempered_logpost, temperature, max_tempered_logpost=None):
     """
     Returns the detempering factors for the weights of a tempered sample, i.e. if ``w_t``
     is the weight of the tempered sample, then the weight of the unit-temperature one is
     ``w_t * f``, where the ``f`` returned by this method is
     ``exp(logp * (1 - 1/temperature))``, where ``logp`` is the (untempered) logposterior.

-    Factors are normalized so that the largest equals one.
+    Factors are normalized so that the largest equals one, according to the maximum
+    logposterior (can be overridden with argument ``max_tempered_logpost``, useful for
+    detempering chain batches).
     """
     if temperature == 1:
         return np.ones(np.atleast_1d(tempered_logpost).shape)
     log_ratio = remove_temperature(tempered_logpost, temperature) - tempered_logpost
-    return np.exp(log_ratio - max(log_ratio))
+    if max_tempered_logpost is None:
+        max_log_ratio = max(log_ratio)
+    else:
+        max_log_ratio = \
+            remove_temperature(max_tempered_logpost, temperature) - max_tempered_logpost
+    return np.exp(log_ratio - max_log_ratio)


 class BaseCollection(HasLogger):
```
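To see why the new ``max_tempered_logpost`` argument matters, here is a minimal sketch (not code from this PR) of the normalization problem it fixes. It assumes ``remove_temperature`` simply rescales a tempered log-posterior back to unit temperature, i.e. multiplies it by ``temperature``, consistent with the docstring above; all values are invented:

```python
import numpy as np

# Assumed helper: in Cobaya the tempered log-posterior is logp / T, so
# removing the temperature multiplies by T (an assumption for this sketch).
def remove_temperature(tempered_logpost, temperature):
    return np.asarray(tempered_logpost) * temperature

def detempering_weights_factor(tempered_logpost, temperature,
                               max_tempered_logpost=None):
    if temperature == 1:
        return np.ones(np.atleast_1d(tempered_logpost).shape)
    log_ratio = remove_temperature(tempered_logpost, temperature) - tempered_logpost
    if max_tempered_logpost is None:
        max_log_ratio = max(log_ratio)  # per-chain normalization (old behavior)
    else:  # shared normalization over the whole batch (new behavior)
        max_log_ratio = (remove_temperature(max_tempered_logpost, temperature)
                         - max_tempered_logpost)
    return np.exp(log_ratio - max_log_ratio)

# Two chains of the same tempered run, with different per-chain maxima:
chain1, chain2, T = np.array([-10.0, -11.0]), np.array([-12.0, -13.0]), 2.0

# Normalizing each chain by its own maximum forces the largest factor of
# *each* chain to 1, silently rescaling the chains relative to each other:
print(detempering_weights_factor(chain1, T))  # [1.    0.368]
print(detempering_weights_factor(chain2, T))  # [1.    0.368]  <- inconsistent

# Normalizing by the shared batch maximum keeps relative weights correct:
m = max(chain1.max(), chain2.max())
print(detempering_weights_factor(chain1, T, max_tempered_logpost=m))  # [1.    0.368]
print(detempering_weights_factor(chain2, T, max_tempered_logpost=m))  # [0.135 0.050]
```

The shared maximum is exactly what the new batch-aware `_detempered_weights` below computes over all chains at once.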
```diff
@@ -202,13 +209,15 @@ class SampleCollection(BaseCollection):

     def __init__(self, model, output=None, cache_size=_default_cache_size, name=None,
                  extension=None, file_name=None, resuming=False, load=False,
-                 temperature=None, onload_skip=0, onload_thin=1, sample_type=None):
+                 temperature=None, onload_skip=0, onload_thin=1, sample_type=None,
+                 is_batch=False):
```

Collaborator: is_batch is not used?

Contributor (author): Oops!

```diff
         super().__init__(model, name)
         if sample_type is not None and (not isinstance(sample_type, str) or
                                         not sample_type.lower() in sample_types):
             raise LoggedError(self.log, "'sample_type' must be one of %r.", sample_types)
         self.sample_type = sample_type.lower() if sample_type is not None else sample_type
         self.cache_size = cache_size
+        self.is_batch = is_batch
         self._data = None
         self._n = None
         # Create/load the main data frame and the tracking indices
```
```diff
@@ -478,7 +487,8 @@ def _check_logps(self, temperature_only=False):
                 np.atleast_2d(
                     self[self.minuslogprior_names].to_numpy(dtype=np.float64)),
                 axis=-1,
-            )
+            ),
+            rtol=1e-4,
         ):
             raise LoggedError(
                 self.log,
@@ -489,7 +499,8 @@
             np.sum(
                 np.atleast_2d(self[self.chi2_names].to_numpy(dtype=np.float64)),
                 axis=-1,
-            )
+            ),
+            rtol=1e-4,
         ):
             raise LoggedError(
                 self.log,
```
```diff
@@ -502,37 +513,40 @@
     def _check_weights(
             self,
             weights: Optional[np.ndarray] = None,
-            length: Optional[int] = None
+            length: Optional[Union[int, List[int]]] = None
     ):
         """
         Checks correct length, shape and signs of the ``weights``.

         If no weights passed, checks internal consistency.

-        If ``length`` passed, checks for specific length of the weights vector.
+        If ``length`` passed, checks for specific length(s) of the weights vector(s).

         Raises ``LoggedError`` if the weights are badly arranged or invalid.
         """
         if weights is None:
-            weights_array = self[OutPar.weight].to_numpy(dtype=np.float64)
+            weights = [self[OutPar.weight].to_numpy(dtype=np.float64)]
         else:
-            weights_array = np.atleast_1d(weights)
-        if len(weights_array.shape) != 1:
+            if not hasattr(weights[0], "__len__"):
+                weights = [weights]
+            weights = [np.array(ws) for ws in weights]
+        if length is None:
+            length = [len(w) for w in weights]
+        lengths_array = np.atleast_1d(length)
+        if len(weights) != len(lengths_array):
+            expected_msg = f"Expected a list of {len(lengths_array)} 1d arrays"
             raise LoggedError(
                 self.log,
-                "The shape of the weights is wrong. Expected a 1d array, "
-                "but got shape %r.",
-                weights_array.shape
+                f"The shape of the weights is wrong. {expected_msg}, "
+                f"but got {weights}."
             )
-        check_length = len(self) if length is None else length
-        if len(weights_array) != check_length:
+        if any(len(w) != l for w, l in zip(weights, lengths_array)):
             raise LoggedError(
                 self.log,
-                "The length of the weights vector is wrong. Expected %d but got %d.",
-                check_length,
-                len(weights_array)
+                f"The lengths of the weights vectors are wrong. Expected "
+                f"{[len(w) for w in weights]} but got {lengths_array}."
             )
-        if np.any(weights_array < 0):
+        if any(np.any(ws < 0) for ws in weights):
             raise LoggedError(
                 self.log,
                 "The weight vector contains negative elements."
```
```diff
@@ -554,17 +568,47 @@ def has_int_weights(self) -> bool:
         weights = self[OutPar.weight]
         return np.allclose(np.round(weights), weights)

-    def _detempered_weights(self):
-        """Computes the detempered weights."""
+    # pylint: disable=protected-access
+    def _detempered_weights(self, with_batch=None):
+        """
+        Computes the detempered weights.
+
+        If this sample is part of a batch, call this method passing the rest of the batch
+        as a list using the argument ``with_batch`` (otherwise inconsistent weights
+        between samples will be introduced). If additional chains are passed with
+        ``with_batch``, their temperature will be reset in-place.
+
+        Always returns a list of weight vectors: one element per collection in the batch.
+        """
+        batch = [self]
+        if with_batch is not None:
+            batch += list(with_batch)
+        elif self.is_batch:
+            self.log.warning(
+                "Trying to get detempered weights for individual sample collection that "
+                "appears to be part of a batch (e.g. of parallel MCMC chains). This will "
+                "produce inconsistent weights across chains, unless passing the rest of "
+                "the batch as ``with_batch=[collection_1, collection_2,... ]``.")
+        temps = [c.temperature for c in batch]
+        if not np.allclose(temps, temps[0]):
+            raise LoggedError(
+                self.log,
+                f"Temperature inconsistent across the batch: {temps}."
+            )
+        for c in batch:
+            c._cache_dump()
         if self.temperature == 1:
-            return self._data[OutPar.weight].to_numpy(dtype=np.float64)
-        return (
-            self._data[OutPar.weight].to_numpy(dtype=np.float64) *
+            return [c._data[OutPar.weight].to_numpy(dtype=np.float64) for c in batch]
+        max_logpost = np.max(np.concatenate(
+            [-c._data[OutPar.minuslogpost].to_numpy(dtype=np.float64) for c in batch]))
+        return [
+            c._data[OutPar.weight].to_numpy(dtype=np.float64) *
             detempering_weights_factor(
-                -self._data[OutPar.minuslogpost].to_numpy(dtype=np.float64),
-                self.temperature
-            )
-        )
+                -c._data[OutPar.minuslogpost].to_numpy(dtype=np.float64),
+                c.temperature,
+                max_tempered_logpost=max_logpost
+            ) for c in batch
+        ]

     def _detempered_minuslogpost(self):
         """Computes the detempered -log-posterior."""
```
```diff
@@ -575,21 +619,30 @@ def _detempered_minuslogpost(self):
             self.temperature
         )

-    def reset_temperature(self):
+    # pylint: disable=protected-access
+    def reset_temperature(self, with_batch=None):
         """
         Drops the information about sampling temperature: ``weight`` and ``minuslogpost``
         columns will now correspond to those of a unit-temperature posterior sample.

+        If this sample is part of a batch, call this method passing the rest of the batch
+        as a list using the argument ``with_batch`` (otherwise inconsistent weights
+        between samples will be introduced). If additional chains are passed with
+        ``with_batch``, their temperature will be reset in-place.
+
         This cannot be undone (e.g. recovering original integer tempered weights).
         You may want to call this method on a copy (see :func:`SampleCollection.copy`).
         """
-        self._cache_dump()
+        weights_batch = self._detempered_weights(with_batch=with_batch)
+        # Calling *after* getting weights, since that call checks consistency across batch
         if self.temperature == 1:
             return
-        self._data[OutPar.weight] = self._detempered_weights()
-        self._drop_samples_null_weight()
-        self._data[OutPar.minuslogpost] = self._detempered_minuslogpost()
-        self.temperature = 1
+        batch = [self] + list(with_batch or [])
+        for c, weights in zip(batch, weights_batch):
+            c._data[OutPar.weight] = weights
+            c._drop_samples_null_weight()
+            c._data[OutPar.minuslogpost] = c._detempered_minuslogpost()
+            c.temperature = 1

     def _enlarge(self, n):
         """
```
```diff
@@ -719,8 +772,9 @@ def _weights_for_stats(
             weights /= max(weights)
             return weights, np.allclose(np.round(weights), weights)
         if self.is_tempered and not tempered:
-            # For sure the weights are not integer
-            return self._detempered_weights()[first:last], False
+            # For sure the weights are not integer in this case
+            # NB: Index [0] below bc a list is returned, in case of batch processing
+            return self._detempered_weights()[0][first:last], False
         return (
             self[OutPar.weight][first:last].to_numpy(dtype=np.float64),
             self.has_int_weights
```
```diff
@@ -732,7 +786,7 @@ def mean(
             last: Optional[int] = None,
             weights: Optional[np.ndarray] = None,
             derived: bool = False,
-            tempered: bool = False
+            tempered: bool = False,
     ) -> np.ndarray:
         """
         Returns the (weighted) mean of the parameters in the chain,
```
```diff
@@ -770,7 +824,7 @@ def cov(
             last: Optional[int] = None,
             weights: Optional[np.ndarray] = None,
             derived: bool = False,
-            tempered: bool = False
+            tempered: bool = False,
     ) -> np.ndarray:
         """
         Returns the (weighted) covariance matrix of the parameters in the chain,
```
```diff
@@ -800,31 +854,47 @@ def cov(
         return np.atleast_2d(np.cov(  # type: ignore
             self[list(self.sampled_params) +
                  (list(self.derived_params) if derived else [])][first:last].to_numpy(
-                dtype=np.float64).T,
+                     dtype=np.float64).T,
             ddof=0,  # does simple mean w/o bias factor; weights are used as probabilities
             **{weight_type_kwarg: weights_cov}))

     def _drop_samples_null_weight(self):
         """Removes from the DataFrame all samples that have 0 weight."""
         self._data = self.data[self._data.weight > 0].reset_index(drop=True)
         self._n = self._data.last_valid_index() + 1

-    def reweight(self, importance_weights, check=True):
+    def reweight(self, importance_weights, with_batch=None, check=True):
         """
-        Reweights the sample with the given ``importance_weights``.
+        Reweights the sample in-place with the given ``importance_weights``.

         Temperature information is dropped.

+        If this sample is part of a batch, call this method passing the rest of the batch
+        as a list using the argument ``with_batch`` (otherwise inconsistent weights
+        between samples will be introduced). If additional chains are passed with
+        ``with_batch``, they will also be reweighted in-place. In that case,
+        ``importance_weights`` needs to be a list of weight vectors, the first of which
+        is applied to the current instance, and the rest to the collections passed with
+        ``with_batch``.
+
         This cannot be fully undone (e.g. recovering original integer weights).
         You may want to call this method on a copy (see :func:`SampleCollection.copy`).

         For the sake of speed, length and positivity checks on the importance weights can
         be skipped with ``check=False`` (default ``True``).
         """
-        self.reset_temperature()  # includes a self._cache_dump()
+        self.reset_temperature(with_batch=with_batch)  # includes a self._cache_dump()
+        if not hasattr(importance_weights[0], "__len__"):
+            importance_weights = [importance_weights]
         if check:
-            self._check_weights(importance_weights)
-        self._data[OutPar.weight] *= importance_weights
-        self._drop_samples_null_weight()
+            self._check_weights(
+                importance_weights,
+                length=[len(self)] + [len(c) for c in with_batch or []]
+            )
+        batch = [self] + list(with_batch or [])
+        for c, iweights in zip(batch, importance_weights):
+            c._data[OutPar.weight] *= iweights
+            c._drop_samples_null_weight()

     def filtered_copy(self, where) -> 'SampleCollection':
         """Returns a copy of the collection with some condition ``where`` imposed."""
```
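A hedged sketch of batch reweighting (``copies`` is assumed to be a list of ``SampleCollection`` objects forming a batch, as in the earlier sketch; the weight values are invented):

```python
import numpy as np

# One importance-weight vector per collection; the first entry applies to the
# collection on which reweight() is called, the rest to `with_batch`:
iw = [np.exp(-0.5 * np.arange(len(c), dtype=float)) for c in copies]
copies[0].reweight(iw, with_batch=copies[1:])
# Each chain's weight column has been multiplied in place by its vector and
# zero-weight rows dropped; temperature information was removed first.
```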
4 changes: 2 additions & 2 deletions cobaya/input.py
```diff
@@ -14,7 +14,7 @@
 from itertools import chain
 from functools import reduce
 from typing import Mapping, MutableMapping, Union, Optional, TypeVar, Callable, Dict, \
-    List, Sized
+    List
 from collections import defaultdict

 # Local
```
```diff
@@ -570,7 +570,7 @@ def is_equal_info(info_old, info_new, strict=True, print_not_log=False, ignore_b
         for value in [block1[k], block2[k]]:
             if isinstance(value, MutableMapping):
                 for kk in value:
-                    if isinstance(value[kk], Sized) and len(value[kk]) == 0:
+                    if hasattr(value[kk], "__len__") and len(value[kk]) == 0:
                         value[kk] = None
         if block1[k] != block2[k]:
             # For clarity, pop common stuff before printing
```
1 change: 0 additions & 1 deletion cobaya/likelihood.py
```diff
@@ -58,7 +58,6 @@ def current_logp(self) -> float:
         :return: log likelihood from the current state as a scalar
         """
         value = self.current_state["logp"]
-        # Unfortunately, numpy arrays are not derived from Sequence, so the test is ugly
         if hasattr(value, "__len__"):
             value = value[0]
         return value
```
Collaborator (@cmbant, Oct 11, 2023): Sized does not imply indexable?

Collaborator: Note the Sized form is also nearly ten times slower than hasattr, which I think may be why we were using that.

Contributor (author): Sized simply implies that it has length (as opposed to Sequence, which is not compatible with numpy arrays). Reverting where applicable. I thought you preferred stronger type checks and less duck-typing, but fine by me.
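The speed claim above is easy to check with a small sketch (not part of the PR): ``isinstance(..., Sized)`` dispatches through the abstract-base-class subclass hook, while ``hasattr`` is a plain attribute lookup. Exact timings are machine-dependent:

```python
from collections.abc import Sized
from timeit import timeit

value = 1.0  # a scalar, as in current_logp above
t_abc = timeit(lambda: isinstance(value, Sized), number=1_000_000)
t_attr = timeit(lambda: hasattr(value, "__len__"), number=1_000_000)
print(f"isinstance(..., Sized): {t_abc:.2f}s   hasattr: {t_attr:.2f}s")
```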
5 changes: 4 additions & 1 deletion cobaya/output.py
```diff
@@ -50,6 +50,7 @@ def __init__(self, filename=None, log=None):
         if filename:
             self.set_lock(log, filename)

+    # pylint: disable=consider-using-with
     def set_lock(self, log, filename, force=False):
         if self.has_lock():
             return
```
```diff
@@ -318,7 +319,8 @@ def load_collections(self, model, skip=0, thin=1, combined=False,
         from cobaya.collection import SampleCollection
         collections = [
             SampleCollection(model, self, name="%d" % (1 + i), file_name=filename,
-                             load=True, onload_skip=skip, onload_thin=thin)
+                             load=True, onload_skip=skip, onload_thin=thin,
+                             is_batch=len(filenames) > 1)
             for i, filename in enumerate(filenames)]
         # MARKED FOR DEPRECATION IN v3.3
         if concatenate is not None:
@@ -332,6 +334,7 @@
             for collection_i in collections[1:]:
                 # noinspection PyProtectedMember
                 collection._append(collection_i)  # pylint: disable=protected-access
+            collection.is_batch = False
             return collection
         return collections
```
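The intended semantics of the new flag, in a hedged sketch (``out`` is assumed to be a cobaya ``Output`` attached to a run that produced several chain files, and ``model`` the corresponding ``Model``):

```python
# Each chain of a multi-chain run is marked as belonging to a batch:
chains = out.load_collections(model)
print(all(c.is_batch for c in chains))      # True: detemper/reweight together

# A combined (concatenated) collection stands alone, so the flag is cleared:
combined = out.load_collections(model, combined=True)
print(combined.is_batch)                    # False
```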
20 changes: 14 additions & 6 deletions cobaya/post.py
```diff
@@ -244,18 +244,26 @@ def post(info_or_yaml_or_file: Union[InputDict, str, os.PathLike],
     else:
         raise LoggedError(log, "No output from where to load from, "
                                "nor input collections given.")
-    # A note on tempered chains: detempering happens automatically when reweighting,
-    # which is done later in this function in most cases.
-    # But for the sake of robustness, we detemper all chains at init.
-    if mpi.is_main_process() and any(c.is_tempered for c in in_collections):
-        log.info("Starting from tempered chains. Will detemper before proceeding.")
     # Let's make sure we work on a copy if the chain is going to be altered
     already_copied = bool(output_in) or (sample is not None and (skip or thin != 1))
     for i, collection in enumerate(in_collections):
         if not already_copied:
             collection = collection.copy()
-        collection.reset_temperature()
         in_collections[i] = collection
+    # A note on tempered chains: detempering happens automatically when reweighting,
+    # which is done later in this function in most cases.
+    # But for the sake of robustness, we detemper all chains at init.
+    # In order not to introduce reweighting errors coming from subtractions of the max
+    # log-posterior at detempering, we need to detemper all samples at once.
+    all_in_collections = mpi.gather(in_collections)
+    if mpi.is_main_process():
+        flat_in_collections = list(chain(*all_in_collections))
+        if any(c.is_tempered for c in flat_in_collections):
+            log.info("Starting from tempered chains. Will detemper before proceeding.")
+            flat_in_collections[0].reset_temperature(with_batch=flat_in_collections[1:])
+    # Detempering happens in place, so one can scatter back the original
+    # all_in_collections object to preserve the distribution of in_collections
+    # across processes
+    in_collections = mpi.scatter(all_in_collections)
     if any(len(c) <= 1 for c in in_collections):
         raise LoggedError(
             log, "Not enough samples for post-processing. Try using a larger sample, "
```