diff --git a/argnorm/lib.py b/argnorm/lib.py index c9024c4..54472c1 100644 --- a/argnorm/lib.py +++ b/argnorm/lib.py @@ -8,29 +8,21 @@ _ROOT = os.path.abspath(os.path.dirname(__file__)) -def is_number(num): - try: - int(num) - except ValueError: - return False - - return True - def get_aro_mapping_table(database): - aro_mapping_table = pd.read_csv(os.path.join(_ROOT, 'data', f'{database}_ARO_mapping.tsv'), sep='\t') - - manual_curation = pd.read_csv(os.path.join(_ROOT, 'data/manual_curation', f'{database}_curation.tsv'), sep='\t') - manual_curation['Database'] = aro_mapping_table['Database'] - - aro_mapping_table = aro_mapping_table.drop_duplicates(subset=['Original ID'], ignore_index=True).set_index('Original ID') - for i in manual_curation['Original ID']: - if i in aro_mapping_table.index: - aro_mapping_table.loc[i, 'ARO'] = manual_curation.set_index('Original ID').loc[i, 'ARO'] - aro_mapping_table.loc[i, 'Gene Name in CARD'] = manual_curation.set_index('Original ID').loc[i, 'Gene Name in CARD'] - else: - aro_mapping_table.loc[i] = manual_curation.set_index('Original ID').loc[i] - - aro_mapping_table[TARGET_ARO_COL] = aro_mapping_table[TARGET_ARO_COL].map(lambda a: f'ARO:{int(a)}' if is_number(a) else a) + aro_mapping_table = pd.read_csv( + os.path.join(_ROOT, 'data', f'{database}_ARO_mapping.tsv'), + sep='\t') + aro_mapping_table.drop_duplicates(subset=['Original ID'], inplace=True) + aro_mapping_table.set_index('Original ID', inplace=True) + + manual_curation = pd.read_csv( + os.path.join(_ROOT, 'data/manual_curation', f'{database}_curation.tsv'), + sep='\t', index_col=0) + manual_curation['Database'] = aro_mapping_table['Database'].iloc[0] + aro_mapping_table.drop(index=set(manual_curation.index) & set(aro_mapping_table.index), inplace=True) + aro_mapping_table = pd.concat([aro_mapping_table, manual_curation]) + + aro_mapping_table['ARO'] = aro_mapping_table['ARO'].map(lambda a: f'ARO:{int(a)}', na_action='ignore') return aro_mapping_table.reset_index() def map_to_aro(gene, database): @@ -50,4 +42,4 @@ def map_to_aro(gene, database): if type(result) != str: return ARO[list(set(result))[0]] else: - return ARO[result] \ No newline at end of file + return ARO[result] diff --git a/tests/test_lib.py b/tests/test_lib.py index 11ced10..3e31d98 100644 --- a/tests/test_lib.py +++ b/tests/test_lib.py @@ -1,4 +1,5 @@ -from argnorm.lib import map_to_aro +import pytest +from argnorm.lib import map_to_aro, get_aro_mapping_table import pronto def test_map_to_aro(): @@ -18,4 +19,10 @@ def test_map_to_aro(): ] for t, e in zip(test_cases, expected_output): - assert map_to_aro(t[0], t[1]) == e \ No newline at end of file + assert map_to_aro(t[0], t[1]) == e + +@pytest.mark.parametrize('database', ['argannot', 'megares', 'ncbi', 'resfinder', 'resfinderfg']) +def test_get_aro_mapping_table_smoke(database): + df = get_aro_mapping_table(database) + assert len(df) > 0 +