From a7f9743f8032447244236ead90535d079bf2b654 Mon Sep 17 00:00:00 2001 From: dylanhitt Date: Tue, 4 Mar 2025 20:17:00 -0500 Subject: [PATCH] feat: add Server option of WithStripTrailingSlash Add the ability for a user to implicitly strip all slashes from all routes as well as strip all ending slashes from all requests at the fuego Server level. --- default_middlewares.go | 10 ++++++++ option.go | 11 +++++++++ option_test.go | 8 +++++++ server.go | 10 ++++++++ server_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 93 insertions(+) diff --git a/default_middlewares.go b/default_middlewares.go index 05c3432a..47393df1 100644 --- a/default_middlewares.go +++ b/default_middlewares.go @@ -5,6 +5,7 @@ import ( "log/slog" "net" "net/http" + "strings" "time" "github.com/google/uuid" @@ -158,3 +159,12 @@ func (l defaultLogger) middleware(next http.Handler) http.Handler { } }) } + +func stripTrailingSlashMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(r.URL.Path) > 1 { + r.URL.Path = strings.TrimRight(r.URL.Path, "/") + } + next.ServeHTTP(w, r) + }) +} diff --git a/option.go b/option.go index 9b0e6664..0f448dcc 100644 --- a/option.go +++ b/option.go @@ -5,6 +5,7 @@ import ( "net/http" "slices" "strconv" + "strings" "github.com/getkin/kin-openapi/openapi3" ) @@ -460,3 +461,13 @@ func OptionSecurity(securityRequirements ...openapi3.SecurityRequirement) func(* *r.Operation.Security = append(*r.Operation.Security, securityRequirements...) } } + +// OptionStripTrailingSlash ensure that the route declaration +// will have its ending trailing slash stripped. +func OptionStripTrailingSlash() func(*BaseRoute) { + return func(r *BaseRoute) { + if len(r.Path) > 1 { + r.Path = strings.TrimRight(r.Path, "/") + } + } +} diff --git a/option_test.go b/option_test.go index 59147765..dc0cdabc 100644 --- a/option_test.go +++ b/option_test.go @@ -1016,3 +1016,11 @@ func TestDefaultStatusCode(t *testing.T) { require.Equal(t, 500, w.Code) }) } + +func TestOptionStripTrailingSlash(t *testing.T) { + t.Run("Route trailing slash is stripped", func(t *testing.T) { + s := fuego.NewServer() + route := fuego.Get(s, "/test/", helloWorld, fuego.OptionStripTrailingSlash()) + require.Equal(t, "/test", route.Path) + }) +} diff --git a/server.go b/server.go index b08d8b30..470738c1 100644 --- a/server.go +++ b/server.go @@ -449,3 +449,13 @@ func WithLoggingMiddleware(loggingConfig LoggingConfig) func(*Server) { } } } + +// WithStripTrailingSlash ensure all declared routes trailing slash +// is stripped. This option also applies a middleware +// that strips the trailing slash from every incoming request. +func WithStripTrailingSlash() func(*Server) { + return func(s *Server) { + s.routeOptions = append(s.routeOptions, OptionStripTrailingSlash()) + s.globalMiddlewares = append(s.globalMiddlewares, stripTrailingSlashMiddleware) + } +} diff --git a/server_test.go b/server_test.go index 2e96655a..4942bac9 100644 --- a/server_test.go +++ b/server_test.go @@ -843,3 +843,57 @@ func TestWithSeveralGlobalMiddelwares(t *testing.T) { require.Equal(t, "two", res.Body.String()) }) } + +func TestWithStripTrailingSlash(t *testing.T) { + s := NewServer( + WithStripTrailingSlash(), + WithAddr(":9998"), + ) + Get(s, "/withtrailingslash/", dummyController) + Get(s, "/withouttrailingslash", dummyController) + + err := s.setup() + require.NoError(t, err) + + t.Run("requests with trailing slash", func(t *testing.T) { + t.Run("route with trailing slash", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/withtrailingslash/", nil) + res := httptest.NewRecorder() + + s.Handler.ServeHTTP(res, req) + + t.Log(res.Body.String()) + require.Equal(t, 200, res.Code) + }) + t.Run("route without trailing slash", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/withouttrailingslash/", nil) + res := httptest.NewRecorder() + + s.Handler.ServeHTTP(res, req) + + t.Log(res.Body.String()) + require.Equal(t, 200, res.Code) + }) + }) + + t.Run("requests without trailing slash", func(t *testing.T) { + t.Run("route with trailing slash", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/withtrailingslash", nil) + res := httptest.NewRecorder() + + s.Handler.ServeHTTP(res, req) + + t.Log(res.Body.String()) + require.Equal(t, 200, res.Code) + }) + t.Run("route without trailing slash", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/withouttrailingslash", nil) + res := httptest.NewRecorder() + + s.Handler.ServeHTTP(res, req) + + t.Log(res.Body.String()) + require.Equal(t, 200, res.Code) + }) + }) +}