Skip to content

Commit

Permalink
refactor(core,stf,x)!: remove InvokeTyped from router (#21224)
Browse files Browse the repository at this point in the history
  • Loading branch information
julienrbrt authored Aug 23, 2024
1 parent fc87374 commit a554a21
Show file tree
Hide file tree
Showing 55 changed files with 422 additions and 583 deletions.
6 changes: 2 additions & 4 deletions core/router/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (
type Service interface {
// CanInvoke returns an error if the given request cannot be invoked.
CanInvoke(ctx context.Context, typeURL string) error
// InvokeTyped execute a message or query. It should be used when the called knows the type of the response.
InvokeTyped(ctx context.Context, req, res transaction.Msg) error
// InvokeUntyped execute a Msg or query. It should be used when the called doesn't know the type of the response.
InvokeUntyped(ctx context.Context, req transaction.Msg) (res transaction.Msg, err error)
// Invoke execute a message or query. The response should be type casted by the caller to the expected response.
Invoke(ctx context.Context, req transaction.Msg) (res transaction.Msg, err error)
}
16 changes: 8 additions & 8 deletions docs/rfc/rfc-006-handlers.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ import (
type PreMsgHandlerRouter interface {
// RegisterGlobalPreMsgHandler will register a pre msg handler that hooks before any message executes.
// Handler will be called before ANY message executes.
RegisterGlobalPreMsgHandler(handler func(ctx context.Context, msg protoiface.MessageV1) error)
RegisterGlobalPreMsgHandler(handler func(ctx context.Context, msg transaction.Msg) error)
// RegisterPreMsgHandler will register a pre msg handler that hooks before the provided message
// with the given message name executes. Handler will be called before the message is executed
// by the module.
RegisterPreMsgHandler(msgName string, handler func(ctx context.Context, msg protoiface.MessageV1) error)
RegisterPreMsgHandler(msgName string, handler func(ctx context.Context, msg transaction.Msg) error)
}

type HasPreMsgHandler interface {
Expand All @@ -105,11 +105,11 @@ import (
type PostMsgHandlerRouter interface {
// RegisterGlobalPostMsgHandler will register a post msg handler that hooks after any message executes.
// Handler will be called after ANY message executes, alongside the response.
RegisterGlobalPostMsgHandler(handler func(ctx context.Context, msg, msgResp protoiface.MessageV1) error)
RegisterGlobalPostMsgHandler(handler func(ctx context.Context, msg, msgResp transaction.Msg) error)
// RegisterPostMsgHandler will register a pre msg handler that hooks after the provided message
// with the given message name executes. Handler will be called after the message is executed
// by the module, alongside the response returned by the module.
RegisterPostMsgHandler(msgName string, handler func(ctx context.Context, msg, msgResp protoiface.MessageV1) error)
RegisterPostMsgHandler(msgName string, handler func(ctx context.Context, msg, msgResp transaction.Msg) error)
}

type HasPostMsgHandler interface {
Expand Down Expand Up @@ -142,15 +142,15 @@ import (
)

type MsgHandlerRouter interface {
RegisterMsgHandler(msgName string, handler func(ctx context.Context, msg protoiface.MessageV1) (msgResp protoiface.MessageV1, err error))
RegisterMsgHandler(msgName string, handler func(ctx context.Context, msg transaction.Msg) (msgResp transaction.Msg, err error))
}

type HasMsgHandler interface {
RegisterMsgHandlers(router MsgHandlerRouter)
}

// RegisterMsgHandler is a helper function to retain type safety when creating handlers, so we do not need to cast messages.
func RegisterMsgHandler[Req, Resp protoiface.MessageV1](router MsgHandlerRouter, handler func(ctx context.Context, req Req) (resp Resp, err error)) {
func RegisterMsgHandler[Req, Resp transaction.Msg](router MsgHandlerRouter, handler func(ctx context.Context, req Req) (resp Resp, err error)) {
// impl detail
}
```
Expand Down Expand Up @@ -186,15 +186,15 @@ import (
)

type QueryHandlerRouter interface {
RegisterQueryHandler(msgName string, handler func(ctx context.Context, req protoiface.MessageV1) (resp protoiface.MessageV1, err error))
RegisterQueryHandler(msgName string, handler func(ctx context.Context, req transaction.Msg) (resp transaction.Msg, err error))
}

type HasQueryHandler interface {
RegisterQueryHandlers(router QueryHandlerRouter)
}

// RegisterQueryHandler is a helper function to retain type safety when creating handlers, so we do not need to cast messages.
func RegisterQueryHandler[Req, Resp protoiface.MessageV1](router QueryHandlerRouter, handler func(ctx context.Context, req Req) (resp Resp, err error)) {
func RegisterQueryHandler[Req, Resp transaction.Msg](router QueryHandlerRouter, handler func(ctx context.Context, req Req) (resp Resp, err error)) {
// impl detail
}

Expand Down
62 changes: 27 additions & 35 deletions runtime/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,8 @@ func (m *msgRouterService) CanInvoke(ctx context.Context, typeURL string) error
return nil
}

// InvokeTyped execute a message and fill-in a response.
// The response must be known and passed as a parameter.
// Use InvokeUntyped if the response type is not known.
func (m *msgRouterService) InvokeTyped(ctx context.Context, msg, resp gogoproto.Message) error {
messageName := msgTypeURL(msg)
handler := m.router.HybridHandlerByMsgName(messageName)
if handler == nil {
return fmt.Errorf("unknown message: %s", messageName)
}

return handler(ctx, msg, resp)
}

// InvokeUntyped execute a message and returns a response.
func (m *msgRouterService) InvokeUntyped(ctx context.Context, msg gogoproto.Message) (gogoproto.Message, error) {
// Invoke execute a message and returns a response.
func (m *msgRouterService) Invoke(ctx context.Context, msg gogoproto.Message) (gogoproto.Message, error) {
messageName := msgTypeURL(msg)
respName := m.router.ResponseNameByMsgName(messageName)
if respName == "" {
Expand All @@ -76,7 +63,16 @@ func (m *msgRouterService) InvokeUntyped(ctx context.Context, msg gogoproto.Mess
return nil, fmt.Errorf("could not create response message %s", respName)
}

return msgResp, m.InvokeTyped(ctx, msg, msgResp)
handler := m.router.HybridHandlerByMsgName(messageName)
if handler == nil {
return nil, fmt.Errorf("unknown message: %s", messageName)
}

if err := handler(ctx, msg, msgResp); err != nil {
return nil, err
}

return msgResp, nil
}

// NewQueryRouterService implements router.Service.
Expand Down Expand Up @@ -110,27 +106,12 @@ func (m *queryRouterService) CanInvoke(ctx context.Context, typeURL string) erro
return nil
}

// InvokeTyped execute a message and fill-in a response.
// The response must be known and passed as a parameter.
// Use InvokeUntyped if the response type is not known.
func (m *queryRouterService) InvokeTyped(ctx context.Context, req, resp gogoproto.Message) error {
reqName := msgTypeURL(req)
handlers := m.router.HybridHandlerByRequestName(reqName)
if len(handlers) == 0 {
return fmt.Errorf("unknown request: %s", reqName)
} else if len(handlers) > 1 {
return fmt.Errorf("ambiguous request, query have multiple handlers: %s", reqName)
}

return handlers[0](ctx, req, resp)
}

// InvokeUntyped execute a message and returns a response.
func (m *queryRouterService) InvokeUntyped(ctx context.Context, req gogoproto.Message) (gogoproto.Message, error) {
// Invoke execute a message and returns a response.
func (m *queryRouterService) Invoke(ctx context.Context, req gogoproto.Message) (gogoproto.Message, error) {
reqName := msgTypeURL(req)
respName := m.router.ResponseNameByRequestName(reqName)
if respName == "" {
return nil, fmt.Errorf("could not find response type for request %s (%T)", reqName, req)
return nil, fmt.Errorf("unknown request: could not find response type for request %s (%T)", reqName, req)
}

// get response type
Expand All @@ -143,7 +124,18 @@ func (m *queryRouterService) InvokeUntyped(ctx context.Context, req gogoproto.Me
return nil, fmt.Errorf("could not create response request %s", respName)
}

return reqResp, m.InvokeTyped(ctx, req, reqResp)
handlers := m.router.HybridHandlerByRequestName(reqName)
if len(handlers) == 0 {
return nil, fmt.Errorf("unknown request: %s", reqName)
} else if len(handlers) > 1 {
return nil, fmt.Errorf("ambiguous request, query have multiple handlers: %s", reqName)
}

if err := handlers[0](ctx, req, reqResp); err != nil {
return nil, err
}

return reqResp, nil
}

// msgTypeURL returns the TypeURL of a proto message.
Expand Down
55 changes: 7 additions & 48 deletions runtime/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"github.com/stretchr/testify/require"

bankv1beta1 "cosmossdk.io/api/cosmos/bank/v1beta1"
counterv1 "cosmossdk.io/api/cosmos/counter/v1"
coretesting "cosmossdk.io/core/testing"
storetypes "cosmossdk.io/store/types"

Expand Down Expand Up @@ -38,70 +37,30 @@ func TestRouterService(t *testing.T) {
// Messages

t.Run("invalid msg", func(t *testing.T) {
_, err := messageRouterService.InvokeUntyped(testCtx.Ctx, &bankv1beta1.MsgSend{})
_, err := messageRouterService.Invoke(testCtx.Ctx, &bankv1beta1.MsgSend{})
require.ErrorContains(t, err, "could not find response type for message cosmos.bank.v1beta1.MsgSend")
})

t.Run("invoke untyped: valid msg (proto v1)", func(t *testing.T) {
resp, err := messageRouterService.InvokeUntyped(testCtx.Ctx, &countertypes.MsgIncreaseCounter{
t.Run("invoke: valid msg (proto v1)", func(t *testing.T) {
resp, err := messageRouterService.Invoke(testCtx.Ctx, &countertypes.MsgIncreaseCounter{
Signer: "cosmos1",
Count: 42,
})
require.NoError(t, err)
require.NotNil(t, resp)
})

t.Run("invoke typed: valid msg (proto v1)", func(t *testing.T) {
resp := &countertypes.MsgIncreaseCountResponse{}
err := messageRouterService.InvokeTyped(testCtx.Ctx, &countertypes.MsgIncreaseCounter{
Signer: "cosmos1",
Count: 42,
}, resp)
require.NoError(t, err)
require.NotNil(t, resp)
})

t.Run("invoke typed: valid msg (proto v2)", func(t *testing.T) {
resp := &counterv1.MsgIncreaseCountResponse{}
err := messageRouterService.InvokeTyped(testCtx.Ctx, &counterv1.MsgIncreaseCounter{
Signer: "cosmos1",
Count: 42,
}, resp)
require.NoError(t, err)
require.NotNil(t, resp)
})

// Queries

t.Run("invalid query", func(t *testing.T) {
err := queryRouterService.InvokeTyped(testCtx.Ctx, &bankv1beta1.QueryBalanceRequest{}, &bankv1beta1.QueryBalanceResponse{})
require.ErrorContains(t, err, "unknown request: cosmos.bank.v1beta1.QueryBalanceRequest")
})

t.Run("invoke typed: valid query (proto v1)", func(t *testing.T) {
_ = counterKeeper.CountStore.Set(testCtx.Ctx, 42)

resp := &countertypes.QueryGetCountResponse{}
err := queryRouterService.InvokeTyped(testCtx.Ctx, &countertypes.QueryGetCountRequest{}, resp)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, int64(42), resp.TotalCount)
})

t.Run("invoke typed: valid query (proto v2)", func(t *testing.T) {
_ = counterKeeper.CountStore.Set(testCtx.Ctx, 42)

resp := &counterv1.QueryGetCountResponse{}
err := queryRouterService.InvokeTyped(testCtx.Ctx, &counterv1.QueryGetCountRequest{}, resp)
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, int64(42), resp.TotalCount)
_, err := queryRouterService.Invoke(testCtx.Ctx, &bankv1beta1.QueryBalanceRequest{})
require.ErrorContains(t, err, "could not find response type for request cosmos.bank.v1beta1.QueryBalanceRequest")
})

t.Run("invoke untyped: valid query (proto v1)", func(t *testing.T) {
t.Run("invoke: valid query (proto v1)", func(t *testing.T) {
_ = counterKeeper.CountStore.Set(testCtx.Ctx, 42)

resp, err := queryRouterService.InvokeUntyped(testCtx.Ctx, &countertypes.QueryGetCountRequest{})
resp, err := queryRouterService.Invoke(testCtx.Ctx, &countertypes.QueryGetCountRequest{})
require.NoError(t, err)
require.NotNil(t, resp)
respVal, ok := resp.(*countertypes.QueryGetCountResponse)
Expand Down
2 changes: 1 addition & 1 deletion schema/appdata/batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var testBatch = PacketBatch{
}

func batchListener() (Listener, *PacketBatch) {
var got = new(PacketBatch)
got := new(PacketBatch)
l := Listener{
InitializeModuleData: func(m ModuleInitializationData) error {
*got = append(*got, m)
Expand Down
3 changes: 2 additions & 1 deletion schema/diff/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ func TestCompareModuleSchemas(t *testing.T) {
{
Name: "foo",
KeyFields: []schema.Field{{Name: "key1", Kind: schema.EnumKind, EnumType: schema.EnumType{Name: "bar", Values: []string{"a"}}}},
}},
},
},
AddedObjectTypes: []schema.ObjectType{
{
Name: "bar",
Expand Down
4 changes: 2 additions & 2 deletions scripts/mockgen.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ $mockgen_cmd -source=x/nft/expected_keepers.go -package testutil -destination x/
$mockgen_cmd -source=x/feegrant/expected_keepers.go -package testutil -destination x/feegrant/testutil/expected_keepers_mocks.go
$mockgen_cmd -source=x/mint/types/expected_keepers.go -package testutil -destination x/mint/testutil/expected_keepers_mocks.go
$mockgen_cmd -source=x/auth/tx/config/expected_keepers.go -package testutil -destination x/auth/tx/testutil/expected_keepers_mocks.go
$mockgen_cmd -source=x/auth/types/expected_keepers.go -package testutil -destination x/auth/testutil/expected_keepers_mocks.go
# $mockgen_cmd -source=x/auth/types/expected_keepers.go -package testutil -destination x/auth/testutil/expected_keepers_mocks.go
$mockgen_cmd -source=x/auth/ante/expected_keepers.go -package testutil -destination x/auth/ante/testutil/expected_keepers_mocks.go
$mockgen_cmd -source=x/authz/expected_keepers.go -package testutil -destination x/authz/testutil/expected_keepers_mocks.go
$mockgen_cmd -source=x/bank/types/expected_keepers.go -package testutil -destination x/bank/testutil/expected_keepers_mocks.go
Expand All @@ -24,5 +24,5 @@ $mockgen_cmd -source=x/slashing/types/expected_keepers.go -package testutil -des
$mockgen_cmd -source=x/genutil/types/expected_keepers.go -package testutil -destination x/genutil/testutil/expected_keepers_mocks.go
$mockgen_cmd -source=x/gov/testutil/expected_keepers.go -package testutil -destination x/gov/testutil/expected_keepers_mocks.go
$mockgen_cmd -source=x/staking/types/expected_keepers.go -package testutil -destination x/staking/testutil/expected_keepers_mocks.go
$mockgen_cmd -source=x/auth/vesting/types/expected_keepers.go -package testutil -destination x/auth/vesting/testutil/expected_keepers_mocks.go
# $mockgen_cmd -source=x/auth/vesting/types/expected_keepers.go -package testutil -destination x/auth/vesting/testutil/expected_keepers_mocks.go
$mockgen_cmd -source=x/protocolpool/types/expected_keepers.go -package testutil -destination x/protocolpool/testutil/expected_keepers_mocks.go
37 changes: 5 additions & 32 deletions server/v2/stf/core_router_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,14 @@ func (m msgRouterService) CanInvoke(ctx context.Context, typeURL string) error {
return exCtx.msgRouter.CanInvoke(ctx, typeURL)
}

// InvokeTyped execute a message and fill-in a response.
// The response must be known and passed as a parameter.
// Use InvokeUntyped if the response type is not known.
func (m msgRouterService) InvokeTyped(ctx context.Context, msg, resp transaction.Msg) error {
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}

return exCtx.msgRouter.InvokeTyped(ctx, msg, resp)
}

// InvokeUntyped execute a message and returns a response.
func (m msgRouterService) InvokeUntyped(ctx context.Context, msg transaction.Msg) (transaction.Msg, error) {
// Invoke execute a message and returns a response.
func (m msgRouterService) Invoke(ctx context.Context, msg transaction.Msg) (transaction.Msg, error) {
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return nil, err
}

return exCtx.msgRouter.InvokeUntyped(ctx, msg)
return exCtx.msgRouter.Invoke(ctx, msg)
}

// NewQueryRouterService implements router.Service.
Expand All @@ -72,23 +60,8 @@ func (m queryRouterService) CanInvoke(ctx context.Context, typeURL string) error
return exCtx.queryRouter.CanInvoke(ctx, typeURL)
}

// InvokeTyped execute a message and fill-in a response.
// The response must be known and passed as a parameter.
// Use InvokeUntyped if the response type is not known.
func (m queryRouterService) InvokeTyped(
ctx context.Context,
req, resp transaction.Msg,
) error {
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}

return exCtx.queryRouter.InvokeTyped(ctx, req, resp)
}

// InvokeUntyped execute a message and returns a response.
func (m queryRouterService) InvokeUntyped(
func (m queryRouterService) Invoke(
ctx context.Context,
req transaction.Msg,
) (transaction.Msg, error) {
Expand All @@ -97,5 +70,5 @@ func (m queryRouterService) InvokeUntyped(
return nil, err
}

return exCtx.queryRouter.InvokeUntyped(ctx, req)
return exCtx.queryRouter.Invoke(ctx, req)
}
4 changes: 2 additions & 2 deletions server/v2/stf/stf.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (s STF[T]) runTxMsgs(
execCtx.setGasLimit(gasLimit)
for i, msg := range msgs {
execCtx.sender = txSenders[i]
resp, err := s.msgRouter.InvokeUntyped(execCtx, msg)
resp, err := s.msgRouter.Invoke(execCtx, msg)
if err != nil {
return nil, 0, nil, fmt.Errorf("message execution at index %d failed: %w", i, err)
}
Expand Down Expand Up @@ -457,7 +457,7 @@ func (s STF[T]) Query(
queryCtx := s.makeContext(ctx, nil, queryState, internal.ExecModeSimulate)
queryCtx.setHeaderInfo(hi)
queryCtx.setGasLimit(gasLimit)
return s.queryRouter.InvokeUntyped(queryCtx, req)
return s.queryRouter.Invoke(queryCtx, req)
}

// RunWithCtx is made to support genesis, if genesis was just the execution of messages instead
Expand Down
Loading

0 comments on commit a554a21

Please sign in to comment.