diff --git a/schist/tools/_affinity_tools.py b/schist/tools/_affinity_tools.py index 7528844..13d96d6 100644 --- a/schist/tools/_affinity_tools.py +++ b/schist/tools/_affinity_tools.py @@ -9,7 +9,7 @@ from scanpy import logging as logg import graph_tool.all as gt import pandas as pd -from .._utils import get_cell_loglikelihood, get_cell_back_p, state_from_blocks +from .._utils import get_cell_loglikelihood, get_cell_back_p, state_from_blocks, get_graph_tool_from_adata from scanpy._utils import get_igraph_from_adjacency, _choose_graph @@ -358,7 +358,8 @@ def label_transfer( adata_ref: Optional[AnnData] = None, obs: Optional[str] = None, label_unk: Optional[str] = 'unknown', - use_best: Optional[bool] = False, + #use_best: Optional[bool] = False, + keep_old: Optional[bool] = True, neighbors_key: Optional[str] = 'neighbors', adjacency: Optional[sparse.spmatrix] = None, directed: bool = False, @@ -371,7 +372,7 @@ def label_transfer( ) -> Optional[AnnData]: """\ - Transfer annotation from one dataset to another using cell affinities. + Transfer annotation from one dataset to another using a SBM. If two datasets are given, it uses harmony to perform integration and then the kNN graph. If only no reference is given, it is assumed that the only adata already contains the proper kNN graph and that @@ -391,10 +392,11 @@ def label_transfer( The label for unassigned cells. If no `adata_ref` is given, this label identifies cells to be assigned in `adata`. If `adata_ref` is given, this label will be given to all cells that cannot be assigned. - use_best - When assigning labels, some cells may have not enough evidence and, therefore, - left `unknown`. If this parameter is set to `True`, all cells will be assigned - to the best possible, even if it may not be optimal + keep_old + Labels are assigned using a MCMC without merge/split step, only allowing moves. + This implies that cells with known label can be moved as well. Once the MCMC + has converged, this option sets initial labels to cells that did not need + relabeling. neighbors_key Use neighbors connectivities as adjacency. If not specified, leiden looks .obsp['connectivities'] for connectivities @@ -521,35 +523,60 @@ def label_transfer( adata_merge.obs[obs] = adata_merge.obs[obs].fillna(label_unk) - # calculate affinity + # get the kNN graph + g = get_graph_tool_from_adata(adata_merge, neighbors_key=neighbors_key) - calculate_affinity(adata_merge, group_by=obs, neighbors_key=neighbors_key) + # get the blocks. Here the unknown label is already present + # we need blocks as integers otherwise graphtool can't handle them + blocks = np.array(adata_merge.obs[obs].cat.codes) + old_blocks = blocks.copy() + new_blocks = blocks.copy() + block_dict = dict(zip(adata_merge.obs[obs], adata_merge.obs[obs].cat.codes)) + label_dict = dict(zip(adata_merge.obs[obs].cat.codes, adata_merge.obs[obs])) - # now work on affinity, rank it to get the new labels - categories = adata_merge.obs[obs].cat.categories - affinity = pd.DataFrame(adata_merge.obsm[f'CA_{obs}'], - index=adata_merge.obs_names, columns=categories) - # if use_best we need to remove label unknonw from the matrix so it - # does not get scored - if use_best: - affinity.drop(label_unk, axis='columns', inplace=True) + known_blocks = [block_dict[x] for x in block_dict if x != label_unk] + known_labels = [label_dict[x] for x in sorted(known_blocks)] - rank_affinity = affinity.rank(axis=1, ascending=False) - adata_merge.obs[f'_{obs}_tmp'] = adata_merge.obs[obs].values - unk_cells = adata_merge.obs.query('_label_transfer == "_unk"').index - for c in rank_affinity.columns: - # pretty sure there's a way to do it without a - # for loop :-/ I really need a course on pandas - cells = rank_affinity[rank_affinity[c] == 1].index - # do not relabel known cells - cells = cells.intersection(unk_cells) - if len(cells) > 0: - adata_merge.obs.loc[cells, f'_{obs}_tmp'] = c + n_unknowns = adata_merge.obs['_label_transfer'].value_counts()['_unk'] + + adj_key = adata_merge.uns[neighbors_key]['connectivities_key'] + adj_mat = adata_merge.obsp[adj_key] + + unk_idx = np.where(adata_merge.obs['_label_transfer'] == '_unk')[0] + I = np.transpose(adj_mat[unk_idx].nonzero()) + # assign to each unknown the most common block according to the adj matrix + # In reality this works also with random assignments + + for x in range(n_unknowns): + x_blocks = old_blocks[I[I[:, 0] == x][:, 1]] + # exclude the unknowns + x_blocks[x_blocks != block_dict[label_unk]] + if len(x_blocks) == 0: + # handle the case of no blocks left + x_blocks = [known_blocks[np.random.choice(len(known_blocks))]] + b, c = np.unique(x_blocks, return_counts=True) + new_blocks[unk_idx[x]] = b[np.argmax(c)] + + # time to MCMC + state = gt.BlockState(g, b=new_blocks) + fast_tol = 1e-6 + dS = 1e9 + n = 0 + max_iter = 100 + while (np.abs(dS) > fast_tol) and (n < max_iter): + # we essentially move stuff, don't look for new partitions + dS, _, _ = state.multiflip_mcmc_sweep(beta=np.inf, niter=10, c=0.2, + psplit=0, pmerge=0, pmergesplit=0) + n += 1 + + new_blocks = np.array(state.b.a) + if keep_old: + keep_idx = np.where(adata_merge.obs['_label_transfer'] == '_ref')[0] + new_blocks[keep_idx] = old_blocks[keep_idx] + + # assign proper categories + adata_merge.obs[f'_{obs}_tmp'] = pd.Categorical([label_dict[x] for x in new_blocks], categories=known_labels) - # do actual transfer to dataset 1 - # here we assume that concatenation does not change the order of cells - # only cell names - labels = adata_merge.obs[f'_{obs}_tmp'].cat.categories if adata_ref: # transfer has been done between two files @@ -558,11 +585,12 @@ def label_transfer( # transfer is within dataset adata_merge.obs[obs] = adata_merge.obs[f'_{obs}_tmp'].values adata_merge.obs.drop(f'_{obs}_tmp', axis='columns', inplace=True) + adata_merge.obs.drop('_label_transfer', axis='columns', inplace=True) adata = adata_merge # ensure that it is categorical with proper order adata.obs[obs] = pd.Categorical(adata.obs[obs], categories=labels) - + # transfer colors if any if adata_ref and f'{obs}_colors' in adata_ref.uns: colors = list(adata_ref.uns[f'{obs}_colors']) @@ -571,9 +599,6 @@ def label_transfer( colors.append('#aabbcc') adata.uns[f'{obs}_colors'] = colors - # remove unused categories if "use_best" hence no "unknown" - if use_best: - adata.obs[obs].cat.remove_unused_categories() return adata if copy else None \ No newline at end of file