Skip to content

Commit

Permalink
Merge pull request #68 from dawe/label_transfer
Browse files Browse the repository at this point in the history
Label transfer
  • Loading branch information
dawe authored Oct 12, 2023
2 parents 436cd4c + 59e83f8 commit 1b61346
Showing 1 changed file with 61 additions and 36 deletions.
97 changes: 61 additions & 36 deletions schist/tools/_affinity_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from scanpy import logging as logg
import graph_tool.all as gt
import pandas as pd
from .._utils import get_cell_loglikelihood, get_cell_back_p, state_from_blocks
from .._utils import get_cell_loglikelihood, get_cell_back_p, state_from_blocks, get_graph_tool_from_adata
from scanpy._utils import get_igraph_from_adjacency, _choose_graph


Expand Down Expand Up @@ -358,7 +358,8 @@ def label_transfer(
adata_ref: Optional[AnnData] = None,
obs: Optional[str] = None,
label_unk: Optional[str] = 'unknown',
use_best: Optional[bool] = False,
#use_best: Optional[bool] = False,
keep_old: Optional[bool] = True,
neighbors_key: Optional[str] = 'neighbors',
adjacency: Optional[sparse.spmatrix] = None,
directed: bool = False,
Expand All @@ -371,7 +372,7 @@ def label_transfer(
) -> Optional[AnnData]:

"""\
Transfer annotation from one dataset to another using cell affinities.
Transfer annotation from one dataset to another using a SBM.
If two datasets are given, it uses harmony to perform
integration and then the kNN graph. If only no reference is given, it is assumed
that the only adata already contains the proper kNN graph and that
Expand All @@ -391,10 +392,11 @@ def label_transfer(
The label for unassigned cells. If no `adata_ref` is given, this label
identifies cells to be assigned in `adata`. If `adata_ref` is given, this
label will be given to all cells that cannot be assigned.
use_best
When assigning labels, some cells may have not enough evidence and, therefore,
left `unknown`. If this parameter is set to `True`, all cells will be assigned
to the best possible, even if it may not be optimal
keep_old
Labels are assigned using a MCMC without merge/split step, only allowing moves.
This implies that cells with known label can be moved as well. Once the MCMC
has converged, this option sets initial labels to cells that did not need
relabeling.
neighbors_key
Use neighbors connectivities as adjacency.
If not specified, leiden looks .obsp['connectivities'] for connectivities
Expand Down Expand Up @@ -521,35 +523,60 @@ def label_transfer(

adata_merge.obs[obs] = adata_merge.obs[obs].fillna(label_unk)

# calculate affinity
# get the kNN graph
g = get_graph_tool_from_adata(adata_merge, neighbors_key=neighbors_key)

calculate_affinity(adata_merge, group_by=obs, neighbors_key=neighbors_key)
# get the blocks. Here the unknown label is already present
# we need blocks as integers otherwise graphtool can't handle them
blocks = np.array(adata_merge.obs[obs].cat.codes)
old_blocks = blocks.copy()
new_blocks = blocks.copy()
block_dict = dict(zip(adata_merge.obs[obs], adata_merge.obs[obs].cat.codes))
label_dict = dict(zip(adata_merge.obs[obs].cat.codes, adata_merge.obs[obs]))

# now work on affinity, rank it to get the new labels
categories = adata_merge.obs[obs].cat.categories
affinity = pd.DataFrame(adata_merge.obsm[f'CA_{obs}'],
index=adata_merge.obs_names, columns=categories)
# if use_best we need to remove label unknonw from the matrix so it
# does not get scored
if use_best:
affinity.drop(label_unk, axis='columns', inplace=True)
known_blocks = [block_dict[x] for x in block_dict if x != label_unk]
known_labels = [label_dict[x] for x in sorted(known_blocks)]

rank_affinity = affinity.rank(axis=1, ascending=False)
adata_merge.obs[f'_{obs}_tmp'] = adata_merge.obs[obs].values
unk_cells = adata_merge.obs.query('_label_transfer == "_unk"').index
for c in rank_affinity.columns:
# pretty sure there's a way to do it without a
# for loop :-/ I really need a course on pandas
cells = rank_affinity[rank_affinity[c] == 1].index
# do not relabel known cells
cells = cells.intersection(unk_cells)
if len(cells) > 0:
adata_merge.obs.loc[cells, f'_{obs}_tmp'] = c
n_unknowns = adata_merge.obs['_label_transfer'].value_counts()['_unk']

adj_key = adata_merge.uns[neighbors_key]['connectivities_key']
adj_mat = adata_merge.obsp[adj_key]

unk_idx = np.where(adata_merge.obs['_label_transfer'] == '_unk')[0]
I = np.transpose(adj_mat[unk_idx].nonzero())
# assign to each unknown the most common block according to the adj matrix
# In reality this works also with random assignments

for x in range(n_unknowns):
x_blocks = old_blocks[I[I[:, 0] == x][:, 1]]
# exclude the unknowns
x_blocks[x_blocks != block_dict[label_unk]]
if len(x_blocks) == 0:
# handle the case of no blocks left
x_blocks = [known_blocks[np.random.choice(len(known_blocks))]]
b, c = np.unique(x_blocks, return_counts=True)
new_blocks[unk_idx[x]] = b[np.argmax(c)]

# time to MCMC
state = gt.BlockState(g, b=new_blocks)
fast_tol = 1e-6
dS = 1e9
n = 0
max_iter = 100
while (np.abs(dS) > fast_tol) and (n < max_iter):
# we essentially move stuff, don't look for new partitions
dS, _, _ = state.multiflip_mcmc_sweep(beta=np.inf, niter=10, c=0.2,
psplit=0, pmerge=0, pmergesplit=0)
n += 1

new_blocks = np.array(state.b.a)
if keep_old:
keep_idx = np.where(adata_merge.obs['_label_transfer'] == '_ref')[0]
new_blocks[keep_idx] = old_blocks[keep_idx]

# assign proper categories
adata_merge.obs[f'_{obs}_tmp'] = pd.Categorical([label_dict[x] for x in new_blocks], categories=known_labels)

# do actual transfer to dataset 1
# here we assume that concatenation does not change the order of cells
# only cell names

labels = adata_merge.obs[f'_{obs}_tmp'].cat.categories
if adata_ref:
# transfer has been done between two files
Expand All @@ -558,11 +585,12 @@ def label_transfer(
# transfer is within dataset
adata_merge.obs[obs] = adata_merge.obs[f'_{obs}_tmp'].values
adata_merge.obs.drop(f'_{obs}_tmp', axis='columns', inplace=True)
adata_merge.obs.drop('_label_transfer', axis='columns', inplace=True)
adata = adata_merge

# ensure that it is categorical with proper order
adata.obs[obs] = pd.Categorical(adata.obs[obs], categories=labels)

# transfer colors if any
if adata_ref and f'{obs}_colors' in adata_ref.uns:
colors = list(adata_ref.uns[f'{obs}_colors'])
Expand All @@ -571,9 +599,6 @@ def label_transfer(
colors.append('#aabbcc')
adata.uns[f'{obs}_colors'] = colors

# remove unused categories if "use_best" hence no "unknown"
if use_best:
adata.obs[obs].cat.remove_unused_categories()

return adata if copy else None

0 comments on commit 1b61346

Please sign in to comment.