From 9f2644e15abdde0d126557b075abc3e3e5b09812 Mon Sep 17 00:00:00 2001
From: xinehc
Date: Tue, 19 Sep 2023 14:45:27 +0800
Subject: [PATCH] init

---
 MANIFEST.in           |   4 +
 setup.cfg             |  27 +++
 setup.py              |   2 +
 src/melon/__init__.py |   3 +
 src/melon/cli.py      | 179 +++++++++++++++++++++
 src/melon/melon.py    | 344 ++++++++++++++++++++++++++++++++++++++++++
 src/melon/utils.py    | 102 +++++++++++++
 7 files changed, 661 insertions(+)
 create mode 100644 MANIFEST.in
 create mode 100644 setup.cfg
 create mode 100644 setup.py
 create mode 100644 src/melon/__init__.py
 create mode 100755 src/melon/cli.py
 create mode 100644 src/melon/melon.py
 create mode 100644 src/melon/utils.py

diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..ca5d9b4
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,4 @@
+include README.md
+include LICENSE
+
+global-exclude *.DS_Store
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..00f5529
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,27 @@
+[metadata]
+name = melon
+version = attr: melon.__version__
+classifiers =
+    Programming Language :: Python :: 3
+
+license = MIT
+description = MELON: Metagenomic Taxonomy Profiling and Genome Copies Estimation using Nanopore Long Reads
+long_description = file: README.md
+long_description_content_type = text/markdown
+keywords =
+    taxonomy
+    taxonomy profiling
+
+[options]
+zip_safe = False
+python_requires = >=3.7
+package_dir =
+    = src
+packages = find:
+
+[options.entry_points]
+console_scripts =
+    melon = melon.cli:cli
+
+[options.packages.find]
+where = src
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..8ab824c
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,2 @@
+from setuptools import setup
+setup()
\ No newline at end of file
diff --git a/src/melon/__init__.py b/src/melon/__init__.py
new file mode 100644
index 0000000..08b6623
--- /dev/null
+++ b/src/melon/__init__.py
@@ -0,0 +1,3 @@
+__version__ = '0.0.1'
+
+from .melon import GenomeProfiler
diff --git a/src/melon/cli.py b/src/melon/cli.py
new file mode 100755
index 0000000..6c059f5
--- /dev/null
+++ b/src/melon/cli.py
@@ -0,0 +1,179 @@
+import sys
+import os
+import glob
+
+from argparse import ArgumentParser, SUPPRESS
+from . import __version__
+from .utils import logger
+from .melon import GenomeProfiler
+
+
+def cli(argv=sys.argv):
+    '''
+    Entry point for the command line interface.
+    '''
+    parser = ArgumentParser(description='Melon: \
+        long-read taxonomic profiling and genome copies estimation \
+        using phylogenetic marker genes.', add_help=False)
+    required = parser.add_argument_group('required arguments')
+    optional = parser.add_argument_group('optional arguments')
+    additional = parser.add_argument_group('additional arguments')
+
+    parser.add_argument(
+        'FILE',
+        nargs='+',
+        help='Input fasta <*.fa|*.fasta> or fastq <*.fq|*.fastq> file, gzip optional <*.gz>.')
+
+    required.add_argument(
+        '-d',
+        '--db',
+        metavar='DIR',
+        required=True,
+        help='Unzipped database folder, should contain <prot.fa>, <nucl.*.fa> and <metadata.tsv>.')
+
+    required.add_argument(
+        '-o',
+        '--output',
+        metavar='DIR',
+        required=True,
+        help='Output folder.')
+
+    optional.add_argument(
+        '-t',
+        '--threads',
+        metavar='INT',
+        type=int,
+        default=32,
+        help='Number of CPU threads. [32]')
+
+    optional.add_argument(
+        '-k',
+        '--db_kraken',
+        metavar='DIR',
+        help='Unzipped kraken2 database for pre-filtering of non-prokaryotic reads. Skip if not given.')
+
+    optional.add_argument(
+        '--skip-profile',
+        action='store_true',
+        help='Skip profiling, output only estimated total genome copies.')
+
+    optional.add_argument(
+        '--skip-clean',
+        action='store_true',
+        help='Skip cleaning, keep all temporary <*.tmp> files.')
+
+    additional.add_argument(
+        '-m',
+        metavar='INT',
+        type=int,
+        default=25,
+        help='Max. number of target sequences to report (--max-target-seqs/-k in diamond). [25]')
+
+    additional.add_argument(
+        '-e',
+        metavar='FLOAT',
+        type=float,
+        default=1e-15,
+        help='Max. expected value to report alignments (--evalue/-e in diamond). [1e-15]')
+
+    additional.add_argument(
+        '-i',
+        metavar='FLOAT',
+        type=float,
+        default=0,
+        help='Min. identity in percentage to report alignments (--id in diamond). [0]')
+
+    additional.add_argument(
+        '-s',
+        metavar='FLOAT',
+        type=float,
+        default=75,
+        help='Min. subject cover to report alignments (--subject-cover in diamond). [75]')
+
+    additional.add_argument(
+        '-n',
+        metavar='INT',
+        type=int,
+        default=2147483647,
+        help='Max. number of secondary alignments to report (-N in minimap2). [2147483647]')
+
+    additional.add_argument(
+        '-p',
+        metavar='FLOAT',
+        type=float,
+        default=0.9,
+        help='Min. secondary-to-primary score ratio to report secondary alignments (-p in minimap2). [0.9]')
+
+    parser.add_argument('-v', '--version', action='version', version=__version__, help=SUPPRESS)
+    parser.add_argument('-h', '--help', action='help', help=SUPPRESS)
+
+    if len(argv) == 1:
+        print(" __ \n __ _ ___ / /__ ___ \n / ' \\/ -_) / _ \\/ _ \\\n/_/_/_/\\__/_/\\___/_//_/ ver. {}\n".format(__version__))
+
+    opt = parser.parse_args(argv[1:])
+    run(opt)
+
+
+def run(opt):
+    '''
+    Sanity check of options, then run the pipeline.
+    '''
+    ## check for output folder
+    if not os.path.isdir(opt.output):
+        os.makedirs(opt.output, exist_ok=True)
+    else:
+        logger.warning('Folder <{}> exists. Files will be overwritten.'.format(opt.output))
+
+    ## check for input files
+    for file in opt.FILE:
+        if not os.path.isfile(file):
+            logger.critical('File <{}> does not exist.'.format(file))
+            sys.exit(2)
+
+    ## check for database
+    if not os.path.isdir(opt.db):
+        logger.critical('Database folder <{}> does not exist.'.format(opt.db))
+        sys.exit(2)
+    else:
+        files = [os.path.basename(x) for x in glob.glob(os.path.join(opt.db, '*'))]
+        if 'metadata.tsv' not in files or len([x for x in files if 'prot' in x]) != 1 or len([x for x in files if 'nucl' in x]) != 16:
+            logger.critical('Database <{}> is not complete.'.format(opt.db))
+            sys.exit(2)
+
+    ## check for kraken2 database
+    if opt.db_kraken is not None:
+        if not os.path.isdir(opt.db_kraken):
+            logger.critical('Kraken2 database folder <{}> does not exist.'.format(opt.db_kraken))
+            sys.exit(2)
+        else:
+            files = [os.path.basename(x) for x in glob.glob(os.path.join(opt.db_kraken, '*'))]
+            if 'ktaxonomy.tsv' not in files or len([x for x in files if 'database' in x]) != 7:
+                logger.critical('Kraken2 database <{}> is not complete.'.format(opt.db_kraken))
+                sys.exit(2)
+
+    ## run
+    for i, file in enumerate(opt.FILE):
+        if len(opt.FILE) > 1:
+            logger.info('Processing file <{}> ({}/{}) ...'.format(file, i+1, len(opt.FILE)))
+
+        GenomeProfiler(file, opt.output, opt.threads).run(
+            db=opt.db,
+            db_kraken=opt.db_kraken,
+            skip_profile=opt.skip_profile,
+            skip_clean=opt.skip_clean,
+            max_target_seqs=opt.m, evalue=opt.e, identity=opt.i, subject_cover=opt.s,
+            secondary_num=opt.n, secondary_ratio=opt.p)
+
+        if i == len(opt.FILE) - 1:
+            logger.info('Done.')
+        else:
+            logger.info('Done.\n')
+
+
+if __name__ == '__main__':
+    cli(sys.argv)
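+
+## A minimal usage sketch (hypothetical filenames), assuming the database has
+## been downloaded and unzipped into <db>:
+##     melon sample.fq.gz -d db -o out -t 8
+## Multiple input files may be given at once; -k <db_kraken> enables pre-filtering.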
diff --git a/src/melon/melon.py b/src/melon/melon.py
new file mode 100644
index 0000000..b52e92e
--- /dev/null
+++ b/src/melon/melon.py
@@ -0,0 +1,344 @@
+import glob
+from collections import defaultdict
+
+from .utils import *
+
+
+class GenomeProfiler:
+    '''
+    Profile prokaryotic genomes taxonomically using a set of marker genes.
+    '''
+    def __init__(self, file, output, threads=32):
+        self.file = file
+        self.output = output
+        self.threads = threads
+
+        self.aset = {'l2', 'l11', 'l10e', 'l15e', 'l18e', 's3ae', 's19e', 's28e'}
+        self.bset = {'l2', 'l11', 'l20', 'l27', 's2', 's7', 's9', 's16'}
+        self.nset = set()
+
+        self.ranks = ['superkingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species']
+
+        ## genome copies
+        self.copies = {'bacteria': 0, 'archaea': 0}
+
+        ## temporary variables for diamond's hits or minimap2's mappings
+        self.hits, self.maps = [], []
+
+        ## taxonomy assignments
+        self.assignments = {}
+
+
+    def run_kraken(self, db_kraken):
+        '''
+        Run kraken2 for pre-filtering.
+        '''
+        subprocess.run([
+            'kraken2',
+            '--db', db_kraken,
+            '--report', get_filename(self.file, self.output, '.kraken.report.tmp'),
+            '--output', get_filename(self.file, self.output, '.kraken.output.tmp'),
+            '--threads', str(self.threads),
+            self.file,
+        ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+
+
+    def parse_kraken(self):
+        '''
+        Parse kraken2's output by:
+        1. recording the taxids of eukaryota (2759), viruses (10239) and other entries (2787854);
+        2. adding a sequence's id to the negative search set if more than 25 of its k-mers match these taxids.
+        '''
+        seqid, taxid = set(), set()
+        record = False
+        with open(get_filename(self.file, self.output, '.kraken.report.tmp')) as f:
+            for line in f:
+                ls = line.rstrip().split('\t')
+                if ls[3] in {'D', 'R1'}:
+                    record = ls[5].strip() in {'Eukaryota', 'Viruses', 'other entries'}
+
+                if record:
+                    taxid.add(ls[4])
+
+        with open(get_filename(self.file, self.output, '.kraken.output.tmp')) as f:
+            for line in f:
+                ls = line.rstrip().split('\t')
+                ks = [x.split(':') for x in ls[-1].split()]
+                if any(x[0] in taxid for x in ks):
+                    cnt = defaultdict(lambda: 0)
+                    for key, val in ks:
+                        cnt[key] += int(val)
+
+                    if any(key in taxid for key, val in cnt.items() if val > 25):
+                        seqid.add(ls[1])
+
+        self.nset.update(seqid)
+
+
+    def run_diamond(self, db, max_target_seqs=25, evalue=1e-15, identity=0, subject_cover=75):
+        '''
+        Run diamond to get total prokaryotic genome copies.
+        '''
+        outfmt = ['qseqid', 'sseqid', 'pident', 'length', 'qlen', 'qstart', 'qend', 'slen', 'sstart', 'send', 'evalue', 'bitscore']
+        subprocess.run([
+            'diamond', 'blastx',
+            '--db', os.path.join(db, 'prot.fa'),
+            '--query', self.file,
+            '--out', get_filename(self.file, self.output, '.diamond.tmp'),
+            '--outfmt', '6', *outfmt,
+            '--evalue', str(evalue), '--subject-cover', str(subject_cover), '--id', str(identity),
+            '--range-culling', '-F', '15', '--range-cover', '25',
+            '--max-hsps', '0', '--max-target-seqs', str(max_target_seqs),
+            '--threads', str(self.threads)
+        ], check=True, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
+
+
+    def parse_diamond(self):
+        '''
+        Parse diamond's output and record the hits.
+        '''
+        qrange = defaultdict(set)
+        srange = {}
+
+        with open(get_filename(self.file, self.output, '.diamond.tmp')) as f:
+            for line in f:
+                ls = line.rstrip().split('\t')
+                qseqid, sseqid = ls[0], ls[1]
+                qstart, qend = sort_coordinate(int(ls[5]), int(ls[6]))
+                sstart, send = sort_coordinate(int(ls[8]), int(ls[9]))
+                slen = int(ls[7])
+
+                ## bypass non-prokaryotic reads
+                if qseqid not in self.nset:
+                    if (
+                        qseqid not in qrange or
+                        all(compute_overlap((qstart, qend, *x), max) < 0.25 for x in qrange[qseqid])
+                    ):
+                        qrange[qseqid].add((qstart, qend))
+
+                        ss = sseqid.split('-')
+                        gene, kingdom = ss[0], ss[-2]
+                        if (
+                            (gene in self.bset and kingdom == 'bacteria') or
+                            (gene in self.aset and kingdom == 'archaea')
+                        ):
+                            if sseqid not in srange:
+                                srange[sseqid] = np.zeros(slen)
+                            srange[sseqid][sstart:send] += 1
+
+                            ## append qseqid and coordinates for back-tracing
+                            self.hits.append([qseqid, kingdom, gene, qstart, qend])
+
+        for key, val in srange.items():
+            cut = round(len(val) / 4)  # in this case same as counting hits, since all hits have subject cover > 75%
+            kingdom = key.split('-')[-2]
+            if kingdom == 'bacteria':
+                self.copies[kingdom] += np.mean(val[cut:-cut]) / len(self.bset)
+            else:
+                self.copies[kingdom] += np.mean(val[cut:-cut]) / len(self.aset)
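+
+        ## Illustrative arithmetic with made-up numbers: if a bacterial marker's
+        ## coverage vector averages 12x over its central half (a quarter trimmed
+        ## off each end), it contributes 12 / len(self.bset) = 12 / 8 = 1.5
+        ## genome copies; summing over all eight markers smooths per-marker noise.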
+
+
+    def run_minimap(self, db, secondary_num=2147483647, secondary_ratio=0.9):
+        '''
+        Run minimap2 to get taxonomic profiles.
+        '''
+        with open(get_filename(self.file, self.output, '.sequence.tmp'), 'w') as w:
+            w.write(extract_sequences(self.file, {x[0] for x in self.hits}))
+
+        ## consider each kingdom + gene separately
+        genes = defaultdict(set)
+        for i in self.hits:
+            genes[i[1] + '.' + i[2].replace('/', '_')].add(i[0])
+
+        with open(get_filename(self.file, self.output, '.minimap.tmp'), 'w') as w:
+            for key, val in genes.items():
+                sequences = extract_sequences(get_filename(self.file, self.output, '.sequence.tmp'), val)
+                subprocess.run([
+                    'minimap2',
+                    '-cx', 'map-ont',
+                    '-f', '0',
+                    '-N', str(secondary_num), '-p', str(secondary_ratio),
+                    '-t', str(self.threads),
+                    os.path.join(db, 'nucl.' + key + '.fa'), '-',
+                ], check=True, stdout=w, stderr=subprocess.DEVNULL, input=sequences, text=True)
+
+
+    def parse_minimap(self):
+        '''
+        Parse minimap2's output and record the mappings.
+        '''
+        coordinates = defaultdict(set)
+        for i in self.hits:
+            coordinates[i[0]].add(tuple(i[-2:]))
+
+        with open(get_filename(self.file, self.output, '.minimap.tmp')) as f:
+            for line in f:
+                ls = line.rstrip().split('\t')
+                qstart, qend, qseqid, sseqid = int(ls[2]), int(ls[3]), ls[0], ls[5]
+
+                AS = int(ls[14].split('AS:i:')[-1])
+                MS = int(ls[13].split('ms:i:')[-1])
+                ID = 1 - float(ls[19].split('de:f:')[-1]) if ls[16] in {'tp:A:S', 'tp:A:i'} else 1 - float(ls[20].split('de:f:')[-1])
+
+                ## filter out non-overlapping mappings
+                for i in coordinates[qseqid]:
+                    if compute_overlap((qstart, qend, *i)) > 0:
+                        self.maps.append([qseqid, sseqid, AS, MS, ID])
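+
+    ## Toy illustration of the EM reassignment in postprocess(): given the
+    ## read-by-lineage count matrix
+    ##     matrix = np.array([[1, 1],
+    ##                        [1, 0]])
+    ## reassign_taxonomy(matrix) converges to p_mappings ~ [1, 0], so the
+    ## multi-mapped first read is reassigned to the lineage supported by the
+    ## unambiguous second read.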
+
+
+    def postprocess(self, db):
+        '''
+        Post-processing and label reassignment using EM.
+        '''
+        accession2lineage = {}
+        with open(os.path.join(db, 'metadata.tsv')) as f:
+            next(f)
+            for line in f:
+                ls = line.rstrip().split('\t')
+                accession2lineage[ls[0]] = ';'.join(ls[1:])
+
+        ## paste assigned taxonomy then sort
+        maps = sorted([
+            [*x, accession2lineage[x[1].rsplit('_', 1)[0]]] for x in self.maps
+        ], key=lambda x: (x[0], x[2], x[3], x[4]), reverse=True)
+
+        ## keep only the first per qseqid and lineage, remove all inferior alignments
+        data = []
+        duplicates = set()
+        max_scores = defaultdict(lambda: {'AS': 0, 'MS': 0, 'ID': 0})
+
+        for row in maps:
+            max_scores[row[0]]['AS'] = max(max_scores[row[0]]['AS'], row[2])
+            max_scores[row[0]]['MS'] = max(max_scores[row[0]]['MS'], row[3])
+            max_scores[row[0]]['ID'] = max(max_scores[row[0]]['ID'], row[4])
+
+            if (row[0], row[-1]) not in duplicates:
+                data.append(row)
+                duplicates.add((row[0], row[-1]))
+
+        data = [row for row in data if (
+            row[2] > max_scores[row[0]]['AS'] * 0.99 or
+            row[3] > max_scores[row[0]]['MS'] * 0.99 or
+            row[4] > max_scores[row[0]]['ID'] * 0.999
+        )]
+
+        ## create a matrix then fill
+        qseqids, lineages = np.unique([row[0] for row in data]), np.unique([row[-1] for row in data])
+        qseqid2index, lineage2index = {qseqid: index for index, qseqid in enumerate(qseqids)}, {lineage: index for index, lineage in enumerate(lineages)}
+
+        matrix = np.zeros((len(qseqids), len(lineages)), dtype=int)
+        for row in data:
+            matrix[qseqid2index[row[0]], lineage2index[row[-1]]] += 1
+
+        ## run EM using the count matrix as input
+        assignments = reassign_taxonomy(matrix)
+        ties = defaultdict(list)
+        for index, indices in enumerate(assignments):
+            if len(assignment := lineages[indices]) > 1:
+                ties[tuple(assignment)].append(qseqids[index])
+            else:
+                self.assignments[qseqids[index]] = assignment[0]
+
+        ## resolve ties for equal-probability cases using AS, MS and ID
+        for key, val in ties.items():
+            target = [row for row in data if row[-1] in key and row[0] in val]
+
+            scores = defaultdict(lambda: defaultdict(list))
+            for row in target:
+                scores[row[-1]]['AS'].append(row[2])
+                scores[row[-1]]['MS'].append(row[3])
+                scores[row[-1]]['ID'].append(row[4])
+
+            ## if still tied, choose the one with a known species name
+            target = sorted([
+                [np.mean(val['AS']), np.mean(val['MS']), np.mean(val['ID']), not bool(re.search(r' sp\.$| sp\. |sp[0-9]+', key.split(';')[-1])), key] for key, val in scores.items()
+            ], reverse=True)[0][-1]
+
+            for qseqid in val:
+                self.assignments[qseqid] = target
+
+
+    def run(self, db, db_kraken=None, skip_profile=False, skip_clean=False,
+            max_target_seqs=25, evalue=1e-15, identity=0, subject_cover=75,
+            secondary_num=2147483647, secondary_ratio=0.9):
+        '''
+        Run the pipeline.
+        '''
+        if db_kraken is not None:
+            logger.info('Filtering reads ...')
+            self.run_kraken(db_kraken)
+            self.parse_kraken()
+            logger.info('... removed {} putative non-prokaryotic reads.'.format(len(self.nset)))
+
+        logger.info('Estimating genome copies ...')
+        self.run_diamond(db, max_target_seqs, evalue, identity, subject_cover)
+        self.parse_diamond()
+        logger.info('... found {} copies of genomes (bacteria: {}; archaea: {}).'.format(
+            sum(self.copies.values()), self.copies['bacteria'], self.copies['archaea']))
+
+        if not skip_profile:
+            logger.info('Assigning taxonomy ...')
+            self.run_minimap(db, secondary_num, secondary_ratio)
+            self.parse_minimap()
+            self.postprocess(db)
+
+            ## fill missing ones according to hits
+            replacement = {
+                'bacteria': ';'.join(['2|Bacteria'] + ['0|unclassified Bacteria ' + x.lower() for x in self.ranks[1:]]),
+                'archaea': ';'.join(['2157|Archaea'] + ['0|unclassified Archaea ' + x.lower() for x in self.ranks[1:]])
+            }
+
+            ## fit GTDB style
+            if self.assignments and '|' not in next(iter(self.assignments.values())).split(';')[0]:
+                replacement = {key: ';'.join(x.split('|')[-1] for x in val.split(';')) for key, val in replacement.items()}
+
+            ## count assigned taxonomic labels
+            counts, total_counts = defaultdict(lambda: 0), defaultdict(lambda: 0)
+            for i in self.hits:
+                counts[(self.assignments.get(i[0], replacement.get(i[1])), i[1])] += 1
+                total_counts[i[1]] += 1
+
+            ## generate a profile output
+            self.profile = sorted([
+                [*key[0].split(';'), val * self.copies[key[1]] / total_counts[key[1]], val / sum(total_counts.values())] for key, val in counts.items()
+            ], key=lambda x: (x[-2], x[-3]))
+
+            richness = {'bacteria': 0, 'archaea': 0}
+            with open(get_filename(self.file, self.output, '.tsv'), 'w') as w:
+                w.write('\t'.join(self.ranks + ['copies', 'abundance']) + '\n')
+                for line in self.profile:
+                    if not re.search('unclassified (Bacteria|Archaea) species', line[6]):
+                        richness[line[0].split('|')[-1].lower()] += 1
+                    w.write('\t'.join(str(x) for x in line) + '\n')
+
+            logger.info('... found {} unique species (bacteria: {}; archaea: {}).'.format(
+                sum(richness.values()), richness['bacteria'], richness['archaea']))
+
+        else:
+            self.profile = sorted([
+                ['2|Bacteria', self.copies['bacteria'], self.copies['bacteria'] / sum(self.copies.values())],
+                ['2157|Archaea', self.copies['archaea'], self.copies['archaea'] / sum(self.copies.values())]
+            ], key=lambda x: (x[-2], x[-3]))
+
+            with open(get_filename(self.file, self.output, '.tsv'), 'w') as w:
+                w.write('\t'.join(['superkingdom', 'copies', 'abundance']) + '\n')
+                for line in self.profile:
+                    w.write('\t'.join(str(x) for x in line) + '\n')
+
+        ## clean up
+        if not skip_clean:
+            for f in glob.glob(get_filename(self.file, self.output, '.*.tmp')):
+                os.remove(f)
diff --git a/src/melon/utils.py b/src/melon/utils.py
new file mode 100644
index 0000000..205dd3f
--- /dev/null
+++ b/src/melon/utils.py
@@ -0,0 +1,102 @@
+import os
+import re
+import subprocess
+import logging
+import numpy as np
+
+## setup logger format
+logging.basicConfig(
+    level="INFO",
+    format="[%(asctime)s] %(levelname)s: %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S")
+
+logging.addLevelName(logging.WARNING,
+    f'\033[1m\x1b[33;20m{logging.getLevelName(logging.WARNING)}\033[1;0m')
+logging.addLevelName(logging.CRITICAL,
+    f'\033[1m\x1b[31;20m{logging.getLevelName(logging.CRITICAL)}\033[1;0m')
+
+logger = logging.getLogger(__name__)
+
+
+def sort_coordinate(start, end):
+    '''
+    Convert a 1-based coordinate pair to 0-based half-open, then sort.
+    '''
+    return (start-1, end) if start < end else (end-1, start)
+
+
+def compute_overlap(coordinates, func=None):
+    '''
+    Compute the overlap between two 0-based intervals.
+    '''
+    qstart, qend, sstart, send = coordinates
+    overlap = min(qend, send) - max(qstart, sstart)
+
+    if func is None:
+        return overlap
+
+    assert func in {max, min}, 'Use either max or min for aggregation.'
+    return func(overlap / (qend - qstart), overlap / (send - sstart))
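+
+## Worked example for the two helpers above:
+##     sort_coordinate(7, 3)                -> (2, 7), i.e. 1-based inclusive (3, 7) made 0-based half-open
+##     compute_overlap((0, 10, 5, 20))      -> 5, the raw overlap length
+##     compute_overlap((0, 10, 5, 20), max) -> max(5/10, 5/15) = 0.5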
+
+
+def get_filename(file, output=None, extension=None):
+    '''
+    Get the filename of a file, optionally prepending an output dir and appending an extension.
+    '''
+    filename = re.sub(r'\.f(ast)?[aq](\.gz)?$', '', os.path.basename(file))
+    if output is not None:
+        filename = os.path.join(output, filename)
+
+    if extension is not None:
+        filename += extension
+    return filename
+
+
+def reassign_taxonomy(matrix, eps=1e-5, max_iteration=100):
+    '''
+    Reassign multi-mapped reads with EM.
+    '''
+    n_reads, n_mappings = matrix.shape
+
+    ## init
+    p_reads = np.zeros((n_reads, n_mappings))
+    p_mappings = np.ones(n_mappings) / n_mappings
+    p_mappings_hist = p_mappings.copy()
+
+    iteration = 0
+    while iteration < max_iteration:
+        iteration += 1
+
+        ## e-step
+        p_reads = np.divide(p_mappings * matrix, np.dot(matrix, p_mappings).reshape(-1, 1))
+
+        ## m-step
+        p_mappings = np.sum(p_reads, axis=0) / n_reads
+
+        ## check convergence
+        if np.sum(np.abs(p_mappings - p_mappings_hist)) < eps:
+            break
+
+        ## update p_mappings_hist
+        np.copyto(p_mappings_hist, p_mappings)
+
+    ## return assignments
+    assignments = [np.where(row == row.max())[0].tolist() for row in p_reads]
+    return assignments
+
+
+def extract_sequences(file, ids):
+    '''
+    Extract sequences from a source fa/fq file using seqkit.
+    '''
+    cmd = ['seqkit', 'grep', '-f', '-', file]
+    return subprocess.run(cmd, check=True, input='\n'.join(ids) + '\n', text=True, capture_output=True).stdout
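+
+## extract_sequences shells out to seqkit, which must be on PATH. For instance
+## (hypothetical ids), extract_sequences('reads.fq.gz', {'read_1', 'read_42'})
+## returns the matching records as a single string in the input's format.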