diff --git a/compact/range.go b/compact/range.go
index a34c0be..9246d9e 100644
--- a/compact/range.go
+++ b/compact/range.go
@@ -96,7 +96,8 @@ func (r *Range) Append(hash []byte, visitor VisitFn) error {
 
 // AppendRange extends the compact range by merging in the other compact range
 // from the right. It uses the tree hasher to calculate hashes of newly created
-// nodes, and reports them through the visitor function (if non-nil).
+// nodes, and reports them through the visitor function (if non-nil). The other
+// range must begin where the current range ends.
 func (r *Range) AppendRange(other *Range, visitor VisitFn) error {
 	if other.f != r.f {
 		return errors.New("incompatible ranges")
@@ -110,6 +111,29 @@ func (r *Range) AppendRange(other *Range, visitor VisitFn) error {
 	return r.appendImpl(other.end, other.hashes[0], other.hashes[1:], visitor)
 }
 
+// Merge extends the compact range by merging in the other compact range from
+// the right. Unlike AppendRange, the other range may overlap this one: it must
+// begin between this range's begin and end, inclusive. Hashes of newly created
+// nodes are reported through the visitor function (if non-nil).
+//
+// Warning: This method modifies both this and the other Range.
+// Warning: This method is experimental.
+func (r *Range) Merge(other *Range, visitor VisitFn) error {
+	if _, err := r.intersectImpl(other); err != nil {
+		return err
+	}
+	return r.AppendRange(other, visitor)
+}
+
+// Intersect returns the intersection of this compact range and the other one.
+// The other range must begin between this range's begin and end, inclusive.
+//
+// Warning: This method modifies both this and the other Range.
+// Warning: This method is experimental.
+func (r *Range) Intersect(other *Range) (*Range, error) {
+	return r.intersectImpl(other)
+}
+
 // GetRootHash returns the root hash of the Merkle tree represented by this
 // compact range. Requires the range to start at index 0. If the range is
 // empty, returns nil.
@@ -208,6 +232,54 @@ func (r *Range) appendImpl(end uint64, seed []byte, hashes [][]byte, visitor Vis
 	return nil
 }
 
+// intersectImpl computes the intersection of the compact ranges `r` and
+// `other`. As a side effect, it trims `r` and/or `other` so that they become
+// adjacent: a subsequent AppendRange operation between them then yields a
+// compact range that represents the union of the two original ranges.
+func (r *Range) intersectImpl(other *Range) (*Range, error) {
+	if other.f != r.f {
+		return nil, errors.New("incompatible ranges")
+	}
+	if other.begin < r.begin {
+		return nil, errors.New("ranges unordered")
+	}
+
+	// A nested `other` range is the intersection itself: hand out a copy of it,
+	// and replace `other` with a new empty range created at r.end.
+	if other.end <= r.end {
+		nested := *other
+		*other = *r.f.NewEmptyRange(r.end)
+		return &nested, nil
+	}
+
+	lo, hi := other.begin, r.end
+	switch {
+	case lo > hi: // The other range is disjoint.
+		return nil, fmt.Errorf("ranges are disjoint: other.begin=%d, want <= %d", lo, hi)
+	case lo == hi: // The ranges touch ends, so the intersection is empty.
+		return r.f.NewEmptyRange(lo), nil
+	}
+
+	// Decompose the intersection [lo, hi): the number of set bits in each
+	// returned value is the number of hashes taken from the matching range.
+	left, right := Decompose(lo, hi)
+	fromOther, fromR := bits.OnesCount64(left), bits.OnesCount64(right)
+	hashes := make([][]byte, 0, fromOther+fromR)
+
+	// Move the leading hashes of `other` into the intersection.
+	hashes = append(hashes, other.hashes[:fromOther]...)
+	other.begin += left
+	other.hashes = other.hashes[fromOther:]
+
+	// Move the trailing hashes of `r` into the intersection.
+	keep := len(r.hashes) - fromR
+	hashes = append(hashes, r.hashes[keep:]...)
+	r.end -= right
+	r.hashes = r.hashes[:keep]
+
+	return &Range{f: r.f, begin: lo, end: hi, hashes: hashes}, nil
+}
+
 // getMergePath returns the merging path between the compact range [begin, mid)
 // and [mid, end).
The path is represented as a range of bits within mid, with // bit indices [low, high). A bit value of 1 on level i of mid means that the diff --git a/compact/range_test.go b/compact/range_test.go index 9d261c1..8525830 100644 --- a/compact/range_test.go +++ b/compact/range_test.go @@ -220,7 +220,7 @@ func TestGoldenRanges(t *testing.T) { } // Merge down from [339,340) to [0,340) by prepending single entries. -func TestMergeBackwards(t *testing.T) { +func TestAppendBackwards(t *testing.T) { const numNodes = uint64(340) tree, visit := newTree(t, numNodes) rng := factory.NewEmptyRange(numNodes) @@ -243,7 +243,7 @@ func TestMergeBackwards(t *testing.T) { // Build ranges [0, 13), [13, 26), ... [208,220) by appending single entries to // each. Then append those ranges one by one to [0,0), to get [0,220). -func TestMergeInBatches(t *testing.T) { +func TestAppendInBatches(t *testing.T) { const numNodes = uint64(220) const batch = uint64(13) tree, visit := newTree(t, numNodes) @@ -274,7 +274,7 @@ func TestMergeInBatches(t *testing.T) { } // Build many trees of random size by randomly merging their sub-ranges. 
-func TestMergeRandomly(t *testing.T) {
+func TestAppendRandomly(t *testing.T) {
 	for seed := int64(1); seed < 100; seed++ {
 		t.Run(fmt.Sprintf("seed:%d", seed), func(t *testing.T) {
 			rnd := rand.New(rand.NewSource(seed))
@@ -307,6 +307,71 @@ func TestMergeRandomly(t *testing.T) {
 	}
 }
 
+// TestMergeAndIntersect checks Merge and Intersect on all pairs of ranges
+// within [0, 20) in which the second range begins between the begin and the
+// end of the first one.
+func TestMergeAndIntersect(t *testing.T) {
+	const size = uint64(20)
+	tree, visit := newTree(t, size)
+	getRange := func(begin, end uint64) *Range {
+		cr := factory.NewEmptyRange(begin)
+		for i := begin; i < end; i++ {
+			if err := cr.Append(tree.leaf(i), visit); err != nil {
+				t.Fatalf("Append: %v", err)
+			}
+		}
+		return cr
+	}
+
+	type pair struct {
+		begin, end uint64
+	}
+	var pairs []pair
+	for begin := uint64(0); begin <= size; begin++ {
+		for end := begin; end <= size; end++ {
+			pairs = append(pairs, pair{begin: begin, end: end})
+		}
+	}
+	for _, first := range pairs {
+		for _, second := range pairs {
+			// Merge and Intersect only accept the other range if it begins within
+			// [first.begin, first.end].
+			if second.begin < first.begin || second.begin > first.end {
+				continue
+			}
+			t.Logf("%+v : %+v", first, second)
+			rng := getRange(first.begin, first.end)
+			other := getRange(second.begin, second.end)
+			if err := rng.Merge(other, visit); err != nil {
+				t.Fatalf("Merge: %v", err)
+			}
+			checkRangeBounds(t, rng, first.begin, max(first.end, second.end))
+			tree.verifyRange(t, rng, true)
+
+			// Merge may have modified both ranges, so rebuild them for Intersect.
+			rng = getRange(first.begin, first.end)
+			other = getRange(second.begin, second.end)
+			inters, err := rng.Intersect(other)
+			if err != nil {
+				t.Fatalf("Intersect: %v", err)
+			}
+			checkRangeBounds(t, inters, second.begin, min(first.end, second.end))
+			tree.verifyRange(t, inters, true)
+		}
+	}
+}
+
+// checkRangeBounds fails the test if the range r is not [begin, end).
+func checkRangeBounds(t *testing.T, r *Range, begin, end uint64) {
+	t.Helper() // Attribute failures to the calling test line, not this helper.
+	if got, want := r.Begin(), begin; got != want {
+		t.Fatalf("range [%d, %d): want begin %d", got, r.End(), want)
+	}
+	if got, want := r.End(), end; got != want {
+		t.Fatalf("range [%d, %d): want end %d", r.Begin(), got, want)
+	}
+}
+
 func TestNewRange(t *testing.T) {
 	const numNodes = uint64(123)
 	tree, visit := newTree(t, numNodes)
@@ -813,3 +878,22 @@ func shorten(hash []byte) []byte {
 	}
 	return hash[:4]
 }
+
+// min returns the smaller of a and b.
+//
+// NOTE(review): Go 1.21+ has built-in min and max, which this helper and max
+// below shadow; both can be dropped once the module requires that version.
+func min(a, b uint64) uint64 {
+	if b < a {
+		return b
+	}
+	return a
+}
+
+// max returns the larger of a and b.
+func max(a, b uint64) uint64 {
+	if b > a {
+		return b
+	}
+	return a
+}