From a19d2f9d78d976ec7513cc9a1d79eb349984f64a Mon Sep 17 00:00:00 2001 From: "shota.silagadze" Date: Tue, 31 Dec 2024 22:43:42 +0400 Subject: [PATCH] implement get proof rpc endpoint --- erigon-lib/commitment/hex_patricia_hashed.go | 110 +++++++++++ erigon-lib/trie/proof.go | 19 +- turbo/jsonrpc/eth_call.go | 197 ++++++++++++------- 3 files changed, 251 insertions(+), 75 deletions(-) diff --git a/erigon-lib/commitment/hex_patricia_hashed.go b/erigon-lib/commitment/hex_patricia_hashed.go index 46dd0251fb5..635971a6906 100644 --- a/erigon-lib/commitment/hex_patricia_hashed.go +++ b/erigon-lib/commitment/hex_patricia_hashed.go @@ -1879,6 +1879,116 @@ func (hph *HexPatriciaHashed) RootHash() ([]byte, error) { return rootHash[1:], nil // first byte is 128+hash_len=160 } +func (hph *HexPatriciaHashed) GenerateProofTrie(ctx context.Context, updates *Updates, expectedRootHash []byte, logPrefix string) (proofTrie *trie.Trie, rootHash []byte, err error) { + var ( + m runtime.MemStats + ki uint64 + + updatesCount = updates.Size() + logEvery = time.NewTicker(20 * time.Second) + ) + defer logEvery.Stop() + var tries []*trie.Trie = make([]*trie.Trie, 0, len(updates.keys)) // slice of tries, i.e the witness for each key, these will be all merged into single trie + err = updates.HashSort(ctx, func(hashedKey, plainKey []byte, stateUpdate *Update) error { + select { + case <-logEvery.C: + dbg.ReadMemStats(&m) + log.Info(fmt.Sprintf("[%s][agg] computing trie", logPrefix), + "progress", fmt.Sprintf("%s/%s", common.PrettyCounter(ki), common.PrettyCounter(updatesCount)), + "alloc", common.ByteCount(m.Alloc), "sys", common.ByteCount(m.Sys)) + + default: + } + + var tr *trie.Trie + var computedRootHash []byte + + fmt.Printf("\n%d/%d) plainKey [%x] hashedKey [%x] currentKey [%x]\n", ki+1, updatesCount, plainKey, hashedKey, hph.currentKey[:hph.currentKeyLen]) + + if len(plainKey) == 20 { // account + account, err := hph.ctx.Account(plainKey) + if err != nil { + return fmt.Errorf("account with plainkey=%x not found: %w", plainKey, err) + } else { + addrHash := ecrypto.Keccak256(plainKey) + fmt.Printf("account with plainKey=%x, addrHash=%x FOUND = %v\n", plainKey, addrHash, account) + } + } else { + storage, err := hph.ctx.Storage(plainKey) + if err != nil { + return fmt.Errorf("storage with plainkey=%x not found: %w", plainKey, err) + } + fmt.Printf("storage found = %v\n", storage.Storage) + } + fmt.Println("shota unfolding", hex.EncodeToString(plainKey), hex.EncodeToString(hashedKey)) + // Keep folding until the currentKey is the prefix of the key we modify + for hph.needFolding(hashedKey) { + if err := hph.fold(); err != nil { + return fmt.Errorf("fold: %w", err) + } + } + // Now unfold until we step on an empty cell + for unfolding := hph.needUnfolding(hashedKey); unfolding > 0; unfolding = hph.needUnfolding(hashedKey) { + if err := hph.unfold(hashedKey, unfolding); err != nil { + return fmt.Errorf("unfold: %w", err) + } + } + hph.PrintGrid() + + // convert grid to trie.Trie + tr, err = hph.ToTrie(hashedKey, nil) // build witness trie for this key, based on the current state of the grid + if err != nil { + return err + } + computedRootHash = tr.Root() + fmt.Printf("computedRootHash = %x\n", computedRootHash) + + if !bytes.Equal(computedRootHash, expectedRootHash) { + err = fmt.Errorf("root hash mismatch computedRootHash(%x)!=expectedRootHash(%x)", computedRootHash, expectedRootHash) + return err + } + + tries = append(tries, tr) + ki++ + return nil + }) + + if err != nil { + return nil, nil, fmt.Errorf("hash sort failed: %w", err) + } + + // Folding everything up to the root + for hph.activeRows > 0 { + if err := hph.fold(); err != nil { + return nil, nil, fmt.Errorf("final fold: %w", err) + } + } + + rootHash, err = hph.RootHash() + if err != nil { + return nil, nil, fmt.Errorf("root hash evaluation failed: %w", err) + } + if hph.trace { + fmt.Printf("root hash %x updates %d\n", rootHash, updatesCount) + } + + // merge all individual tries + proofTrie, err = trie.MergeTries(tries) + if err != nil { + return nil, nil, err + } + + proofTrieRootHash := proofTrie.Root() + + fmt.Printf("mergedTrieRootHash = %x\n", proofTrieRootHash) + + if !bytes.Equal(proofTrieRootHash, expectedRootHash) { + return nil, nil, fmt.Errorf("root hash mismatch witnessTrieRootHash(%x)!=expectedRootHash(%x)", proofTrieRootHash, expectedRootHash) + } + + return proofTrie, rootHash, nil +} + // Generate the block witness. This works by loading each key from the list of updates (they are not really updates since we won't modify the trie, // but currently need to be defined like that for the fold/unfold algorithm) into the grid and traversing the grid to convert it into `trie.Trie`. // All the individual tries are combined to create the final witness trie. diff --git a/erigon-lib/trie/proof.go b/erigon-lib/trie/proof.go index 6e54235cc18..40925f56655 100644 --- a/erigon-lib/trie/proof.go +++ b/erigon-lib/trie/proof.go @@ -36,8 +36,9 @@ import ( // If the trie does not contain a value for key, the returned proof contains all // nodes of the longest existing prefix of the key (at least the root node), ending // with the node that proves the absence of the key. -func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, error) { +func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, []byte, error) { var proof [][]byte + var value []byte hasher := newHasher(t.valueNodesRLPEncoded) defer returnHasherToPool(hasher) // Collect all nodes on the path to key. @@ -51,7 +52,7 @@ func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, error) if rlp, err := hasher.hashChildren(n, 0); err == nil { proof = append(proof, libcommon.CopyBytes(rlp)) } else { - return nil, err + return nil, value, err } } nKey := n.Key @@ -63,17 +64,20 @@ func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, error) tn = nil } else { tn = n.Val + if valNode, ok := n.Val.(ValueNode); ok { + value = valNode + } key = key[len(nKey):] } if fromLevel > 0 { - fromLevel -= len(nKey) + fromLevel-- } case *DuoNode: if fromLevel == 0 { if rlp, err := hasher.hashChildren(n, 0); err == nil { proof = append(proof, libcommon.CopyBytes(rlp)) } else { - return nil, err + return nil, value, err } } i1, i2 := n.childrenIdx() @@ -95,7 +99,7 @@ func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, error) if rlp, err := hasher.hashChildren(n, 0); err == nil { proof = append(proof, libcommon.CopyBytes(rlp)) } else { - return nil, err + return nil, value, err } } tn = n.Children[key[0]] @@ -110,14 +114,15 @@ func (t *Trie) Prove(key []byte, fromLevel int, storage bool) ([][]byte, error) tn = nil } case ValueNode: + value = n tn = nil case HashNode: - return nil, fmt.Errorf("encountered hashNode unexpectedly, key %x, fromLevel %d", key, fromLevel) + return nil, value, fmt.Errorf("encountered hashNode unexpectedly, key %x, fromLevel %d", key, fromLevel) default: panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } - return proof, nil + return proof, value, nil } func decodeRef(buf []byte) (Node, []byte, error) { diff --git a/turbo/jsonrpc/eth_call.go b/turbo/jsonrpc/eth_call.go index eb4c18780be..848d9bee36a 100644 --- a/turbo/jsonrpc/eth_call.go +++ b/turbo/jsonrpc/eth_call.go @@ -19,14 +19,17 @@ package jsonrpc import ( "bytes" "context" + "encoding/hex" "errors" "fmt" "math/big" + "unsafe" "github.com/erigontech/erigon-lib/kv/dbutils" "github.com/erigontech/erigon-lib/trie" "github.com/erigontech/erigon-lib/commitment" + "github.com/erigontech/erigon-lib/common" libstate "github.com/erigontech/erigon-lib/state" "github.com/holiman/uint256" "google.golang.org/grpc" @@ -346,88 +349,146 @@ func (api *APIImpl) EstimateGas(ctx context.Context, argsOrNil *ethapi2.CallArgs // GetProof is partially implemented; no Storage proofs, and proofs must be for // blocks within maxGetProofRewindBlockCount blocks of the head. func (api *APIImpl) GetProof(ctx context.Context, address libcommon.Address, storageKeys []libcommon.Hash, blockNrOrHash rpc.BlockNumberOrHash) (*accounts.AccProofResult, error) { - return nil, errors.New("not supported by Erigon3") - /* - tx, err := api.db.BeginTemporalRo(ctx) - if err != nil { - return nil, err - } - defer tx.Rollback() + roTx, err := api.db.BeginRo(ctx) + if err != nil { + return nil, err + } + defer roTx.Rollback() - blockNr, _, _, err := rpchelper.GetBlockNumber(blockNrOrHash, tx, api.filters) - if err != nil { - return nil, err - } + requestedBlockNr, _, _, err := rpchelper.GetCanonicalBlockNumber(ctx, blockNrOrHash, roTx, api._blockReader, api.filters) + if err != nil { + return nil, err + } - header, err := api._blockReader.HeaderByNumber(ctx, tx, blockNr) - if err != nil { - return nil, err - } + latestBlock, err := rpchelper.GetLatestBlockNumber(roTx) + if err != nil { + return nil, err + } - latestBlock, err := rpchelper.GetLatestBlockNumber(tx) - if err != nil { - return nil, err - } + if requestedBlockNr != latestBlock { + return nil, errors.New("proofs are available only for the 'latest' block") + } - if latestBlock < blockNr { - // shouldn't happen, but check anyway - return nil, fmt.Errorf("block number is in the future latest=%d requested=%d", latestBlock, blockNr) - } + return api.getProof(ctx, address, storageKeys, rpc.BlockNumberOrHashWithNumber(rpc.BlockNumber(latestBlock)), api.db, api.logger) +} - rl := trie.NewRetainList(0) - var loader *trie.FlatDBTrieLoader - if blockNr < latestBlock { - if latestBlock-blockNr > uint64(api.MaxGetProofRewindBlockCount) { - return nil, fmt.Errorf("requested block is too old, block must be within %d blocks of the head block number (currently %d)", uint64(api.MaxGetProofRewindBlockCount), latestBlock) - } - batch := membatchwithdb.NewMemoryBatch(tx, api.dirs.Tmp, api.logger) - defer batch.Rollback() +func (api *APIImpl) getProof(ctx context.Context, address libcommon.Address, storageKeys []libcommon.Hash, blockNrOrHash rpc.BlockNumberOrHash, db kv.RoDB, logger log.Logger) (*accounts.AccProofResult, error) { + roTx, err := api.db.BeginRo(ctx) + if err != nil { + return nil, err + } + defer roTx.Rollback() - unwindState := &stagedsync.UnwindState{UnwindPoint: blockNr} - stageState := &stagedsync.StageState{BlockNumber: latestBlock} + blockNr, hash, _, err := rpchelper.GetCanonicalBlockNumber(ctx, blockNrOrHash, roTx, api._blockReader, api.filters) + if err != nil { + return nil, err + } - hashStageCfg := stagedsync.StageHashStateCfg(nil, api.dirs, api.historyV3(batch)) - if err := stagedsync.UnwindHashStateStage(unwindState, stageState, batch, hashStageCfg, ctx, api.logger); err != nil { - return nil, err - } + // Witness for genesis block is empty + if blockNr == 0 { + return nil, errors.New("block not found") + } - interHashStageCfg := stagedsync.StageTrieCfg(nil, false, false, false, api.dirs.Tmp, api._blockReader, nil, api.historyV3(batch), api._agg) - loader, err = stagedsync.UnwindIntermediateHashesForTrieLoader("eth_getProof", rl, unwindState, stageState, batch, interHashStageCfg, nil, nil, ctx.Done(), api.logger) - if err != nil { - return nil, err - } - tx = batch - } else { - loader = trie.NewFlatDBTrieLoader("eth_getProof", rl, nil, nil, false) - } + block, err := api.blockWithSenders(ctx, roTx, hash, blockNr) + if err != nil { + return nil, err + } + if block == nil { + return nil, nil + } - reader, err := rpchelper.CreateStateReader(ctx, tx, blockNrOrHash, 0, api.filters, api.stateCache, "") - if err != nil { - return nil, err - } - a, err := reader.ReadAccountData(address) - if err != nil { - return nil, err - } - if a == nil { - a = &accounts.Account{} - } - pr, err := trie.NewProofRetainer(address, a, storageKeys, rl) - if err != nil { - return nil, err - } + // Compute the witness if it's for a tx or it's not present in db + prevHeader, err := api._blockReader.HeaderByNumber(ctx, roTx, blockNr) + if err != nil { + return nil, err + } - loader.SetProofRetainer(pr) - root, err := loader.CalcTrieRoot(tx, nil) + roTx2, err := db.BeginRo(ctx) + if err != nil { + return nil, err + } + defer roTx2.Rollback() + txBatch2 := membatchwithdb.NewMemoryBatch(roTx2, "", logger) + defer txBatch2.Rollback() + + domains, err := libstate.NewSharedDomains(txBatch2, log.New()) + if err != nil { + return nil, err + } + sdCtx := libstate.NewSharedDomainsCommitmentContext(domains, commitment.ModeUpdate, commitment.VariantHexPatriciaTrie) + patricieTrie := sdCtx.Trie() + hph, ok := patricieTrie.(*commitment.HexPatriciaHashed) + if !ok { + return nil, errors.New("casting to HexPatriciaTrieHashed failed") + } + + // define these keys as "updates", but we are not really updating anything, we just want to load them into the grid, + // so this is just to satisfy the current hex patricia trie api. + updates := commitment.NewUpdates(commitment.ModeDirect, sdCtx.TempDir(), hph.HashAndNibblizeKey) + updates.TouchPlainKey(string(address.Bytes()), nil, updates.TouchAccount) + for _, storageKey := range storageKeys { + updates.TouchPlainKey(string(common.FromHex(address.Hex()[2:]+storageKey.String()[2:])), nil, updates.TouchStorage) + } + hph.SetTrace(false) // disable tracing to avoid mixing with trace from witness computation + + // generate the block witness, this works by loading the merkle paths to the touched keys (they are loaded from the state at block #blockNr-1) + proofTrie, proofRootHash, err := hph.GenerateProofTrie(ctx, updates, prevHeader.Root[:], "eth_getProof") + if err != nil { + return nil, err + } + + a, found := proofTrie.GetAccount(crypto.Keccak256(address.Bytes())) + if !found { + return nil, errors.New("account not found in the trie") + } + + // verify hash + if !bytes.Equal(proofRootHash, prevHeader.Root[:]) { + return nil, fmt.Errorf("witness root hash mismatch actual(%x)!=expected(%x)", proofRootHash, prevHeader.Root[:]) + } + + // set initial response fields + proof := &accounts.AccProofResult{ + Address: address, + Balance: (*hexutil.Big)(a.Balance.ToBig()), + Nonce: hexutil.Uint64(a.Nonce), + CodeHash: a.CodeHash, + StorageHash: a.Root, + } + + // get account proof + accountProof, _, err := proofTrie.Prove(crypto.Keccak256(address.Bytes()), 0, false) + proof.AccountProof = *(*[]hexutility.Bytes)(unsafe.Pointer(&accountProof)) + + // get storage key proofs + proof.StorageProof = make([]accounts.StorProofResult, len(storageKeys)) + for i, keyHash := range storageKeys { + // prepare key path (keccak(address) | keccak(key)) + var fullKey []byte + fullKey = append(fullKey, crypto.Keccak256(address.Bytes())...) + fullKey = append(fullKey, crypto.Keccak256(keyHash.Bytes())...) + + // get proof for the given key + storageProof, value, err := proofTrie.Prove(fullKey, len(proof.AccountProof), true) if err != nil { - return nil, err + return nil, errors.New("cannot verify store proof") } - if root != header.Root { - return nil, fmt.Errorf("mismatch in expected state root computed %v vs %v indicates bug in proof implementation", root, header.Root) + // Decode the hexadecimal string into the big.Int + // The base is 16 for hexadecimal + n := new(big.Int) + n, success := n.SetString(hex.EncodeToString(value), 16) + if !success { + fmt.Println("Failed to parse hexadecimal string") } - return pr.ProofResult() - */ + + // set key proof + proof.StorageProof[i].Key = keyHash + proof.StorageProof[i].Value = (*hexutil.Big)(unsafe.Pointer(n)) + proof.StorageProof[i].Proof = *(*[]hexutility.Bytes)(unsafe.Pointer(&storageProof)) + } + + return proof, nil } func (api *APIImpl) GetWitness(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (hexutility.Bytes, error) {