From 05169d236d0ba859e846217cebb2573483cf4f61 Mon Sep 17 00:00:00 2001 From: Trenton w Fleming Date: Wed, 24 Jan 2024 14:41:14 -0500 Subject: [PATCH] Run length bwt (#440) * run length bwt * better clarify mapping from original sequnce space to run space --------- Co-authored-by: Timothy Stiles --- CHANGELOG.md | 1 + search/bwt/bwt.go | 239 ++++++++++++++++++++++++++++++++++--- search/bwt/bwt_test.go | 167 +++++++++++++++++++++++++- search/bwt/wavelet.go | 29 +++-- search/bwt/wavelet_test.go | 17 +-- 5 files changed, 409 insertions(+), 44 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bd26ded8..82b157bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Basic BWT for sub-sequence count and offset for sequence alignment. Only supports exact matches for now. - Moved `BWT`, `align`, and `mash` packages to new `search` sub-directory. +- Implemented Run-Length Burrows Wheeler Transform. ## [0.30.0] - 2023-12-18 diff --git a/search/bwt/bwt.go b/search/bwt/bwt.go index 66020166..3d01c03b 100644 --- a/search/bwt/bwt.go +++ b/search/bwt/bwt.go @@ -186,13 +186,38 @@ type BWT struct { // represented as a list of skipEntries because the first column of // the BWT is always lexicographically ordered. This saves time and memory. firstColumnSkipList []skipEntry - // Column last column of the BWT- the actual textual representation - // of the BWT. - lastColumn waveletTree // suffixArray an array that allows us to map a position in the first // column to a position in the original sequence. This is needed to be // able to extract text from the BWT. suffixArray []int + // runLengthCompressedBWT is the compressed version of the BWT. The compression + // is for each run. For Example: + // the sequence "banana" has BWT "annb$aa" + // the run length compression of "annb$aa" is "anb$a" + // This helps us save a lot of memory while still having a search index we can + // use to align the original sequence. This allows us to understand how many + // runs of a certain character there are and where a run of a certain rank exists. + runBWTCompression waveletTree + // runStartPositions are the starting position of each run in the original sequence + // For example: + // "annb$aa" will have the runStartPositions [0, 1, 3, 4, 5] + // This helps us map our search range from "uncompressed BWT Space" to its + // "compressed BWT Run Space". With this, we can understand which runs we need + // to consider during LF mapping. + runStartPositions runInfo + // runCumulativeCounts is the cumulative count of characters for each run. + // This helps us efficiently lookup the number of occurrences of a given + // character before a given offset in "uncompressed BWT Space" + // For Example: + // "annb$aa" will have the runCumulativeCounts: + // "a": [0, 1, 3], + // "n": [0, 2], + // "b": [0, 1], + // "$": [0, 1], + runCumulativeCounts map[string]runInfo + + // flag for turning on BWT debugging + debug bool } // Count represents the number of times the provided pattern @@ -269,7 +294,34 @@ func (bwt BWT) Len() int { // GetTransform returns the last column of the BWT transform of the original sequence. func (bwt BWT) GetTransform() string { - return bwt.lastColumn.reconstruct() + lastColumn := strings.Builder{} + lastColumn.Grow(bwt.getLenOfOriginalStringWithNullChar()) + for i := 0; i < bwt.runBWTCompression.length; i++ { + currChar := bwt.runBWTCompression.Access(i) + var currCharEnd int + if i+1 >= len(bwt.runStartPositions) { + currCharEnd = bwt.getLenOfOriginalStringWithNullChar() + } else { + currCharEnd = bwt.runStartPositions[i+1] + } + for lastColumn.Len() < currCharEnd { + lastColumn.WriteByte(currChar) + } + } + return lastColumn.String() +} + +//lint:ignore U1000 Ignore unused function. This is valuable for future debugging +func (bwt BWT) getFirstColumnStr() string { + firstColumn := strings.Builder{} + firstColumn.Grow(bwt.getLenOfOriginalStringWithNullChar()) + for i := 0; i < len(bwt.firstColumnSkipList); i++ { + e := bwt.firstColumnSkipList[i] + for j := e.openEndedInterval.start; j < e.openEndedInterval.end; j++ { + firstColumn.WriteByte(e.char) + } + } + return firstColumn.String() } // getFCharPosFromOriginalSequenceCharPos looks up mapping from the original position @@ -291,21 +343,55 @@ func (bwt BWT) getFCharPosFromOriginalSequenceCharPos(originalPos int) int { func (bwt BWT) lfSearch(pattern string) interval { searchRange := interval{start: 0, end: bwt.getLenOfOriginalStringWithNullChar()} for i := 0; i < len(pattern); i++ { + if bwt.debug { + printLFDebug(bwt, searchRange, i) + } if searchRange.end-searchRange.start <= 0 { return interval{} } c := pattern[len(pattern)-1-i] - skip, ok := bwt.lookupSkipByChar(c) - if !ok { - return interval{} - } - searchRange.start = skip.openEndedInterval.start + bwt.lastColumn.Rank(c, searchRange.start) - searchRange.end = skip.openEndedInterval.start + bwt.lastColumn.Rank(c, searchRange.end) + nextStart := bwt.getNextLfSearchOffset(c, searchRange.start) + nextEnd := bwt.getNextLfSearchOffset(c, searchRange.end) + searchRange.start = nextStart + searchRange.end = nextEnd } return searchRange } +func (bwt BWT) getNextLfSearchOffset(c byte, offset int) int { + nearestRunStart := bwt.runStartPositions.FindNearestRunStartPosition(offset + 1) + maxRunInCompressedSpace := bwt.runBWTCompression.Rank(c, nearestRunStart) + + skip, ok := bwt.lookupSkipByChar(c) + if !ok { + return 0 + } + + cumulativeCounts, ok := bwt.runCumulativeCounts[string(c)] + if !ok { + return 0 + } + + cumulativeCountBeforeMaxRun := cumulativeCounts[maxRunInCompressedSpace] + + currRunStart := bwt.runStartPositions.FindNearestRunStartPosition(offset) + currentRunChar := string(bwt.runBWTCompression.Access(currRunStart)) + extraOffset := 0 + // It is possible that an offset currently lies within a run of the same + // character we are inspecting. In this case, cumulativeCountBeforeMaxRun + // is not enough since the Max Run in this case does not include the run + // the offset is currently in. To adjust for this, we must count the number + // of character occurrences since the beginning of the run that the offset + // is currently in. + if c == currentRunChar[0] { + o := bwt.runStartPositions[nearestRunStart] + extraOffset += offset - o + } + + return skip.openEndedInterval.start + cumulativeCountBeforeMaxRun + extraOffset +} + // lookupSkipByChar looks up a skipEntry by its character in the First Column func (bwt BWT) lookupSkipByChar(c byte) (entry skipEntry, ok bool) { for i := range bwt.firstColumnSkipList { @@ -372,30 +458,64 @@ func New(sequence string) (BWT, error) { sortPrefixArray(prefixArray) suffixArray := make([]int, len(sequence)) - lastColBuilder := strings.Builder{} + charCount := 0 + runBWTCompressionBuilder := strings.Builder{} + var runStartPositions runInfo + runCumulativeCounts := make(map[string]runInfo) + + var prevChar *byte for i := 0; i < len(prefixArray); i++ { currChar := sequence[getBWTIndex(len(sequence), len(prefixArray[i]))] - lastColBuilder.WriteByte(currChar) + if prevChar == nil { + prevChar = &currChar + } + + if currChar != *prevChar { + runBWTCompressionBuilder.WriteByte(*prevChar) + runStartPositions = append(runStartPositions, i-charCount) + addRunCumulativeCountEntry(runCumulativeCounts, *prevChar, charCount) + charCount = 0 + prevChar = &currChar + } + + charCount++ suffixArray[i] = len(sequence) - len(prefixArray[i]) } + runBWTCompressionBuilder.WriteByte(*prevChar) + runStartPositions = append(runStartPositions, len(prefixArray)-charCount) + addRunCumulativeCountEntry(runCumulativeCounts, *prevChar, charCount) + fb := strings.Builder{} for i := 0; i < len(prefixArray); i++ { fb.WriteByte(prefixArray[i][0]) } - wt, err := newWaveletTreeFromString(lastColBuilder.String()) + skipList := buildSkipList(prefixArray) + + wt, err := newWaveletTreeFromString(runBWTCompressionBuilder.String()) if err != nil { return BWT{}, err } - return BWT{ - firstColumnSkipList: buildSkipList(prefixArray), - lastColumn: wt, + firstColumnSkipList: skipList, suffixArray: suffixArray, + runBWTCompression: wt, + runStartPositions: runStartPositions, + runCumulativeCounts: runCumulativeCounts, }, nil } +func addRunCumulativeCountEntry(rumCumulativeCounts map[string]runInfo, char byte, charCount int) { + cumulativeCountsOfChar, ok := rumCumulativeCounts[string(char)] + if ok { + cumulativeCountsOfChar = append(cumulativeCountsOfChar, charCount+cumulativeCountsOfChar[len(cumulativeCountsOfChar)-1]) + } else { + cumulativeCountsOfChar = runInfo{0, charCount} + } + rumCumulativeCounts[string(char)] = cumulativeCountsOfChar +} + // buildSkipList compressed the First Column of the BWT into a skip list func buildSkipList(prefixArray []string) []skipEntry { prevChar := prefixArray[0][0] @@ -457,6 +577,38 @@ func bwtRecovery(operation string, err *error) { } } +// runInfo each element of runInfo should represent an offset i where i +// corresponds to the start of a run in a given sequence. For example, +// aaaabbccc would have the run info [0, 4, 6] +type runInfo []int + +// FindNearestRunStartPosition given some offset, find the nearest starting position for the. +// beginning of a run. Another way of saying this is give me the max i where runStartPositions[i] <= offset. +// This is needed so we can understand which run an offset is a part of. +func (r runInfo) FindNearestRunStartPosition(offset int) int { + start := 0 + end := len(r) - 1 + for start < end { + mid := start + (end-start)/2 + if r[mid] < offset { + start = mid + 1 + continue + } + if r[mid] > offset { + end = mid - 1 + continue + } + + return mid + } + + if r[start] > offset { + return start - 1 + } + + return start +} + func isValidPattern(s string) (err error) { if len(s) == 0 { return errors.New("Pattern can not be empty") @@ -480,3 +632,58 @@ func validateSequenceBeforeTransforming(sequence *string) (err error) { } return nil } + +// printLFDebug this will print the first column and last column of the BWT along with some ascii visualizations. +// This is very helpful for debugging the LF mapping. For example, lets say you're in the middle of making some changes to the LF +// mapping and the test for counting starts to fail. To understand where the LF search is going wrong, you +// can do something like the below to outline which parts of the BWT are being searched some given iteration. +// +// For Example, if you had the BWT of: +// "rowrowrowyourboat" +// and wanted to Count the number of occurrences of "row" +// Then the iterations of the LF search would look something like: +// +// BWT Debug Begin Iteration: 0 +// torbyrrru$wwaoooow +// $abooooorrrrtuwwwy +// ^^^^^^^^^^^^^^^^^^X +// +// BWT Debug Begin Iteration: 1 +// torbyrrru$wwaoooow +// $abooooorrrrtuwwwy +// ______________^^^X +// +// BWT Debug Begin Iteration: 2 +// torbyrrru$wwaoooow +// $abooooorrrrtuwwwy +// _____^^^X +// +// Where: +// * '^' denotes the active search range +// * 'X' denotes one character after the end of the active search searchRange +// * '_' is visual padding to help align the active search range +// +// NOTE: It can also be helpful to include the other auxiliary data structures. For example, it can be very helpful to include +// a similar visualization for the run length compression to help debug and understand which run were used to compute the active +// search window during each iteration. +func printLFDebug(bwt BWT, searchRange interval, iteration int) { + first := bwt.getFirstColumnStr() + last := bwt.GetTransform() + lastRunCompression := bwt.runBWTCompression.reconstruct() + + fullASCIIRange := strings.Builder{} + fullASCIIRange.Grow(searchRange.end + 1) + for i := 0; i < searchRange.start; i++ { + fullASCIIRange.WriteRune('_') + } + for i := searchRange.start; i < searchRange.end; i++ { + fullASCIIRange.WriteRune('^') + } + fullASCIIRange.WriteRune('X') + + fmt.Println("BWT Debug Begin Iteration:", iteration) + fmt.Println(last) + fmt.Println(first) + fmt.Println(fullASCIIRange.String()) + fmt.Println(lastRunCompression) +} diff --git a/search/bwt/bwt_test.go b/search/bwt/bwt_test.go index 7b5512ff..15703a28 100644 --- a/search/bwt/bwt_test.go +++ b/search/bwt/bwt_test.go @@ -1,6 +1,9 @@ package bwt import ( + "bytes" + "io" + "os" "strings" "testing" @@ -23,21 +26,30 @@ func TestBWT_Count(t *testing.T) { testTable := []BWTCountTestCase{ {"uick", 3}, - {"the", 6}, {"over", 6}, {"own", 12}, {"ana", 6}, {"an", 9}, {"na", 9}, {"rown", 6}, + {"frown", 3}, + {"brown", 3}, + {"all", 6}, + {"alle", 3}, + {"alla", 3}, + {"l", 21}, + {"the", 6}, + {"town", 3}, {"townthe", 2}, - + {"nt", 5}, // patterns that should not exist + {"@", 0}, {"zzz", 0}, {"clown", 0}, {"crown", 0}, {"spark", 0}, {"brawn", 0}, + {"overtly", 0}, } for _, v := range testTable { @@ -79,17 +91,27 @@ func TestBWT_Locate(t *testing.T) { testTable := []BWTLocateTestCase{ {"uick", []int{4, 117, 230}}, - {"the", []int{0, 25, 113, 138, 226, 251}}, {"over", []int{21, 41, 134, 154, 247, 267}}, {"own", []int{10, 48, 106, 110, 123, 161, 219, 223, 236, 274, 332, 336}}, {"ana", []int{87, 89, 200, 202, 313, 315}}, {"an", []int{39, 87, 89, 152, 200, 202, 265, 313, 315}}, {"na", []int{50, 88, 90, 163, 201, 203, 276, 314, 316}}, {"rown", []int{9, 47, 122, 160, 235, 273}}, + {"frown", []int{46, 159, 272}}, + {"brown", []int{8, 121, 234}}, + {"all", []int{70, 96, 183, 209, 296, 322}}, + {"alle", []int{70, 183, 296}}, + {"alla", []int{96, 209, 322}}, + {"l", []int{28, 60, 71, 72, 74, 97, 98, 141, 173, 184, 185, 187, 210, 211, 254, 286, 297, 298, 300, 323, 324}}, + {"the", []int{0, 25, 113, 138, 226, 251}}, + {"town", []int{109, 222, 335}}, {"townthe", []int{109, 222}}, + {"nt", []int{108, 112, 221, 225, 334}}, + {"overtly", nil}, // patterns that should not exist {"zzz", nil}, + {"@", nil}, {"clown", nil}, {"crown", nil}, {"spark", nil}, @@ -252,6 +274,22 @@ func TestBWT_Extract_DoNotAllowExtractionOfLastNullChar(t *testing.T) { } } +func TestBWT_GetTransform(t *testing.T) { + baseTestStr := "thequickbrownfoxjumpsoverthelazydogwithanovertfrownafterfumblingitsparallelogramshapedbananagramallarounddowntown" // len == 112 + testStr := strings.Join([]string{baseTestStr, baseTestStr, baseTestStr}, "") + + bwt, err := New(testStr) + if err != nil { + t.Fatal(err) + } + + expected := "nnnnnnnmmmrrrrrrrrrnnnbbbhhhhhhppplllllldddmmmkkkiiieeennnyyydddppphhhlllhhhtttvvvvvvnnntttaaarrrnnnaaaooooootttsssttttttuuulllwwwgggxxxcccllleeelllbbbaaaaaaeeeaaauuuuuuaaawwwwaaaaaauuuwwwiiiaaawwwwwllldddrrrnnnssstrrrrrrttdddfffsssaaammmeeeaaaggggggeeeaaafffbbbeeeeeemmmppptttfffrrriiirrrnn$nnniiiqqqfffjjjooooooooogggooooooooooooooozzzaaa" + actual := bwt.GetTransform() + if expected != actual { + t.Fatalf("expected did not match actual\nexpected:\t%s\nactual:\t%s", expected, actual) + } +} + func TestBWT_Len(t *testing.T) { testStr := "banana" @@ -265,6 +303,89 @@ func TestBWT_Len(t *testing.T) { } } +type sparseOnesTestCase struct { + pos int + expected int +} + +func TestRunInfo_FindNearestRunStartPosition(t *testing.T) { + runs := runInfo{ + 0, + 6, + 12, + 33, + 99, + 204, + 205, + 300, + 302, + 305, + 306, + 999, + } + + testCases := []sparseOnesTestCase{ + {0, 0}, + {4, 0}, + {5, 0}, + {6, 1}, + + {7, 1}, + {8, 1}, + {9, 1}, + {11, 1}, + {12, 2}, + + {13, 2}, + {15, 2}, + {22, 2}, + {32, 2}, + {33, 3}, + + {56, 3}, + {64, 3}, + {65, 3}, + {79, 3}, + {98, 3}, + {99, 4}, + + {100, 4}, + {112, 4}, + {168, 4}, + {197, 4}, + {199, 4}, + {203, 4}, + {204, 5}, + + {205, 6}, + + {206, 6}, + {271, 6}, + {299, 6}, + {300, 7}, + + {301, 7}, + {302, 8}, + + {303, 8}, + {304, 8}, + {305, 9}, + + {306, 10}, + {307, 10}, + {999, 11}, + + {1000, 11}, + } + + for _, v := range testCases { + actual := runs.FindNearestRunStartPosition(v.pos) + if actual != v.expected { + t.Fatalf("expected RankOnes(%d) to be %d but got %d", v.pos, v.expected, actual) + } + } +} + func TestNewBWTWithSequenceContainingNullChar(t *testing.T) { nc := nullChar testStr := "banana" + nc @@ -417,3 +538,43 @@ func TestBWTRecovery(t *testing.T) { func doPanic() { panic("test panic") } +func TestPrintLFDebug(t *testing.T) { + bwt, err := New("banana") + if err != nil { + t.Fatal(err) + } + + searchRange := interval{start: 2, end: 5} + iteration := 1 + + expectedOutput := "BWT Debug Begin Iteration: 1" + "\n" + expectedOutput += "annb$aa" + "\n" + expectedOutput += "$aaabnn" + "\n" + expectedOutput += "__^^^X" + "\n" + expectedOutput += "anb$a" + "\n" + + // Redirect stdout to capture the output + old := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + printLFDebug(bwt, searchRange, iteration) + + // Reset stdout + w.Close() + os.Stdout = old + + // Read the captured output + var buf bytes.Buffer + _, err = io.Copy(&buf, r) + + if err != nil { + t.Fatal(err) + } + + // Compare the output with the expected value + actualOutput := buf.String() + if actualOutput != expectedOutput { + t.Errorf("Unexpected output:\nExpected:\n%s\nActual:\n%s", expectedOutput, actualOutput) + } +} diff --git a/search/bwt/wavelet.go b/search/bwt/wavelet.go index af20d5c3..7978fc8e 100644 --- a/search/bwt/wavelet.go +++ b/search/bwt/wavelet.go @@ -178,7 +178,10 @@ func (wt waveletTree) Rank(char byte, i int) int { } curr := wt.root - ci := wt.lookupCharInfo(char) + ci, ok := wt.lookupCharInfo(char) + if !ok { + return 0 + } level := 0 var rank int for !curr.isLeaf() { @@ -208,7 +211,10 @@ func (wt waveletTree) Select(char byte, rank int) int { } curr := wt.root - ci := wt.lookupCharInfo(char) + ci, ok := wt.lookupCharInfo(char) + if !ok { + return 0 + } level := 0 for !curr.isLeaf() { @@ -236,16 +242,6 @@ func (wt waveletTree) Select(char byte, rank int) int { return rank } -func (wt waveletTree) lookupCharInfo(char byte) charInfo { - for i := range wt.alpha { - if wt.alpha[i].char == char { - return wt.alpha[i] - } - } - msg := fmt.Sprintf("could not find character %s in alphabet %+v. this should not be possible and indicates that the WaveletTree is malformed", string(char), wt.alpha) - panic(msg) -} - func (wt waveletTree) reconstruct() string { str := "" for i := 0; i < wt.length; i++ { @@ -254,6 +250,15 @@ func (wt waveletTree) reconstruct() string { return str } +func (wt waveletTree) lookupCharInfo(char byte) (charInfo, bool) { + for i := range wt.alpha { + if wt.alpha[i].char == char { + return wt.alpha[i], true + } + } + return charInfo{}, false +} + type node struct { data rsaBitVector char *byte diff --git a/search/bwt/wavelet_test.go b/search/bwt/wavelet_test.go index 432f4a85..f9471f07 100644 --- a/search/bwt/wavelet_test.go +++ b/search/bwt/wavelet_test.go @@ -141,6 +141,7 @@ func TestWaveletTree_Select(t *testing.T) { } testCases := []WaveletTreeSelectTestCase{ + {"@", 0, 0}, {"A", 0, 0}, {"A", 1, 1}, {"A", 2, 2}, @@ -153,6 +154,7 @@ func TestWaveletTree_Select(t *testing.T) { {"T", 4, 18}, {"G", 4, 19}, + {"@", 5, 0}, {"T", 5, 20}, {"G", 5, 21}, {"C", 5, 22}, @@ -168,6 +170,8 @@ func TestWaveletTree_Select(t *testing.T) { {"G", 8, 32}, {"A", 11, 47}, + + {"@", 200, 0}, } for _, tc := range testCases { @@ -270,16 +274,3 @@ func TestBuildWaveletTree_ZeroAlpha(t *testing.T) { t.Fatalf("expected root to be nil but got %v", root) } } -func TestWaveletTree_LookupCharInfo_Panic(t *testing.T) { - wt := waveletTree{ - alpha: []charInfo{}, - } - - defer func() { - if r := recover(); r == nil { - t.Errorf("expected panic but got nil") - } - }() - - wt.lookupCharInfo('B') -}