From e091fe281ce20c0d29b2f0a65489db5bc69eb43c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20C=2E=20For=C3=A9s?= Date: Tue, 12 Nov 2024 21:12:57 +0100 Subject: [PATCH] feat(router): moved to its own pkg again --- middleware/transformer.go | 29 +++++++++++ router.go => router/router.go | 65 +++++++++++++++++-------- router_test.go => router/router_test.go | 9 ++-- transformer.go | 61 +++++++++++++++++++++++ 4 files changed, 141 insertions(+), 23 deletions(-) create mode 100644 middleware/transformer.go rename router.go => router/router.go (57%) rename router_test.go => router/router_test.go (89%) create mode 100644 transformer.go diff --git a/middleware/transformer.go b/middleware/transformer.go new file mode 100644 index 0000000..e7a7164 --- /dev/null +++ b/middleware/transformer.go @@ -0,0 +1,29 @@ +package middleware + +import ( + "net/http" + + "github.com/studiolambda/akumu" +) + +func Transformer(transformer akumu.Transformer) akumu.Middleware { + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + req, err := akumu.TransformWith(request, transformer) + + if err != nil { + akumu. + Failed(err). + Handle(writer, request) + + return + } + + handler.ServeHTTP(writer, req) + }) + } +} + +func TransformerFunc(transformer akumu.TransformerFunc) akumu.Middleware { + return Transformer(transformer) +} diff --git a/router.go b/router/router.go similarity index 57% rename from router.go rename to router/router.go index f64a1b1..a9e9d3e 100644 --- a/router.go +++ b/router/router.go @@ -1,4 +1,4 @@ -package akumu +package router import ( "fmt" @@ -7,13 +7,16 @@ import ( "path" "slices" "strings" + + "github.com/studiolambda/akumu" + "github.com/studiolambda/akumu/middleware" ) type Router struct { native *http.ServeMux pattern string parent *Router - middlewares []Middleware + middlewares []akumu.Middleware } func NewRouter() *Router { @@ -21,7 +24,7 @@ func NewRouter() *Router { native: http.NewServeMux(), pattern: "", parent: nil, - middlewares: make([]Middleware, 0), + middlewares: make([]akumu.Middleware, 0), } } @@ -34,7 +37,7 @@ func (router *Router) Group(pattern string, subrouter func(*Router)) { }) } -func (router *Router) With(middlewares ...Middleware) *Router { +func (router *Router) With(middlewares ...akumu.Middleware) *Router { return &Router{ native: nil, // parent's native will be used pattern: router.pattern, @@ -43,6 +46,30 @@ func (router *Router) With(middlewares ...Middleware) *Router { } } +func (router *Router) WithTransformer(transformer akumu.Transformer) *Router { + return router.With(middleware.Transformer(transformer)) +} + +func (router *Router) WithTransformerFunc(transformer akumu.TransformerFunc) *Router { + return router.WithTransformer(transformer) +} + +func (router *Router) WithValidator(validator akumu.Validator) *Router { + return router.With(middleware.Validator(validator)) +} + +func (router *Router) WithValidatorFunc(validator akumu.ValidatorFunc) *Router { + return router.WithValidator(validator) +} + +func (router *Router) WithAuthorizer(authorizer akumu.Authorizer) *Router { + return router.With(middleware.Authorizer(authorizer)) +} + +func (router *Router) WithAuthorizerFunc(authorizer akumu.AuthorizerFunc) *Router { + return router.WithAuthorizer(authorizer) +} + func (router *Router) mux() *http.ServeMux { if router.parent != nil { return router.parent.mux() @@ -59,11 +86,11 @@ func (router *Router) wrap(handler http.Handler) http.Handler { return handler } -func (router *Router) Use(middlewares ...Middleware) { +func (router *Router) Use(middlewares ...akumu.Middleware) { router.middlewares = append(router.middlewares, middlewares...) } -func (router *Router) Method(method string, pattern string, handler Handler) { +func (router *Router) Method(method string, pattern string, handler akumu.Handler) { pattern = path.Join(router.pattern, pattern) if !strings.HasSuffix(pattern, "/") { @@ -75,39 +102,39 @@ func (router *Router) Method(method string, pattern string, handler Handler) { Handle(fmt.Sprintf("%s %s{$}", method, pattern), router.wrap(handler)) } -func (router *Router) Get(pattern string, handler Handler) { +func (router *Router) Get(pattern string, handler akumu.Handler) { router.Method(http.MethodGet, pattern, handler) } -func (router *Router) Head(pattern string, handler Handler) { +func (router *Router) Head(pattern string, handler akumu.Handler) { router.Method(http.MethodHead, pattern, handler) } -func (router *Router) Post(pattern string, handler Handler) { +func (router *Router) Post(pattern string, handler akumu.Handler) { router.Method(http.MethodPost, pattern, handler) } -func (router *Router) Put(pattern string, handler Handler) { +func (router *Router) Put(pattern string, handler akumu.Handler) { router.Method(http.MethodPut, pattern, handler) } -func (router *Router) Patch(pattern string, handler Handler) { +func (router *Router) Patch(pattern string, handler akumu.Handler) { router.Method(http.MethodPatch, pattern, handler) } -func (router *Router) Delete(pattern string, handler Handler) { +func (router *Router) Delete(pattern string, handler akumu.Handler) { router.Method(http.MethodDelete, pattern, handler) } -func (router *Router) Connect(pattern string, handler Handler) { +func (router *Router) Connect(pattern string, handler akumu.Handler) { router.Method(http.MethodConnect, pattern, handler) } -func (router *Router) Options(pattern string, handler Handler) { +func (router *Router) Options(pattern string, handler akumu.Handler) { router.Method(http.MethodOptions, pattern, handler) } -func (router *Router) Trace(pattern string, handler Handler) { +func (router *Router) Trace(pattern string, handler akumu.Handler) { router.Method(http.MethodTrace, pattern, handler) } @@ -131,7 +158,7 @@ func (router *Router) Matches(request *http.Request) bool { return ok } -func (router *Router) Handler(method string, pattern string) (Handler, bool) { +func (router *Router) Handler(method string, pattern string) (akumu.Handler, bool) { if request, err := http.NewRequest(method, pattern, nil); err == nil { return router.HandlerMatch(request) } @@ -139,9 +166,9 @@ func (router *Router) Handler(method string, pattern string) (Handler, bool) { return nil, false } -func (router *Router) HandlerMatch(request *http.Request) (Handler, bool) { +func (router *Router) HandlerMatch(request *http.Request) (akumu.Handler, bool) { if handler, pattern := router.native.Handler(request); pattern != "" { - if handler, ok := handler.(Handler); ok { + if handler, ok := handler.(akumu.Handler); ok { return handler, true } } @@ -150,5 +177,5 @@ func (router *Router) HandlerMatch(request *http.Request) (Handler, bool) { } func (router *Router) Record(request *http.Request) *httptest.ResponseRecorder { - return RecordHandler(router, request) + return akumu.RecordHandler(router, request) } diff --git a/router_test.go b/router/router_test.go similarity index 89% rename from router_test.go rename to router/router_test.go index dcc8d79..5f7e1c9 100644 --- a/router_test.go +++ b/router/router_test.go @@ -1,14 +1,15 @@ -package akumu_test +package router_test import ( "net/http" "testing" "github.com/studiolambda/akumu" + "github.com/studiolambda/akumu/router" ) func TestRouterHas(t *testing.T) { - router := akumu.NewRouter() + router := router.NewRouter() router.Get("/", func(request *http.Request) error { return akumu.Response(http.StatusOK) @@ -28,7 +29,7 @@ func TestRouterHas(t *testing.T) { } func TestRouterMatches(t *testing.T) { - router := akumu.NewRouter() + router := router.NewRouter() router.Get("/", func(request *http.Request) error { return akumu.Response(http.StatusOK) @@ -46,7 +47,7 @@ func TestRouterMatches(t *testing.T) { } func TestRouterHandler(t *testing.T) { - router := akumu.NewRouter() + router := router.NewRouter() router.Get("/", func(request *http.Request) error { return akumu.Response(http.StatusOK) diff --git a/transformer.go b/transformer.go new file mode 100644 index 0000000..b32398f --- /dev/null +++ b/transformer.go @@ -0,0 +1,61 @@ +package akumu + +import ( + "errors" + "fmt" + "net/http" + "reflect" +) + +type Transformer interface { + Transform(request *http.Request) (*http.Request, error) +} + +type TransformerFunc func(request *http.Request) (*http.Request, error) + +var ( + ErrTransformFailed = errors.New("transform failed") + ErrInvalidTransformerInContext = errors.New("invalid transformer in context") + ErrTransformerNotFoundInContext = errors.New("transformer not found in context") +) + +func (transformer TransformerFunc) Transform(request *http.Request) (*http.Request, error) { + return transformer(request) +} + +func Transform[T Transformer](request *http.Request) (*http.Request, error) { + return TransformWith(request, *new(T)) +} + +func TransformFrom(request *http.Request, key any) (*http.Request, error) { + value := request.Context().Value(key) + + if value != nil { + return nil, fmt.Errorf("%w: %s", ErrTransformerNotFoundInContext, reflect.TypeOf(key)) + } + + transformer, ok := value.(Transformer) + + if !ok { + return nil, fmt.Errorf("%w: %s", ErrInvalidTransformerInContext, reflect.TypeOf(key)) + } + + return TransformWith(request, transformer) +} + +func TransformWith[T Transformer](request *http.Request, transformer T) (*http.Request, error) { + req, err := transformer.Transform(request) + + if err != nil { + return nil, NewProblem( + errors.Join(ErrTransformFailed, err), + http.StatusUnprocessableEntity, + ) + } + + return req, nil +} + +func TransformWithFunc(request *http.Request, transformer TransformerFunc) (*http.Request, error) { + return TransformWith(request, transformer) +}