From 80fffb695a102db3a4df25bce3c910c2c7dd58bd Mon Sep 17 00:00:00 2001 From: Maxim Sidorov Date: Mon, 21 Feb 2022 16:21:53 +0300 Subject: [PATCH 01/11] add leavesCount method for state_tree.go --- api/get_network_info.go | 5 ++--- storage/state_tree.go | 5 +++++ storage/state_tree_test.go | 14 ++++++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/api/get_network_info.go b/api/get_network_info.go index 4113ff108..27519d59f 100644 --- a/api/get_network_info.go +++ b/api/get_network_info.go @@ -35,12 +35,11 @@ func (a *API) unsafeGetNetworkInfo() (*dto.NetworkInfo, error) { TransactionCount: a.storage.GetTransactionCount(), } - // TODO this ignores the fact that other nodes can put new accounts in arbitrary state leaves; to be revisited in the future - accountCount, err := a.storage.StateTree.NextAvailableStateID() + accountCount, err := a.storage.StateTree.LeavesCount() if err != nil { return nil, err } - networkInfo.AccountCount = *accountCount + networkInfo.AccountCount = accountCount latestBatch, err := a.storage.GetLatestSubmittedBatch() if err != nil && !storage.IsNotFoundError(err) { diff --git a/storage/state_tree.go b/storage/state_tree.go index e84b6eb2b..81b1d7ee9 100644 --- a/storage/state_tree.go +++ b/storage/state_tree.go @@ -61,6 +61,11 @@ func (s *StateTree) LeafOrEmpty(stateID uint32) (*models.StateLeaf, error) { return leaf, err } +func (s *StateTree) LeavesCount() (uint32, error) { + count, err := s.database.Badger.Count(&stored.StateLeaf{}, nil) + return uint32(count), err +} + func (s *StateTree) NextAvailableStateID() (*uint32, error) { return s.NextVacantSubtree(0) } diff --git a/storage/state_tree_test.go b/storage/state_tree_test.go index 3c965f5ef..89e5e5122 100644 --- a/storage/state_tree_test.go +++ b/storage/state_tree_test.go @@ -365,6 +365,20 @@ func (s *StateTreeTestSuite) TestSet_ReturnsWitness() { s.Equal(node.DataHash, witness[31]) } +func (s *StateTreeTestSuite) TestLeavesCount_UpdateCountAfterAddingNewLeaves() { + s.setStateLeaves(0, 2, 4, 8, 9) + + count, err := s.storage.StateTree.LeavesCount() + s.NoError(err) + s.EqualValues(5, count) +} + +func (s *StateTreeTestSuite) TestLeavesCount_NoLeaves() { + count, err := s.storage.StateTree.LeavesCount() + s.NoError(err) + s.EqualValues(0, count) +} + func (s *StateTreeTestSuite) TestRevertTo() { states := []models.UserState{ { From 1c6d81283ef8179b0f1174658bb593d801fe8421 Mon Sep 17 00:00:00 2001 From: Maxim Sidorov Date: Tue, 22 Feb 2022 14:41:27 +0300 Subject: [PATCH 02/11] add leaves counter cache --- api/get_network_info.go | 7 +--- models/dto/network_info.go | 2 +- storage/state_tree.go | 82 +++++++++++++++++++++++++++----------- storage/storage.go | 7 +++- 4 files changed, 67 insertions(+), 31 deletions(-) diff --git a/api/get_network_info.go b/api/get_network_info.go index 27519d59f..61181c550 100644 --- a/api/get_network_info.go +++ b/api/get_network_info.go @@ -33,14 +33,9 @@ func (a *API) unsafeGetNetworkInfo() (*dto.NetworkInfo, error) { Rollup: a.client.ChainState.Rollup, BlockNumber: a.storage.GetLatestBlockNumber(), TransactionCount: a.storage.GetTransactionCount(), + AccountCount: a.storage.StateTree.LeavesCount(), } - accountCount, err := a.storage.StateTree.LeavesCount() - if err != nil { - return nil, err - } - networkInfo.AccountCount = accountCount - latestBatch, err := a.storage.GetLatestSubmittedBatch() if err != nil && !storage.IsNotFoundError(err) { return nil, err diff --git a/models/dto/network_info.go b/models/dto/network_info.go index bd02bd8ff..8975c95ee 100644 --- a/models/dto/network_info.go +++ b/models/dto/network_info.go @@ -17,7 +17,7 @@ type NetworkInfo struct { Rollup common.Address BlockNumber uint32 TransactionCount uint64 - AccountCount uint32 + AccountCount uint64 LatestBatch *models.Uint256 LatestFinalisedBatch *models.Uint256 SignatureDomain bls.Domain diff --git a/storage/state_tree.go b/storage/state_tree.go index 81b1d7ee9..465758346 100644 --- a/storage/state_tree.go +++ b/storage/state_tree.go @@ -1,6 +1,8 @@ package storage import ( + "sync/atomic" + "github.com/Worldcoin/hubble-commander/db" "github.com/Worldcoin/hubble-commander/encoder" "github.com/Worldcoin/hubble-commander/models" @@ -21,17 +23,33 @@ const ( type StateTree struct { database *Database merkleTree *StoredMerkleTree + + leavesCounter *uint64 +} + +func NewStateTree(database *Database) (*StateTree, error) { + stateTree := newStateTree(database) + count, err := stateTree.leavesCountFromStorage() + if err != nil { + return nil, err + } + atomic.StoreUint64(stateTree.leavesCounter, count) + return stateTree, nil } -func NewStateTree(database *Database) *StateTree { +func newStateTree(database *Database) *StateTree { return &StateTree{ - database: database, - merkleTree: NewStoredMerkleTree("state", database, StateTreeDepth), + database: database, + merkleTree: NewStoredMerkleTree("state", database, StateTreeDepth), + leavesCounter: ref.Uint64(0), } } func (s *StateTree) copyWithNewDatabase(database *Database) *StateTree { - return NewStateTree(database) + stateTree := newStateTree(database) + leavesCount := atomic.LoadUint64(s.leavesCounter) + atomic.StoreUint64(stateTree.leavesCounter, leavesCount) + return stateTree } func (s *StateTree) Root() (*common.Hash, error) { @@ -53,17 +71,22 @@ func (s *StateTree) Leaf(stateID uint32) (stateLeaf *models.StateLeaf, err error func (s *StateTree) LeafOrEmpty(stateID uint32) (*models.StateLeaf, error) { leaf, err := s.Leaf(stateID) if IsNotFoundError(err) { - return &models.StateLeaf{ - StateID: stateID, - DataHash: merkletree.GetZeroHash(0), - }, nil + return emptyStateLeaf(stateID), nil } return leaf, err } -func (s *StateTree) LeavesCount() (uint32, error) { +func (s *StateTree) LeavesCount() uint64 { + return atomic.LoadUint64(s.leavesCounter) +} + +func (s *StateTree) leavesCountFromStorage() (uint64, error) { count, err := s.database.Badger.Count(&stored.StateLeaf{}, nil) - return uint32(count), err + return count, err +} + +func (s *StateTree) incrementLeavesCounter() { + atomic.AddUint64(s.leavesCounter, 1) } func (s *StateTree) NextAvailableStateID() (*uint32, error) { @@ -133,13 +156,17 @@ func roundAndValidateStateTreeSlot(rangeStart, rangeEnd, subtreeWidth int64) *in // Set returns a witness containing 32 elements for the current set operation func (s *StateTree) Set(id uint32, state *models.UserState) (witness models.Witness, err error) { + isNewLeaf := false err = s.database.ExecuteInTransaction(TxOptions{}, func(txDatabase *Database) error { - witness, err = NewStateTree(txDatabase).unsafeSet(id, state) + witness, isNewLeaf, err = newStateTree(txDatabase).unsafeSet(id, state) return err }) if err != nil { return nil, err } + if isNewLeaf { + s.incrementLeavesCounter() + } return witness, nil } @@ -162,7 +189,7 @@ func (s *StateTree) RevertTo(targetRootHash common.Hash) error { } return s.database.ExecuteInTransaction(TxOptions{}, func(txDatabase *Database) (err error) { - stateTree := NewStateTree(txDatabase) + stateTree := newStateTree(txDatabase) err = txDatabase.Badger.Iterator(models.StateUpdatePrefix, db.ReversePrefetchIteratorOpts, func(item *bdg.Item) (bool, error) { var stateUpdate *models.StateUpdate @@ -206,31 +233,34 @@ func decodeStateUpdate(item *bdg.Item) (*models.StateUpdate, error) { return &stateUpdate, nil } -func (s *StateTree) unsafeSet(index uint32, state *models.UserState) (models.Witness, error) { - prevLeaf, err := s.LeafOrEmpty(index) - if err != nil { - return nil, err +func (s *StateTree) unsafeSet(index uint32, state *models.UserState) (witness models.Witness, isNewLeaf bool, err error) { + prevLeaf, err := s.Leaf(index) + if IsNotFoundError(err) { + prevLeaf = emptyStateLeaf(index) + isNewLeaf = true + } else if err != nil { + return nil, false, err } prevRoot, err := s.Root() if err != nil { - return nil, err + return nil, false, err } currentLeaf, err := NewStateLeaf(index, state) if err != nil { - return nil, err + return nil, false, err } err = s.upsertStateLeaf(currentLeaf) if err != nil { - return nil, err + return nil, false, err } prevLeafPath := models.MakeMerklePathFromLeafID(prevLeaf.StateID) currentRoot, witness, err := s.merkleTree.SetNode(&prevLeafPath, currentLeaf.DataHash) if err != nil { - return nil, err + return nil, false, err } err = s.addStateUpdate(&models.StateUpdate{ @@ -239,10 +269,9 @@ func (s *StateTree) unsafeSet(index uint32, state *models.UserState) (models.Wit PrevStateLeaf: *prevLeaf, }) if err != nil { - return nil, err + return nil, false, err } - - return witness, nil + return witness, isNewLeaf, nil } func (s *StateTree) getLeafByPubKeyIDAndTokenID(pubKeyID uint32, tokenID models.Uint256) (*models.StateLeaf, error) { @@ -307,6 +336,13 @@ func (s *StateTree) IterateLeaves(action func(stateLeaf *models.StateLeaf) error return nil } +func emptyStateLeaf(stateID uint32) *models.StateLeaf { + return &models.StateLeaf{ + StateID: stateID, + DataHash: merkletree.GetZeroHash(0), + } +} + func NewStateLeaf(stateID uint32, state *models.UserState) (*models.StateLeaf, error) { dataHash, err := encoder.HashUserState(state) if err != nil { diff --git a/storage/storage.go b/storage/storage.go index 01560c35b..8b63bc236 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -59,6 +59,11 @@ func newStorageFromDatabase(database *Database) (*Storage, error) { pendingStakeWithdrawalStorage := NewPendingStakeWithdrawalStorage(database) + stateTree, err := NewStateTree(database) + if err != nil { + return nil, err + } + storage := &Storage{ BatchStorage: batchStorage, CommitmentStorage: commitmentStorage, @@ -67,7 +72,7 @@ func newStorageFromDatabase(database *Database) (*Storage, error) { ChainStateStorage: chainStateStorage, RegisteredTokenStorage: registeredTokenStorage, RegisteredSpokeStorage: registeredSpokeStorage, - StateTree: NewStateTree(database), + StateTree: stateTree, AccountTree: accountTree, PendingStakeWithdrawalStorage: pendingStakeWithdrawalStorage, database: database, From 84bfb1ede715ced948ee498ab8e8b0e5130c543d Mon Sep 17 00:00:00 2001 From: Maxim Sidorov Date: Tue, 22 Feb 2022 14:41:40 +0300 Subject: [PATCH 03/11] update tests --- storage/state_tree_test.go | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/storage/state_tree_test.go b/storage/state_tree_test.go index 89e5e5122..46aae40c7 100644 --- a/storage/state_tree_test.go +++ b/storage/state_tree_test.go @@ -368,14 +368,27 @@ func (s *StateTreeTestSuite) TestSet_ReturnsWitness() { func (s *StateTreeTestSuite) TestLeavesCount_UpdateCountAfterAddingNewLeaves() { s.setStateLeaves(0, 2, 4, 8, 9) - count, err := s.storage.StateTree.LeavesCount() - s.NoError(err) + count := s.storage.StateTree.LeavesCount() + s.EqualValues(5, count) + + s.setStateLeaves(4, 8) + count = s.storage.StateTree.LeavesCount() s.EqualValues(5, count) } +func (s *StateTreeTestSuite) TestLeavesCount_TheSameCountAfterUpdatingLeaves() { + s.setStateLeaves(0, 1) + + count := s.storage.StateTree.LeavesCount() + s.EqualValues(2, count) + + s.setStateLeaves(0, 1) + count = s.storage.StateTree.LeavesCount() + s.EqualValues(2, count) +} + func (s *StateTreeTestSuite) TestLeavesCount_NoLeaves() { - count, err := s.storage.StateTree.LeavesCount() - s.NoError(err) + count := s.storage.StateTree.LeavesCount() s.EqualValues(0, count) } From cffafcf1f2dccc2f75d65412d0cff581e9dd2c1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Sieczkowski?= Date: Tue, 1 Mar 2022 13:57:57 +0100 Subject: [PATCH 04/11] Make TestLeavesCount_UpdatesCountAfterAddingNewLeaves do what it claims --- storage/state_tree_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/storage/state_tree_test.go b/storage/state_tree_test.go index 46aae40c7..06aeef674 100644 --- a/storage/state_tree_test.go +++ b/storage/state_tree_test.go @@ -365,15 +365,15 @@ func (s *StateTreeTestSuite) TestSet_ReturnsWitness() { s.Equal(node.DataHash, witness[31]) } -func (s *StateTreeTestSuite) TestLeavesCount_UpdateCountAfterAddingNewLeaves() { +func (s *StateTreeTestSuite) TestLeavesCount_UpdatesCountAfterAddingNewLeaves() { s.setStateLeaves(0, 2, 4, 8, 9) count := s.storage.StateTree.LeavesCount() s.EqualValues(5, count) - s.setStateLeaves(4, 8) + s.setStateLeaves(5, 7) count = s.storage.StateTree.LeavesCount() - s.EqualValues(5, count) + s.EqualValues(7, count) } func (s *StateTreeTestSuite) TestLeavesCount_TheSameCountAfterUpdatingLeaves() { From 113da4fb8585419a9b66cfd499a821ca8ff5e5ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Sieczkowski?= Date: Tue, 1 Mar 2022 13:58:38 +0100 Subject: [PATCH 05/11] Reorder lines --- storage/storage.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/storage/storage.go b/storage/storage.go index 8b63bc236..503e16d75 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -52,18 +52,18 @@ func newStorageFromDatabase(database *Database) (*Storage, error) { registeredSpokeStorage := NewRegisteredSpokeStorage(database) - accountTree, err := NewAccountTree(database) + stateTree, err := NewStateTree(database) if err != nil { return nil, err } - pendingStakeWithdrawalStorage := NewPendingStakeWithdrawalStorage(database) - - stateTree, err := NewStateTree(database) + accountTree, err := NewAccountTree(database) if err != nil { return nil, err } + pendingStakeWithdrawalStorage := NewPendingStakeWithdrawalStorage(database) + storage := &Storage{ BatchStorage: batchStorage, CommitmentStorage: commitmentStorage, From e40b812f8e647bccbc0f9fd59505bd43ae2f7e5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Sieczkowski?= Date: Tue, 1 Mar 2022 14:39:41 +0100 Subject: [PATCH 06/11] Rename leavesCounter -> leavesCount --- storage/state_tree.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/storage/state_tree.go b/storage/state_tree.go index 465758346..4971ec3e5 100644 --- a/storage/state_tree.go +++ b/storage/state_tree.go @@ -24,7 +24,7 @@ type StateTree struct { database *Database merkleTree *StoredMerkleTree - leavesCounter *uint64 + leavesCount *uint64 } func NewStateTree(database *Database) (*StateTree, error) { @@ -33,22 +33,22 @@ func NewStateTree(database *Database) (*StateTree, error) { if err != nil { return nil, err } - atomic.StoreUint64(stateTree.leavesCounter, count) + atomic.StoreUint64(stateTree.leavesCount, count) return stateTree, nil } func newStateTree(database *Database) *StateTree { return &StateTree{ - database: database, - merkleTree: NewStoredMerkleTree("state", database, StateTreeDepth), - leavesCounter: ref.Uint64(0), + database: database, + merkleTree: NewStoredMerkleTree("state", database, StateTreeDepth), + leavesCount: ref.Uint64(0), } } func (s *StateTree) copyWithNewDatabase(database *Database) *StateTree { stateTree := newStateTree(database) - leavesCount := atomic.LoadUint64(s.leavesCounter) - atomic.StoreUint64(stateTree.leavesCounter, leavesCount) + leavesCount := atomic.LoadUint64(s.leavesCount) + atomic.StoreUint64(stateTree.leavesCount, leavesCount) return stateTree } @@ -77,7 +77,7 @@ func (s *StateTree) LeafOrEmpty(stateID uint32) (*models.StateLeaf, error) { } func (s *StateTree) LeavesCount() uint64 { - return atomic.LoadUint64(s.leavesCounter) + return atomic.LoadUint64(s.leavesCount) } func (s *StateTree) leavesCountFromStorage() (uint64, error) { @@ -86,7 +86,7 @@ func (s *StateTree) leavesCountFromStorage() (uint64, error) { } func (s *StateTree) incrementLeavesCounter() { - atomic.AddUint64(s.leavesCounter, 1) + atomic.AddUint64(s.leavesCount, 1) } func (s *StateTree) NextAvailableStateID() (*uint32, error) { From bed7443228b7cb9bbeb2e27c1d5e80ed16e27263 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Sieczkowski?= Date: Tue, 1 Mar 2022 14:40:01 +0100 Subject: [PATCH 07/11] Rename incrementLeavesCounter -> incrementLeavesCount --- storage/state_tree.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/storage/state_tree.go b/storage/state_tree.go index 4971ec3e5..1ae3dabdd 100644 --- a/storage/state_tree.go +++ b/storage/state_tree.go @@ -85,7 +85,7 @@ func (s *StateTree) leavesCountFromStorage() (uint64, error) { return count, err } -func (s *StateTree) incrementLeavesCounter() { +func (s *StateTree) incrementLeavesCount() { atomic.AddUint64(s.leavesCount, 1) } @@ -165,7 +165,7 @@ func (s *StateTree) Set(id uint32, state *models.UserState) (witness models.Witn return nil, err } if isNewLeaf { - s.incrementLeavesCounter() + s.incrementLeavesCount() } return witness, nil From 4afef6ca7d1e2c920b6a76bfd09d1f84c2bd3abc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Sieczkowski?= Date: Tue, 1 Mar 2022 14:41:09 +0100 Subject: [PATCH 08/11] Rename leavesCountFromStorage -> getLeavesCountFromStorage --- storage/state_tree.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/storage/state_tree.go b/storage/state_tree.go index 1ae3dabdd..30a8a84d4 100644 --- a/storage/state_tree.go +++ b/storage/state_tree.go @@ -29,7 +29,7 @@ type StateTree struct { func NewStateTree(database *Database) (*StateTree, error) { stateTree := newStateTree(database) - count, err := stateTree.leavesCountFromStorage() + count, err := stateTree.getLeavesCountFromStorage() if err != nil { return nil, err } @@ -80,7 +80,7 @@ func (s *StateTree) LeavesCount() uint64 { return atomic.LoadUint64(s.leavesCount) } -func (s *StateTree) leavesCountFromStorage() (uint64, error) { +func (s *StateTree) getLeavesCountFromStorage() (uint64, error) { count, err := s.database.Badger.Count(&stored.StateLeaf{}, nil) return count, err } From 103b2fb8c6278ac16ffee94f1287d7231e03b43f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Sieczkowski?= Date: Tue, 1 Mar 2022 14:49:27 +0100 Subject: [PATCH 09/11] Extract leafOrEmpty and unsafeSetLeaf --- storage/state_tree.go | 43 +++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/storage/state_tree.go b/storage/state_tree.go index 30a8a84d4..a0592ffd0 100644 --- a/storage/state_tree.go +++ b/storage/state_tree.go @@ -69,11 +69,21 @@ func (s *StateTree) Leaf(stateID uint32) (stateLeaf *models.StateLeaf, err error } func (s *StateTree) LeafOrEmpty(stateID uint32) (*models.StateLeaf, error) { - leaf, err := s.Leaf(stateID) + leaf, _, err := s.leafOrEmpty(stateID) + if err != nil { + return nil, err + } + return leaf, nil +} + +func (s *StateTree) leafOrEmpty(stateID uint32) (leaf *models.StateLeaf, isEmpty bool, err error) { + leaf, err = s.Leaf(stateID) if IsNotFoundError(err) { - return emptyStateLeaf(stateID), nil + return emptyStateLeaf(stateID), true, nil + } else if err != nil { + return nil, false, err } - return leaf, err + return leaf, false, nil } func (s *StateTree) LeavesCount() uint64 { @@ -234,33 +244,37 @@ func decodeStateUpdate(item *bdg.Item) (*models.StateUpdate, error) { } func (s *StateTree) unsafeSet(index uint32, state *models.UserState) (witness models.Witness, isNewLeaf bool, err error) { - prevLeaf, err := s.Leaf(index) - if IsNotFoundError(err) { - prevLeaf = emptyStateLeaf(index) - isNewLeaf = true - } else if err != nil { + prevLeaf, isNewLeaf, err := s.leafOrEmpty(index) + if err != nil { return nil, false, err } + witness, err = s.unsafeSetLeaf(index, prevLeaf, state) + if err != nil { + return nil, false, err + } + return witness, isNewLeaf, nil +} +func (s *StateTree) unsafeSetLeaf(index uint32, prevLeaf *models.StateLeaf, state *models.UserState) (witness models.Witness, err error) { prevRoot, err := s.Root() if err != nil { - return nil, false, err + return nil, err } currentLeaf, err := NewStateLeaf(index, state) if err != nil { - return nil, false, err + return nil, err } err = s.upsertStateLeaf(currentLeaf) if err != nil { - return nil, false, err + return nil, err } prevLeafPath := models.MakeMerklePathFromLeafID(prevLeaf.StateID) currentRoot, witness, err := s.merkleTree.SetNode(&prevLeafPath, currentLeaf.DataHash) if err != nil { - return nil, false, err + return nil, err } err = s.addStateUpdate(&models.StateUpdate{ @@ -269,9 +283,10 @@ func (s *StateTree) unsafeSet(index uint32, state *models.UserState) (witness mo PrevStateLeaf: *prevLeaf, }) if err != nil { - return nil, false, err + return nil, err } - return witness, isNewLeaf, nil + + return witness, nil } func (s *StateTree) getLeafByPubKeyIDAndTokenID(pubKeyID uint32, tokenID models.Uint256) (*models.StateLeaf, error) { From c5a1c25765dc82a309c53a0550d8dff122a488f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Sieczkowski?= Date: Tue, 1 Mar 2022 14:54:34 +0100 Subject: [PATCH 10/11] Add a failing test for LeavesCount --- storage/state_tree_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/storage/state_tree_test.go b/storage/state_tree_test.go index 06aeef674..a5bc84a10 100644 --- a/storage/state_tree_test.go +++ b/storage/state_tree_test.go @@ -392,6 +392,24 @@ func (s *StateTreeTestSuite) TestLeavesCount_NoLeaves() { s.EqualValues(0, count) } +func (s *StateTreeTestSuite) TestLeavesCount_ReturnsCorrectCountAfterRevert() { + s.setStateLeaves(0, 2, 4, 8, 9) + + root, err := s.storage.StateTree.Root() + s.NoError(err) + + s.setStateLeaves(1) + + count := s.storage.StateTree.LeavesCount() + s.EqualValues(6, count) + + err = s.storage.StateTree.RevertTo(*root) + s.NoError(err) + + count = s.storage.StateTree.LeavesCount() + s.EqualValues(5, count) +} + func (s *StateTreeTestSuite) TestRevertTo() { states := []models.UserState{ { From 6102466ff55588d4a1a0e7a94033e5e41dbcf5b6 Mon Sep 17 00:00:00 2001 From: Maxim Sidorov Date: Tue, 1 Mar 2022 18:12:21 +0300 Subject: [PATCH 11/11] fix decreasing count for RevertTo method --- storage/state_tree.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/storage/state_tree.go b/storage/state_tree.go index a0592ffd0..fa166e613 100644 --- a/storage/state_tree.go +++ b/storage/state_tree.go @@ -99,6 +99,10 @@ func (s *StateTree) incrementLeavesCount() { atomic.AddUint64(s.leavesCount, 1) } +func (s *StateTree) decreaseLeavesCount(delta uint64) { + atomic.AddUint64(s.leavesCount, ^(delta - 1)) +} + func (s *StateTree) NextAvailableStateID() (*uint32, error) { return s.NextVacantSubtree(0) } @@ -197,6 +201,7 @@ func (s *StateTree) RevertTo(targetRootHash common.Hash) error { if *currentRootHash == targetRootHash { return nil } + revertedLeavesCount := uint64(0) return s.database.ExecuteInTransaction(TxOptions{}, func(txDatabase *Database) (err error) { stateTree := newStateTree(txDatabase) @@ -215,6 +220,7 @@ func (s *StateTree) RevertTo(targetRootHash common.Hash) error { if err != nil { return false, err } + revertedLeavesCount++ return *currentRootHash == targetRootHash, nil }) if err != nil && err != db.ErrIteratorFinished { @@ -224,6 +230,7 @@ func (s *StateTree) RevertTo(targetRootHash common.Hash) error { if *currentRootHash != targetRootHash { return errors.WithStack(ErrNonexistentState) } + s.decreaseLeavesCount(revertedLeavesCount) return nil }) }