Skip to content

Commit

Permalink
use default methodMatcher instead of flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Corné de Jong committed Jun 20, 2024
1 parent d1be7b4 commit c3e6635
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 29 deletions.
36 changes: 26 additions & 10 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ var (

// NewRouter returns a new router instance.
func NewRouter() *Router {
return &Router{namedRoutes: make(map[string]*Route)}
return &Router{
namedRoutes: make(map[string]*Route),
routeConf: routeConf{
methodMatcher: methodDefaultMatcher{},
},
}
}

// Router registers routes to be matched and dispatches a handler.
Expand Down Expand Up @@ -104,9 +109,8 @@ type routeConf struct {

buildVarsFunc BuildVarsFunc

// If true, methods will be matched case insensitive.
// The methodCaseInsensitiveMatcher will be used instead of methodMatcher
matchMethodCaseInsensitive bool
// Holds the default method matcher
methodMatcher matcher
}

// returns an effective deep copy of `routeConf`
Expand Down Expand Up @@ -282,12 +286,6 @@ func (r *Router) OmitRouteFromContext(value bool) *Router {
return r
}

// MatchMethodCaseInsensitive defines the behaviour of ignoring casing for request methods.
func (r *Router) MatchMethodCaseInsensitive(value bool) *Router {
r.matchMethodCaseInsensitive = value
return r
}

// UseEncodedPath tells the router to match the encoded original path
// to the routes.
// For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to".
Expand All @@ -299,6 +297,24 @@ func (r *Router) UseEncodedPath() *Router {
return r
}

// MatchMethodCaseInsensitive defines the behaviour of ignoring casing for request methods.
func (r *Router) MatchMethodDefault() *Router {
r.methodMatcher = methodDefaultMatcher{}
return r
}

// MatchMethodCaseInsensitive defines the behaviour of ignoring casing for request methods.
func (r *Router) MatchMethodCaseInsensitive() *Router {
r.methodMatcher = methodCaseInsensitiveMatcher{}
return r
}

// MatchMethodExact defines the behaviour of matching exact request methods.
func (r *Router) MatchMethodExact() *Router {
r.methodMatcher = methodCaseExactMatcher{}
return r
}

// ----------------------------------------------------------------------------
// Route factories
// ----------------------------------------------------------------------------
Expand Down
66 changes: 64 additions & 2 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2132,8 +2132,8 @@ func TestMethodMatchingCaseInsensitiveOnRouter(t *testing.T) {
t.Errorf("Expecting code %v", 405)
}

r.MatchMethodCaseInsensitive(true)
r.HandleFunc("/a", func1).Methods("get")
r.MatchMethodCaseInsensitive()
r.HandleFunc("/a", func1).Methods("get").Name("t")
req, _ = http.NewRequest("get", "http://localhost/a", nil)

match = new(RouteMatch)
Expand All @@ -2148,6 +2148,68 @@ func TestMethodMatchingCaseInsensitiveOnRouter(t *testing.T) {
}
}

func TestMethodMatchingCaseExact(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}

r := NewRouter()
r.HandleFunc("/a", func1).Methods("get")
r.HandleFunc("/b", func1).MethodsCaseExact("get")

req, _ := http.NewRequest("get", "http://localhost/a", nil)
match := new(RouteMatch)
matched := r.Match(req, match)

if matched {
t.Error("Should not have matched route for method")
}

if match.MatchErr != ErrMethodMismatch {
t.Error("Should get ErrMethodMismatch error")
}

resp := NewRecorder()
r.ServeHTTP(resp, req)
if resp.Code != http.StatusMethodNotAllowed {
t.Errorf("Expecting code %v", 405)
}

req, _ = http.NewRequest("GET", "http://localhost/b", nil)
match = new(RouteMatch)
matched = r.Match(req, match)

if matched {
t.Error("Should not have matched route for method")
}

if match.MatchErr != ErrMethodMismatch {
t.Error("Should get ErrMethodMismatch error")
}

resp = NewRecorder()
r.ServeHTTP(resp, req)
if resp.Code != http.StatusMethodNotAllowed {
t.Errorf("Expecting code %v", 405)
}

resp = NewRecorder()
r.ServeHTTP(resp, req)
if resp.Code != http.StatusMethodNotAllowed {
t.Errorf("Expecting code %v", 405)
}

req, _ = http.NewRequest("get", "http://localhost/b", nil)
match = new(RouteMatch)
matched = r.Match(req, match)

if !matched {
t.Error("Should have matched route")
}

if match.MatchErr != nil {
t.Error("Should not have any matching error. Found:", match.MatchErr)
}
}

func TestMultipleDefinitionOfSamePathWithDifferentMethods(t *testing.T) {
emptyHandler := func(w http.ResponseWriter, r *http.Request) {}

Expand Down
10 changes: 5 additions & 5 deletions old_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,29 +281,29 @@ var hostMatcherTests = []hostMatcherTest{
}

type methodMatcherTest struct {
matcher methodMatcher
matcher methodDefaultMatcher
method string
result bool
}

var methodMatcherTests = []methodMatcherTest{
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
matcher: ([]string{"GET", "POST", "PUT"}),
method: "GET",
result: true,
},
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
matcher: methodDefaultMatcher([]string{"GET", "POST", "PUT"}),
method: "POST",
result: true,
},
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
matcher: methodDefaultMatcher([]string{"GET", "POST", "PUT"}),
method: "PUT",
result: true,
},
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
matcher: methodDefaultMatcher([]string{"GET", "POST", "PUT"}),
method: "DELETE",
result: false,
},
Expand Down
63 changes: 51 additions & 12 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
// Match everything.
for _, m := range r.matchers {
if matched := m.Match(req, match); !matched {
if _, ok := m.(methodMatcher); ok {
if _, ok := m.(methodDefaultMatcher); ok {
matchErr = ErrMethodMismatch
continue
} else if _, ok := m.(methodCaseInsensitiveMatcher); ok {
matchErr = ErrMethodMismatch
continue
} else if _, ok := m.(methodCaseExactMatcher); ok {
matchErr = ErrMethodMismatch
continue
}

// Multiple routes may share the same path but use different HTTP methods. For instance:
Expand Down Expand Up @@ -325,30 +328,50 @@ func (r *Route) MatcherFunc(f MatcherFunc) *Route {

// Methods --------------------------------------------------------------------

// methodMatcher matches the request against HTTP methods.
type methodMatcher []string
// methodDefaultMatcher matches the request against HTTP methods.
// The supplied methods will be transformed to uppercase. The request method not.
type methodDefaultMatcher []string

func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool {
func (m methodDefaultMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, r.Method)
}

// methodMatcher matches the request against HTTP methods without case sensitivity.
// Both the supplied methods as well as the request method will be transformed to uppercase.
type methodCaseInsensitiveMatcher []string

func (m methodCaseInsensitiveMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, strings.ToUpper(r.Method))
}

// methodCaseExactMatcher matches the request against HTTP methods exactly.
// No transformation of supplied methods or the request method is applied.
type methodCaseExactMatcher []string

func (m methodCaseExactMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, r.Method)
}

// Methods adds a matcher for HTTP methods.
// It accepts a sequence of one or more methods to be matched, e.g.:
// "GET", "POST", "PUT".
func (r *Route) Methods(methods ...string) *Route {
if r.routeConf.matchMethodCaseInsensitive {
if _, ok := r.methodMatcher.(methodCaseInsensitiveMatcher); ok {
return r.MethodsCaseInsensitive(methods...)
} else if _, ok := r.methodMatcher.(methodCaseExactMatcher); ok {
return r.MethodsCaseExact(methods...)
} else {
return r.MethodsCaseSensitive(methods...)
return r.MethodsDefault(methods...)
}
}

// Methods adds a matcher for HTTP methods.
// It accepts a sequence of one or more methods to be matched, e.g.:
// "GET", "POST", "PUT".
func (r *Route) MethodsDefault(methods ...string) *Route {
return r.addMatcher(methodDefaultMatcher(sliceToUpper(methods)))
}

// MethodsCaseInsensitive adds a matcher for HTTP methods without case sensitivity.
// This will override the initial config on the router for 'matchMethodCaseInsensitive'
// It accepts a sequence of one or more methods to be matched, e.g.:
Expand All @@ -357,12 +380,12 @@ func (r *Route) MethodsCaseInsensitive(methods ...string) *Route {
return r.addMatcher(methodCaseInsensitiveMatcher(sliceToUpper(methods)))
}

// MethodsCaseInsensitive adds a matcher for HTTP methods with case sensitivity.
// MethodsCaseInsensitive adds a matcher for exact HTTP methods with no transformation.
// This will override the initial config on the router for 'matchMethodCaseInsensitive'
// It accepts a sequence of one or more methods to be matched, e.g.:
// "GET", "POST", "PUT".
func (r *Route) MethodsCaseSensitive(methods ...string) *Route {
return r.addMatcher(methodMatcher(sliceToUpper(methods)))
func (r *Route) MethodsCaseExact(methods ...string) *Route {
return r.addMatcher(methodCaseExactMatcher(methods))
}

// Path -----------------------------------------------------------------------
Expand Down Expand Up @@ -732,7 +755,11 @@ func (r *Route) GetMethods() ([]string, error) {
return nil, r.err
}
for _, m := range r.matchers {
if methods, ok := m.(methodMatcher); ok {
if methods, ok := m.(methodDefaultMatcher); ok {
return []string(methods), nil
} else if methods, ok := m.(methodCaseInsensitiveMatcher); ok {
return []string(methods), nil
} else if methods, ok := m.(methodCaseExactMatcher); ok {
return []string(methods), nil
}
}
Expand Down Expand Up @@ -790,8 +817,20 @@ func (r *Route) buildVars(m map[string]string) map[string]string {
return m
}

// MatchMethodDefault defines the behaviour of matching request methods with the default matcher on this route.
func (r *Route) MatchMethodDefault() *Route {
r.methodMatcher = methodDefaultMatcher{}
return r
}

// MatchMethodCaseInsensitive defines the behaviour of ignoring casing for request methods on this route.
func (r *Route) MatchMethodCaseInsensitive(value bool) *Route {
r.matchMethodCaseInsensitive = value
func (r *Route) MatchMethodCaseInsensitive() *Route {
r.methodMatcher = methodCaseInsensitiveMatcher{}
return r
}

// MatchMethodCaseExact defines the behaviour of matching exact request methods on this route.
func (r *Route) MatchMethodCaseExact(value bool) *Route {
r.methodMatcher = methodCaseExactMatcher{}
return r
}

0 comments on commit c3e6635

Please sign in to comment.