Skip to content

Commit

Permalink
Remove dependency on nebulous in proof generation (#68)
Browse files Browse the repository at this point in the history
Use the same recursive approach as in nmt.Root() to generate proof
Removes the dependency on Celestiaorg/merkletree in proof generation

- [x] Passing full testing
- [x] Fixes #15

Co-authored-by: John Adler <[email protected]>
  • Loading branch information
rahulghangas and adlerjohn authored Oct 12, 2022
1 parent 6274243 commit 230d27f
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 107 deletions.
10 changes: 0 additions & 10 deletions internal/doc.go

This file was deleted.

52 changes: 0 additions & 52 deletions internal/subtree_hasher.go

This file was deleted.

77 changes: 57 additions & 20 deletions nmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@ import (
"hash"
"math/bits"

"github.com/celestiaorg/merkletree"
"github.com/celestiaorg/nmt/internal"
"github.com/celestiaorg/nmt/namespace"
)

var (
ErrInvalidRange = errors.New("invalid proof range")
ErrMismatchedNamespaceSize = errors.New("mismatching namespace sizes")
ErrInvalidPushOrder = errors.New("pushed data has to be lexicographically ordered by namespace IDs")
noOp = func(hash []byte, children ...[]byte) {}
Expand Down Expand Up @@ -127,12 +126,11 @@ func (n NamespacedMerkleTree) Prove(index int) (Proof, error) {
func (n NamespacedMerkleTree) ProveRange(start, end int) (Proof, error) {
isMaxNsIgnored := n.treeHasher.IsMaxNamespaceIDIgnored()
n.computeLeafHashesIfNecessary()
subTreeHasher := internal.NewCachedSubtreeHasher(n.leafHashes, n.treeHasher)
// TODO: store nodes and re-use the hashes instead recomputing parts of the tree here
proof, err := merkletree.BuildRangeProof(start, end, subTreeHasher)
if err != nil {
return NewEmptyRangeProof(isMaxNsIgnored), err
if start < 0 || start >= end || end > len(n.leafHashes) {
return NewEmptyRangeProof(isMaxNsIgnored), ErrInvalidRange
}
proof := n.buildRangeProof(start, end)

return NewInclusionProof(start, end, proof, isMaxNsIgnored), nil
}
Expand Down Expand Up @@ -171,27 +169,66 @@ func (n NamespacedMerkleTree) ProveNamespace(nID namespace.ID) (Proof, error) {
// the range it would be in (to generate a proof of absence and to return
// the corresponding leaf hashes).
n.computeLeafHashesIfNecessary()
subTreeHasher := internal.NewCachedSubtreeHasher(n.leafHashes, n.treeHasher)
var err error
proof, err := merkletree.BuildRangeProof(proofStart, proofEnd, subTreeHasher)
if err != nil {
// This should never happen.
// TODO would be good to back this by more tests and fuzzing.
return Proof{}, fmt.Errorf(
"unexpected err: %w on nID: %v, range: [%v, %v)",
err,
nID,
proofStart,
proofEnd,
)
}
proof := n.buildRangeProof(proofStart, proofEnd)

if found {
return NewInclusionProof(proofStart, proofEnd, proof, isMaxNsIgnored), nil
}
return NewAbsenceProof(proofStart, proofEnd, proof, n.leafHashes[proofStart], isMaxNsIgnored), nil
}

func (n NamespacedMerkleTree) buildRangeProof(proofStart, proofEnd int) [][]byte {
proof := [][]byte{}
var recurse func(start, end int, includeNode bool) []byte
recurse = func(start, end int, includeNode bool) []byte {
if start >= len(n.leafHashes) {
return nil
}

// reached a leaf
if end-start == 1 {
leafHash := n.leafHashes[start]
// if current range does not overlap with proof range, add a node to proofs
if (start < proofStart || start >= proofEnd) && includeNode {
proof = append(proof, leafHash)
}
return leafHash
}

// recursively get left and right subtree
newIncludeNode := includeNode
if (end <= proofStart || start >= proofEnd) && includeNode {
newIncludeNode = false
}

k := getSplitPoint(end - start)
left := recurse(start, start+k, newIncludeNode)
right := recurse(start+k, end, newIncludeNode)

// only right leaf/subtree can be non-existent
var hash []byte
if right == nil {
hash = left
} else {
hash = n.treeHasher.HashNode(left, right)
}

// highest node in subtree that lies outside proof range
if includeNode && !newIncludeNode {
proof = append(proof, hash)
}

return hash
}

fullTreeSize := getSplitPoint(len(n.leafHashes)) * 2
if fullTreeSize < 1 {
fullTreeSize = 1
}
recurse(0, fullTreeSize, true)
return proof
}

// Get returns leaves for the given namespace.ID.
func (n NamespacedMerkleTree) Get(nID namespace.ID) [][]byte {
_, start, end := n.foundInRange(nID)
Expand Down
33 changes: 11 additions & 22 deletions nmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,15 +509,14 @@ func TestNodeVisitor(t *testing.T) {

func TestNamespacedMerkleTree_ProveErrors(t *testing.T) {
tests := []struct {
name string
nidLen int
index int
pushData []namespaceDataPair
wantErr bool
wantPanic bool
name string
nidLen int
index int
pushData []namespaceDataPair
wantErr bool
}{
{"negative index", 1, -1, generateLeafData(1, 0, 10, []byte("_data")), false, true},
{"too large index", 1, 11, generateLeafData(1, 0, 10, []byte("_data")), true, false},
{"negative index", 1, -1, generateLeafData(1, 0, 10, []byte("_data")), true},
{"too large index", 1, 11, generateLeafData(1, 0, 10, []byte("_data")), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -534,20 +533,10 @@ func TestNamespacedMerkleTree_ProveErrors(t *testing.T) {
t.Fatalf("Prove() failed on valid index: %v, err: %v", i, err)
}
}
if tt.wantPanic {
shouldPanic(t, func() {
_, err := n.Prove(tt.index)
if (err != nil) != tt.wantErr {
t.Errorf("Prove() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
} else {
_, err := n.Prove(tt.index)
if (err != nil) != tt.wantErr {
t.Errorf("Prove() error = %v, wantErr %v", err, tt.wantErr)
return
}
_, err := n.Prove(tt.index)
if (err != nil) != tt.wantErr {
t.Errorf("Prove() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
Expand Down
50 changes: 47 additions & 3 deletions proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,58 @@ package nmt
import (
"bytes"
"crypto/sha256"
"io"
"testing"

"github.com/celestiaorg/merkletree"

"github.com/celestiaorg/nmt/internal"
"github.com/celestiaorg/nmt/namespace"
)

type treeHasher interface {
merkletree.TreeHasher
Size() int
}

// CachedSubtreeHasher implements SubtreeHasher using a set of precomputed
// leaf hashes.
type cachedSubtreeHasher struct {
leafHashes [][]byte
treeHasher
}

// NextSubtreeRoot implements SubtreeHasher.
func (csh *cachedSubtreeHasher) NextSubtreeRoot(subtreeSize int) ([]byte, error) {
if len(csh.leafHashes) == 0 {
return nil, io.EOF
}
tree := merkletree.NewFromTreehasher(csh.treeHasher)
for i := 0; i < subtreeSize && len(csh.leafHashes) > 0; i++ {
if err := tree.PushSubTree(0, csh.leafHashes[0]); err != nil {
return nil, err
}
csh.leafHashes = csh.leafHashes[1:]
}
return tree.Root(), nil
}

// Skip implements SubtreeHasher.
func (csh *cachedSubtreeHasher) Skip(n int) error {
if n > len(csh.leafHashes) {
return io.ErrUnexpectedEOF
}
csh.leafHashes = csh.leafHashes[n:]
return nil
}

// newCachedSubtreeHasher creates a CachedSubtreeHasher using the specified
// leaf hashes and hash function.
func newCachedSubtreeHasher(leafHashes [][]byte, h treeHasher) *cachedSubtreeHasher {
return &cachedSubtreeHasher{
leafHashes: leafHashes,
treeHasher: h,
}
}

func TestProof_VerifyNamespace_False(t *testing.T) {
const testNidLen = 3

Expand Down Expand Up @@ -98,7 +142,7 @@ func TestProof_VerifyNamespace_False(t *testing.T) {

func rangeProof(t *testing.T, n *NamespacedMerkleTree, start, end int) [][]byte {
n.computeLeafHashesIfNecessary()
subTreeHasher := internal.NewCachedSubtreeHasher(n.leafHashes, n.treeHasher)
subTreeHasher := newCachedSubtreeHasher(n.leafHashes, n.treeHasher)
incompleteRange, err := merkletree.BuildRangeProof(start, end, subTreeHasher)
if err != nil {
t.Fatalf("Could not create range proof: %v", err)
Expand Down

0 comments on commit 230d27f

Please sign in to comment.