Skip to content

Commit

Permalink
fix(api): url wildcard param (#17)
Browse files Browse the repository at this point in the history
* fix(api): url wildcard param

* fix: unit test failing due to unset route context

---------

Co-authored-by: Niklas Treml <[email protected]>
  • Loading branch information
y-eight and niklastreml authored Nov 24, 2023
1 parent 2fadcf4 commit 2ea017e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
8 changes: 4 additions & 4 deletions pkg/checks/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (rt *RoundTrip) Run(ctx context.Context) (Result, error) {

func (rt *RoundTrip) Startup(ctx context.Context, cResult chan<- Result) error {
// TODO register http handler for this check
http.HandleFunc("/rtt", func(w http.ResponseWriter, r *http.Request) {
http.HandleFunc("rtt", func(w http.ResponseWriter, r *http.Request) {
// TODO handle
})

Expand All @@ -58,7 +58,7 @@ func (rt *RoundTrip) Startup(ctx context.Context, cResult chan<- Result) error {
// Shutdown is called once when the check is unregistered or sparrow shuts down

func (rt *RoundTrip) Shutdown(ctx context.Context) error {
http.Handle("/rtt", http.NotFoundHandler())
http.Handle("rtt", http.NotFoundHandler())

return nil
}
Expand All @@ -78,11 +78,11 @@ func (rt *RoundTrip) Schema() (*openapi3.SchemaRef, error) {
}

func (rt *RoundTrip) RegisterHandler(ctx context.Context, router *api.RoutingTree) {
router.Add(http.MethodGet, "/rtt", rt.handleRoundtrip)
router.Add(http.MethodGet, "rtt", rt.handleRoundtrip)
}

func (rt *RoundTrip) DeregisterHandler(ctx context.Context, router *api.RoutingTree) {
router.Remove(http.MethodGet, "/rtt")
router.Remove(http.MethodGet, "rtt")
}

func (rt *RoundTrip) handleRoundtrip(w http.ResponseWriter, r *http.Request) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sparrow/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (s *Sparrow) getOpenapi(w http.ResponseWriter, r *http.Request) {
// Returns a 404 if no handler is registered for the request
func (s *Sparrow) handleChecks(w http.ResponseWriter, r *http.Request) {
method := r.Method
path := r.URL.Path
path := chi.URLParam(r, "*")

handler, ok := s.routingTree.Get(method, path)
if !ok {
Expand Down
11 changes: 10 additions & 1 deletion pkg/sparrow/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,15 @@ func TestSparrow_getCheckMetrics(t *testing.T) {
}
}

func addRouteParams(r *http.Request, values map[string]string) *http.Request {
rctx := chi.NewRouteContext()
for k, v := range values {
rctx.URLParams.Add(k, v)
}

return r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
}

func TestSparrow_handleChecks(t *testing.T) {
type route struct {
Method string
Expand Down Expand Up @@ -224,7 +233,7 @@ func TestSparrow_handleChecks(t *testing.T) {
wantCode int
}{
{name: "no check handlers", fields: fields{routingTree: api.NewRoutingTree()}, args: args{w: httptest.NewRecorder(), r: httptest.NewRequest(http.MethodGet, "/v1/notfound", bytes.NewBuffer([]byte{}))}, wantCode: http.StatusNotFound, want: []byte(http.StatusText(http.StatusNotFound))},
{name: "has check handlers", fields: fields{routingTree: api.NewRoutingTree(), routes: []route{{Method: http.MethodGet, Path: "/v1/test", Handler: func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("test")) }}}}, args: args{w: httptest.NewRecorder(), r: httptest.NewRequest(http.MethodGet, "/v1/test", bytes.NewBuffer([]byte{}))}, wantCode: http.StatusOK, want: []byte("test")},
{name: "has check handlers", fields: fields{routingTree: api.NewRoutingTree(), routes: []route{{Method: http.MethodGet, Path: "/v1/test", Handler: func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("test")) }}}}, args: args{w: httptest.NewRecorder(), r: addRouteParams(httptest.NewRequest(http.MethodGet, "/v1/test", bytes.NewBuffer([]byte{})), map[string]string{"*": "/v1/test"})}, wantCode: http.StatusOK, want: []byte("test")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down

0 comments on commit 2ea017e

Please sign in to comment.