Skip to content

Commit

Permalink
initialize megavault module account in 7.0.0 upgrade handler (backport
Browse files Browse the repository at this point in the history
…#2394) (#2400)

Co-authored-by: Tian <[email protected]>
  • Loading branch information
mergify[bot] and tqin7 authored Sep 27, 2024
1 parent fd396aa commit f39b3e9
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 1 deletion.
1 change: 1 addition & 0 deletions protocol/app/upgrades.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func (app *App) setupUpgradeHandlers() {
v7_0_0.CreateUpgradeHandler(
app.ModuleManager,
app.configurator,
app.AccountKeeper,
app.PricesKeeper,
app.VaultKeeper,
),
Expand Down
75 changes: 75 additions & 0 deletions protocol/app/upgrades/v7.0.0/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
upgradetypes "cosmossdk.io/x/upgrade/types"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/module"
authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper"
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
"github.com/dydxprotocol/v4-chain/protocol/lib"
"github.com/dydxprotocol/v4-chain/protocol/lib/slinky"
pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types"
Expand All @@ -20,6 +22,75 @@ const (
QUOTE_QUANTUMS_PER_MEGAVAULT_SHARE = 1_000_000
)

var (
ModuleAccsToInitialize = []string{
vaulttypes.MegavaultAccountName,
}
)

// This module account initialization logic is copied from v3.0.0 upgrade handler.
func initializeModuleAccs(ctx sdk.Context, ak authkeeper.AccountKeeper) {
for _, modAccName := range ModuleAccsToInitialize {
// Get module account and relevant permissions from the accountKeeper.
addr, perms := ak.GetModuleAddressAndPermissions(modAccName)
if addr == nil {
panic(fmt.Sprintf(
"Did not find %v in `ak.GetModuleAddressAndPermissions`. This is not expected. Skipping.",
modAccName,
))
}

// Try to get the account in state.
acc := ak.GetAccount(ctx, addr)
if acc != nil {
// Account has been initialized.
macc, isModuleAccount := acc.(sdk.ModuleAccountI)
if isModuleAccount {
// Module account was correctly initialized. Skipping
ctx.Logger().Info(fmt.Sprintf(
"module account %+v was correctly initialized. No-op",
macc,
))
continue
}
// Module account has been initialized as a BaseAccount. Change to module account.
// Note: We need to get the base account to retrieve its account number, and convert it
// in place into a module account.
baseAccount, ok := acc.(*authtypes.BaseAccount)
if !ok {
panic(fmt.Sprintf(
"cannot cast %v into a BaseAccount, acc = %+v",
modAccName,
acc,
))
}
newModuleAccount := authtypes.NewModuleAccount(
baseAccount,
modAccName,
perms...,
)
ak.SetModuleAccount(ctx, newModuleAccount)
ctx.Logger().Info(fmt.Sprintf(
"Successfully converted %v to module account in state: %+v",
modAccName,
newModuleAccount,
))
continue
}

// Account has not been initialized at all. Initialize it as module.
// Implementation taken from
// https://github.com/dydxprotocol/cosmos-sdk/blob/bdf96fdd/x/auth/keeper/keeper.go#L213
newModuleAccount := authtypes.NewEmptyModuleAccount(modAccName, perms...)
maccI := (ak.NewAccount(ctx, newModuleAccount)).(sdk.ModuleAccountI) // this set the account number
ak.SetModuleAccount(ctx, maccI)
ctx.Logger().Info(fmt.Sprintf(
"Successfully initialized module account in state: %+v",
newModuleAccount,
))
}
}

func initCurrencyPairIDCache(ctx sdk.Context, k pricestypes.PricesKeeper) {
marketParams := k.GetAllMarketParams(ctx)
for _, mp := range marketParams {
Expand Down Expand Up @@ -126,13 +197,17 @@ func migrateVaultSharesToMegavaultShares(ctx sdk.Context, k vaultkeeper.Keeper)
func CreateUpgradeHandler(
mm *module.Manager,
configurator module.Configurator,
accountKeeper authkeeper.AccountKeeper,
pricesKeeper pricestypes.PricesKeeper,
vaultKeeper vaultkeeper.Keeper,
) upgradetypes.UpgradeHandler {
return func(ctx context.Context, plan upgradetypes.Plan, vm module.VersionMap) (module.VersionMap, error) {
sdkCtx := lib.UnwrapSDKContext(ctx, "app/upgrades")
sdkCtx.Logger().Info(fmt.Sprintf("Running %s Upgrade...", UpgradeName))

// Initialize module accounts.
initializeModuleAccs(sdkCtx, accountKeeper)

// Initialize the currency pair ID cache for all existing market params.
initCurrencyPairIDCache(sdkCtx, pricesKeeper)

Expand Down
23 changes: 22 additions & 1 deletion protocol/app/upgrades/v7.0.0/upgrade_container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/cosmos/gogoproto/proto"

authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
v_7_0_0 "github.com/dydxprotocol/v4-chain/protocol/app/upgrades/v7.0.0"
"github.com/dydxprotocol/v4-chain/protocol/dtypes"
"github.com/dydxprotocol/v4-chain/protocol/testing/containertest"
Expand Down Expand Up @@ -48,9 +49,12 @@ func preUpgradeChecks(node *containertest.Node, t *testing.T) {
}

func postUpgradeChecks(node *containertest.Node, t *testing.T) {
// Add test for your upgrade handler logic below
// Check that vault quoting params are successfully migrated to vault params.
postUpgradeVaultParamsCheck(node, t)
// Check that vault shares are successfully migrated to megavault shares.
postUpgradeMegavaultSharesCheck(node, t)
// Check that megavault module account is successfully initialized.
postUpgradeMegavaultModuleAccCheck(node, t)

// Check that the affiliates module has been initialized with the default tiers.
postUpgradeAffiliatesModuleTiersCheck(node, t)
Expand Down Expand Up @@ -181,3 +185,20 @@ func postUpgradeAffiliatesModuleTiersCheck(node *containertest.Node, t *testing.
require.NoError(t, err)
require.Equal(t, affiliatestypes.DefaultAffiliateTiers, affiliateTiersResp.Tiers)
}

func postUpgradeMegavaultModuleAccCheck(node *containertest.Node, t *testing.T) {
resp, err := containertest.Query(
node,
authtypes.NewQueryClient,
authtypes.QueryClient.ModuleAccountByName,
&authtypes.QueryModuleAccountByNameRequest{
Name: vaulttypes.MegavaultAccountName,
},
)
require.NoError(t, err)
require.NotNil(t, resp)

moduleAccResp := authtypes.QueryModuleAccountByNameResponse{}
err = proto.UnmarshalText(resp.String(), &moduleAccResp)
require.NoError(t, err)
}

0 comments on commit f39b3e9

Please sign in to comment.