diff --git a/matchers/serve.go b/matchers/serve.go index 2dc13cf..53593c5 100644 --- a/matchers/serve.go +++ b/matchers/serve.go @@ -17,6 +17,7 @@ import ( func Serve(expected interface{}) *ServeMatcher { return &ServeMatcher{ expected: expected, + client: http.DefaultClient, docker: occam.NewDocker(), } } @@ -27,6 +28,7 @@ type ServeMatcher struct { endpoint string docker occam.Docker response string + client *http.Client } // OnPort sets the container port that is expected to be exposed. @@ -35,6 +37,14 @@ func (sm *ServeMatcher) OnPort(port int) *ServeMatcher { return sm } +// WithClient sets the http client that will be used to make the request. This +// allows for non-default client settings like custom redirect handling or +// adding a cookie jar. +func (sm *ServeMatcher) WithClient(client *http.Client) *ServeMatcher { + sm.client = client + return sm +} + // WithEndpoint sets the endpoint or subdirectory where the expected content // should be available. For example, WithEndpoint("/health") will attempt to // access the server's /health endpoint. @@ -74,7 +84,7 @@ func (sm *ServeMatcher) Match(actual interface{}) (success bool, err error) { return false, fmt.Errorf("ServeMatcher looking for response from container port %s which is not in container port map", port) } - response, err := http.Get(fmt.Sprintf("http://%s:%s%s", container.Host(), container.HostPort(port), sm.endpoint)) + response, err := sm.client.Get(fmt.Sprintf("http://%s:%s%s", container.Host(), container.HostPort(port), sm.endpoint)) if err != nil { return false, err diff --git a/matchers/serve_test.go b/matchers/serve_test.go index c7a5231..202afc9 100644 --- a/matchers/serve_test.go +++ b/matchers/serve_test.go @@ -39,6 +39,9 @@ func testServe(t *testing.T, context spec.G, it spec.S) { case "/": w.WriteHeader(http.StatusOK) fmt.Fprint(w, "some string") + case "/redirect": + w.Header()["Location"] = []string{"/"} + w.WriteHeader(http.StatusMovedPermanently) case "/empty": // do nothing case "/teapot": @@ -149,6 +152,38 @@ func testServe(t *testing.T, context spec.G, it spec.S) { }) }) + context("when given a client", func() { + var ( + redirectFunctionCalled bool + ) + + it.Before(func() { + redirectFunctionCalled = false + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + redirectFunctionCalled = true + return nil + }, + } + + matcher = matcher.WithClient(client) + }) + + it("uses the provided client", func() { + result, err := matcher.WithEndpoint("/redirect").Match(occam.Container{ + Ports: map[string]string{ + "8080": port, + }, + Env: map[string]string{"PORT": "8080"}, + }) + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(BeTrue()) + + Expect(redirectFunctionCalled).To(BeTrue()) + }) + }) + context("when given a port", func() { it.Before(func() { matcher = matcher.OnPort(8080)