Skip to content

Commit

Permalink
Fixed logic identifying named tax lineages
Browse files Browse the repository at this point in the history
  • Loading branch information
crosenth committed Jun 5, 2024
1 parent b0dedc7 commit 9f8baba
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 19 deletions.
9 changes: 6 additions & 3 deletions taxtastic/subcommands/named.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def action(args):
engine = sqlalchemy.create_engine(args.url, echo=args.verbosity > 3)
tax = Taxonomy(engine, schema=args.schema)
if args.seq_info:
named = set(tax.named(no_rank=not args.ranked))
seq_info = csv.DictReader(args.seq_info)
tax_ids = (i['tax_id'] for i in seq_info)
named = tax.is_valid(tax_ids, no_rank=not args.ranked)
args.seq_info.seek(0)
seq_info = csv.DictReader(args.seq_info)
out = csv.DictWriter(args.outfile, fieldnames=seq_info.fieldnames)
out.writeheader()
Expand All @@ -67,8 +70,8 @@ def action(args):
tax_ids = args.tax_ids
else:
tax_ids = (i.strip() for i in args.tax_id_file)
tax_ids = (i for i in tax_ids if i)
named = set(tax.named())
tax_ids = [i for i in tax_ids if i]
named = tax.is_valid(tax_ids, no_rank=not args.ranked)
tax_ids = (i for i in tax_ids if i in named)
for i in tax_ids:
args.outfile.write(i + '\n')
53 changes: 37 additions & 16 deletions taxtastic/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,14 @@ def _get_lineage_table(self, tax_ids, merge_obsolete=True):
raise ValueError('no tax_ids were found')
else:
returned = {row[0] for row in rows}
# TODO: compare set membership, not lengths
if len(returned) < len(tax_ids):
msg = ('{} tax_ids were provided '
'but only {} were returned').format(
len(tax_ids), len(returned))
log.error('Input tax_ids not represented in output:')
log.error(sorted(set(tax_ids) - returned))
raise ValueError(msg)
# # TODO: compare set membership, not lengths
# if len(returned) < len(tax_ids):
# msg = ('{} tax_ids were provided '
# 'but only {} were returned').format(
# len(tax_ids), len(returned))
# log.error('Input tax_ids not represented in output:')
# log.error(sorted(set(tax_ids) - returned))
# raise ValueError(msg)

return rows

Expand Down Expand Up @@ -832,12 +832,33 @@ def species_below(self, tax_id):
assert self.is_ancestor_of(newc, tax_id)
return newc

def named(self, no_rank=True):
names = self.names
s = select(names.c.tax_id)
s = s.where(names.c.is_classified)
def descendants_of(self, tax_ids):
    """Return the tax_ids found at or below any of *tax_ids*.

    A recursive CTE starts from the rows of ``nodes`` whose tax_id is
    in *tax_ids* and repeatedly follows ``parent_id`` links downward.
    The starting tax_ids are therefore part of the result when they
    exist in ``nodes``.  Only tax_ids that have a row in ``names``
    with ``is_primary`` set are returned, each at most once.

    tax_ids -- iterable of tax_ids; each value is compared as a string.

    Returns a list of tax_ids in whatever order the database yields
    them.  An empty *tax_ids* returns ``[]`` without running a query.
    """
    # Materialize once: callers may pass a generator, and the values
    # are needed for the emptiness check and as bound parameters.
    # str() mirrors the previous behavior of quoting every id.
    tax_ids = [str(t) for t in tax_ids]
    if not tax_ids:
        # "IN ()" is rejected by some backends, and nothing descends
        # from an empty set of nodes anyway.
        return []

    # Bind the ids through an expanding IN parameter instead of
    # formatting them into the SQL text, so values containing quotes
    # can neither break nor alter the statement.
    cmd = sa.text("""
        WITH RECURSIVE descendants AS (
         SELECT tax_id
         FROM nodes
         WHERE tax_id IN :tax_ids
         UNION ALL
         SELECT
         n.tax_id
         FROM nodes n
         JOIN descendants d ON d.tax_id = n.parent_id
        ) SELECT DISTINCT tax_id
        FROM descendants
        JOIN names using(tax_id)
        WHERE is_primary;
        """).bindparams(sa.bindparam('tax_ids', expanding=True))
    with self.engine.connect() as con:
        rows = con.execute(cmd, {'tax_ids': tax_ids}).fetchall()
    return [row[0] for row in rows]

def is_valid(self, tax_ids=None, no_rank=True):
    """Return the tax_ids whose ``nodes`` row is flagged ``is_valid``.

    tax_ids -- optional iterable of tax_ids to restrict the result
        to.  ``None`` or an empty collection applies no restriction.
    no_rank -- when true (the default) nodes of every rank are
        returned; when false, nodes whose rank is ``'no_rank'`` are
        excluded.

    Returns a list of tax_ids.
    """
    nodes = self.nodes
    s = select(nodes.c.tax_id).where(nodes.c.is_valid)

    # Callers may hand in a generator, which is always truthy and is
    # not a sequence for in_(); turn it into a list before testing it.
    if tax_ids is not None:
        tax_ids = list(tax_ids)
    if tax_ids:
        s = s.where(nodes.c.tax_id.in_(tax_ids))

    if not no_rank:
        # Excluding unranked nodes means rank must differ from
        # 'no_rank'; an equality test here would keep only the
        # unranked nodes, the opposite of what the flag asks for.
        s = s.where(nodes.c.rank != 'no_rank')
    return [r[0] for r in self.fetchall(s)]

0 comments on commit 9f8baba

Please sign in to comment.