Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
wsdewitt committed Jul 26, 2024
1 parent 7b68b45 commit d01e5ab
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 1,043 deletions.
14 changes: 7 additions & 7 deletions multidms/jaxmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ class Data(eqx.Module):

def __init__(self, multidms_data: multidms.Data, condition: str) -> None:
X = multidms_data.arrays["X"][condition]
not_wt = X.indices[:, 0] != 0
not_wt = X.indices[:, 0] != 0 # assumes WT is the first
sparse_data = X.data[not_wt]
sparse_idxs = X.indices[not_wt]
sparse_idxs = sparse_idxs.at[:, 0].add(-1)
sparse_idxs = sparse_idxs.at[:, 0].add(-1) # assumes WT is the first
X = jax.experimental.sparse.BCOO(
(sparse_data, sparse_idxs), shape=(X.shape[0] - 1, X.shape[1])
)

self.x_wt = multidms_data.arrays["X"][condition][0].todense()
self.pre_count_wt = multidms_data.arrays["pre_count"][condition][0]
self.post_count_wt = multidms_data.arrays["post_count"][condition][0]
self.pre_count_wt = multidms_data.arrays["pre_count"][condition][0] # assumes WT is the first
self.post_count_wt = multidms_data.arrays["post_count"][condition][0] # assumes WT is the first
self.X = X
self.pre_counts = multidms_data.arrays["pre_count"][condition][1:]
self.post_counts = multidms_data.arrays["post_count"][condition][1:]
self.functional_scores = multidms_data.arrays["y"][condition][1:]
self.pre_counts = multidms_data.arrays["pre_count"][condition][1:] # assumes WT is the first
self.post_counts = multidms_data.arrays["post_count"][condition][1:] # assumes WT is the first
self.functional_scores = multidms_data.arrays["y"][condition][1:] # assumes WT is the first


class Latent(eqx.Module):
Expand Down
Loading

0 comments on commit d01e5ab

Please sign in to comment.