diff --git a/mash/mash.go b/mash/mash.go index 550d5e2a..0c109555 100644 --- a/mash/mash.go +++ b/mash/mash.go @@ -106,35 +106,28 @@ func (mash *Mash) Sketch(sequence string) { // Similarity returns the Jaccard similarity between two sketches (number of matching hashes / sketch size) func (mash *Mash) Similarity(other *Mash) float64 { var sameHashes int + largerSketch := mash + smallerSketch := other - var largerSketch *Mash - var smallerSketch *Mash - - if mash.SketchSize > other.SketchSize { - largerSketch = mash - smallerSketch = other - } else { + if mash.SketchSize < other.SketchSize { largerSketch = other smallerSketch = mash } - largerSketchSizeShifted := largerSketch.SketchSize - 1 - smallerSketchSizeShifted := smallerSketch.SketchSize - 1 - - // if the largest hash in the larger sketch is smaller than the smallest hash in the smaller sketch, the distance is 1 - if largerSketch.Sketches[largerSketchSizeShifted] < smallerSketch.Sketches[0] { - return 0 - } - - // if the largest hash in the smaller sketch is smaller than the smallest hash in the larger sketch, the distance is 1 - if smallerSketch.Sketches[smallerSketchSizeShifted] < largerSketch.Sketches[0] { + if largerSketch.Sketches[largerSketch.SketchSize-1] < smallerSketch.Sketches[0] || smallerSketch.Sketches[smallerSketch.SketchSize-1] < largerSketch.Sketches[0] { return 0 } - for _, hash := range smallerSketch.Sketches { - ind := sort.Search(largerSketchSizeShifted, func(ind int) bool { return largerSketch.Sketches[ind] <= hash }) - if largerSketch.Sketches[ind] == hash { + smallSketchIndex, largeSketchIndex := 0, 0 + for smallSketchIndex < smallerSketch.SketchSize && largeSketchIndex < largerSketch.SketchSize { + if smallerSketch.Sketches[smallSketchIndex] == largerSketch.Sketches[largeSketchIndex] { sameHashes++ + smallSketchIndex++ + largeSketchIndex++ + } else if smallerSketch.Sketches[smallSketchIndex] < largerSketch.Sketches[largeSketchIndex] { + smallSketchIndex++ + } else { + largeSketchIndex++ } } diff --git a/mash/mash_test.go b/mash/mash_test.go index 6ba9a663..ccf20ce4 100644 --- a/mash/mash_test.go +++ b/mash/mash_test.go @@ -37,4 +37,38 @@ func TestMash(t *testing.T) { if distance != 1 { t.Errorf("Expected distance to be 1, got %f", distance) } + + fingerprint1 = mash.New(17, 10) + fingerprint1.Sketch("ATGCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA") + + fingerprint2 = mash.New(17, 5) + fingerprint2.Sketch("ATCGATCGATCGATCGATCGATCGATCGATCGATCGAATGCGATCGATCGATCGATCGATCG") + + distance = fingerprint1.Distance(fingerprint2) + if !(distance > 0.19 && distance < 0.21) { + t.Errorf("Expected distance to be 0.19999999999999996, got %f", distance) + } + + fingerprint1 = mash.New(17, 10) + fingerprint1.Sketch("ATCGATCGATCGATCGATCGATCGATCGATCGATCGAATGCGATCGATCGATCGATCGATCG") + + fingerprint2 = mash.New(17, 5) + fingerprint2.Sketch("ATGCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA") + + distance = fingerprint1.Distance(fingerprint2) + if distance != 0 { + t.Errorf("Expected distance to be 0, got %f", distance) + } +} + +func BenchmarkMashDistancee(b *testing.B) { + fingerprint1 := mash.New(17, 10) + fingerprint1.Sketch("ATGCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA") + + fingerprint2 := mash.New(17, 9) + fingerprint2.Sketch("ATGCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA") + + for i := 0; i < b.N; i++ { + fingerprint1.Distance(fingerprint2) + } }