Commit

added inference functionality to constraint of GE fx scaling (#146)

jgallowa07 authored Mar 22, 2024
1 parent 900cb53 commit 79dc61a
Showing 4 changed files with 319 additions and 81 deletions.
9 changes: 5 additions & 4 deletions multidms/biophysical.py
@@ -323,8 +323,7 @@ def _abstract_epistasis(
 def _lasso_lock_prox(
     params,
     hyperparams_prox=dict(
-        lasso_params=None,
-        lock_params=None,
+        lasso_params=None, lock_params=None, upper_bound_theta_ge_scale=None
     ),
     scaling=1.0,
 ):
@@ -340,9 +339,11 @@ def _lasso_lock_prox(
     scaling : float
         Scaling factor for lasso penalty
     """
-    # enforce monotonic epistasis
+    # enforce monotonic epistasis and constrain ge_scale upper limit
     if "ge_scale" in params["theta"]:
-        params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip(0)
+        params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip(
+            0, hyperparams_prox["upper_bound_theta_ge_scale"]
+        )

     if "p_weights_1" in params["theta"]:
         params["theta"]["p_weights_1"] = params["theta"]["p_weights_1"].clip(0)
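For intuition, a minimal sketch (not part of this diff) of the clipping behavior the new upper_bound_theta_ge_scale hyperparameter drives in the proximal step; the array values are invented:

    import jax.numpy as jnp

    ge_scale = jnp.array([-1.0, 0.5, 7.3])

    # old proximal step: non-negativity only
    ge_scale.clip(0)        # [0.  0.5 7.3]

    # new proximal step with a float upper bound: boxed into [0, 5.0]
    ge_scale.clip(0, 5.0)   # [0.  0.5 5. ]

    # a None upper bound reproduces the old, unconstrained-above behavior
    ge_scale.clip(0, None)  # [0.  0.5 7.3]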
46 changes: 37 additions & 9 deletions multidms/model.py
@@ -193,12 +193,14 @@ class Model:
     Next, we fit the model with some chosen hyperparameters.

-    >>> model.fit(maxiter=1000, lasso_shift=1e-5)
+    >>> model.fit(maxiter=1000, lasso_shift=1e-5, warn_unconverged=False)
     >>> model.loss
     Array(6.0517805e-06, dtype=float32)

     The model tunes its parameters in place, and the subsequent call to retrieve
     the loss reflects our model's loss given its updated parameters.

+    TODO: add more examples, and explain the convergence criteria and warning.
+
     """  # noqa: E501

     counter = 0
@@ -212,11 +214,11 @@ def __init__(
         alpha_d=False,
         gamma_corrected=False,
         PRNGKey=0,
-        lower_bound=None,
-        n_hidden_units=5,
         init_beta_naught=0.0,
         init_theta_scale=5.0,
         init_theta_bias=-5.0,
+        n_hidden_units=5,
+        lower_bound=None,
         name=None,
     ):
         """See class docstring."""
@@ -267,6 +269,10 @@ def __init__(
             self._params["theta"] = dict(ghost_param=jnp.zeros(shape=(1,)))

         elif epistatic_model == multidms.biophysical.nn_global_epistasis:
+            if n_hidden_units is None:
+                raise ValueError(
+                    "n_hidden_units must be specified for nn_global_epistasis"
+                )
             key, key1, key2, key3, key4 = jax.random.split(key, num=5)
             self._params["theta"] = dict(
                 p_weights_1=jax.random.normal(shape=(n_hidden_units,), key=key1).clip(
@@ -629,17 +635,12 @@ def get_df_loss(self, df, error_if_unknown=False, verbose=False, conditional=False):

         # build binary variants as csr matrix, make prediction, and append
         valid, invalid = 0, 0  # counts of valid and invalid variants
-        # binary_variants = []
         variant_targets = []
         row_ind = []  # row indices of elements that are one
         col_ind = []  # column indices of elements that are one

         for subs, target in zip(variant_subs, condition_df[func_score_col]):
             try:
-                # binary_variants.append(ref_bmap.sub_str_to_binary(subs))
-                # variant_targets.append(target)
-                # valid += 1
-
                 for isub in ref_bmap.sub_str_to_indices(subs):
                     row_ind.append(valid)
                     col_ind.append(isub)
@@ -973,6 +974,7 @@ def fit(
         acceleration=True,
         lock_params={},
         warn_unconverged=True,
+        upper_bound_theta_ge_scale="infer",
         **kwargs,
     ):
         r"""
@@ -997,6 +999,13 @@ def fit(
            convergence is defined by whether the model tolerance (``tol``) threshold
            was passed during the optimization process.
            Defaults to True.
+        upper_bound_theta_ge_scale : float, None, or 'infer'
+            The positive upper bound of the theta scale parameter;
+            negative values are not allowed.
+            Passing ``None`` allows the scale of the sigmoid to be unconstrained.
+            Passing the string literal 'infer' results in the
+            scale being set to double the range of the training data.
+            Defaults to 'infer'.
         **kwargs : dict
             Additional keyword arguments passed to the objective function.
             These include hyperparameters like a ridge penalty on beta, shift, and gamma
@@ -1035,9 +1044,28 @@ def fit(
                 continue
             lasso_params[f"shift_{non_ref_condition}"] = lasso_shift

+        if not isinstance(upper_bound_theta_ge_scale, (float, int, type(None), str)):
+            raise ValueError(
+                "upper_bound_theta_ge_scale must be a float, None, or 'infer'"
+            )
+        if isinstance(upper_bound_theta_ge_scale, (float, int)):
+            if upper_bound_theta_ge_scale < 0:
+                raise ValueError("upper_bound_theta_ge_scale must be non-negative")
+
+        # infer the range of the training data, and double it
+        # to set the upper bound of the theta scale parameter.
+        # see https://github.com/matsengrp/multidms/issues/143 for details
+        if upper_bound_theta_ge_scale == "infer":
+            y = jnp.concatenate(list(self.data.training_data["y"].values()))
+            y_range = y.max() - y.min()
+            upper_bound_theta_ge_scale = 2 * y_range
+
         self._params, self._state = solver.run(
             self._params,
-            hyperparams_prox=dict(lasso_params=lasso_params, lock_params=lock_params),
+            hyperparams_prox=dict(
+                lasso_params=lasso_params,
+                lock_params=lock_params,
+                upper_bound_theta_ge_scale=upper_bound_theta_ge_scale,
+            ),
             data=(self._data.training_data["X"], self._data.training_data["y"]),
             **kwargs,
         )
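To see how the new option is used, a hypothetical sketch (not part of this diff) of fit() calls; it assumes an already-initialized multidms.Model named model:

    # default: the bound is inferred as 2 * (y.max() - y.min())
    # over the concatenated training targets
    model.fit(maxiter=1000, lasso_shift=1e-5)

    # explicit float bound on the sigmoid scale
    model.fit(maxiter=1000, lasso_shift=1e-5, upper_bound_theta_ge_scale=10.0)

    # None leaves the scale unconstrained above, as before this commit
    model.fit(maxiter=1000, lasso_shift=1e-5, upper_bound_theta_ge_scale=None)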
67 changes: 54 additions & 13 deletions multidms/model_collection.py
@@ -55,28 +55,30 @@ def _explode_params_dict(params_dict):
 ]


+# everything below verbose could be kwargs passed to fit()
 def fit_one_model(
     dataset,
-    huber_scale_huber=1,
-    scale_coeff_lasso_shift=2e-5,
-    scale_coeff_ridge_beta=0,
-    scale_coeff_ridge_shift=0,
-    scale_coeff_ridge_gamma=0,
-    scale_coeff_ridge_alpha_d=0,
     epistatic_model="Sigmoid",
     output_activation="Identity",
     lock_beta_naught_at=None,
     gamma_corrected=False,
     alpha_d=False,
     init_beta_naught=0.0,
-    tol=1e-4,
-    num_training_steps=1,
-    iterations_per_step=20000,
-    acceleration=True,
     n_hidden_units=5,
     lower_bound=None,
     PRNGKey=0,
     verbose=False,
+    tol=1e-4,
+    huber_scale_huber=1,
+    scale_coeff_lasso_shift=2e-5,
+    scale_coeff_ridge_beta=0,
+    scale_coeff_ridge_shift=0,
+    scale_coeff_ridge_gamma=0,
+    scale_coeff_ridge_alpha_d=0,
+    num_training_steps=1,
+    iterations_per_step=20000,
+    upper_bound_theta_ge_scale="infer",
+    acceleration=True,
 ):
     """
     Fit a multidms model to a dataset. This is a wrapper around the multidms
@@ -136,6 +138,8 @@ def fit_one_model(
     lower_bound : float, optional
         The lower bound for use with the softplus activation function.
         The default is None, but must be specified if using the softplus activation.
+    upper_bound_theta_ge_scale : float, None, or 'infer', optional
+        The upper bound for the theta_ge_scale parameter. The default is 'infer'.
     PRNGKey : int, optional
         The PRNGKey to use to initialize model parameters. The default is 0.
     verbose : bool, optional
@@ -164,6 +168,7 @@ def fit_one_model(
         "Softplus": multidms.biophysical.softplus_activation,
     }

+    # should these all be kwargs?
     imodel = multidms.Model(
         dataset,
         epistatic_model=biophysical_model[epistatic_model],
@@ -185,6 +190,7 @@ def fit_one_model(
     del fit_attributes["verbose"]

     fit_attributes["step_loss"] = onp.repeat(onp.nan, num_training_steps + 1)
+    fit_attributes["step_error"] = onp.repeat(onp.nan, num_training_steps + 1)
     fit_attributes["step_loss"][0] = float(imodel.loss)
     fit_attributes["dataset_name"] = dataset.name
     fit_attributes["model"] = imodel
@@ -208,6 +214,7 @@ def fit_one_model(
             scale_coeff_ridge_beta=scale_coeff_ridge_beta,
             scale_coeff_ridge_gamma=scale_coeff_ridge_gamma,
             scale_coeff_ridge_alpha_d=scale_coeff_ridge_alpha_d,
+            upper_bound_theta_ge_scale=upper_bound_theta_ge_scale,
             warn_unconverged=False,
         )
         end = time.time()
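A hypothetical call of the wrapper with the new keyword (not part of this diff); it assumes a multidms.Data object named dataset:

    result = fit_one_model(
        dataset,
        scale_coeff_lasso_shift=2e-5,
        upper_bound_theta_ge_scale="infer",  # the default; a float or None also works
    )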
@@ -1067,6 +1074,7 @@ def mut_param_dataset_correlation(
        width_scalar=150,
        height=200,
        return_data=False,
+        r=2,
        **kwargs,
    ):
        """
@@ -1075,13 +1083,37 @@ def mut_param_dataset_correlation(
        We compute the correlation of mutation parameters across each pair of datasets
        in the collection.

+        Parameters
+        ----------
+        x : str, optional
+            The parameter to plot on the x-axis.
+            The default is "scale_coeff_lasso_shift".
+        width_scalar : int, optional
+            The width of the chart. The default is 150.
+        height : int, optional
+            The height of the chart. The default is 200.
+        return_data : bool, optional
+            Whether to return the underlying data. The default is False.
+        r : int, optional
+            The exponent applied to the reported correlation coefficient.
+            May be either 1 for the Pearson correlation or
+            2 for the coefficient of determination (r-squared).
+            The default is 2.
+        **kwargs : dict
+            The keyword arguments to pass to the
+            :func:`multidms.model_collection.ModelCollection.split_apply_combine_muts`
+            method. See the method docstring for details.

        Returns
        -------
        altair.Chart or Tuple(altair.Chart, pd.DataFrame)
            A chart object which can be displayed in a Jupyter notebook
            or saved to a file. If `return_data=True`, then a tuple
            containing the chart and the underlying data will be returned.
        """
+        if r not in [1, 2]:
+            raise ValueError("invalid r, must be 1 or 2")
+
        query = "dataset_name.notna()" if "query" not in kwargs else kwargs["query"]
        if len(self.fit_models.query(query).dataset_name.unique()) < 2:
            raise ValueError("Must specify a subset of fits with multiple datasets")
@@ -1107,15 +1139,20 @@ def mut_param_dataset_correlation(
                    {
                        "datasets": ",".join(datasets),
                        "mut_param": mut_param,
-                        "correlation": replicate_params_df.T.corr().iloc[0, 1],
+                        "correlation": replicate_params_df.T.corr().iloc[0, 1] ** r,
                        x: x_i,
                    },
                    index=[0],
                ),
            )

-        replicate_df = pd.concat(replicate_series)
+        # thin wrapper around pd.concat, to work around a Pylance false positive:
+        # https://github.com/microsoft/pylance-release/issues/5630
+        def my_concat(dfs_list, axis=0):
+            return pd.concat(dfs_list, axis=axis)
+
+        replicate_df = my_concat(replicate_series)

+        title_suffix = "(R^2)" if r == 2 else "(pearson)"
        # create altair chart
        base_chart = (
            alt.Chart(replicate_df)
                ).axis(
                    format=".1e",
                ),
-                y=alt.Y("correlation", type="quantitative", title="Correlation"),
+                y=alt.Y(
+                    "correlation",
+                    type="quantitative",
+                    title=f"Correlation {title_suffix}",
+                ),
                color=alt.Color("mut_param", type="nominal", title="Parameter"),
                tooltip=["datasets", "correlation", "mut_param"],
            )
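For intuition on the new r option, a minimal sketch (not part of this diff) of the exponent semantics, using invented values:

    import pandas as pd

    a = pd.Series([0.1, 0.4, 0.35, 0.8])
    b = pd.Series([0.12, 0.38, 0.30, 0.75])

    pearson = a.corr(b)       # what r=1 reports
    r_squared = pearson ** 2  # what r=2 reports (coefficient of determination)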
(diff for the fourth changed file did not load and is not shown)
