Skip to content

Commit

Permalink
Update tl.py
Browse files Browse the repository at this point in the history
  • Loading branch information
serjisa authored Jul 6, 2024
1 parent b81893a commit 8bdc3a9
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions sclitr/tl.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,17 +220,26 @@ def clone2vec(

clone2vec = model.embedding.weight.data.cpu().numpy()
if not (fill_ct is None):
cell_counts = adata_only_clones.obs.groupby([fill_ct, obs_name]).size().unstack()[adata_only_clones.uns[f"{obsm_name}_names"]]
cell_counts = cell_counts.values
cell_counts = adata_only_clones.obs.groupby(
[fill_ct, obs_name]
).size().unstack()[adata_only_clones.uns[f"{obsm_name}_names"]]

var_names = list(cell_counts.index)
obs_names = list(cell_counts.columns)

cell_counts = cell_counts.values
freqs = cell_counts / cell_counts.sum(axis=0)
freqs = freqs.values
else:
cell_counts = np.array([0] * len(adata.uns[f"{obsm_name}_names"]))
freqs = np.array([0] * len(adata.uns[f"{obsm_name}_names"]))
var_names = ["None"]
obs_names = adata.uns[f"{obsm_name}_names"]

cell_counts = np.matrix([0] * len(adata.uns[f"{obsm_name}_names"]))
freqs = np.matrix([0] * len(adata.uns[f"{obsm_name}_names"]))

clones = sc.AnnData(
X=cell_counts.T,
obs=pd.DataFrame(index=obs_names),
var=pd.DataFrame(index=var_names),
layers={
"frequencies": freqs.T,
"counts": cell_counts.T,
Expand Down

0 comments on commit 8bdc3a9

Please sign in to comment.