latest before moving spike analysis #141

Merged (2 commits) on Mar 11, 2024
8 changes: 0 additions & 8 deletions multidms/data.py
@@ -599,7 +599,6 @@ def targets(self) -> dict:
"""The functional scores for each variant in the training data."""
return self._training_data["y"]

# TODO, rename mutparser
@property
def mutparser(self) -> MutationParser:
"""
@@ -608,7 +607,6 @@ def mutparser(self) -> MutationParser:
"""
return self._mutparser

# TODO, rename
@property
def parse_mut(self) -> MutationParser:
"""
@@ -618,7 +616,6 @@ def parse_mut(self) -> MutationParser:
"""
return self.mutparser.parse_mut

# TODO, document rename issue
@property
def parse_muts(self) -> partial:
"""
@@ -628,11 +625,6 @@ def parse_muts(self) -> partial:
"""
return self._parse_muts

# TODO should this be cached? how does caching interact with the way in
# which we applying this function in parallel?
# although, unless the variants are un-collapsed, this cache will be
# pretty useless.
# although it could be useful for the Model.add_phenotypes_to_df method.
def convert_subs_wrt_ref_seq(self, condition, aa_subs):
"""
Convert amino acid substitutions to be with respect to the reference sequence.
64 changes: 56 additions & 8 deletions multidms/model.py
@@ -209,7 +209,7 @@ def __init__(
epistatic_model=multidms.biophysical.sigmoidal_global_epistasis,
output_activation=multidms.biophysical.identity_activation,
conditional_shifts=True,
alpha_d=False, # TODO raise issue to be squashed in this PR
alpha_d=False,
gamma_corrected=False,
PRNGKey=0,
init_beta_naught=0.0,
@@ -375,6 +375,25 @@ def loss(self) -> float:
data = (self.data.training_data["X"], self.data.training_data["y"])
return jax.jit(self.model_components["objective"])(self.params, data, **kwargs)

@property
def conditional_loss(self) -> float:
"""Compute loss individually for each condition."""
kwargs = {
"scale_coeff_ridge_beta": 0.0,
"scale_coeff_ridge_shift": 0.0,
"scale_coeff_ridge_gamma": 0.0,
"scale_ridge_alpha_d": 0.0,
}

X, y = self.data.training_data["X"], self.data.training_data["y"]
loss_fxn = jax.jit(self.model_components["objective"])
ret = {}
for condition in self.data.conditions:
condition_data = ({condition: X[condition]}, {condition: y[condition]})
ret[condition] = float(loss_fxn(self.params, condition_data, **kwargs))
ret["total"] = sum(ret.values())
return ret

@property
def variants_df(self):
"""
@@ -546,7 +565,7 @@ def get_mutations_df(

return mutations_df[col_order]

def get_df_loss(self, df, error_if_unknown=False, verbose=False):
def get_df_loss(self, df, error_if_unknown=False, verbose=False, conditional=False):
"""
Get the loss of the model on a given data frame.

@@ -563,10 +582,13 @@ def get_df_loss(self, df, error_if_unknown=False, verbose=False):
in the loss calculation. If `True`, raise an error.
verbose : bool
If True, print the number of valid and invalid variants.
conditional : bool
If True, return the loss for each condition as a dictionary.
If False, return the total loss.

Returns
-------
float
float or dict
The loss of the model on the given data frame.
"""
substitutions_col = "aa_substitutions"
@@ -579,8 +601,11 @@ def get_df_loss(self, df, error_if_unknown=False, verbose=False):
if condition_col not in df.columns:
raise ValueError("`df` lacks `condition_col` " f"{condition_col}")

X, y = {}, {}
loss_fxn = jax.jit(self.model_components["objective"])

ret = {}
for condition, condition_df in df.groupby(condition_col):
X, y = {}, {}
variant_subs = condition_df[substitutions_col]
if condition not in self.data.reference_sequence_conditions:
variant_subs = condition_df.apply(
@@ -592,14 +617,23 @@ def get_df_loss(self, df, error_if_unknown=False, verbose=False):

# build binary variants as csr matrix, make prediction, and append
valid, invalid = 0, 0  # counts of valid and invalid variants
binary_variants = []
# binary_variants = []
variant_targets = []
row_ind = [] # row indices of elements that are one
col_ind = [] # column indices of elements that are one

for subs, target in zip(variant_subs, condition_df[func_score_col]):
try:
binary_variants.append(ref_bmap.sub_str_to_binary(subs))
# binary_variants.append(ref_bmap.sub_str_to_binary(subs))
# variant_targets.append(target)
# valid += 1

for isub in ref_bmap.sub_str_to_indices(subs):
row_ind.append(valid)
col_ind.append(isub)
variant_targets.append(target)
valid += 1

except ValueError:
if error_if_unknown:
raise ValueError(
@@ -615,12 +649,26 @@ def get_df_loss(self, df, error_if_unknown=False, verbose=False):
f"{valid}, n invalid variants: {invalid}"
)

# X[condition] = sparse.BCOO.from_scipy_sparse(
# scipy.sparse.csr_matrix(onp.vstack(binary_variants))
# )
X[condition] = sparse.BCOO.from_scipy_sparse(
scipy.sparse.csr_matrix(onp.vstack(binary_variants))
scipy.sparse.csr_matrix(
(onp.ones(len(row_ind), dtype="int8"), (row_ind, col_ind)),
shape=(valid, ref_bmap.binarylength),
dtype="int8",
)
)

y[condition] = jnp.array(variant_targets)

return self.model_components["objective"](self.params, (X, y))
ret[condition] = float(loss_fxn(self.params, (X, y)))

ret["total"] = sum(ret.values())

if not conditional:
return ret["total"]
return ret
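A brief usage sketch for the extended get_df_loss; df is a hypothetical dataframe carrying the aa_substitutions, condition, and functional-score columns the method expects:

    # hypothetical example: evaluate loss on held-out variants
    per_condition = model.get_df_loss(df, conditional=True)  # dict keyed by condition, plus "total"
    total = model.get_df_loss(df)  # float, equal to per_condition["total"]

The method also switches from stacking dense binary rows to building the CSR matrix directly from (row, column) index pairs, which avoids materializing a dense array for large variant sets. A standalone sketch of the same scipy pattern on toy data (not this repository's API):

    import numpy as onp
    import scipy.sparse

    # toy data: variant 0 carries mutations at columns 2 and 5, variant 1 at column 0
    row_ind = [0, 0, 1]
    col_ind = [2, 5, 0]
    data = onp.ones(len(row_ind), dtype="int8")
    X = scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=(2, 6), dtype="int8")
    print(X.toarray())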

def add_phenotypes_to_df(
self,
124 changes: 102 additions & 22 deletions multidms/model_collection.py
@@ -396,8 +396,19 @@ def __init__(self, fit_models):
)
all_mutations = set.union(all_mutations, set(fit.data.mutations))

# add the final training loss to the fit_models dataframe
fit_models["training_loss"] = fit_models.step_loss.apply(lambda x: x[-1])
# initialize empty columns for conditional loss
fit_models = fit_models.assign(
**{
f"{condition}_loss_training": onp.nan
for condition in first_dataset.conditions
},
total_loss=onp.nan,
)
# assign conditional loss columns
for idx, fit in fit_models.iterrows():
conditional_loss = fit.model.conditional_loss
for condition, loss in conditional_loss.items():
fit_models.loc[idx, f"{condition}_loss_training"] = loss

self._site_map_union = site_map_union
self._conditions = first_dataset.conditions
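With this change, fit_models carries one training-loss column per condition plus a total, replacing the single training_loss column; a hypothetical way to inspect them, assuming collection is a constructed ModelCollection:

    # hypothetical example: list the per-condition training loss columns
    loss_cols = [c for c in collection.fit_models.columns if c.endswith("_loss_training")]
    print(collection.fit_models[["dataset_name"] + loss_cols].head())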
@@ -432,7 +443,6 @@ def all_mutations(self) -> tuple:
"""The mutations shared by each fitting dataset."""
return self._all_mutations

# TODO remove verbose everywhere
@lru_cache(maxsize=10)
def split_apply_combine_muts(
self,
@@ -482,32 +492,52 @@ def split_apply_combine_muts(
A dataframe containing the aggregated mutational parameter values
"""
print("cache miss - this could take a moment")
queried_fits = (
self.fit_models.query(query) if query is not None else self.fit_models
)
if len(queried_fits) == 0:
raise ValueError("invalid query, no fits returned")

if groupby is None:
groupby = tuple(
set(self.fit_models.columns)
- set(["model", "data", "step_loss", "verbose"])
# groupby = tuple(
# set(queried_fits.columns)
# - set(
# ["model", "dataset_name", "verbose"]
# + [col for col in queried_fits.columns if "loss" in col]
# )
# )
ret = (
pd.concat(
[
fit["model"].get_mutations_df(return_split=False, **kwargs)
for _, fit in queried_fits.iterrows()
],
join="inner", # the columns will always match based on class req.
)
.query(
f"mutation.isin({list(self.shared_mutations)})"
if inner_merge_dataset_muts
else "mutation.notna()"
)
.groupby("mutation")
.aggregate(aggregate_func)
)
return ret

elif isinstance(groupby, str):
groupby = tuple([groupby])

elif isinstance(groupby, tuple):
if not all(feature in self.fit_models.columns for feature in groupby):
if not all(feature in queried_fits.columns for feature in groupby):
raise ValueError(
f"invalid groupby, values must be in {self.fit_models.columns}"
)
else:
raise ValueError(
"invalid groupby, must be tuple with values "
f"in {self.fit_models.columns}"
f"in {queried_fits.columns}"
)

queried_fits = (
self.fit_models.query(query) if query is not None else self.fit_models
)
if len(queried_fits) == 0:
raise ValueError("invalid query, no fits returned")

ret = pd.concat(
[
pd.concat(
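A usage sketch for the default (no groupby) path of split_apply_combine_muts; the query string and penalty value are hypothetical:

    # hypothetical example: average mutation effects across fits at one lasso penalty
    agg_muts = collection.split_apply_combine_muts(
        query="scale_coeff_lasso_shift == 1e-5",
        aggregate_func="mean",
    )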
@@ -566,20 +596,69 @@ def add_validation_loss(self, test_data, overwrite=False):
# check there's a testing dataframe for each unique dataset_name
assert set(test_data.keys()) == set(self.fit_models["dataset_name"].unique())

if "validation_loss" in self.fit_models.columns and not overwrite:
validation_cols_exist = onp.any(
[
f"{condition}_loss_validation" in self.fit_models.columns
for condition in self.conditions
]
)
if validation_cols_exist and not overwrite:
raise ValueError(
"validation_loss already exists in self.fit_models, set overwrite=True "
"to overwrite"
)

self.fit_models["validation_loss"] = onp.nan
self.fit_models = self.fit_models.assign(
**{
f"{condition}_loss_validation": onp.nan for condition in self.conditions
},
total_loss_validation=onp.nan,
)

for idx, fit in self.fit_models.iterrows():
self.fit_models.loc[idx, "validation_loss"] = fit["model"].get_df_loss(
test_data[fit["dataset_name"]]
conditional_df_loss = fit.model.get_df_loss(
test_data[fit["dataset_name"]], conditional=True
)
for condition, loss in conditional_df_loss.items():
self.fit_models.loc[idx, f"{condition}_loss_validation"] = loss

return None
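A usage sketch for add_validation_loss; the dataset name and held-out dataframe are hypothetical, and test_data needs one entry per unique dataset_name in fit_models:

    # hypothetical example: attach per-condition validation loss columns in place
    collection.add_validation_loss({"my_dataset": held_out_df}, overwrite=True)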

def get_conditional_loss_df(self, query=None):
"""
Return a long-form dataframe with columns
"dataset_name", "scale_coeff_lasso_shift",
"split" ("training" or "validation"),
"loss" (the loss value), and "condition".

Parameters
----------
query : str, optional
The query to apply to the fit_models dataframe
before formatting the loss dataframe. The default is None.
"""
if query is not None:
queried_fits = self.fit_models.query(query)
else:
queried_fits = self.fit_models
if len(queried_fits) == 0:
raise ValueError("invalid query, no fits returned")

id_vars = ["dataset_name", "scale_coeff_lasso_shift"]
value_vars = [
c for c in queried_fits.columns if "loss" in c and c != "step_loss"
]
loss_df = queried_fits.melt(
id_vars=id_vars,
value_vars=value_vars,
var_name="condition",
value_name="loss",
).assign(
split=lambda x: x.condition.str.split("_").str.get(-1),
condition=lambda x: x.condition.str.split("_").str[:-2].str.join("_"),
)
return loss_df
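A sketch of how the long-form loss table from get_conditional_loss_df might be consumed; column names follow the docstring above and the query value is illustrative:

    loss_df = collection.get_conditional_loss_df()
    # columns: dataset_name, scale_coeff_lasso_shift, condition, split, loss
    validation_only = loss_df.query("split == 'validation'")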

def mut_param_heatmap(
self,
query=None,
@@ -652,7 +731,11 @@ def mut_param_heatmap(
if len(queried_fits) == 0:
raise ValueError("invalid query, no fits returned")
shouldbe_uniform = list(
set(queried_fits.columns) - set(["model", "dataset_name", "step_loss"])
set(queried_fits.columns)
- set(
["model", "dataset_name"]
+ [col for col in queried_fits.columns if "loss" in col]
)
)
if len(queried_fits.groupby(list(shouldbe_uniform)).groups) > 1:
raise ValueError(
@@ -921,9 +1004,6 @@ def mut_type(mut):
return "stop" if mut.endswith("*") else "nonsynonymous"

# apply, drop, and melt
# TODO This throws deprecation warning
# because of the include_groups argument ...
# set to False, and lose the drop call after ...
sparsity_df = (
df.drop(columns=to_throw)
.assign(mut_type=lambda x: x.mutation.apply(mut_type))