160 161 transform normal init beta (#163)
* bump version 1.0.0 -> 1.1.0
jgallowa07 authored Jul 12, 2024
1 parent d7ad087 commit 23dd3ed
Showing 9 changed files with 55 additions and 34 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.rst
@@ -6,11 +6,16 @@ All notable changes to this project will be documented in this file.

The format is based on `Keep a Changelog <https://keepachangelog.com>`_.

+1.1.0
+-----
+* No longer calling transform() on parameters for single condition fits. See `#160 <https://github.com/matsengrp/multidms/issues/160>`_.
+* Added `init_beta_variance` parameter to the `Model` instantiation to allow the user to initialize beta parameters by sampling a normal distribution. See `#161 <https://github.com/matsengrp/multidms/issues/161>`_.
+

1.0.0
-----
-- This release re-implements the joint model as a using a generalized lasso, and bit-flipping, as described in `#156 https://github.com/matsengrp/multidms/issues/156`_. Please see the issue for more detailed description about how, and why these changes were made. Note that this changes the parameters that one may get from the model including a set of beta's for each experimental condition.
-- It also cleans up various TODO's in the code as checked-off in `# https://github.com/matsengrp/multidms/issues/153`.
+- This release re-implements the joint model as a using a generalized lasso, and bit-flipping, as described in `#156 <https://github.com/matsengrp/multidms/issues/156>`_. Please see the issue for more detailed description about how, and why these changes were made. Note that this changes the parameters that one may get from the model including a set of beta's for each experimental condition.
+- It also cleans up various TODO's in the code as checked-off in `#153 <https://github.com/matsengrp/multidms/issues/153>`.
- Fixes a bug, where the phenotype predictions for single mutants did not correctly include the bundle effects.
- Fixes and cleans various plotting bugs.

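Taken together, the two 1.1.0 entries above amount to one new knob on the model. A minimal sketch of the `init_beta_variance` option, mirroring the linear-model test added later in this commit (it assumes an existing `multidms.Data` object named `data`, construction not shown):

    import multidms

    # `data` is assumed to be a multidms.Data object built elsewhere.
    model = multidms.Model(
        data,
        multidms.biophysical.identity_activation,
        init_beta_variance=1.0,  # draw initial betas from a scaled normal;
                                 # 0.0 (the Model default) keeps the old all-zeros init
        PRNGKey=23,
    )
    model.fit(maxiter=2, warn_unconverged=False)
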
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -16,7 +16,7 @@
project = "multidms"
copyright = "2023, Jared Galloway, Hugh Haddox"
author = "Jared Galloway"
release = "1.0.0"
release = "1.1.0"

needs_sphinx = "1.0"

2 changes: 1 addition & 1 deletion multidms/__init__.py
@@ -47,7 +47,7 @@ class works to compose, compile, and optimize the model parameters

__author__ = "Jared Galloway"
__email__ = "[email protected]"
__version__ = "1.0.0"
__version__ = "1.1.0"
__url__ = "https://github.com/matsengrp/multidms"

from binarymap.binarymap import AAS_NOSTOP as AAS # noqa: F401
6 changes: 3 additions & 3 deletions multidms/biophysical.py
@@ -330,10 +330,10 @@ def proximal_box_constraints(params, hyperparameters, *args, **kwargs):
(
ge_scale_upper_bound,
lock_params,
-bundle_idxs,
+# bundle_idxs,
) = hyperparameters

-params = transform(params, bundle_idxs)
+# params = transform(params, bundle_idxs)

# clamp theta scale to monotonic, and with optional upper bound
if "ge_scale" in params["theta"]:
@@ -347,7 +347,7 @@ def proximal_box_constraints(params, hyperparameters, *args, **kwargs):
for (param, subparam), value in lock_params.items():
params[param][subparam] = value

-params = transform(params, bundle_idxs)
+# params = transform(params, bundle_idxs)
return params


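With the transform round-trip dropped, the proximal operator reduces to a plain box constraint. A consolidated sketch of the post-commit function, stitched together from the hunks above (the clamp body is collapsed in this view, so its exact form is an assumption):

    def proximal_box_constraints(params, hyperparameters, *args, **kwargs):
        (
            ge_scale_upper_bound,
            lock_params,
        ) = hyperparameters

        # clamp theta scale to monotonic, and with optional upper bound
        # (assumed implementation; the actual body is collapsed in the diff above)
        if "ge_scale" in params["theta"]:
            params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip(
                0, ge_scale_upper_bound
            )

        # re-impose locked parameter values after each gradient step
        for (param, subparam), value in lock_params.items():
            params[param][subparam] = value

        return params
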
2 changes: 1 addition & 1 deletion multidms/data.py
@@ -16,7 +16,7 @@
import pandas as pd
from polyclonal.plot import DEFAULT_POSITIVE_COLORS
from polyclonal.utils import MutationParser
-from tqdm.auto import tqdm
+from tqdm import tqdm

from multidms import AAS
from multidms.utils import rereference, split_subs
38 changes: 24 additions & 14 deletions multidms/model.py
@@ -88,6 +88,10 @@ class Model:
init_theta_bias : float
Initialize the bias parameter :math:`\theta_{\text{bias}}` of
a two parameter epistatic model (Sigmoid or Softplus).
+init_beta_variance : float
+Beta parameters are initialized by sampling from
+a normal distribution. This parameter specifies the
+variance of the distribution being sampled.
n_hidden_units : int or None
If using :func:`multidms.biophysical.nn_global_epistasis`
as the epistatic model, this is the number of hidden units
@@ -145,15 +149,16 @@ class Model:
wts sites muts times_seen_a times_seen_b beta_a beta_b shift_b \
mutation
M1E M 1 E 1 3 0.0 0.0 0.0
-M1W M 1 W 1 0 0.0 0.0 0.0
-G3P G 3 P 1 4 0.0 -0.0 0.0
-G3R G 3 R 1 2 0.0 0.0 0.0
-predicted_func_score_a predicted_func_score_b
+M1W M 1 W 1 0 0.0 -0.0 0.0
+G3P G 3 P 1 4 -0.0 -0.0 -0.0
+G3R G 3 R 1 2 -0.0 0.0 -0.0
+<BLANKLINE>
+predicted_func_score_a predicted_func_score_b
mutation
-M1E 0.0 0.0
-M1W 0.0 0.0
-G3P 0.0 0.0
-G3R 0.0 0.0
+M1E 0.0 0.0
+M1W 0.0 0.0
+G3P 0.0 0.0
+G3R 0.0 0.0
Notice the respective single mutation effects (``"beta"``), conditional shifts
(``shift_d``),
@@ -214,6 +219,7 @@ def __init__(
n_hidden_units=5,
init_theta_scale=5.0,
init_theta_bias=-5.0,
+init_beta_variance=0.0,
name=None,
):
"""See class docstring."""
@@ -229,15 +235,21 @@
# as defined in multidms.biophysical.additive_model
latent_model = multidms.biophysical.additive_model
if latent_model == multidms.biophysical.additive_model:
-n_beta_shift = len(self._data.mutations)
self._scaled_data_params["beta0"] = {
cond: jnp.zeros(shape=(1,)) for cond in data.conditions
}

+n_beta_shift = len(self._data.mutations)
+beta_keys = jax.random.split(key, num=len(self.data.conditions))
self._scaled_data_params["beta"] = {
-cond: jnp.zeros(shape=(n_beta_shift,)) for cond in data.conditions
+cond: init_beta_variance
+* jax.random.normal(shape=(n_beta_shift,), key=ikey)
+for cond, ikey in zip(data.conditions, beta_keys)
}
self._scaled_data_params["shift"] = {
-cond: jnp.zeros(shape=(n_beta_shift,)) for cond in data.conditions
+cond: self._scaled_data_params["beta"][self.data.reference]
+- self._scaled_data_params["beta"][cond]
+for cond in data.conditions
}
# GAMMA
# self._params["gamma"] = {
@@ -1108,10 +1120,8 @@ def fit(
hyperparams_prox = (
upper_bound_ge_scale,
lock_params,
-self.data.bundle_idxs,
)
-# compiled_proximal = jax.jit(self._model_components["proximal"])
-compiled_proximal = self._model_components["proximal"]
+compiled_proximal = jax.jit(self._model_components["proximal"])

solver = ProximalGradient(
compiled_objective,
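The new initialization logic above, extracted into a self-contained toy (condition names, sizes, and the key are made up; in the model the key derives from the `PRNGKey` argument):

    import jax

    conditions = ["a", "b"]   # toy condition names; "a" is the reference
    reference = "a"
    n_beta_shift = 4          # stands in for len(data.mutations)
    init_beta_variance = 1.0  # 0.0 reduces to the previous all-zeros init

    key = jax.random.PRNGKey(0)
    beta_keys = jax.random.split(key, num=len(conditions))

    # one scaled standard-normal draw per condition
    beta = {
        cond: init_beta_variance * jax.random.normal(ikey, shape=(n_beta_shift,))
        for cond, ikey in zip(conditions, beta_keys)
    }

    # shifts are defined relative to the reference condition's betas,
    # so the reference condition's shift is exactly zero
    shift = {cond: beta[reference] - beta[cond] for cond in conditions}
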
5 changes: 5 additions & 0 deletions multidms/model_collection.py
@@ -54,6 +54,7 @@ def fit_one_model(
# gamma_corrected=False, # GAMMA
init_theta_scale=6.5,
init_theta_bias=-3.5,
+init_beta_variance=1.0,
n_hidden_units=5,
lower_bound=None,
PRNGKey=0,
@@ -79,6 +80,9 @@
The scale to use for initializing the model parameters. The default is 6.5.
init_theta_bias : float, optional
The bias to use for initializing the model parameters. The default is -3.5.
+init_beta_variance : float, optional
+The variance to use for initializing the model's beta parameters from a normal
+distribution. The default is 1.0.
n_hidden_units : int, optional
The number of hidden units to use in the neural network model. The default is 5.
lower_bound : float, optional
@@ -125,6 +129,7 @@ def fit_one_model(
output_activation=biophysical_model[output_activation],
init_theta_scale=init_theta_scale,
init_theta_bias=init_theta_bias,
+init_beta_variance=init_beta_variance,
# gamma_corrected=gamma_corrected, GAMMA
n_hidden_units=n_hidden_units,
lower_bound=lower_bound,
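The same knob is threaded through `fit_one_model`, where it defaults to 1.0 rather than the `Model` default of 0.0. A hypothetical call; only keyword names visible in this diff are used, and the positional data argument and return value are assumptions:

    from multidms.model_collection import fit_one_model

    fit = fit_one_model(
        data,                    # assumed: a multidms.Data object
        init_beta_variance=1.0,  # new in 1.1.0; matches the documented default
        n_hidden_units=5,
        PRNGKey=0,
    )
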
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "multidms"
version = "1.0.0"
version = "1.1.0"
description = "Joint modeling of multiple deep mutational scanning experiments."
readme = "README.md"
authors = [
@@ -111,7 +111,7 @@ repository = "https://github.com/matsengrp/multidms"
packages = ["multidms"]

[tool.bumpver]
-current_version = "1.0.0"
+current_version = "1.1.0"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "bump version {old_version} -> {new_version}"
commit = true
21 changes: 11 additions & 10 deletions tests/test_data.py
@@ -302,6 +302,7 @@ def test_linear_model_multi_cond_fit_simple():
reference="a",
assert_site_integrity=False,
)
+assert np.all([not bi for bi in list(data.bundle_idxs["a"])])
model = multidms.Model(data, multidms.biophysical.identity_activation, PRNGKey=23)
model.fit(maxiter=2, warn_unconverged=False)

@@ -560,12 +561,12 @@ def test_model_phenotype_predictions():
external_pred = model.add_phenotypes_to_df(
TEST_FUNC_SCORES, unknown_as_nan=True, phenotype_as_effect=False
).dropna()
-assert np.all(
-internal_pred.predicted_latent.values == external_pred.predicted_latent.values
+assert np.allclose(
+internal_pred.predicted_latent.values, external_pred.predicted_latent.values
)
-assert np.all(
-internal_pred.predicted_func_score.values
-== external_pred.predicted_func_score.values
+assert np.allclose(
+internal_pred.predicted_func_score.values,
+external_pred.predicted_func_score.values,
)


@@ -580,12 +581,12 @@ def test_model_phenotype_effect_predictions():
external_pred = model.add_phenotypes_to_df(
TEST_FUNC_SCORES, unknown_as_nan=True, phenotype_as_effect=True
).dropna()
-assert np.all(
-internal_pred.predicted_latent.values == external_pred.predicted_latent.values
+assert np.allclose(
+internal_pred.predicted_latent.values, external_pred.predicted_latent.values
)
-assert np.all(
-internal_pred.predicted_func_score.values
-== external_pred.predicted_func_score.values
+assert np.allclose(
+internal_pred.predicted_func_score.values,
+external_pred.predicted_func_score.values,
)


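The test changes swap exact equality for `np.allclose`, which guards against floating-point round-off between the two prediction paths. A toy illustration of why (not from the test suite):

    import numpy as np

    a = np.array([0.1 + 0.2])
    b = np.array([0.3])

    assert not np.all(a == b)  # exact comparison fails: 0.30000000000000004 != 0.3
    assert np.allclose(a, b)   # tolerance-based comparison passes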
