Skip to content

Commit

Permalink
jaxmodels tidying and linting
Browse files Browse the repository at this point in the history
  • Loading branch information
wsdewitt committed Jul 30, 2024
1 parent d01e5ab commit 22fd4b4
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 190 deletions.
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ and how much the effects differ between experiments.

- The source code is `on GitHub <https://github.com/matsengrp/multidms>`_.

- For questions or inquaries about the software please `raise an issue <https://github.com/matsengrp/multidms/issues>`_, or contact jgallowa \<at\> fredhutch.org.
- For questions or inquiries about the software please `raise an issue <https://github.com/matsengrp/multidms/issues>`_, or contact jgallowa \<at\> fredhutch.org.

.. toctree::
:hidden:
Expand Down
29 changes: 21 additions & 8 deletions multidms/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
dms experiments under various conditions.
"""

import os
from functools import partial, cached_property
import warnings

Expand Down Expand Up @@ -247,7 +246,13 @@ def __init__(
self._mutparser = MutationParser(alphabet, letter_suffixed_sites)

# Configure new variants df
cols = ["condition", "aa_substitutions", "func_score", "pre_count", "post_count"]
cols = [
"condition",
"aa_substitutions",
"func_score",
"pre_count",
"post_count",
]
if "weight" in variants_df.columns:
cols.append(
"weight"
Expand Down Expand Up @@ -456,8 +461,12 @@ def get_nis_from_site_map(site_map):
binmaps[condition] = cond_bmap
X[condition] = sparse.BCOO.from_scipy_sparse(cond_bmap.binary_variants)
y[condition] = jnp.array(condition_func_score_df["func_score"].values)
pre_count[condition] = jnp.array(condition_func_score_df["pre_count"].values)
post_count[condition] = jnp.array(condition_func_score_df["post_count"].values)
pre_count[condition] = jnp.array(
condition_func_score_df["pre_count"].values
)
post_count[condition] = jnp.array(
condition_func_score_df["post_count"].values
)
if "weight" in condition_func_score_df.columns:
w[condition] = jnp.array(condition_func_score_df["weight"].values)

Expand Down Expand Up @@ -486,9 +495,7 @@ def get_nis_from_site_map(site_map):
for condition in self._conditions:
# compute times seen in data
# compute the sum of each mutation (column) in the scaled data
times_seen = pd.Series(
self._scaled_arrays["X"][condition].sum(0).todense()
)
times_seen = pd.Series(self._scaled_arrays["X"][condition].sum(0).todense())
times_seen.index = cond_bmap.all_subs

assert (times_seen == times_seen.astype(int)).all()
Expand All @@ -497,7 +504,13 @@ def get_nis_from_site_map(site_map):
mut_df = mut_df.merge(times_seen, on="mutation", how="left") # .fillna(0)

# set training data properties
self._arrays = {"X": X, "y": y, "w": w, "pre_count": pre_count, "post_count": post_count}
self._arrays = {
"X": X,
"y": y,
"w": w,
"pre_count": pre_count,
"post_count": post_count,
}
self._binarymaps = binmaps

self._mutations_df = mut_df
Expand Down
Loading

0 comments on commit 22fd4b4

Please sign in to comment.