From 890cbcd6ce7827c37c42c681e52ebd77f98e188e Mon Sep 17 00:00:00 2001 From: Stefano Scafiti Date: Thu, 18 May 2023 16:05:14 +0200 Subject: [PATCH] Forward request context inside webhook handler --- webhooks.go | 13 +++++++------ webhooks_test.go | 7 ++++--- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/webhooks.go b/webhooks.go index 172af97..96ce79b 100644 --- a/webhooks.go +++ b/webhooks.go @@ -1,6 +1,7 @@ package gocardless import ( + "context" "crypto/hmac" "crypto/sha256" "encoding/hex" @@ -12,15 +13,15 @@ import ( // EventHandler is the interface that must be implemented to handle events from a webhook. type EventHandler interface { - HandleEvent(Event) error + HandleEvent(context.Context, Event) error } // EventHandlerFunc can be used to convert a function into an EventHandler -type EventHandlerFunc func(Event) error +type EventHandlerFunc func(context.Context, Event) error // HandleEvent will call the EventHandlerFunc function -func (h EventHandlerFunc) HandleEvent(e Event) error { - return h(e) +func (h EventHandlerFunc) HandleEvent(ctx context.Context, e Event) error { + return h(ctx, e) } // WebhookHandler allows you to process incoming events from webhooks. @@ -43,7 +44,7 @@ func NewWebhookHandler(secret string, h EventHandler) (*WebhookHandler, error) { // ServeHTTP processes incoming webhooks and dispatches events to the corresponsing handlers. func (h *WebhookHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { sig, err := hex.DecodeString(r.Header.Get("Webhook-Signature")) - if len(sig) == 0 { + if len(sig) == 0 || err != nil { http.Error(w, "invalid signature", 498) return } @@ -65,7 +66,7 @@ func (h *WebhookHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } for _, event := range events.Events { - err := h.HandleEvent(event) + err := h.HandleEvent(r.Context(), event) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/webhooks_test.go b/webhooks_test.go index 5e246fe..4005f7e 100644 --- a/webhooks_test.go +++ b/webhooks_test.go @@ -1,6 +1,7 @@ package gocardless import ( + "context" "errors" "net/http" "net/http/httptest" @@ -9,7 +10,7 @@ import ( ) func TestWebhookFailsWithInvalidSignature(t *testing.T) { - wh, err := NewWebhookHandler("testing", EventHandlerFunc(func(e Event) error { + wh, err := NewWebhookHandler("testing", EventHandlerFunc(func(ctx context.Context, e Event) error { t.Error("unexpected call") return nil })) @@ -36,7 +37,7 @@ func TestWebhookFailsWithInvalidSignature(t *testing.T) { func TestWebhookFailsWithValidSignature(t *testing.T) { var called int - wh, err := NewWebhookHandler("testing", EventHandlerFunc(func(e Event) error { + wh, err := NewWebhookHandler("testing", EventHandlerFunc(func(ctx context.Context, e Event) error { called++ expectedID := "EVTESTNE86TNZS" if e.Id != expectedID { @@ -72,7 +73,7 @@ func TestWebhookFailsWithValidSignature(t *testing.T) { func TestWebhookWhenHandlerFails(t *testing.T) { var called int - wh, err := NewWebhookHandler("testing", EventHandlerFunc(func(e Event) error { + wh, err := NewWebhookHandler("testing", EventHandlerFunc(func(ctx context.Context, e Event) error { called++ expectedID := "EVTESTNE86TNZS" if e.Id != expectedID {