NB loss
wsdewitt committed Jul 26, 2024
1 parent ad05c8c commit 7b68b45
Showing 2 changed files with 1,118 additions and 1,102 deletions.
37 changes: 25 additions & 12 deletions multidms/jaxmodels.py
@@ -25,9 +25,9 @@ class Data(eqx.Module):
     x_wt: Int[Array, "n_mutations"]
     """Binary encoding of the wildtype sequence."""
-    pre_count_wt: Float[Array, ""]
+    pre_count_wt: Int[Array, ""]
     """Wildtype pre-selection count."""
-    post_count_wt: Float[Array, ""]
+    post_count_wt: Int[Array, ""]
     """Wildtype post-selection count."""
     X: Int[Array, "n_variants n_mutations"]
     """Variant encoding matrix (sparse format)."""

[CI annotations on this hunk: GitHub Actions / build-and-test (macos-latest, 3.10) reports Ruff failures on the jaxtyping shape-string annotations: F821 at 26:23 (undefined name `n_mutations`) and F722 at 28:30, 30:31, and 32:19 (syntax error in forward annotation).]
@@ -110,6 +110,8 @@ class Model(eqx.Module):
     """Latent models for each condition."""
     α: dict[str, Float[Array, ""]]
     """Fitness-functional score scaling factors for each condition."""
+    logθ: dict[str, Float[Array, ""]]
+    """Overdispersion parameter for each condition."""
     reference_condition: str = eqx.field(static=True)
     """The condition to use as a reference."""

@@ -184,26 +186,28 @@ def loss(
         self,
         data_sets: dict[str, Data],
     ) -> dict[str, Float[Array, ""]]:
-        r"""Compute the loss.
+        r"""Count-based loss.
 
         Args:
             data_sets: Data sets for each condition.
         """
         post_count_pred = self.predict_post_count(data_sets)
         result = {}
         for d in data_sets:
-            m_v = data_sets[d].post_counts
-            m_v_pred = post_count_pred[d]
-            result[d] = (m_v_pred - xlogy(m_v, m_v_pred) + gammaln(m_v + 1)).sum()
+            m = data_sets[d].post_counts
+            m_pred = post_count_pred[d]
+            logθ = self.logθ[d]
+            θ = jnp.exp(logθ)
+            result[d] = (-gammaln(m + θ) + gammaln(m + 1) + gammaln(θ) - θ * logθ - m * jnp.log(m_pred) + (θ + m) * jnp.log(θ + m_pred)).sum()
         return result
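An annotation of ours, not commit text: the new summand is the negative log-likelihood of a negative binomial distribution with mean μ = m_pred and dispersion θ = exp(logθ); parameterizing by logθ keeps θ positive under unconstrained gradient updates. From the pmf p(m) = Γ(m+θ) / (Γ(m+1) Γ(θ)) · (θ/(θ+μ))^θ (μ/(θ+μ))^m:

    -\log p(m \mid \mu, \theta)
        = -\log\Gamma(m+\theta) + \log\Gamma(m+1) + \log\Gamma(\theta)
          - \theta\log\theta - m\log\mu + (\theta+m)\log(\theta+\mu)

which matches the added line term for term. As θ → ∞ this tends to the Poisson NLL μ − m log μ + log Γ(m+1), exactly the summand deleted above, so the old loss is recovered in the no-overdispersion limit.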


 def fit(
     data_sets: dict[str, Data],
     reference_condition: str,
-    l2reg_α=0.0,
     l2reg=0.0,
     fusionreg=0.0,
+    share_calibration=False,
     opt_kwargs=dict(tol=1e-8, maxiter=1000),
 ) -> tuple[Model, jaxopt._src.proximal_gradient.ProxGradState]:
     r"""
@@ -212,9 +216,9 @@ def fit(
     Args:
         data_sets: Data to fit to. Each key is a condition.
        reference_condition: The condition to use as a reference.
-        l2reg_α: L2 regularization strength for α.
         l2reg: L2 regularization strength for mutation effects.
         fusionreg: Fusion (shift lasso) regularization strength.
+        share_calibration: Whether to share experimental calibration parameters across conditions.
         opt_kwargs: Keyword arguments to pass to solver.
 
     Returns:
@@ -227,24 +231,23 @@ def fit(
 
     opt = jaxopt.ProximalGradient(
         _objective_smooth_preconditioned,
-        prox=_prox,
+        prox=_prox_shared_calibration if share_calibration else _prox,
         value_and_grad=True,
         **opt_kwargs,
     )
 
     model = Model(
         φ={d: Latent(data_sets[d], l2reg=l2reg) for d in data_sets},
         α={d: jnp.ptp(data_sets[d].functional_scores) for d in data_sets},
+        logθ={d: 0.0 for d in data_sets},
         reference_condition=reference_condition,
     )
 
     Ps = {d: jnp.diag(1 / (1 + data_sets[d].X.sum(axis=0).todense())) for d in data_sets}
 
     hyperparameters = dict(fusionreg=fusionreg, Ps=Ps)
-    args = (data_sets, Ps)
-    kwargs = dict(l2reg_α=l2reg_α, l2reg=l2reg)
 
-    model, state = opt.run(model, hyperparameters, *args, **kwargs)
+    model, state = opt.run(model, hyperparameters, data_sets, Ps, l2reg=l2reg)
 
     return model, state
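A side note on the unchanged Ps line above (our reading, not commit text): each preconditioner is diagonal with entries

    P_{jj} = \frac{1}{1 + \sum_{v} X_{vj}},

so a mutation observed in many variants takes a proportionally smaller proximal-gradient step than a rare one.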

@@ -301,3 +304,13 @@ def _prox(model, hyperparameters, scaling=1.0):
             lambda model: model.α[d], model, jnp.clip(model.α[d], 0.0, jnp.inf)
         )
     return model
+
+
+def _prox_shared_calibration(model, hyperparameters, scaling=1.0):
+    model = _prox(model, hyperparameters, scaling)
+    α_mean = sum(model.α.values()) / len(model.α)
+    logθ_mean = sum(model.logθ.values()) / len(model.logθ)
+    for d in model.α:
+        model = eqx.tree_at(lambda model_: model_.α[d], model, α_mean)
+        model = eqx.tree_at(lambda model_: model_.logθ[d], model, logθ_mean)
+    return model
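Why averaging (our annotation, not commit text): the Euclidean projection of a tuple (x_1, …, x_K) onto the tied subspace {x_1 = ⋯ = x_K} is the componentwise mean,

    \operatorname{proj}(x_1, \ldots, x_K) = (\bar{x}, \ldots, \bar{x}), \qquad \bar{x} = \frac{1}{K} \sum_{d=1}^{K} x_d,

so running _prox and then replacing every condition's α and logθ with their means projects the iterate onto the shared-calibration constraint. A hypothetical call, with placeholder values and `data_sets` built elsewhere:

    # Placeholder sketch, not from the commit: tie α and logθ across conditions.
    model, state = fit(
        data_sets,
        reference_condition="reference",  # must be a key of data_sets
        l2reg=1e-4,                       # illustrative value
        fusionreg=1e-2,                   # illustrative value
        share_calibration=True,           # routes the solver through _prox_shared_calibration
    )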
2,183 changes: 1,093 additions & 1,090 deletions notebooks/jaxmodels/jaxmodels.ipynb

Large diffs are not rendered by default.
