diff --git a/middleware.go b/middleware.go index 55d8b96..59bbe57 100644 --- a/middleware.go +++ b/middleware.go @@ -19,6 +19,8 @@ type ( Logger *Logger // Skipper defines a function to skip middleware. Skipper middleware.Skipper + // AfterNextSkipper defines a function to skip middleware after the next handler is called. + AfterNextSkipper middleware.Skipper // BeforeNext is a function that is executed before the next handler is called. BeforeNext middleware.BeforeFunc // Enricher is a function that can be used to enrich the logger with additional information. @@ -62,6 +64,10 @@ func Middleware(config Config) echo.MiddlewareFunc { config.Skipper = middleware.DefaultSkipper } + if config.AfterNextSkipper == nil { + config.AfterNextSkipper = middleware.DefaultSkipper + } + if config.Logger == nil { config.Logger = New(os.Stdout, WithTimestamp()) } @@ -129,6 +135,10 @@ func Middleware(config Config) echo.MiddlewareFunc { } } + if config.AfterNextSkipper(c) { + return err + } + stop := time.Now() latency := stop.Sub(start) var mainEvt *zerolog.Event diff --git a/middleware_test.go b/middleware_test.go index 14bc7e3..d99b924 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -12,6 +12,7 @@ import ( "github.com/labstack/gommon/log" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "github.com/ziflex/lecho/v3" ) @@ -161,4 +162,64 @@ func TestMiddleware(t *testing.T) { assert.Contains(t, str, `"level":"info"`) assert.NotContains(t, str, `"level":"warn"`) }) + + t.Run("should skip middleware before calling next handler when Skipper func returns true", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/skip", nil) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + b := &bytes.Buffer{} + l := lecho.New(b) + l.SetLevel(log.INFO) + m := lecho.Middleware(lecho.Config{ + Logger: l, + Skipper: func(c echo.Context) bool { + return c.Request().URL.Path == "/skip" + }, + }) + + next := func(c echo.Context) error { + return nil + } + + handler := m(next) + err := handler(c) + + assert.NoError(t, err, "should not return error") + + str := b.String() + assert.Empty(t, str, "should not log anything") + }) + + t.Run("should skip middleware after calling next handler when AfterNextSkipper func returns true", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + b := &bytes.Buffer{} + l := lecho.New(b) + l.SetLevel(log.INFO) + m := lecho.Middleware(lecho.Config{ + Logger: l, + AfterNextSkipper: func(c echo.Context) bool { + return c.Response().Status == http.StatusMovedPermanently + }, + }) + + next := func(c echo.Context) error { + return c.Redirect(http.StatusMovedPermanently, "/other") + } + + handler := m(next) + err := handler(c) + + assert.NoError(t, err, "should not return error") + + str := b.String() + assert.Empty(t, str, "should not log anything") + }) }