Skip to content

Commit

Permalink
Merge pull request #219 from andersen-lab/collapse_updates
Browse files Browse the repository at this point in the history
Collapse updates
  • Loading branch information
joshuailevy authored Mar 18, 2024
2 parents 2339e51 + 9ca1877 commit 8d736dc
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 41 deletions.
30 changes: 25 additions & 5 deletions freyja/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,15 @@ def print_barcode_version(ctx, param, value):
@click.option('--region_of_interest', default='', help='JSON file containing'
'region(s) of interest for which to compute additional coverage'
'estimates')
@click.option('--relaxedmrca', is_flag=True, default=False,
help='for use with depth cutoff,'
'clusters are assigned robust mrca to handle outliers')
@click.option('--relaxedthresh', default=0.9,
help='associated threshold for robust mrca function')
def demix(variants, depths, output, eps, barcodes, meta,
covcut, confirmedonly, depthcutoff, lineageyml,
adapt, a_eps, region_of_interest):
adapt, a_eps, region_of_interest,
relaxedmrca, relaxedthresh):
"""
Generate prevalence of lineages per sample
Expand All @@ -85,7 +91,10 @@ def demix(variants, depths, output, eps, barcodes, meta,
:param lineageyml: used to pass a custom lineage hierarchy file
:param adapt: used to set adaptive lasso penalty parameter
:param a_eps: used to set adaptive lasso
penalty parameter hard threshold
penalty parameter hard threshold'
:param relaxedmrca: for use with depth cutoff,
clusters are assigned robust mrca to handle outliers
:param relaxedthresh: associated threshold for robust mrca function
:return : a tsv file that includes the
lineages present,their corresponding abundances,
and summarization by constellation.
Expand Down Expand Up @@ -114,7 +123,8 @@ def demix(variants, depths, output, eps, barcodes, meta,
df_depth = pd.read_csv(depths, sep='\t', header=None, index_col=1)
if depthcutoff != 0:
df_barcodes = collapse_barcodes(df_barcodes, df_depth, depthcutoff,
lineageyml, locDir, output)
lineageyml, locDir, output,
relaxedmrca, relaxedthresh)
muts = list(df_barcodes.columns)
mapDict = buildLineageMap(meta)
print('building mix/depth matrices')
Expand Down Expand Up @@ -386,8 +396,14 @@ def variants(bamfile, ref, variants, depths, refname, minq, annot, varthresh):
@click.option('--depthcutoff', default=0,
help='exclude sites with coverage depth below this value and'
'group identical barcodes')
@click.option('--relaxedmrca', is_flag=True, default=False,
help='for use with depth cutoff,'
'clusters are assigned robust mrca to handle outliers')
@click.option('--relaxedthresh', default=0.9,
help='associated threshold for robust mrca function')
def boot(variants, depths, output_base, eps, barcodes, meta,
nb, nt, boxplot, confirmedonly, lineageyml, depthcutoff, rawboots):
nb, nt, boxplot, confirmedonly, lineageyml, depthcutoff,
rawboots, relaxedmrca, relaxedthresh):
"""
Perform bootstrapping method for freyja
Expand All @@ -408,6 +424,9 @@ def boot(variants, depths, output_base, eps, barcodes, meta,
:param lineageyml: used to pass a custom lineage hierarchy file
:param depthcutoff: used to exclude sites with coverage depth
below this value andgroup identical barcodes
:param relaxedmrca: for use with depth cutoff,
clusters are assigned robust mrca to handle outliers
:param relaxedthresh: associated threshold for robust mrca function
:return : base-name_lineages.csv and base-name_summarized.csv
"""
Expand Down Expand Up @@ -439,7 +458,8 @@ def boot(variants, depths, output_base, eps, barcodes, meta,
if depthcutoff != 0:
df_barcodes = collapse_barcodes(
df_barcodes, df_depths, depthcutoff,
lineageyml, locDir, output_base)
lineageyml, locDir, output_base,
relaxedmrca, relaxedthresh)

muts = list(df_barcodes.columns)
mapDict = buildLineageMap(meta)
Expand Down
6 changes: 5 additions & 1 deletion freyja/sample_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def reindex_dfs(df_barcodes, mix, depths):
def map_to_constellation(sample_strains, vals, mapDict):
# maps lineage names to constellations
localDict = {}
for jj, lin in enumerate(sample_strains):
for jj, lin0 in enumerate(sample_strains):
if '-like' in lin0:
lin = lin0.split('-like')[0]
else:
lin = lin0
if lin in mapDict.keys():
if mapDict[lin] not in localDict.keys():
localDict[mapDict[lin]] = vals[jj]
Expand Down
3 changes: 2 additions & 1 deletion freyja/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ def test_collapse_barcodes(self):
original = barcodes.shape
lineages_yml = 'freyja/data/lineages.yml'
barcodes = collapse_barcodes(barcodes, self.depth,
100, lineages_yml, 'freyja', 'test')
100, lineages_yml, 'freyja', 'test',
False, 0.9)
collapsed = barcodes.shape

self.assertLess(collapsed[0], original[0])
Expand Down
148 changes: 114 additions & 34 deletions freyja/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ def get_color_scheme(df, default_color_scheme, config=None):
return color_scheme


def prepLineageDict(agg_d0, thresh=0.001, config=None, lineage_info=None):
def prepLineageDict(agg_d0, thresh=0.001, config=None, lineage_info=None,
mergeLikes=False):

if len(agg_d0.index[agg_d0.index.duplicated(keep=False)]) > 0:
print('WARNING: multiple samples have the same ID/filename.')
Expand Down Expand Up @@ -234,7 +235,22 @@ def prepLineageDict(agg_d0, thresh=0.001, config=None, lineage_info=None):
.apply(lambda x:
re.sub(' +', ' ', x)
.split(' ')).copy()
# print([float(abund) for abund in agg_d0.iloc[0].loc['abundances']])

if mergeLikes:
agg_d0.loc[:, 'lineages'] = agg_d0['lineages'].apply(
lambda x: [x0.split('-like')[0] for x0 in x])
newLins, newAbunds = [], []
for lins, abunds in zip(agg_d0['lineages'], agg_d0['abundances']):
linsUnique, indicesUnique = np.unique(lins, return_inverse=True)
newLins.append(linsUnique)
newAbund = [0]*len(linsUnique)
for ind0, j in enumerate(indicesUnique):
newAbund[j] += float(abunds[ind0])
newAbunds.append(newAbund)
agg_d0['lineages'] = newLins
agg_d0['abundances'] = newAbunds
# agg_d0.loc[:,'abundances'] = agg_d0[['lineages']
# merge abundances
agg_d0.loc[:, 'linDict'] = [{lin: float(abund) for lin, abund in
zip(agg_d0.loc[samp, 'lineages'],
agg_d0.loc[samp, 'abundances'])}
Expand Down Expand Up @@ -496,7 +512,6 @@ def makePlot_time(agg_df, lineages, times_df, interval, outputFn,
weekInfo = [Week.fromdate(dfi).weektuple()
for dfi in df_abundances.index]
df_abundances.index = [str(wi[0])+'-'+str(wi[1]) for wi in weekInfo]
print(df_abundances)
for i in range(0, df_abundances.shape[1]):
label = df_abundances.columns[i]
ax.bar(df_abundances.index, df_abundances.iloc[:, i],
Expand Down Expand Up @@ -902,7 +917,8 @@ def make_dashboard(agg_df, meta_df, thresh, title, introText,


def collapse_barcodes(df_barcodes, df_depth, depthcutoff,
lineageyml, locDir, output):
lineageyml, locDir, output,
relaxed, relaxedthresh):
# drop low coverage sites
low_cov_sites = df_depth[df_depth[3].astype(int) < depthcutoff] \
.index.astype(str)
Expand All @@ -913,7 +929,6 @@ def collapse_barcodes(df_barcodes, df_depth, depthcutoff,
max_depth = df_depth[3].astype(int).max()

# find lineages with identical barcodes

try:
duplicates = df_barcodes.groupby(df_barcodes.columns.tolist()).apply(
lambda x: tuple(x.index) if len(x.index) > 1 else None
Expand All @@ -935,46 +950,111 @@ def collapse_barcodes(df_barcodes, df_depth, depthcutoff,

# collapse lineages into MRCA, where possible
for tup in duplicates:
pango_aliases = [lineage_data[lin]['alias']
for lin in tup]
alias_dict = {lineage_data[lin]['alias']: lin for lin in tup}

try:
pango_aliases = [lineage_data[lin]['alias']
for lin in tup]
except KeyError:
print('Lineage hierarchy file is likely behind'
' the selected barcode file. Try updating'
' the hierarchy file.')
# handle cases where multiple lineage classes are being merged
# e.g. (A.5, B.12) or (XBB, XBN)
multiple_lin_classes = len(
set([alias[0] for alias in pango_aliases])) > 1
set([alias.split('.')[0] for alias in pango_aliases])) > 1

if multiple_lin_classes:
# for recombinant lineages, find the parent lineages
parent_aliases = []
for alias in pango_aliases:

if 'recombinant_parents' in lineage_data[alias_dict[alias]]:

# replace with its recombinant parents
pango_aliases.remove(alias)
parents = lineage_data[alias]['recombinant_parents'] \
.replace('*', '').split(',')
parents = [lineage_data[lin]['alias'] for lin in parents]
recombs = [alias for alias in pango_aliases
if 'recombinant_parents' in
lineage_data[alias.split('.')[0]]]

# only consider parents related to the other lineages
for parent in parents:
if any([alias.startswith(parent)
for alias in pango_aliases]) and \
parent not in pango_aliases:
parent_aliases.append(parent)

pango_aliases += parent_aliases

# get MRCA
mrca = os.path.commonpath(
[lin.replace('.', '/') for lin in pango_aliases]
).replace('/', '.')
# for recombinant lineages, find the parent lineages
startTypes = set([alias.split('.')[0] for alias in pango_aliases])
# figure out which are the candidates for recomb merging
# if they exist
while len(recombs) > 0:
parent_aliases = []
for alias in recombs:
if 'recombinant_parents' in lineage_data[alias.split('.')
[0]]:
# trace up tree until a recombination event.
# grab parents of recombinant
parents = lineage_data[alias.split('.')
[0]]['recombinant_parents'
].replace('*', ''
).split(',')
parent_aliases.append([lineage_data[lin]['alias']
for lin in parents])

distinct = []
newRecombs = []
mergedIn = False
for alias, pa in zip(recombs, parent_aliases):
for aliasP in pa:
if aliasP.split('.')[0] in startTypes:
mergedIn = True
# if now using same start as others,
# add to list of aliases
pango_aliases.append(aliasP)
if alias in pango_aliases:
pango_aliases.remove(alias)
elif 'recombinant_parents' in lineage_data[aliasP.
split('.')
[0]]:
# check if it's a different recombinant
newRecombs.append(aliasP)
else:
# non-recombinant, but not in current start types.
distinct.append(aliasP)
if not mergedIn:
# if no merges, remove the recombinants
# and add in the parents
for r in recombs:
pango_aliases.remove(r)
pango_aliases.extend(distinct+newRecombs)

startTypes = set([alias.split('.')[0]
for alias in pango_aliases])
if len(startTypes) == 1:
break
recombs = [alias for alias in pango_aliases if
'recombinant_parents' in
lineage_data[alias.split('.')[0]]]

if not relaxed:
mrca = os.path.commonpath(
[lin.replace('.', '/') for lin in pango_aliases]
).replace('/', '.')
else:
j0 = 1
groupCt = float(len(pango_aliases))
ext_counts = np.unique([lin.split('.')[0:j0]
for lin in pango_aliases],
return_counts=True)
coherentFrac = np.max(ext_counts[1]) / groupCt
if coherentFrac < relaxedthresh:
mrca = ''
else:
maxLength = np.max([len(lin.split('.'))
for lin in pango_aliases])
while coherentFrac >= relaxedthresh and j0 <= maxLength:
ext_counts = np.unique([lin.split('.')[0:j0]
if j0 <= len(lin.split('.'))
else lin.split('.') +
['']*(j0-len(lin.split('.')))
for lin in pango_aliases],
return_counts=True,
axis=0)
max_ind = np.argmax(ext_counts[1])
coherentFrac = ext_counts[1][max_ind] / groupCt
mrca = '.'.join(ext_counts[0][max_ind])
j0 += 1

# assign placeholder if no MRCA found
if len(mrca) == 0:
mrca = 'Misc'
else:
if mrca[-1] == '.':
mrca = mrca[0:(len(mrca)-1)]
# otherwise, get the shortened alias, if available
for lineage in lineage_data:
if lineage_data[lineage]['alias'] == mrca:
Expand Down

0 comments on commit 8d736dc

Please sign in to comment.