diff --git a/protocol/x/accountplus/genesis.go b/protocol/x/accountplus/genesis.go index d84289a0db..66601dfbc2 100644 --- a/protocol/x/accountplus/genesis.go +++ b/protocol/x/accountplus/genesis.go @@ -18,7 +18,19 @@ func ExportGenesis(ctx sdk.Context, k keeper.Keeper) *types.GenesisState { if err != nil { panic(err) } + + params := k.GetParams(ctx) + nextAuthenticatorId := k.InitializeOrGetNextAuthenticatorId(ctx) + + data, err := k.GetAllAuthenticatorData(ctx) + if err != nil { + panic(err) + } + return &types.GenesisState{ - Accounts: accounts, + Accounts: accounts, + Params: params, + NextAuthenticatorId: nextAuthenticatorId, + AuthenticatorData: data, } } diff --git a/protocol/x/accountplus/genesis_test.go b/protocol/x/accountplus/genesis_test.go index dc0a82d055..df323ff829 100644 --- a/protocol/x/accountplus/genesis_test.go +++ b/protocol/x/accountplus/genesis_test.go @@ -18,7 +18,10 @@ func TestImportExportGenesis(t *testing.T) { // The order of this list may not match the order in GenesisState. We want our tests to be deterministic so // order of expectedAccountStates was manually set based on test debug. This ordering should only be changed if // additional accounts added to genesisState. If a feature breaks the existing ordering, should look into why. - expectedAccountStates []types.AccountState + expectedAccountStates []types.AccountState + expectedParams types.Params + expectedNextAuthenticatorId uint64 + expectedAuthenticatorData []types.AuthenticatorData }{ "non-empty genesis": { genesisState: &types.GenesisState{ @@ -45,6 +48,32 @@ func TestImportExportGenesis(t *testing.T) { }, }, }, + Params: types.Params{ + IsSmartAccountActive: true, + }, + NextAuthenticatorId: 100, + AuthenticatorData: []types.AuthenticatorData{ + { + Address: constants.AliceAccAddress.String(), + Authenticators: []types.AccountAuthenticator{ + { + Id: 1, + Type: "MessageFilter", + Config: []byte("/cosmos.bank.v1beta1.MsgSend"), + }, + }, + }, + { + Address: constants.BobAccAddress.String(), + Authenticators: []types.AccountAuthenticator{ + { + Id: 1, + Type: "ClobPairIdFilter", + Config: []byte("0,1,2"), + }, + }, + }, + }, }, expectedAccountStates: []types.AccountState{ { @@ -69,6 +98,32 @@ func TestImportExportGenesis(t *testing.T) { }, }, }, + expectedParams: types.Params{ + IsSmartAccountActive: true, + }, + expectedNextAuthenticatorId: 100, + expectedAuthenticatorData: []types.AuthenticatorData{ + { + Address: constants.BobAccAddress.String(), + Authenticators: []types.AccountAuthenticator{ + { + Id: 1, + Type: "ClobPairIdFilter", + Config: []byte("0,1,2"), + }, + }, + }, + { + Address: constants.AliceAccAddress.String(), + Authenticators: []types.AccountAuthenticator{ + { + Id: 1, + Type: "MessageFilter", + Config: []byte("/cosmos.bank.v1beta1.MsgSend"), + }, + }, + }, + }, }, "empty genesis": { genesisState: &types.GenesisState{ @@ -95,11 +150,26 @@ func TestImportExportGenesis(t *testing.T) { "Keeper account states do not match Genesis account states", ) + // Check that keeper params are correct + actualParams := k.GetParams(ctx) + require.Equal(t, tc.expectedParams, actualParams) + + // Check that keeper next authenticator id is correct + actualNextAuthenticatorId := k.InitializeOrGetNextAuthenticatorId(ctx) + require.Equal(t, tc.expectedNextAuthenticatorId, actualNextAuthenticatorId) + + // Check that keeper authenticator data is correct + actualAuthenticatorData, _ := k.GetAllAuthenticatorData(ctx) + require.Equal(t, tc.expectedAuthenticatorData, actualAuthenticatorData) + exportedGenesis := accountplus.ExportGenesis(ctx, k) // Check that the exported state matches the expected state expectedGenesis := &types.GenesisState{ - Accounts: tc.expectedAccountStates, + Accounts: tc.expectedAccountStates, + Params: tc.expectedParams, + NextAuthenticatorId: tc.expectedNextAuthenticatorId, + AuthenticatorData: tc.expectedAuthenticatorData, } require.Equal(t, *exportedGenesis, *expectedGenesis) }) diff --git a/protocol/x/accountplus/keeper/authenticators.go b/protocol/x/accountplus/keeper/authenticators.go index 71d8b9aede..82f7289c1f 100644 --- a/protocol/x/accountplus/keeper/authenticators.go +++ b/protocol/x/accountplus/keeper/authenticators.go @@ -96,21 +96,29 @@ func (k Keeper) AddAuthenticator( } k.SetNextAuthenticatorId(ctx, id+1) - - store := prefix.NewStore( - ctx.KVStore(k.storeKey), - []byte(types.AuthenticatorKeyPrefix), - ) authenticator := types.AccountAuthenticator{ Id: id, Type: authenticatorType, Config: config, } - b := k.cdc.MustMarshal(&authenticator) - store.Set(types.KeyAccountId(account, id), b) + k.SetAuthenticator(ctx, account.String(), id, authenticator) return id, nil } +func (k Keeper) SetAuthenticator( + ctx sdk.Context, + account string, + authenticatorId uint64, + authenticator types.AccountAuthenticator, +) { + store := prefix.NewStore( + ctx.KVStore(k.storeKey), + []byte(types.AuthenticatorKeyPrefix), + ) + b := k.cdc.MustMarshal(&authenticator) + store.Set(types.BuildKey(account, authenticatorId), b) +} + // RemoveAuthenticator removes an authenticator from an account func (k Keeper) RemoveAuthenticator(ctx sdk.Context, account sdk.AccAddress, authenticatorId uint64) error { store := prefix.NewStore( diff --git a/protocol/x/accountplus/keeper/keeper.go b/protocol/x/accountplus/keeper/keeper.go index def0ced51e..19f59e7ee6 100644 --- a/protocol/x/accountplus/keeper/keeper.go +++ b/protocol/x/accountplus/keeper/keeper.go @@ -3,6 +3,7 @@ package keeper import ( "errors" "fmt" + "strings" "cosmossdk.io/log" storetypes "cosmossdk.io/store/types" @@ -53,7 +54,14 @@ func (k Keeper) GetAllAccountStates(ctx sdk.Context) ([]types.AccountState, erro accounts := []types.AccountState{} for ; iterator.Valid(); iterator.Next() { - accountState, found := k.GetAccountState(ctx, iterator.Key()) + key := iterator.Key() + + // Temporary workaround to exclude smart account kv pairs. + if strings.HasPrefix(string(key), types.SmartAccountKeyPrefix) { + continue + } + + accountState, found := k.GetAccountState(ctx, key) if !found { return accounts, errors.New("Could not get account state for address: " + sdk.AccAddress(iterator.Key()).String()) } @@ -73,9 +81,67 @@ func (k Keeper) SetGenesisState(ctx sdk.Context, data types.GenesisState) error k.SetAccountState(ctx, address, account) } + k.SetParams(ctx, data.Params) + k.SetNextAuthenticatorId(ctx, data.NextAuthenticatorId) + + for _, data := range data.GetAuthenticatorData() { + address := data.GetAddress() + for _, authenticator := range data.GetAuthenticators() { + k.SetAuthenticator(ctx, address, authenticator.Id, authenticator) + } + } + return nil } +// GetAllAuthenticatorData is used in genesis export to export all the authenticator for all accounts +func (k Keeper) GetAllAuthenticatorData(ctx sdk.Context) ([]types.AuthenticatorData, error) { + var accountAuthenticators []types.AuthenticatorData + + parse := func(key []byte, value []byte) error { + var authenticator types.AccountAuthenticator + err := k.cdc.Unmarshal(value, &authenticator) + if err != nil { + return err + } + + // Extract account address from key + accountAddr := strings.Split(string(key), "/")[2] + + // Check if this entry is for a new address or the same as the last one processed + if len(accountAuthenticators) == 0 || + accountAuthenticators[len(accountAuthenticators)-1].Address != accountAddr { + // If it's a new address, create a new AuthenticatorData entry + accountAuthenticators = append(accountAuthenticators, types.AuthenticatorData{ + Address: accountAddr, + Authenticators: []types.AccountAuthenticator{authenticator}, + }) + } else { + // If it's the same address, append the authenticator to the last entry in the list + lastIndex := len(accountAuthenticators) - 1 + accountAuthenticators[lastIndex].Authenticators = append( + accountAuthenticators[lastIndex].Authenticators, + authenticator, + ) + } + + return nil + } + + // Iterate over all entries in the store using a prefix iterator + iterator := storetypes.KVStorePrefixIterator(ctx.KVStore(k.storeKey), []byte(types.AuthenticatorKeyPrefix)) + defer iterator.Close() + + for ; iterator.Valid(); iterator.Next() { + err := parse(iterator.Key(), iterator.Value()) + if err != nil { + return nil, err + } + } + + return accountAuthenticators, nil +} + func GetAccountPlusStateWithTimestampNonceDetails( address sdk.AccAddress, tsNonce uint64,