Skip to content

Commit

Permalink
up to date simulation notebook with fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
jgallowa07 committed Nov 30, 2023
1 parent 4c1a418 commit ee8e688
Show file tree
Hide file tree
Showing 3 changed files with 5,287 additions and 4 deletions.
13 changes: 10 additions & 3 deletions multidms/model_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,14 @@ def fit_one_model(
return pd.Series(fit_attributes)[col_order] # .to_frame().T[col_order]


def _fit_fun(arg):
def _fit_fun(params):
"""Workaround for multiprocessing to fit models with sets of kwargs"""
# import jax
# jax.config.update("jax_platform_name", "cpu")
_, kwargs = arg
# data, kwargs = params
_, kwargs = params
try:
# return fit_one_model(data, **kwargs)
return fit_one_model(**kwargs)
except Exception:
return None
Expand All @@ -254,6 +256,8 @@ def stack_fit_models(fit_models_list):
return pd.concat([f.to_frame().T for f in fit_models_list], ignore_index=True)


# TODO document that these params should not be unpacked
# when passed as with fit_one_model.
def fit_models(params, n_threads, failures="error"):
"""Fit collection of :class:`~multidms.Model` models.
Expand Down Expand Up @@ -295,7 +299,10 @@ def fit_models(params, n_threads, failures="error"):
exploded_params = _explode_params_dict(params)
# see https://pythonspeed.com/articles/python-multiprocessing/ for why we spawn
with get_context("spawn").Pool(n_threads) as p:
fit_models = p.map(_fit_fun, [(None, kwargs) for kwargs in exploded_params])
fit_models = p.map(_fit_fun, [(None, params) for params in exploded_params])
# fit_models = p.map(
# _fit_fun, [(params.pop("dataset"), params) for params in exploded_params]
# )

assert len(fit_models) == len(exploded_params)

Expand Down
5,275 changes: 5,275 additions & 0 deletions notebooks/simulate_data.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ dev = [
"nbsphinx",
"nbsphinx_link",
"bumpver",
"sphinxcontrib-bibtex"
"sphinxcontrib-bibtex",
"dms_variants"
]

[tool.ruff]
Expand Down

0 comments on commit ee8e688

Please sign in to comment.