Skip to content

Commit

Permalink
feat(router): moved to its own pkg again
Browse files Browse the repository at this point in the history
  • Loading branch information
ConsoleTVs committed Nov 12, 2024
1 parent d960be5 commit e091fe2
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 23 deletions.
29 changes: 29 additions & 0 deletions middleware/transformer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package middleware

import (
"net/http"

"github.com/studiolambda/akumu"
)

func Transformer(transformer akumu.Transformer) akumu.Middleware {
return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
req, err := akumu.TransformWith(request, transformer)

if err != nil {
akumu.
Failed(err).
Handle(writer, request)

return
}

handler.ServeHTTP(writer, req)
})
}
}

func TransformerFunc(transformer akumu.TransformerFunc) akumu.Middleware {
return Transformer(transformer)
}
65 changes: 46 additions & 19 deletions router.go → router/router.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package akumu
package router

import (
"fmt"
Expand All @@ -7,21 +7,24 @@ import (
"path"
"slices"
"strings"

"github.com/studiolambda/akumu"
"github.com/studiolambda/akumu/middleware"
)

type Router struct {
native *http.ServeMux
pattern string
parent *Router
middlewares []Middleware
middlewares []akumu.Middleware
}

func NewRouter() *Router {
return &Router{
native: http.NewServeMux(),
pattern: "",
parent: nil,
middlewares: make([]Middleware, 0),
middlewares: make([]akumu.Middleware, 0),
}
}

Expand All @@ -34,7 +37,7 @@ func (router *Router) Group(pattern string, subrouter func(*Router)) {
})
}

func (router *Router) With(middlewares ...Middleware) *Router {
func (router *Router) With(middlewares ...akumu.Middleware) *Router {
return &Router{
native: nil, // parent's native will be used
pattern: router.pattern,
Expand All @@ -43,6 +46,30 @@ func (router *Router) With(middlewares ...Middleware) *Router {
}
}

func (router *Router) WithTransformer(transformer akumu.Transformer) *Router {
return router.With(middleware.Transformer(transformer))
}

func (router *Router) WithTransformerFunc(transformer akumu.TransformerFunc) *Router {
return router.WithTransformer(transformer)
}

func (router *Router) WithValidator(validator akumu.Validator) *Router {
return router.With(middleware.Validator(validator))
}

func (router *Router) WithValidatorFunc(validator akumu.ValidatorFunc) *Router {
return router.WithValidator(validator)
}

func (router *Router) WithAuthorizer(authorizer akumu.Authorizer) *Router {
return router.With(middleware.Authorizer(authorizer))
}

func (router *Router) WithAuthorizerFunc(authorizer akumu.AuthorizerFunc) *Router {
return router.WithAuthorizer(authorizer)
}

func (router *Router) mux() *http.ServeMux {
if router.parent != nil {
return router.parent.mux()
Expand All @@ -59,11 +86,11 @@ func (router *Router) wrap(handler http.Handler) http.Handler {
return handler
}

func (router *Router) Use(middlewares ...Middleware) {
func (router *Router) Use(middlewares ...akumu.Middleware) {
router.middlewares = append(router.middlewares, middlewares...)
}

func (router *Router) Method(method string, pattern string, handler Handler) {
func (router *Router) Method(method string, pattern string, handler akumu.Handler) {
pattern = path.Join(router.pattern, pattern)

if !strings.HasSuffix(pattern, "/") {
Expand All @@ -75,39 +102,39 @@ func (router *Router) Method(method string, pattern string, handler Handler) {
Handle(fmt.Sprintf("%s %s{$}", method, pattern), router.wrap(handler))
}

func (router *Router) Get(pattern string, handler Handler) {
func (router *Router) Get(pattern string, handler akumu.Handler) {
router.Method(http.MethodGet, pattern, handler)
}

func (router *Router) Head(pattern string, handler Handler) {
func (router *Router) Head(pattern string, handler akumu.Handler) {
router.Method(http.MethodHead, pattern, handler)
}

func (router *Router) Post(pattern string, handler Handler) {
func (router *Router) Post(pattern string, handler akumu.Handler) {
router.Method(http.MethodPost, pattern, handler)
}

func (router *Router) Put(pattern string, handler Handler) {
func (router *Router) Put(pattern string, handler akumu.Handler) {
router.Method(http.MethodPut, pattern, handler)
}

func (router *Router) Patch(pattern string, handler Handler) {
func (router *Router) Patch(pattern string, handler akumu.Handler) {
router.Method(http.MethodPatch, pattern, handler)
}

func (router *Router) Delete(pattern string, handler Handler) {
func (router *Router) Delete(pattern string, handler akumu.Handler) {
router.Method(http.MethodDelete, pattern, handler)
}

func (router *Router) Connect(pattern string, handler Handler) {
func (router *Router) Connect(pattern string, handler akumu.Handler) {
router.Method(http.MethodConnect, pattern, handler)
}

func (router *Router) Options(pattern string, handler Handler) {
func (router *Router) Options(pattern string, handler akumu.Handler) {
router.Method(http.MethodOptions, pattern, handler)
}

func (router *Router) Trace(pattern string, handler Handler) {
func (router *Router) Trace(pattern string, handler akumu.Handler) {
router.Method(http.MethodTrace, pattern, handler)
}

Expand All @@ -131,17 +158,17 @@ func (router *Router) Matches(request *http.Request) bool {
return ok
}

func (router *Router) Handler(method string, pattern string) (Handler, bool) {
func (router *Router) Handler(method string, pattern string) (akumu.Handler, bool) {
if request, err := http.NewRequest(method, pattern, nil); err == nil {
return router.HandlerMatch(request)
}

return nil, false
}

func (router *Router) HandlerMatch(request *http.Request) (Handler, bool) {
func (router *Router) HandlerMatch(request *http.Request) (akumu.Handler, bool) {
if handler, pattern := router.native.Handler(request); pattern != "" {
if handler, ok := handler.(Handler); ok {
if handler, ok := handler.(akumu.Handler); ok {
return handler, true
}
}
Expand All @@ -150,5 +177,5 @@ func (router *Router) HandlerMatch(request *http.Request) (Handler, bool) {
}

func (router *Router) Record(request *http.Request) *httptest.ResponseRecorder {
return RecordHandler(router, request)
return akumu.RecordHandler(router, request)
}
9 changes: 5 additions & 4 deletions router_test.go → router/router_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package akumu_test
package router_test

import (
"net/http"
"testing"

"github.com/studiolambda/akumu"
"github.com/studiolambda/akumu/router"
)

func TestRouterHas(t *testing.T) {
router := akumu.NewRouter()
router := router.NewRouter()

router.Get("/", func(request *http.Request) error {
return akumu.Response(http.StatusOK)
Expand All @@ -28,7 +29,7 @@ func TestRouterHas(t *testing.T) {
}

func TestRouterMatches(t *testing.T) {
router := akumu.NewRouter()
router := router.NewRouter()

router.Get("/", func(request *http.Request) error {
return akumu.Response(http.StatusOK)
Expand All @@ -46,7 +47,7 @@ func TestRouterMatches(t *testing.T) {
}

func TestRouterHandler(t *testing.T) {
router := akumu.NewRouter()
router := router.NewRouter()

router.Get("/", func(request *http.Request) error {
return akumu.Response(http.StatusOK)
Expand Down
61 changes: 61 additions & 0 deletions transformer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package akumu

import (
"errors"
"fmt"
"net/http"
"reflect"
)

type Transformer interface {
Transform(request *http.Request) (*http.Request, error)
}

type TransformerFunc func(request *http.Request) (*http.Request, error)

var (
ErrTransformFailed = errors.New("transform failed")
ErrInvalidTransformerInContext = errors.New("invalid transformer in context")
ErrTransformerNotFoundInContext = errors.New("transformer not found in context")
)

func (transformer TransformerFunc) Transform(request *http.Request) (*http.Request, error) {
return transformer(request)
}

func Transform[T Transformer](request *http.Request) (*http.Request, error) {
return TransformWith(request, *new(T))
}

func TransformFrom(request *http.Request, key any) (*http.Request, error) {
value := request.Context().Value(key)

if value != nil {
return nil, fmt.Errorf("%w: %s", ErrTransformerNotFoundInContext, reflect.TypeOf(key))
}

transformer, ok := value.(Transformer)

if !ok {
return nil, fmt.Errorf("%w: %s", ErrInvalidTransformerInContext, reflect.TypeOf(key))
}

return TransformWith(request, transformer)
}

func TransformWith[T Transformer](request *http.Request, transformer T) (*http.Request, error) {
req, err := transformer.Transform(request)

if err != nil {
return nil, NewProblem(
errors.Join(ErrTransformFailed, err),
http.StatusUnprocessableEntity,
)
}

return req, nil
}

func TransformWithFunc(request *http.Request, transformer TransformerFunc) (*http.Request, error) {
return TransformWith(request, transformer)
}

0 comments on commit e091fe2

Please sign in to comment.