Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jaxmodels #164

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ and how much the effects differ between experiments.

- The source code is `on GitHub <https://github.com/matsengrp/multidms>`_.

- For questions or inquaries about the software please `raise an issue <https://github.com/matsengrp/multidms/issues>`_, or contact jgallowa \<at\> fredhutch.org.
- For questions or inquiries about the software please `raise an issue <https://github.com/matsengrp/multidms/issues>`_, or contact jgallowa \<at\> fredhutch.org.

.. toctree::
:hidden:
Expand Down
22 changes: 11 additions & 11 deletions multidms/biophysical.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def proximal_box_constraints(params, hyperparameters, *args, **kwargs):
def proximal_objective(Dop, params, hyperparameters, scaling=1.0):
"""ADMM generalized lasso optimization."""
(
scale_coeff_lasso_shift,
coef_lasso_shift,
admm_niter,
admm_tau,
admm_mu,
Expand All @@ -368,7 +368,7 @@ def proximal_objective(Dop, params, hyperparameters, scaling=1.0):
# see https://pyproximal.readthedocs.io/en/stable/index.html
beta_ravel, shift_ravel = pyproximal.optimization.primal.LinearizedADMM(
pyproximal.L2(b=beta_ravel),
pyproximal.L1(sigma=scaling * scale_coeff_lasso_shift),
pyproximal.L1(sigma=scaling * coef_lasso_shift),
Dop,
niter=admm_niter,
tau=admm_tau,
Expand Down Expand Up @@ -412,9 +412,9 @@ def smooth_objective(
f,
params,
data,
scale_coeff_ridge_beta=0.0,
scale_coeff_ridge_ge_scale=0.0,
scale_coeff_ridge_ge_bias=0.0,
coef_ridge_beta=0.0,
coef_ridge_ge_scale=0.0,
coef_ridge_ge_bias=0.0,
huber_scale=1,
**kwargs,
):
Expand All @@ -432,11 +432,11 @@ def smooth_objective(
return the respective binarymap and the row associated target functional scores
huber_scale : float
Scale parameter for Huber loss function
scale_coeff_ridge_beta : float
coef_ridge_beta : float
Ridge penalty coefficient for shift parameters
scale_coeff_ridge_ge_scale : float
coef_ridge_ge_scale : float
Ridge penalty coefficient for global epistasis scale parameter
scale_coeff_ridge_ge_bias : float
coef_ridge_ge_bias : float
Ridge penalty coefficient for global epistasis bias parameter
kwargs : dict
Additional keyword arguments to pass to the biophysical model function
Expand Down Expand Up @@ -476,15 +476,15 @@ def smooth_objective(

# compute a regularization term that penalizes non-zero
# parameters and add it to the loss function
beta_ridge_penalty += scale_coeff_ridge_beta * (d_params["beta"] ** 2).sum()
beta_ridge_penalty += coef_ridge_beta * (d_params["beta"] ** 2).sum()

huber_cost /= len(X)

ge_scale_ridge_penalty = (
scale_coeff_ridge_ge_scale * (params["theta"]["ge_scale"] ** 2).sum()
coef_ridge_ge_scale * (params["theta"]["ge_scale"] ** 2).sum()
)
ge_bias_ridge_penalty = (
scale_coeff_ridge_ge_bias * (params["theta"]["ge_bias"] ** 2).sum()
coef_ridge_ge_bias * (params["theta"]["ge_bias"] ** 2).sum()
)

return (
Expand Down
59 changes: 32 additions & 27 deletions multidms/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
dms experiments under various conditions.
"""

import os
from functools import partial, cached_property
import warnings

Expand All @@ -26,7 +25,6 @@
import seaborn as sns
from jax.experimental import sparse
from matplotlib import pyplot as plt
from pandarallel import pandarallel

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -97,9 +95,6 @@ class notes.
have the same wild type amino acid, grouped by condition.
verbose : bool
If True, will print progress bars.
nb_workers : int
Number of workers to use for parallel operations.
If None, will use all available CPUs.
name : str or None
Name of the data object. If None, will be assigned
a unique name based upon the number of data objects
Expand Down Expand Up @@ -197,7 +192,6 @@ def __init__(
letter_suffixed_sites=False,
assert_site_integrity=False,
verbose=False,
nb_workers=None,
name=None,
):
"""See main class docstring."""
Expand Down Expand Up @@ -252,7 +246,13 @@ def __init__(
self._mutparser = MutationParser(alphabet, letter_suffixed_sites)

# Configure new variants df
cols = ["condition", "aa_substitutions", "func_score"]
cols = [
"condition",
"aa_substitutions",
"func_score",
"pre_count",
"post_count",
]
if "weight" in variants_df.columns:
cols.append(
"weight"
Expand Down Expand Up @@ -311,11 +311,6 @@ def __init__(
sites_to_throw = na_rows[na_rows].index
site_map.dropna(inplace=True)

nb_workers = min(os.cpu_count(), 4) if nb_workers is None else nb_workers
pandarallel.initialize(
progress_bar=verbose, verbose=0 if not verbose else 2, nb_workers=nb_workers
)

def flags_invalid_sites(disallowed_sites, sites_list):
"""Check to see if a sites list contains
any disallowed sites
Expand All @@ -325,7 +320,7 @@ def flags_invalid_sites(disallowed_sites, sites_list):
return False
return True

df["allowed_variant"] = df.sites.parallel_apply(
df["allowed_variant"] = df.sites.apply(
lambda sl: flags_invalid_sites(sites_to_throw, sl)
)
if verbose:
Expand Down Expand Up @@ -394,7 +389,7 @@ def get_nis_from_site_map(site_map):
invalid_nim.append(site)

# find variants that contain mutations at invalid sites
df["allowed_variant"] = df.sites.parallel_apply(
df["allowed_variant"] = df.sites.apply(
lambda sl: flags_invalid_sites(invalid_nim, sl)
)
if verbose:
Expand Down Expand Up @@ -440,7 +435,7 @@ def get_nis_from_site_map(site_map):
continue

idx = condition_func_df.index
df.loc[idx, "var_wrt_ref"] = condition_func_df.parallel_apply(
df.loc[idx, "var_wrt_ref"] = condition_func_df.apply(
lambda x: self._convert_split_subs_wrt_ref_seq(
condition, x.wts, x.sites, x.muts
),
Expand All @@ -452,9 +447,9 @@ def get_nis_from_site_map(site_map):

# Make BinaryMap representations for each condition
allowed_subs = {s for subs in df.var_wrt_ref for s in subs.split()}
binmaps, X, y, w = {}, {}, {}, {}
binmaps, X, y, w, pre_count, post_count = {}, {}, {}, {}, {}, {}
self._bundle_idxs = {}
self._scaled_training_data = {"X": {}, "y": y, "w": w}
self._scaled_arrays = {"X": {}, "y": y, "w": w}
for condition, condition_func_score_df in df.groupby("condition"):
cond_bmap = bmap.BinaryMap(
condition_func_score_df,
Expand All @@ -466,6 +461,12 @@ def get_nis_from_site_map(site_map):
binmaps[condition] = cond_bmap
X[condition] = sparse.BCOO.from_scipy_sparse(cond_bmap.binary_variants)
y[condition] = jnp.array(condition_func_score_df["func_score"].values)
pre_count[condition] = jnp.array(
condition_func_score_df["pre_count"].values
)
post_count[condition] = jnp.array(
condition_func_score_df["post_count"].values
)
if "weight" in condition_func_score_df.columns:
w[condition] = jnp.array(condition_func_score_df["weight"].values)

Expand All @@ -480,7 +481,7 @@ def get_nis_from_site_map(site_map):
for idx in range(len(cond_bmap.all_subs))
]
)
self._scaled_training_data["X"][condition] = rereference(
self._scaled_arrays["X"][condition] = rereference(
X[condition], self._bundle_idxs[condition]
)

Expand All @@ -494,9 +495,7 @@ def get_nis_from_site_map(site_map):
for condition in self._conditions:
# compute times seen in data
# compute the sum of each mutation (column) in the scaled data
times_seen = pd.Series(
self._scaled_training_data["X"][condition].sum(0).todense()
)
times_seen = pd.Series(self._scaled_arrays["X"][condition].sum(0).todense())
times_seen.index = cond_bmap.all_subs

assert (times_seen == times_seen.astype(int)).all()
Expand All @@ -505,7 +504,13 @@ def get_nis_from_site_map(site_map):
mut_df = mut_df.merge(times_seen, on="mutation", how="left") # .fillna(0)

# set training data properties
self._training_data = {"X": X, "y": y, "w": w}
self._arrays = {
"X": X,
"y": y,
"w": w,
"pre_count": pre_count,
"post_count": post_count,
}
self._binarymaps = binmaps

self._mutations_df = mut_df
Expand Down Expand Up @@ -614,14 +619,14 @@ def reference_sequence_conditions(self) -> list:
return self._reference_sequence_conditions

@property
def training_data(self) -> dict:
def arrays(self) -> dict:
"""A dictionary with keys 'X' and 'y' for the training data."""
return self._training_data
return self._arrays

@property
def scaled_training_data(self) -> dict:
def scaled_arrays(self) -> dict:
"""A dictionary with keys 'X' and 'y' for the scaled training data."""
return self._scaled_training_data
return self._scaled_arrays

@property
def binarymaps(self) -> dict:
Expand All @@ -634,7 +639,7 @@ def binarymaps(self) -> dict:
@property
def targets(self) -> dict:
"""The functional scores for each variant in the training data."""
return self._training_data["y"]
return self._arrays["y"]

@property
def mutparser(self) -> MutationParser:
Expand Down
Loading
Loading