From 9f8babaadd80b936a4a3f95186a4fb20fbef3b73 Mon Sep 17 00:00:00 2001
From: Chris Rosenthal
Date: Wed, 5 Jun 2024 10:36:33 -0700
Subject: [PATCH] Fixed logic identifying named tax lineages

---
Review notes (text between "---" and the diff is not part of the commit):
- named.py: the seq_info tax_ids are collected into a list instead of a
  generator (a generator is always truthy, so the "if tax_ids:" guard in
  is_valid could never skip the IN filter), and both branches wrap the
  result in set() again so the per-row "in named" test is a hash lookup.
- taxonomy.py is_valid(): the rank filter is "!=" as it was in named();
  with "==", args.ranked would have kept only the 'no_rank' nodes.
- taxonomy.py descendants_of(): tax_ids are bound with an expanding
  bind parameter instead of being formatted into the SQL string.
- Line structure and indentation were reconstructed from a copy with
  collapsed whitespace; the index hashes are those of the original patch.

 taxtastic/subcommands/named.py |  9 ++++--
 taxtastic/taxonomy.py          | 55 +++++++++++++++++++++++++---------
 2 files changed, 46 insertions(+), 18 deletions(-)

diff --git a/taxtastic/subcommands/named.py b/taxtastic/subcommands/named.py
index a1b59df..2925a7b 100644
--- a/taxtastic/subcommands/named.py
+++ b/taxtastic/subcommands/named.py
@@ -57,7 +57,10 @@ def action(args):
     engine = sqlalchemy.create_engine(args.url, echo=args.verbosity > 3)
     tax = Taxonomy(engine, schema=args.schema)
     if args.seq_info:
-        named = set(tax.named(no_rank=not args.ranked))
+        seq_info = csv.DictReader(args.seq_info)
+        tax_ids = [i['tax_id'] for i in seq_info]
+        named = set(tax.is_valid(tax_ids, no_rank=not args.ranked))
+        args.seq_info.seek(0)
         seq_info = csv.DictReader(args.seq_info)
         out = csv.DictWriter(args.outfile, fieldnames=seq_info.fieldnames)
         out.writeheader()
@@ -67,8 +70,8 @@ def action(args):
             tax_ids = args.tax_ids
         else:
             tax_ids = (i.strip() for i in args.tax_id_file)
-        tax_ids = (i for i in tax_ids if i)
-        named = set(tax.named())
+        tax_ids = [i for i in tax_ids if i]
+        named = set(tax.is_valid(tax_ids, no_rank=not args.ranked))
         tax_ids = (i for i in tax_ids if i in named)
         for i in tax_ids:
             args.outfile.write(i + '\n')
diff --git a/taxtastic/taxonomy.py b/taxtastic/taxonomy.py
index 4fdd62f..0dc1a49 100644
--- a/taxtastic/taxonomy.py
+++ b/taxtastic/taxonomy.py
@@ -318,14 +318,14 @@ def _get_lineage_table(self, tax_ids, merge_obsolete=True):
                 raise ValueError('no tax_ids were found')
             else:
                 returned = {row[0] for row in rows}
-                # TODO: compare set membership, not lengths
-                if len(returned) < len(tax_ids):
-                    msg = ('{} tax_ids were provided '
-                           'but only {} were returned').format(
-                               len(tax_ids), len(returned))
-                    log.error('Input tax_ids not represented in output:')
-                    log.error(sorted(set(tax_ids) - returned))
-                    raise ValueError(msg)
+                # # TODO: compare set membership, not lengths
+                # if len(returned) < len(tax_ids):
+                #     msg = ('{} tax_ids were provided '
+                #            'but only {} were returned').format(
+                #                len(tax_ids), len(returned))
+                #     log.error('Input tax_ids not represented in output:')
+                #     log.error(sorted(set(tax_ids) - returned))
+                #     raise ValueError(msg)
 
         return rows
 
@@ -832,12 +832,37 @@ def species_below(self, tax_id):
         assert self.is_ancestor_of(newc, tax_id)
         return newc
 
-    def named(self, no_rank=True):
-        names = self.names
-        s = select(names.c.tax_id)
-        s = s.where(names.c.is_classified)
+    def descendants_of(self, tax_ids):
+        """Return primary-named tax_ids at or below any of *tax_ids*"""
+        # bind the ids as an expanding IN parameter rather than formatting
+        # them into the SQL text, so quoting is left to the driver
+        cmd = sa.text("""
+        WITH RECURSIVE descendants AS (
+         SELECT tax_id
+         FROM nodes
+         WHERE tax_id IN :tax_ids
+         UNION ALL
+         SELECT
+          n.tax_id
+         FROM nodes n
+         JOIN descendants d ON d.tax_id = n.parent_id
+        ) SELECT DISTINCT tax_id
+        FROM descendants
+        JOIN names using(tax_id)
+        WHERE is_primary;
+        """).bindparams(sa.bindparam('tax_ids', expanding=True))
+        # str() keeps the comparison textual, as the quoted literals were
+        params = {'tax_ids': [str(t) for t in tax_ids]}
+        with self.engine.connect() as con:
+            return [row[0] for row in con.execute(cmd, params).fetchall()]
+
+    def is_valid(self, tax_ids=None, no_rank=True):
+        """Return tax_ids flagged is_valid, optionally limited to *tax_ids*"""
+        nodes = self.nodes
+        s = select(nodes.c.tax_id).where(nodes.c.is_valid)
+        if tax_ids:
+            s = s.where(nodes.c.tax_id.in_(tax_ids))
         if not no_rank:
-            nodes = self.nodes
-            s = s.join(nodes, nodes.c.tax_id == names.c.tax_id)
+            # no_rank=False means unranked nodes are excluded
             s = s.where(nodes.c.rank != 'no_rank')
-        return [f[0] for f in self.fetchall(s)]
+        return [r[0] for r in self.fetchall(s)]