Skip to content

Commit

Permalink
implement get proof rpc endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
shotasilagadzetaal committed Jan 1, 2025
1 parent 6f33aa7 commit a19d2f9
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 75 deletions.
110 changes: 110 additions & 0 deletions erigon-lib/commitment/hex_patricia_hashed.go
Original file line number Diff line number Diff line change
Expand Up @@ -1879,6 +1879,116 @@ func (hph *HexPatriciaHashed) RootHash() ([]byte, error) {
return rootHash[1:], nil // first byte is 128+hash_len=160
}

func (hph *HexPatriciaHashed) GenerateProofTrie(ctx context.Context, updates *Updates, expectedRootHash []byte, logPrefix string) (proofTrie *trie.Trie, rootHash []byte, err error) {
var (
m runtime.MemStats
ki uint64

updatesCount = updates.Size()
logEvery = time.NewTicker(20 * time.Second)
)
defer logEvery.Stop()
var tries []*trie.Trie = make([]*trie.Trie, 0, len(updates.keys)) // slice of tries, i.e the witness for each key, these will be all merged into single trie
err = updates.HashSort(ctx, func(hashedKey, plainKey []byte, stateUpdate *Update) error {
select {
case <-logEvery.C:
dbg.ReadMemStats(&m)
log.Info(fmt.Sprintf("[%s][agg] computing trie", logPrefix),
"progress", fmt.Sprintf("%s/%s", common.PrettyCounter(ki), common.PrettyCounter(updatesCount)),
"alloc", common.ByteCount(m.Alloc), "sys", common.ByteCount(m.Sys))

default:
}

var tr *trie.Trie
var computedRootHash []byte

fmt.Printf("\n%d/%d) plainKey [%x] hashedKey [%x] currentKey [%x]\n", ki+1, updatesCount, plainKey, hashedKey, hph.currentKey[:hph.currentKeyLen])

if len(plainKey) == 20 { // account
account, err := hph.ctx.Account(plainKey)
if err != nil {
return fmt.Errorf("account with plainkey=%x not found: %w", plainKey, err)
} else {
addrHash := ecrypto.Keccak256(plainKey)
fmt.Printf("account with plainKey=%x, addrHash=%x FOUND = %v\n", plainKey, addrHash, account)
}
} else {
storage, err := hph.ctx.Storage(plainKey)
if err != nil {
return fmt.Errorf("storage with plainkey=%x not found: %w", plainKey, err)
}
fmt.Printf("storage found = %v\n", storage.Storage)
}
fmt.Println("shota unfolding", hex.EncodeToString(plainKey), hex.EncodeToString(hashedKey))
// Keep folding until the currentKey is the prefix of the key we modify
for hph.needFolding(hashedKey) {
if err := hph.fold(); err != nil {
return fmt.Errorf("fold: %w", err)
}
}
// Now unfold until we step on an empty cell
for unfolding := hph.needUnfolding(hashedKey); unfolding > 0; unfolding = hph.needUnfolding(hashedKey) {
if err := hph.unfold(hashedKey, unfolding); err != nil {
return fmt.Errorf("unfold: %w", err)
}
}
hph.PrintGrid()

// convert grid to trie.Trie
tr, err = hph.ToTrie(hashedKey, nil) // build witness trie for this key, based on the current state of the grid
if err != nil {
return err
}
computedRootHash = tr.Root()
fmt.Printf("computedRootHash = %x\n", computedRootHash)

if !bytes.Equal(computedRootHash, expectedRootHash) {
err = fmt.Errorf("root hash mismatch computedRootHash(%x)!=expectedRootHash(%x)", computedRootHash, expectedRootHash)
return err
}

tries = append(tries, tr)
ki++
return nil
})

if err != nil {
return nil, nil, fmt.Errorf("hash sort failed: %w", err)
}

// Folding everything up to the root
for hph.activeRows > 0 {
if err := hph.fold(); err != nil {
return nil, nil, fmt.Errorf("final fold: %w", err)
}
}

rootHash, err = hph.RootHash()
if err != nil {
return nil, nil, fmt.Errorf("root hash evaluation failed: %w", err)
}
if hph.trace {
fmt.Printf("root hash %x updates %d\n", rootHash, updatesCount)
}

// merge all individual tries
proofTrie, err = trie.MergeTries(tries)
if err != nil {
return nil, nil, err
}

proofTrieRootHash := proofTrie.Root()

fmt.Printf("mergedTrieRootHash = %x\n", proofTrieRootHash)

if !bytes.Equal(proofTrieRootHash, expectedRootHash) {
return nil, nil, fmt.Errorf("root hash mismatch witnessTrieRootHash(%x)!=expectedRootHash(%x)", proofTrieRootHash, expectedRootHash)
}

return proofTrie, rootHash, nil
}

// Generate the block witness. This works by loading each key from the list of updates (they are not really updates since we won't modify the trie,
// but currently need to be defined like that for the fold/unfold algorithm) into the grid and traversing the grid to convert it into `trie.Trie`.
// All the individual tries are combined to create the final witness trie.
Expand Down
19 changes: 12 additions & 7 deletions erigon-lib/trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ import (
// If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root node), ending
// with the node that proves the absence of the key.
func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, error) {
func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, []byte, error) {
var proof [][]byte
var value []byte
hasher := newHasher(t.valueNodesRLPEncoded)
defer returnHasherToPool(hasher)
// Collect all nodes on the path to key.
Expand All @@ -51,7 +52,7 @@ func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, error)
if rlp, err := hasher.hashChildren(n, 0); err == nil {
proof = append(proof, libcommon.CopyBytes(rlp))
} else {
return nil, err
return nil, value, err
}
}
nKey := n.Key
Expand All @@ -63,17 +64,20 @@ func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, error)
tn = nil
} else {
tn = n.Val
if valNode, ok := n.Val.(ValueNode); ok {
value = valNode
}
key = key[len(nKey):]
}
if fromLevel > 0 {
fromLevel -= len(nKey)
fromLevel--
}
case *DuoNode:
if fromLevel == 0 {
if rlp, err := hasher.hashChildren(n, 0); err == nil {
proof = append(proof, libcommon.CopyBytes(rlp))
} else {
return nil, err
return nil, value, err
}
}
i1, i2 := n.childrenIdx()
Expand All @@ -95,7 +99,7 @@ func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, error)
if rlp, err := hasher.hashChildren(n, 0); err == nil {
proof = append(proof, libcommon.CopyBytes(rlp))
} else {
return nil, err
return nil, value, err
}
}
tn = n.Children[key[0]]
Expand All @@ -110,14 +114,15 @@ func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, error)
tn = nil
}
case ValueNode:
value = n
tn = nil
case HashNode:
return nil, fmt.Errorf("encountered hashNode unexpectedly, key %x, fromLevel %d", key, fromLevel)
return nil, value, fmt.Errorf("encountered hashNode unexpectedly, key %x, fromLevel %d", key, fromLevel)
default:
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
}
}
return proof, nil
return proof, value, nil
}

func decodeRef(buf []byte) (Node, []byte, error) {
Expand Down
Loading

0 comments on commit a19d2f9

Please sign in to comment.