Skip to content
This repository has been archived by the owner on Oct 6, 2023. It is now read-only.

Return correct account count in hubble_getNetworkInfo result #594

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions api/get_network_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,9 @@ func (a *API) unsafeGetNetworkInfo() (*dto.NetworkInfo, error) {
Rollup: a.client.ChainState.Rollup,
BlockNumber: a.storage.GetLatestBlockNumber(),
TransactionCount: a.storage.GetTransactionCount(),
AccountCount: a.storage.StateTree.LeavesCount(),
}

// TODO this ignores the fact that other nodes can put new accounts in arbitrary state leaves; to be revisited in the future
accountCount, err := a.storage.StateTree.NextAvailableStateID()
if err != nil {
return nil, err
}
networkInfo.AccountCount = *accountCount

latestBatch, err := a.storage.GetLatestSubmittedBatch()
if err != nil && !storage.IsNotFoundError(err) {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion models/dto/network_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type NetworkInfo struct {
Rollup common.Address
BlockNumber uint32
TransactionCount uint64
AccountCount uint32
AccountCount uint64
LatestBatch *models.Uint256
LatestFinalisedBatch *models.Uint256
SignatureDomain bls.Domain
Expand Down
86 changes: 71 additions & 15 deletions storage/state_tree.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package storage

import (
"sync/atomic"

"github.com/Worldcoin/hubble-commander/db"
"github.com/Worldcoin/hubble-commander/encoder"
"github.com/Worldcoin/hubble-commander/models"
Expand All @@ -21,17 +23,33 @@ const (
type StateTree struct {
database *Database
merkleTree *StoredMerkleTree

leavesCount *uint64
}

func NewStateTree(database *Database) (*StateTree, error) {
stateTree := newStateTree(database)
count, err := stateTree.getLeavesCountFromStorage()
if err != nil {
return nil, err
}
atomic.StoreUint64(stateTree.leavesCount, count)
return stateTree, nil
}

func NewStateTree(database *Database) *StateTree {
func newStateTree(database *Database) *StateTree {
return &StateTree{
database: database,
merkleTree: NewStoredMerkleTree("state", database, StateTreeDepth),
database: database,
merkleTree: NewStoredMerkleTree("state", database, StateTreeDepth),
leavesCount: ref.Uint64(0),
}
}

func (s *StateTree) copyWithNewDatabase(database *Database) *StateTree {
return NewStateTree(database)
stateTree := newStateTree(database)
leavesCount := atomic.LoadUint64(s.leavesCount)
atomic.StoreUint64(stateTree.leavesCount, leavesCount)
return stateTree
}

func (s *StateTree) Root() (*common.Hash, error) {
Expand All @@ -51,14 +69,34 @@ func (s *StateTree) Leaf(stateID uint32) (stateLeaf *models.StateLeaf, err error
}

func (s *StateTree) LeafOrEmpty(stateID uint32) (*models.StateLeaf, error) {
leaf, err := s.Leaf(stateID)
leaf, _, err := s.leafOrEmpty(stateID)
if err != nil {
return nil, err
}
return leaf, nil
}

func (s *StateTree) leafOrEmpty(stateID uint32) (leaf *models.StateLeaf, isEmpty bool, err error) {
leaf, err = s.Leaf(stateID)
if IsNotFoundError(err) {
return &models.StateLeaf{
StateID: stateID,
DataHash: merkletree.GetZeroHash(0),
}, nil
return emptyStateLeaf(stateID), true, nil
} else if err != nil {
return nil, false, err
}
return leaf, err
return leaf, false, nil
}

func (s *StateTree) LeavesCount() uint64 {
return atomic.LoadUint64(s.leavesCount)
}

func (s *StateTree) getLeavesCountFromStorage() (uint64, error) {
count, err := s.database.Badger.Count(&stored.StateLeaf{}, nil)
return count, err
}

func (s *StateTree) incrementLeavesCount() {
atomic.AddUint64(s.leavesCount, 1)
}

func (s *StateTree) NextAvailableStateID() (*uint32, error) {
Expand Down Expand Up @@ -128,13 +166,17 @@ func roundAndValidateStateTreeSlot(rangeStart, rangeEnd, subtreeWidth int64) *in

// Set returns a witness containing 32 elements for the current set operation
func (s *StateTree) Set(id uint32, state *models.UserState) (witness models.Witness, err error) {
isNewLeaf := false
err = s.database.ExecuteInTransaction(TxOptions{}, func(txDatabase *Database) error {
witness, err = NewStateTree(txDatabase).unsafeSet(id, state)
witness, isNewLeaf, err = newStateTree(txDatabase).unsafeSet(id, state)
return err
})
if err != nil {
return nil, err
}
if isNewLeaf {
s.incrementLeavesCount()
}

return witness, nil
}
Expand All @@ -157,7 +199,7 @@ func (s *StateTree) RevertTo(targetRootHash common.Hash) error {
}

return s.database.ExecuteInTransaction(TxOptions{}, func(txDatabase *Database) (err error) {
stateTree := NewStateTree(txDatabase)
stateTree := newStateTree(txDatabase)

err = txDatabase.Badger.Iterator(models.StateUpdatePrefix, db.ReversePrefetchIteratorOpts, func(item *bdg.Item) (bool, error) {
var stateUpdate *models.StateUpdate
Expand Down Expand Up @@ -201,12 +243,19 @@ func decodeStateUpdate(item *bdg.Item) (*models.StateUpdate, error) {
return &stateUpdate, nil
}

func (s *StateTree) unsafeSet(index uint32, state *models.UserState) (models.Witness, error) {
prevLeaf, err := s.LeafOrEmpty(index)
func (s *StateTree) unsafeSet(index uint32, state *models.UserState) (witness models.Witness, isNewLeaf bool, err error) {
prevLeaf, isNewLeaf, err := s.leafOrEmpty(index)
if err != nil {
return nil, err
return nil, false, err
}
witness, err = s.unsafeSetLeaf(index, prevLeaf, state)
if err != nil {
return nil, false, err
}
return witness, isNewLeaf, nil
}

func (s *StateTree) unsafeSetLeaf(index uint32, prevLeaf *models.StateLeaf, state *models.UserState) (witness models.Witness, err error) {
prevRoot, err := s.Root()
if err != nil {
return nil, err
Expand Down Expand Up @@ -302,6 +351,13 @@ func (s *StateTree) IterateLeaves(action func(stateLeaf *models.StateLeaf) error
return nil
}

func emptyStateLeaf(stateID uint32) *models.StateLeaf {
return &models.StateLeaf{
StateID: stateID,
DataHash: merkletree.GetZeroHash(0),
}
}

func NewStateLeaf(stateID uint32, state *models.UserState) (*models.StateLeaf, error) {
dataHash, err := encoder.HashUserState(state)
if err != nil {
Expand Down
45 changes: 45 additions & 0 deletions storage/state_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,51 @@ func (s *StateTreeTestSuite) TestSet_ReturnsWitness() {
s.Equal(node.DataHash, witness[31])
}

func (s *StateTreeTestSuite) TestLeavesCount_UpdatesCountAfterAddingNewLeaves() {
s.setStateLeaves(0, 2, 4, 8, 9)

count := s.storage.StateTree.LeavesCount()
s.EqualValues(5, count)

s.setStateLeaves(5, 7)
count = s.storage.StateTree.LeavesCount()
s.EqualValues(7, count)
}

func (s *StateTreeTestSuite) TestLeavesCount_TheSameCountAfterUpdatingLeaves() {
s.setStateLeaves(0, 1)

count := s.storage.StateTree.LeavesCount()
s.EqualValues(2, count)

s.setStateLeaves(0, 1)
count = s.storage.StateTree.LeavesCount()
s.EqualValues(2, count)
}

func (s *StateTreeTestSuite) TestLeavesCount_NoLeaves() {
count := s.storage.StateTree.LeavesCount()
s.EqualValues(0, count)
}

func (s *StateTreeTestSuite) TestLeavesCount_ReturnsCorrectCountAfterRevert() {
s.setStateLeaves(0, 2, 4, 8, 9)

root, err := s.storage.StateTree.Root()
s.NoError(err)

s.setStateLeaves(1)

count := s.storage.StateTree.LeavesCount()
s.EqualValues(6, count)

err = s.storage.StateTree.RevertTo(*root)
s.NoError(err)

count = s.storage.StateTree.LeavesCount()
s.EqualValues(5, count)
}
max-sidorov marked this conversation as resolved.
Show resolved Hide resolved

func (s *StateTreeTestSuite) TestRevertTo() {
states := []models.UserState{
{
Expand Down
7 changes: 6 additions & 1 deletion storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ func newStorageFromDatabase(database *Database) (*Storage, error) {

registeredSpokeStorage := NewRegisteredSpokeStorage(database)

stateTree, err := NewStateTree(database)
if err != nil {
return nil, err
}

accountTree, err := NewAccountTree(database)
if err != nil {
return nil, err
Expand All @@ -67,7 +72,7 @@ func newStorageFromDatabase(database *Database) (*Storage, error) {
ChainStateStorage: chainStateStorage,
RegisteredTokenStorage: registeredTokenStorage,
RegisteredSpokeStorage: registeredSpokeStorage,
StateTree: NewStateTree(database),
StateTree: stateTree,
AccountTree: accountTree,
PendingStakeWithdrawalStorage: pendingStakeWithdrawalStorage,
database: database,
Expand Down