From ef563dced5a8186ba7cad2653c92bf52799f8dfa Mon Sep 17 00:00:00 2001 From: Antoine Pourchet Date: Wed, 10 Jan 2024 17:27:58 -0700 Subject: [PATCH] Readme: Added middleware examples (#7) * Readme: Added middleware examples * more readme nits --- README.md | 100 ++++++++++++++++++++++++++++++++++++++++++------ main.go | 4 +- wrapper.go | 14 +------ wrapper_test.go | 8 ++-- 4 files changed, 97 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 0ff9607..728638f 100644 --- a/README.md +++ b/README.md @@ -39,12 +39,12 @@ func ListMovies(params ListMoviesParams) (ListMoviesResponse, error) { return httpwrap.NewHTTPError(http.StatusBadRequest, "Only 2022 movies are searchable.") } - .... + ... return ListMoviesResponse{ Movies: []string{ - "Finding Nemo", - "Good Will Hunting", + "Finding Nemo", + "Good Will Hunting", }, } } @@ -76,10 +76,27 @@ import ( "github.com/gorilla/mux" ) +func main() { + // Tell the httpwrapper to run checkAPICreds as middleware before moving on to call + // the endpoints themselves. + httpWrapper := httpwrap.NewStandardWrapper().Before(checkAPICreds) + + // Using gorilla/mux for this example, but httpWrapper.Wrap will turn your regular endpoint + // functions into the required http.HandlerFunc type. + router := mux.NewRouter() + router.Handle("/movies/list", httpWrapper.Wrap(ListMovies)).Methods("GET") + router.Handle("/raw-handler", httpWrapper.Wrap(RawHTTPHandler)).Methods("GET") + http.Handle("/", router) + + log.Fatal(http.ListenAndServe(":3000", router)) +} + type APICredentials struct { Key string `http:"header=x-application-passcode"` } +... + // checkAPICreds checks the api credentials passed into the request. Those APICredentials // will be populated using the headers in the http request. func checkAPICreds(creds APICredentials) error { @@ -89,18 +106,79 @@ func checkAPICreds(creds APICredentials) error { return httpwrap.NewHTTPError(http.StatusForbidden, "Bad credentials.") } -func main() { - // Tell httpwrapper to run checkAPICreds as middleware before moving on to call - // the endpoints themselves. - httpWrapper := httpwrap.NewStandardWrapper().Before(checkAPICreds) +``` +This example also displays a simple authorization middleware, `checkAPICreds`. + +## Middleware +Middlewares can be used to either short-circuit the http request lifecycle and return early, or to provide additional +information to the endpoint that gets called after it. The following example uses two separate middleware functions +to accomplish both. +```go +import ( + "log" + "net/http" + + "github.com/apourchet/httpwrap" + "github.com/gorilla/mux" +) - // Using gorilla/mux for this example, but httpWrapper.Wrap will turn your regular endpoint - // functions into the required http.HandlerFunc type. +func main() { + httpWrapper := httpwrap.NewStandardWrapper(). + Before(getUserAccountInfo) + wrapperWithAccessCheck := httpWrapper.Before(ensureUserAccess) + + // The listMovies endpoint needs account information only, but the addMovies endpoint also performs + // an access list check. router := mux.NewRouter() router.Handle("/movies/list", httpWrapper.Wrap(ListMovies)).Methods("GET") - router.Handle("/raw-handler", httpWrapper.Wrap(RawHTTPHandler)).Methods("GET") + router.Handle("/movies/add", wrapperWithAccessCheck.Wrap(AddMovie)).Methods("PUT") http.Handle("/", router) - + log.Fatal(http.ListenAndServe(":3000", router)) } + +type UserAuthenticationMaterial struct { + BearerToken string `http:"header=Authorization"` +} + +func getUserAccountInfo(authMaterial UserAuthenticationMaterial) (UserAccountInfo, error) { + userId, err := decodeBearerToken(authMaterial.BearerToken) + if err != nil { + return httpwrap.NewHTTPError(http.StatusUnauthorized, "Bad authentication material.") + } + + // Find the user information in the database for instance. + accountInfo, err := database.FindUserInformation(userId) + if err != nil { + return err + } + + ... + + return UserAccountInfo{ + UserID: userId, + UserHasAccess: false, + }, nil +} + +func ensureUserAccess(accountInfo UserAccountInfo) error { + if !accountInfo.UserHasAccess { + // Returning an error from a middleware will short-circuit the rest of the request lifecycle and + // early-return this to the client. + return httpwrap.NewHTTPError(http.StatusForbidden, "Access forbidden.") + } + return nil +} + +// The two endpoints below will have access not only to the information provided by the middlewares, but +// can gather additional parameters from the http request by taking in extra arguments. +// NOTE: These two endpoints do not _have to_ take in the information from the middleware, so if accountInfo +// is not actually used in the endpoint, it can be omitted from the function signature altogether. +func ListMovies(accountInfo UserAccountInfo, params ListMoviesParams) (ListMoviesResponse, error) { + ... +} + +func AddMovie(accountInfo UserAccountInfo, params AddMovieParams) (AddMovieResponse, error) { + ... +} ``` \ No newline at end of file diff --git a/main.go b/main.go index be6e36b..5adc748 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,8 @@ package httpwrap -import "reflect" +import ( + "reflect" +) type mainFn struct { val reflect.Value diff --git a/wrapper.go b/wrapper.go index 8acb3c3..65dd5bf 100644 --- a/wrapper.go +++ b/wrapper.go @@ -57,7 +57,7 @@ func (w Wrapper) Finally(fn any) Wrapper { // Wrap sets the main handling function to process requests. This Wrap function must // be called to get an `http.Handler` type. -func (w Wrapper) Wrap(fn any) wrappedHttpHandler { +func (w Wrapper) Wrap(fn any) http.Handler { main, err := newMain(fn) if err != nil { panic(err) @@ -74,18 +74,6 @@ type wrappedHttpHandler struct { main mainFn } -// Before adds the before functions to the underlying Wrapper. -func (h wrappedHttpHandler) Before(fns ...any) wrappedHttpHandler { - h.Wrapper = h.Wrapper.Before(fns...) - return h -} - -// Finally sets the `finally` function of the underlying Wrapper. -func (h wrappedHttpHandler) Finally(fn any) wrappedHttpHandler { - h.Wrapper = h.Wrapper.Finally(fn) - return h -} - // ServeHTTP implements `http.Handler`. func (h wrappedHttpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { ctx := newRunCtx(rw, req, h.construct) diff --git a/wrapper_test.go b/wrapper_test.go index 0e27e29..92b674e 100644 --- a/wrapper_test.go +++ b/wrapper_test.go @@ -213,7 +213,6 @@ func TestWrapper(t *testing.T) { handler := New(). WithConstruct(nopConstructor). - Wrap(func() {}). Before(func() *myerr { return &myerr{} }). @@ -222,7 +221,8 @@ func TestWrapper(t *testing.T) { require.Nil(t, res) require.Error(t, err) rw.WriteHeader(http.StatusCreated) - }) + }). + Wrap(func() {}) handler.ServeHTTP(rw, req) require.Equal(t, http.StatusCreated, rw.Result().StatusCode) @@ -231,12 +231,12 @@ func TestWrapper(t *testing.T) { Before(func() *myerr { return nil }). - Wrap(func() {}). Finally(func(rw http.ResponseWriter, res any, err error) { require.NotNil(t, rw) require.NoError(t, err) rw.WriteHeader(http.StatusCreated) - }) + }). + Wrap(func() {}) handler.ServeHTTP(rw, req) require.Equal(t, http.StatusCreated, rw.Result().StatusCode) })