WIP: use KmerMinHashBTree for hash subtraction #310

Open · wants to merge 15 commits into base: main
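The titular idea: gather repeatedly subtracts each ranked match's hashes from what remains of the query, and a BTree-backed sketch can do that removal in place rather than rebuilding a sorted-Vec minhash at every rank. A minimal sketch of the pattern, assuming `KmerMinHashBTree` exposes the same `remove_hash` method as sourmash's `KmerMinHash` (not verified against this branch):

```rust
use sourmash::sketch::minhash::KmerMinHashBTree;

// Remove every hash belonging to the current best match from the remaining
// query, in place. With a BTree backing store each removal is O(log n), so
// gather avoids copying/rebuilding the query sketch at every rank.
fn subtract_match_hashes(remaining_query: &mut KmerMinHashBTree, match_hashes: &[u64]) {
    for &hash in match_hashes {
        remaining_query.remove_hash(hash);
    }
}
```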
5 changes: 3 additions & 2 deletions Cargo.lock


1 change: 1 addition & 0 deletions Cargo.toml
@@ -26,6 +26,7 @@ csv = "1.3.0"
camino = "1.1.6"
glob = "0.3.1"
rustworkx-core = "0.14.2"
+ streaming-stats = "0.2.3"

[dev-dependencies]
assert_cmd = "2.0.14"
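The new `streaming-stats` dependency (library name `stats`) plausibly backs the abundance summary columns (`average_abund`, `median_abund`, `std_abund`) that the full gather results report. A sketch of the kind of helper involved; this function is illustrative, not code from this branch:

```rust
// Summarize abundances of the hashes unique to one gather match, using the
// free functions from streaming-stats (imported as `stats`); `median`
// returns None on empty input.
fn abundance_summary(abunds: &[u64]) -> (f64, Option<f64>, f64) {
    let average_abund = stats::mean(abunds.iter().copied());
    let median_abund = stats::median(abunds.iter().copied());
    let std_abund = stats::stddev(abunds.iter().copied());
    (average_abund, median_abund, std_abund)
}
```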
11 changes: 10 additions & 1 deletion src/fastgather.rs
@@ -18,6 +18,7 @@ pub fn fastgather(
gather_output: Option<String>,
prefetch_output: Option<String>,
allow_failed_sigpaths: bool,
+ make_full_result: bool,
) -> Result<()> {
let query_collection = load_collection(
&query_filepath,
@@ -93,6 +94,14 @@ pub fn fastgather(
}

// run the gather!
- consume_query_by_gather(query_sig, matchlist, threshold_hashes, gather_output).ok();
+ consume_query_by_gather(
+     query_sig,
+     scaled as u64,
+     matchlist,
+     threshold_hashes,
+     gather_output,
+     make_full_result,
+ )
+ .ok();
Ok(())
}
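Read off the two updated call sites, `consume_query_by_gather` now takes the query's scaled value and a full-results flag; its signature is roughly the stub below. `SigStore` and `PrefetchResult` here are local placeholders standing in for the plugin's real types:

```rust
use std::collections::BinaryHeap;

// Local placeholders so the stub is self-contained; the real plugin types
// live in the branchwater crate.
struct SigStore;
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct PrefetchResult;

// Inferred shape: `scaled` rides along as u64 (needed to convert hash counts
// to bp for the full result columns) and `make_full_result` switches between
// the minimal and the full sourmash-style column sets.
fn consume_query_by_gather(
    _query: SigStore,
    _scaled: u64,
    _matchlist: BinaryHeap<PrefetchResult>,
    _threshold_hashes: u64,
    _gather_output: Option<String>,
    _make_full_result: bool,
) -> anyhow::Result<()> {
    Ok(())
}
```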
4 changes: 4 additions & 0 deletions src/fastmultigather.rs
@@ -23,6 +23,7 @@ pub fn fastmultigather(
scaled: usize,
selection: &Selection,
allow_failed_sigpaths: bool,
+ make_full_result: bool,
) -> Result<()> {
// load query collection
let query_collection = load_collection(
@@ -80,6 +81,7 @@
name: against.name.clone(),
md5sum: against.md5sum.clone(),
minhash: against.minhash.clone(),
+ location: against.location.clone(),
overlap,
};
mm = Some(result);
@@ -98,9 +100,11 @@
// Now, do the gather!
consume_query_by_gather(
query_sig.clone(),
+ scaled as u64,
matchlist,
threshold_hashes,
Some(gather_output),
+ make_full_result,
)
.ok();
} else {
4 changes: 4 additions & 0 deletions src/lib.rs
@@ -79,6 +79,7 @@ fn do_fastgather(
ksize: u8,
scaled: usize,
moltype: String,
+ make_full_result: bool,
output_path_prefetch: Option<String>,
output_path_gather: Option<String>,
) -> anyhow::Result<u8> {
@@ -94,6 +95,7 @@
output_path_prefetch,
output_path_gather,
allow_failed_sigpaths,
+ make_full_result,
) {
Ok(_) => Ok(0),
Err(e) => {
@@ -111,6 +113,7 @@ fn do_fastmultigather(
ksize: u8,
scaled: usize,
moltype: String,
+ make_full_result: bool,
output_path: Option<String>,
) -> anyhow::Result<u8> {
let againstfile_path: camino::Utf8PathBuf = siglist_path.clone().into();
@@ -141,6 +144,7 @@
scaled,
&selection,
allow_failed_sigpaths,
+ make_full_result,
) {
Ok(_) => Ok(0),
Err(e) => {
4 changes: 4 additions & 0 deletions src/python/sourmash_plugin_branchwater/__init__.py
@@ -101,6 +101,7 @@ def __init__(self, p):
help='scaled factor at which to do comparisons (default: 1000)')
p.add_argument('-m', '--moltype', default='DNA', choices = ["DNA", "protein", "dayhoff", "hp"],
help = 'molecule type (DNA, protein, dayhoff, or hp; default DNA)')
+ p.add_argument('--full-results', action='store_true', default=False, help = 'produce full gather results')
p.add_argument('-c', '--cores', default=0, type=int,
help='number of cores to use (default is all available)')

@@ -121,6 +122,7 @@ def main(self, args):
args.ksize,
args.scaled,
args.moltype,
+ args.full_results,
args.output_gather,
args.output_prefetch)
if status == 0:
@@ -147,6 +149,7 @@ def __init__(self, p):
help='scaled factor at which to do comparisons (default: 1000)')
p.add_argument('-m', '--moltype', default='DNA', choices = ["DNA", "protein", "dayhoff", "hp"],
help = 'molecule type (DNA, protein, dayhoff, or hp; default DNA)')
+ p.add_argument('--full-results', action='store_true', default=False, help = 'produce full gather results')
p.add_argument('-c', '--cores', default=0, type=int,
help='number of cores to use (default is all available)')
p.add_argument('-o', '--output', help='CSV output file for matches')
@@ -167,6 +170,7 @@ def main(self, args):
args.ksize,
args.scaled,
args.moltype,
+ args.full_results,
args.output)
if status == 0:
notify(f"...fastmultigather is done!")
170 changes: 170 additions & 0 deletions src/python/tests/test_gather.py
@@ -686,3 +686,173 @@ def test_simple_with_manifest_loading(runtmp):
assert len(df) == 3
keys = set(df.keys())
assert {'query_filename', 'query_name', 'query_md5', 'match_name', 'match_md5', 'rank', 'intersect_bp'}.issubset(keys)


def test_simple_full_output(runtmp):
# test basic execution with --full-results!
query = get_test_data('SRR606249.sig.gz')
against_list = runtmp.output('against.txt')

sig2 = get_test_data('2.fa.sig.gz')
sig47 = get_test_data('47.fa.sig.gz')
sig63 = get_test_data('63.fa.sig.gz')

make_file_list(against_list, [sig2, sig47, sig63])

g_output = runtmp.output('gather.csv')
p_output = runtmp.output('prefetch.csv')

runtmp.sourmash('scripts', 'fastgather', query, against_list,
'-o', g_output, '-s', '100000', '--full-results')
assert os.path.exists(g_output)

df = pandas.read_csv(g_output)
assert len(df) == 3
keys = set(df.keys())
print(keys)
print(df)
assert {'query_filename', 'query_name', 'query_md5', 'match_name', 'match_md5', 'gather_result_rank', 'intersect_bp'}.issubset(keys)
expected_keys = {'match_name', 'query_filename', 'query_n_hashes', 'match_filename', 'f_match_orig',
'query_bp', 'query_abundance', 'match_containment_ani', 'intersect_bp', 'total_weighted_hashes',
'n_unique_weighted_found', 'query_name', 'gather_result_rank', 'moltype',
'query_containment_ani', 'sum_weighted_found', 'f_orig_query', 'ksize', 'max_containment_ani',
'std_abund', 'scaled', 'average_containment_ani', 'f_match', 'f_unique_to_query',
'average_abund', 'unique_intersect_bp', 'median_abund', 'query_md5', 'match_md5', 'remaining_bp',
'f_unique_weighted'}
assert keys == expected_keys

md5s = set(df['match_md5'])
for against_file in (sig2, sig47, sig63):
for ss in sourmash.load_file_as_signatures(against_file, ksize=31):
assert ss.md5sum() in md5s


intersect_bp = set(df['intersect_bp'])
assert intersect_bp == set([4400000, 4100000, 2200000])
f_unique_to_query = set([round(x,4) for x in df['f_unique_to_query']])
assert f_unique_to_query == set([0.0053, 0.0105, 0.0044])
query_containment_ani = set([round(x,4) for x in df['query_containment_ani']])
assert query_containment_ani == set([0.8632, 0.8444, 0.8391])
print(query_containment_ani)
for index, row in df.iterrows():
print(row.to_dict())


def test_fullres_vs_sourmash_gather(runtmp):
# fastgather --full-results output should match sourmash gather results
query = get_test_data('SRR606249.sig.gz')

sig2 = get_test_data('2.fa.sig.gz')
sig47 = get_test_data('47.fa.sig.gz')
sig63 = get_test_data('63.fa.sig.gz')

query_list = runtmp.output('query.txt')
make_file_list(query_list, [query])
against_list = runtmp.output('against.txt')
make_file_list(against_list, [sig2, sig47, sig63])

g_output = runtmp.output('SRR606249.gather.csv')
runtmp.sourmash('scripts', 'fastgather', query_list,
against_list, '-s', '100000', '-t', '0',
'--full-results', '-o', g_output)

print(runtmp.last_result.out)
print(runtmp.last_result.err)
assert os.path.exists(g_output)
# now run sourmash gather
sg_output = runtmp.output('.csv')
runtmp.sourmash('gather', query, against_list,
'-o', sg_output, '--scaled', '100000')

gather_df = pandas.read_csv(g_output)
g_keys = set(gather_df.keys())

sourmash_gather_df = pandas.read_csv(sg_output)
sg_keys = set(sourmash_gather_df.keys())
print(sg_keys)
modified_keys = ["match_md5", "match_name", "match_filename"]
sg_keys.update(modified_keys) # fastgather is more explicit (match_md5 instead of md5, etc)
print('g_keys - sg_keys:', g_keys - sg_keys)
assert not g_keys - sg_keys, g_keys - sg_keys

for _idx, row in sourmash_gather_df.iterrows():
print(row.to_dict())

fg_intersect_bp = set(gather_df['intersect_bp'])
g_intersect_bp = set(sourmash_gather_df['intersect_bp'])
assert fg_intersect_bp == g_intersect_bp == set([4400000, 4100000, 2200000])

fg_f_orig_query = set([round(x,4) for x in gather_df['f_orig_query']])
g_f_orig_query = set([round(x,4) for x in sourmash_gather_df['f_orig_query']])
assert fg_f_orig_query == g_f_orig_query == set([0.0098, 0.0105, 0.0052])

fg_f_match = set([round(x,4) for x in gather_df['f_match']])
g_f_match = set([round(x,4) for x in sourmash_gather_df['f_match']])
assert fg_f_match == g_f_match == set([0.439, 1.0])

fg_f_unique_to_query = set([round(x,3) for x in gather_df['f_unique_to_query']]) # rounding to 4 --> slightly different!
g_f_unique_to_query = set([round(x,3) for x in sourmash_gather_df['f_unique_to_query']])
assert fg_f_unique_to_query == g_f_unique_to_query == set([0.004, 0.01, 0.005])

fg_f_unique_weighted = set([round(x,4) for x in gather_df['f_unique_weighted']])
g_f_unique_weighted = set([round(x,4) for x in sourmash_gather_df['f_unique_weighted']])
assert fg_f_unique_weighted == g_f_unique_weighted == set([0.0063, 0.002, 0.0062])

fg_average_abund = set([round(x,4) for x in gather_df['average_abund']])
g_average_abund = set([round(x,4) for x in sourmash_gather_df['average_abund']])
assert fg_average_abund == g_average_abund == set([8.2222, 10.3864, 21.0455])

fg_median_abund = set([round(x,4) for x in gather_df['median_abund']])
g_median_abund = set([round(x,4) for x in sourmash_gather_df['median_abund']])
assert fg_median_abund == g_median_abund == set([8.0, 10.5, 21.5])

fg_std_abund = set([round(x,4) for x in gather_df['std_abund']])
g_std_abund = set([round(x,4) for x in sourmash_gather_df['std_abund']])
assert fg_std_abund == g_std_abund == set([3.172, 5.6446, 6.9322])

g_match_filename_basename = [os.path.basename(filename) for filename in sourmash_gather_df['filename']]
fg_match_filename_basename = [os.path.basename(filename) for filename in gather_df['match_filename']]
assert all([x in fg_match_filename_basename for x in ['2.fa.sig.gz', '63.fa.sig.gz', '47.fa.sig.gz']])
assert fg_match_filename_basename == g_match_filename_basename

assert list(sourmash_gather_df['name']) == list(gather_df['match_name'])
assert list(sourmash_gather_df['md5']) == list(gather_df['match_md5'])

fg_f_match_orig = set([round(x,4) for x in gather_df['f_match_orig']])
g_f_match_orig = set([round(x,4) for x in sourmash_gather_df['f_match_orig']])
assert fg_f_match_orig == g_f_match_orig == set([1.0])

fg_unique_intersect_bp = set(gather_df['unique_intersect_bp'])
g_unique_intersect_bp = set(sourmash_gather_df['unique_intersect_bp'])
assert fg_unique_intersect_bp == g_unique_intersect_bp == set([4400000, 1800000, 2200000])

fg_gather_result_rank = set(gather_df['gather_result_rank'])
g_gather_result_rank = set(sourmash_gather_df['gather_result_rank'])
assert fg_gather_result_rank == g_gather_result_rank == set([0,1,2])

fg_remaining_bp = list(gather_df['remaining_bp'])
assert fg_remaining_bp == [415600000, 413400000, 411600000]
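# note: these values are consistent with remaining_bp being query_bp minus
# the running sum of unique_intersect_bp, with query_bp = 420000000 here:
# 420000000 - 4400000 = 415600000; - 2200000 = 413400000; - 1800000 = 411600000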
### Gather remaining bp does not match, but I think this one is right?
#g_remaining_bp = list(sourmash_gather_df['remaining_bp'])
#print("gather remaining bp: ", g_remaining_bp) #{4000000, 0, 1800000}
# assert fg_remaining_bp == g_remaining_bp == set([])

fg_query_containment_ani = set([round(x,4) for x in gather_df['query_containment_ani']])
g_query_containment_ani = set([round(x,4) for x in sourmash_gather_df['query_containment_ani']])
assert fg_query_containment_ani == set([0.8632, 0.8444, 0.8391])
# gather cANI are nans here -- perhaps b/c sketches too small?
# assert fg_query_containment_ani == g_query_containment_ani == set([0.8632, 0.8444, 0.8391])
print("fg qcANI: ", fg_query_containment_ani)
print("g_qcANI: ", g_query_containment_ani)

fg_n_unique_weighted_found = set(gather_df['n_unique_weighted_found'])
g_n_unique_weighted_found = set(sourmash_gather_df['n_unique_weighted_found'])
assert fg_n_unique_weighted_found == g_n_unique_weighted_found == set([457, 148, 463])

fg_sum_weighted_found = set(gather_df['sum_weighted_found'])
g_sum_weighted_found = set(sourmash_gather_df['sum_weighted_found'])
assert fg_sum_weighted_found == g_sum_weighted_found == set([920, 457, 1068])

fg_total_weighted_hashes = set(gather_df['total_weighted_hashes'])
g_total_weighted_hashes = set(sourmash_gather_df['total_weighted_hashes'])
assert fg_total_weighted_hashes == g_total_weighted_hashes == set([73489])