Skip to content

Commit

Permalink
Refactor Request Validation
Browse files Browse the repository at this point in the history
  • Loading branch information
joelsmith-2019 committed Nov 28, 2023
1 parent 3c3aa11 commit 2070db6
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 39 deletions.
33 changes: 9 additions & 24 deletions x/clock/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,9 @@ func NewMsgServerImpl(k Keeper) types.MsgServer {
func (k msgServer) RegisterClockContract(goCtx context.Context, req *types.MsgRegisterClockContract) (*types.MsgRegisterClockContractResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate sender address
if _, err := sdk.AccAddressFromBech32(req.SenderAddress); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidAddress, "invalid sender address: %s", req.SenderAddress)
}

// Validate contract address
if _, err := sdk.AccAddressFromBech32(req.ContractAddress); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidAddress, "invalid contract address: %s", req.ContractAddress)
// Validate request
if err := req.ValidateBasic(); err != nil {
return nil, err
}

return &types.MsgRegisterClockContractResponse{}, k.RegisterContract(ctx, req.SenderAddress, req.ContractAddress)
Expand All @@ -46,14 +41,9 @@ func (k msgServer) RegisterClockContract(goCtx context.Context, req *types.MsgRe
func (k msgServer) UnregisterClockContract(goCtx context.Context, req *types.MsgUnregisterClockContract) (*types.MsgUnregisterClockContractResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate sender address
if _, err := sdk.AccAddressFromBech32(req.SenderAddress); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidAddress, "invalid sender address: %s", req.SenderAddress)
}

// Validate contract address
if _, err := sdk.AccAddressFromBech32(req.ContractAddress); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidAddress, "invalid contract address: %s", req.ContractAddress)
// Validate request
if err := req.ValidateBasic(); err != nil {
return nil, err
}

return &types.MsgUnregisterClockContractResponse{}, k.UnregisterContract(ctx, req.SenderAddress, req.ContractAddress)
Expand All @@ -63,14 +53,9 @@ func (k msgServer) UnregisterClockContract(goCtx context.Context, req *types.Msg
func (k msgServer) UnjailClockContract(goCtx context.Context, req *types.MsgUnjailClockContract) (*types.MsgUnjailClockContractResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate sender address
if _, err := sdk.AccAddressFromBech32(req.SenderAddress); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidAddress, "invalid sender address: %s", req.SenderAddress)
}

// Validate contract address
if _, err := sdk.AccAddressFromBech32(req.ContractAddress); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidAddress, "invalid contract address: %s", req.ContractAddress)
// Validate request
if err := req.ValidateBasic(); err != nil {
return nil, err
}

return &types.MsgUnjailClockContractResponse{}, k.SetJailStatusBySender(ctx, req.SenderAddress, req.ContractAddress, false)
Expand Down
29 changes: 14 additions & 15 deletions x/clock/types/msgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ func (msg MsgRegisterClockContract) Type() string { return TypeMsgRegisterFeePay

// ValidateBasic runs stateless checks on the message
func (msg MsgRegisterClockContract) ValidateBasic() error {
if _, err := sdk.AccAddressFromBech32(msg.ContractAddress); err != nil {
return err
}

return nil
return validateAddresses(msg.SenderAddress, msg.ContractAddress)
}

// GetSignBytes encodes the message for signing
Expand All @@ -60,11 +56,7 @@ func (msg MsgUnregisterClockContract) Type() string { return TypeMsgRegisterFeeP

// ValidateBasic runs stateless checks on the message
func (msg MsgUnregisterClockContract) ValidateBasic() error {
if _, err := sdk.AccAddressFromBech32(msg.ContractAddress); err != nil {
return err
}

return nil
return validateAddresses(msg.SenderAddress, msg.ContractAddress)
}

// GetSignBytes encodes the message for signing
Expand All @@ -86,11 +78,7 @@ func (msg MsgUnjailClockContract) Type() string { return TypeMsgRegisterFeePayCo

// ValidateBasic runs stateless checks on the message
func (msg MsgUnjailClockContract) ValidateBasic() error {
if _, err := sdk.AccAddressFromBech32(msg.ContractAddress); err != nil {
return err
}

return nil
return validateAddresses(msg.SenderAddress, msg.ContractAddress)
}

// GetSignBytes encodes the message for signing
Expand Down Expand Up @@ -139,3 +127,14 @@ func (msg *MsgUpdateParams) ValidateBasic() error {

return msg.Params.Validate()
}

// ValidateAddresses validates the provided addresses
func validateAddresses(addresses ...string) error {
for _, address := range addresses {
if _, err := sdk.AccAddressFromBech32(address); err != nil {
return errors.Wrapf(ErrInvalidAddress, "invalid address: %s", address)
}
}

return nil
}

0 comments on commit 2070db6

Please sign in to comment.