Skip to content

Commit

Permalink
Revert _affinity_tools.py to the previous label transfer method
Browse files Browse the repository at this point in the history
  • Loading branch information
dawe committed Nov 9, 2023
1 parent 1b61346 commit 0342206
Showing 1 changed file with 36 additions and 61 deletions.
97 changes: 36 additions & 61 deletions schist/tools/_affinity_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from scanpy import logging as logg
import graph_tool.all as gt
import pandas as pd
from .._utils import get_cell_loglikelihood, get_cell_back_p, state_from_blocks, get_graph_tool_from_adata
from .._utils import get_cell_loglikelihood, get_cell_back_p, state_from_blocks
from scanpy._utils import get_igraph_from_adjacency, _choose_graph


Expand Down Expand Up @@ -358,8 +358,7 @@ def label_transfer(
adata_ref: Optional[AnnData] = None,
obs: Optional[str] = None,
label_unk: Optional[str] = 'unknown',
#use_best: Optional[bool] = False,
keep_old: Optional[bool] = True,
use_best: Optional[bool] = False,
neighbors_key: Optional[str] = 'neighbors',
adjacency: Optional[sparse.spmatrix] = None,
directed: bool = False,
Expand All @@ -372,7 +371,7 @@ def label_transfer(
) -> Optional[AnnData]:

"""\
Transfer annotation from one dataset to another using a SBM.
Transfer annotation from one dataset to another using cell affinities.
If two datasets are given, it uses harmony to perform
integration and then the kNN graph. If no reference is given, it is assumed
that the single adata already contains the proper kNN graph and that
Expand All @@ -392,11 +391,10 @@ def label_transfer(
The label for unassigned cells. If no `adata_ref` is given, this label
identifies cells to be assigned in `adata`. If `adata_ref` is given, this
label will be given to all cells that cannot be assigned.
keep_old
Labels are assigned using a MCMC without merge/split step, only allowing moves.
This implies that cells with known label can be moved as well. Once the MCMC
has converged, this option sets initial labels to cells that did not need
relabeling.
use_best
When assigning labels, some cells may not have enough evidence and are, therefore,
left `unknown`. If this parameter is set to `True`, all cells will be assigned
to the best possible label, even if it may not be optimal.
neighbors_key
Use neighbors connectivities as adjacency.
If not specified, it looks in .obsp['connectivities'] for connectivities
Expand Down Expand Up @@ -523,60 +521,35 @@ def label_transfer(

adata_merge.obs[obs] = adata_merge.obs[obs].fillna(label_unk)

# get the kNN graph
g = get_graph_tool_from_adata(adata_merge, neighbors_key=neighbors_key)
# calculate affinity

# get the blocks. Here the unknown label is already present
# we need blocks as integers otherwise graphtool can't handle them
blocks = np.array(adata_merge.obs[obs].cat.codes)
old_blocks = blocks.copy()
new_blocks = blocks.copy()
block_dict = dict(zip(adata_merge.obs[obs], adata_merge.obs[obs].cat.codes))
label_dict = dict(zip(adata_merge.obs[obs].cat.codes, adata_merge.obs[obs]))
calculate_affinity(adata_merge, group_by=obs, neighbors_key=neighbors_key)

known_blocks = [block_dict[x] for x in block_dict if x != label_unk]
known_labels = [label_dict[x] for x in sorted(known_blocks)]
# now work on affinity, rank it to get the new labels
categories = adata_merge.obs[obs].cat.categories
affinity = pd.DataFrame(adata_merge.obsm[f'CA_{obs}'],
index=adata_merge.obs_names, columns=categories)
# if use_best, we need to remove the unknown label from the matrix so that
# it does not get scored
if use_best:
affinity.drop(label_unk, axis='columns', inplace=True)

n_unknowns = adata_merge.obs['_label_transfer'].value_counts()['_unk']

adj_key = adata_merge.uns[neighbors_key]['connectivities_key']
adj_mat = adata_merge.obsp[adj_key]

unk_idx = np.where(adata_merge.obs['_label_transfer'] == '_unk')[0]
I = np.transpose(adj_mat[unk_idx].nonzero())
# assign to each unknown the most common block according to the adj matrix
# In reality this works also with random assignments

for x in range(n_unknowns):
x_blocks = old_blocks[I[I[:, 0] == x][:, 1]]
# exclude the unknowns
x_blocks[x_blocks != block_dict[label_unk]]
if len(x_blocks) == 0:
# handle the case of no blocks left
x_blocks = [known_blocks[np.random.choice(len(known_blocks))]]
b, c = np.unique(x_blocks, return_counts=True)
new_blocks[unk_idx[x]] = b[np.argmax(c)]

# time to MCMC
state = gt.BlockState(g, b=new_blocks)
fast_tol = 1e-6
dS = 1e9
n = 0
max_iter = 100
while (np.abs(dS) > fast_tol) and (n < max_iter):
# we essentially move stuff, don't look for new partitions
dS, _, _ = state.multiflip_mcmc_sweep(beta=np.inf, niter=10, c=0.2,
psplit=0, pmerge=0, pmergesplit=0)
n += 1

new_blocks = np.array(state.b.a)
if keep_old:
keep_idx = np.where(adata_merge.obs['_label_transfer'] == '_ref')[0]
new_blocks[keep_idx] = old_blocks[keep_idx]

# assign proper categories
adata_merge.obs[f'_{obs}_tmp'] = pd.Categorical([label_dict[x] for x in new_blocks], categories=known_labels)
rank_affinity = affinity.rank(axis=1, ascending=False)
adata_merge.obs[f'_{obs}_tmp'] = adata_merge.obs[obs].values
unk_cells = adata_merge.obs.query('_label_transfer == "_unk"').index
for c in rank_affinity.columns:
# pretty sure there's a way to do it without a
# for loop :-/ I really need a course on pandas
cells = rank_affinity[rank_affinity[c] == 1].index
# do not relabel known cells
cells = cells.intersection(unk_cells)
if len(cells) > 0:
adata_merge.obs.loc[cells, f'_{obs}_tmp'] = c

# do actual transfer to dataset 1
# here we assume that concatenation does not change the order of cells,
# only their names

labels = adata_merge.obs[f'_{obs}_tmp'].cat.categories
if adata_ref:
# transfer has been done between two files
Expand All @@ -585,12 +558,11 @@ def label_transfer(
# transfer is within dataset
adata_merge.obs[obs] = adata_merge.obs[f'_{obs}_tmp'].values
adata_merge.obs.drop(f'_{obs}_tmp', axis='columns', inplace=True)
adata_merge.obs.drop('_label_transfer', axis='columns', inplace=True)
adata = adata_merge

# ensure that it is categorical with proper order
adata.obs[obs] = pd.Categorical(adata.obs[obs], categories=labels)

# transfer colors if any
if adata_ref and f'{obs}_colors' in adata_ref.uns:
colors = list(adata_ref.uns[f'{obs}_colors'])
Expand All @@ -599,6 +571,9 @@ def label_transfer(
colors.append('#aabbcc')
adata.uns[f'{obs}_colors'] = colors

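# remove unused categories when `use_best` is set, as "unknown" may no longer be in use
# NOTE(review): `cat.remove_unused_categories()` returns a new object and the
# result is not assigned below - confirm whether it should be written back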
if use_best:
adata.obs[obs].cat.remove_unused_categories()

return adata if copy else None

0 comments on commit 0342206

Please sign in to comment.