diff --git a/README.md b/README.md index 0d1ec002..90b7497f 100644 --- a/README.md +++ b/README.md @@ -426,8 +426,6 @@ spec: - name: traefik-forward-auth ``` -Note: If using auth host mode, you must apply the middleware to your auth host ingress. - See the examples directory for more examples. #### Selective Container Authentication in Swarm @@ -442,8 +440,6 @@ whoami: - "traefik.http.routers.whoami.middlewares=traefik-forward-auth" ``` -Note: If using auth host mode, you must apply the middleware to the traefik-forward-auth container. - See the examples directory for more examples. #### Rules Based Authentication diff --git a/examples/traefik-v2/kubernetes/advanced-separate-pod/traefik-forward-auth/ingress.yaml b/examples/traefik-v2/kubernetes/advanced-separate-pod/traefik-forward-auth/ingress.yaml index 6d416e01..74ad0e9f 100644 --- a/examples/traefik-v2/kubernetes/advanced-separate-pod/traefik-forward-auth/ingress.yaml +++ b/examples/traefik-v2/kubernetes/advanced-separate-pod/traefik-forward-auth/ingress.yaml @@ -16,7 +16,5 @@ spec: services: - name: traefik-forward-auth port: 4181 - middlewares: - - name: traefik-forward-auth tls: certresolver: default diff --git a/internal/auth.go b/internal/auth.go index 0b0a9676..9b8f0b16 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -125,24 +125,19 @@ func ValidateDomains(email string, domains CommaSeparatedList) bool { // Get the redirect base func redirectBase(r *http.Request) string { - proto := r.Header.Get("X-Forwarded-Proto") - host := r.Header.Get("X-Forwarded-Host") - - return fmt.Sprintf("%s://%s", proto, host) + return fmt.Sprintf("%s://%s", r.Header.Get("X-Forwarded-Proto"), r.Host) } // Return url func returnUrl(r *http.Request) string { - path := r.Header.Get("X-Forwarded-Uri") - - return fmt.Sprintf("%s%s", redirectBase(r), path) + return fmt.Sprintf("%s%s", redirectBase(r), r.URL.Path) } // Get oauth redirect uri func redirectUri(r *http.Request) string { if use, _ := useAuthDomain(r); use { - proto := r.Header.Get("X-Forwarded-Proto") - return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path) + p := r.Header.Get("X-Forwarded-Proto") + return fmt.Sprintf("%s://%s%s", p, config.AuthHost, config.Path) } return fmt.Sprintf("%s%s", redirectBase(r), config.Path) @@ -155,7 +150,7 @@ func useAuthDomain(r *http.Request) (bool, string) { } // Does the request match a given cookie domain? - reqMatch, reqHost := matchCookieDomains(r.Header.Get("X-Forwarded-Host")) + reqMatch, reqHost := matchCookieDomains(r.Host) // Do any of the auth hosts match a cookie domain? authMatch, authHost := matchCookieDomains(config.AuthHost) @@ -284,10 +279,8 @@ func Nonce() (error, string) { // Cookie domain func cookieDomain(r *http.Request) string { - host := r.Header.Get("X-Forwarded-Host") - // Check if any of the given cookie domains matches - _, domain := matchCookieDomains(host) + _, domain := matchCookieDomains(r.Host) return domain } @@ -297,7 +290,7 @@ func csrfCookieDomain(r *http.Request) string { if use, domain := useAuthDomain(r); use { host = domain } else { - host = r.Header.Get("X-Forwarded-Host") + host = r.Host } // Remove port diff --git a/internal/auth_test.go b/internal/auth_test.go index 5b0bedaf..74e8d2f2 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -2,6 +2,7 @@ package tfa import ( "net/http" + "net/http/httptest" "net/url" "strings" "testing" @@ -196,10 +197,8 @@ func TestAuthValidateEmail(t *testing.T) { func TestRedirectUri(t *testing.T) { assert := assert.New(t) - r, _ := http.NewRequest("GET", "http://example.com", nil) + r := httptest.NewRequest("GET", "http://app.example.com/hello", nil) r.Header.Add("X-Forwarded-Proto", "http") - r.Header.Add("X-Forwarded-Host", "app.example.com") - r.Header.Add("X-Forwarded-Uri", "/hello") // // No Auth Host @@ -241,10 +240,8 @@ func TestRedirectUri(t *testing.T) { // With Auth URL + cookie domain, but from different domain // - will not use auth host // - r, _ = http.NewRequest("GET", "http://another.com", nil) + r = httptest.NewRequest("GET", "https://another.com/hello", nil) r.Header.Add("X-Forwarded-Proto", "https") - r.Header.Add("X-Forwarded-Host", "another.com") - r.Header.Add("X-Forwarded-Uri", "/hello") config.AuthHost = "auth.example.com" config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")} @@ -378,10 +375,8 @@ func TestValidateState(t *testing.T) { func TestMakeState(t *testing.T) { assert := assert.New(t) - r, _ := http.NewRequest("GET", "http://example.com", nil) + r := httptest.NewRequest("GET", "http://example.com/hello", nil) r.Header.Add("X-Forwarded-Proto", "http") - r.Header.Add("X-Forwarded-Host", "example.com") - r.Header.Add("X-Forwarded-Uri", "/hello") // Test with google p := provider.Google{} diff --git a/internal/server.go b/internal/server.go index 4bce9635..2e20df53 100644 --- a/internal/server.go +++ b/internal/server.go @@ -58,7 +58,11 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) { // Modify request r.Method = r.Header.Get("X-Forwarded-Method") r.Host = r.Header.Get("X-Forwarded-Host") - r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri")) + + // Read URI from header if we're acting as forward auth middleware + if _, ok := r.Header["X-Forwarded-Uri"]; ok { + r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri")) + } // Pass to mux s.router.ServeHTTP(w, r) diff --git a/internal/server_test.go b/internal/server_test.go index 8ec0f01d..d461a4cc 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -31,6 +31,37 @@ func init() { * Tests */ +func TestServerRootHandler(t *testing.T) { + assert := assert.New(t) + config = newDefaultConfig() + + // X-Forwarded headers should be read into request + req := httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should?ignore=me", nil) + req.Header.Add("X-Forwarded-Method", "GET") + req.Header.Add("X-Forwarded-Proto", "https") + req.Header.Add("X-Forwarded-Host", "example.com") + req.Header.Add("X-Forwarded-Uri", "/foo?q=bar") + NewServer().RootHandler(httptest.NewRecorder(), req) + + assert.Equal("GET", req.Method, "x-forwarded-method should be read into request") + assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request") + assert.Equal("/foo", req.URL.Path, "x-forwarded-uri should be read into request") + assert.Equal("/foo?q=bar", req.URL.RequestURI(), "x-forwarded-uri should be read into request") + + // Other X-Forwarded headers should be read in into request and original URL + // should be preserved if X-Forwarded-Uri not present + req = httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should-not?ignore=me", nil) + req.Header.Add("X-Forwarded-Method", "GET") + req.Header.Add("X-Forwarded-Proto", "https") + req.Header.Add("X-Forwarded-Host", "example.com") + NewServer().RootHandler(httptest.NewRecorder(), req) + + assert.Equal("GET", req.Method, "x-forwarded-method should be read into request") + assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request") + assert.Equal("/should-not", req.URL.Path, "request url should be preserved if x-forwarded-uri not present") + assert.Equal("/should-not?ignore=me", req.URL.RequestURI(), "request url should be preserved if x-forwarded-uri not present") +} + func TestServerAuthHandlerInvalid(t *testing.T) { assert := assert.New(t) config = newDefaultConfig() @@ -90,10 +121,10 @@ func TestServerAuthHandlerExpired(t *testing.T) { config.Domains = []string{"test.com"} // Should redirect expired cookie - req := newDefaultHttpRequest("/foo") + req := newHTTPRequest("GET", "http://example.com/foo") c := MakeCookie(req, "test@example.com") res, _ := doHttpRequest(req, c) - assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected") + require.Equal(t, 307, res.StatusCode, "request with expired cookie should be redirected") // Check for CSRF cookie var cookie *http.Cookie @@ -116,7 +147,7 @@ func TestServerAuthHandlerValid(t *testing.T) { config = newDefaultConfig() // Should allow valid request email - req := newDefaultHttpRequest("/foo") + req := newHTTPRequest("GET", "http://example.com/foo") c := MakeCookie(req, "test@example.com") config.Domains = []string{} @@ -131,6 +162,7 @@ func TestServerAuthHandlerValid(t *testing.T) { func TestServerAuthCallback(t *testing.T) { assert := assert.New(t) + require := require.New(t) config = newDefaultConfig() // Setup OAuth server @@ -148,27 +180,28 @@ func TestServerAuthCallback(t *testing.T) { } // Should pass auth response request to callback - req := newDefaultHttpRequest("/_oauth") + req := newHTTPRequest("GET", "http://example.com/_oauth") res, _ := doHttpRequest(req, nil) assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised") // Should catch invalid csrf cookie - req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect") + nonce := "12345678901234567890123456789012" + req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":http://redirect") c := MakeCSRFCookie(req, "nononononononononononononononono") res, _ = doHttpRequest(req, c) assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised") // Should catch invalid provider cookie - req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:invalid:http://redirect") - c = MakeCSRFCookie(req, "12345678901234567890123456789012") + req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":invalid:http://redirect") + c = MakeCSRFCookie(req, nonce) res, _ = doHttpRequest(req, c) assert.Equal(401, res.StatusCode, "auth callback with invalid provider shouldn't be authorised") // Should redirect valid request - req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect") - c = MakeCSRFCookie(req, "12345678901234567890123456789012") + req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":google:http://redirect") + c = MakeCSRFCookie(req, nonce) res, _ = doHttpRequest(req, c) - assert.Equal(307, res.StatusCode, "valid auth callback should be allowed") + require.Equal(307, res.StatusCode, "valid auth callback should be allowed") fwd, _ := res.Location() assert.Equal("http", fwd.Scheme, "valid request should be redirected to return url") @@ -360,17 +393,17 @@ func TestServerRouteHost(t *testing.T) { } // Should block any request - req := newHttpRequest("GET", "https://example.com/", "/") + req := newHTTPRequest("GET", "https://example.com/") res, _ := doHttpRequest(req, nil) assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") // Should allow matching request - req = newHttpRequest("GET", "https://api.example.com/", "/") + req = newHTTPRequest("GET", "https://api.example.com/") res, _ = doHttpRequest(req, nil) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") // Should allow matching request - req = newHttpRequest("GET", "https://sub8.example.com/", "/") + req = newHTTPRequest("GET", "https://sub8.example.com/") res, _ = doHttpRequest(req, nil) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") } @@ -386,12 +419,12 @@ func TestServerRouteMethod(t *testing.T) { } // Should block any request - req := newHttpRequest("GET", "https://example.com/", "/") + req := newHTTPRequest("GET", "https://example.com/") res, _ := doHttpRequest(req, nil) assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") // Should allow matching request - req = newHttpRequest("PUT", "https://example.com/", "/") + req = newHTTPRequest("PUT", "https://example.com/") res, _ = doHttpRequest(req, nil) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") } @@ -441,12 +474,12 @@ func TestServerRouteQuery(t *testing.T) { } // Should block any request - req := newHttpRequest("GET", "https://example.com/", "/?q=no") + req := newHTTPRequest("GET", "https://example.com/?q=no") res, _ := doHttpRequest(req, nil) assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") // Should allow matching request - req = newHttpRequest("GET", "https://api.example.com/", "/?q=test123") + req = newHTTPRequest("GET", "https://api.example.com/?q=test123") res, _ = doHttpRequest(req, nil) assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") } @@ -531,16 +564,17 @@ func newDefaultConfig() *Config { return config } +// TODO: replace with newHTTPRequest("GET", "http://example.com/"+uri) func newDefaultHttpRequest(uri string) *http.Request { - return newHttpRequest("", "http://example.com/", uri) + return newHTTPRequest("GET", "http://example.com"+uri) } -func newHttpRequest(method, dest, uri string) *http.Request { - r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil) - p, _ := url.Parse(dest) +func newHTTPRequest(method, target string) *http.Request { + u, _ := url.Parse(target) + r := httptest.NewRequest(method, target, nil) r.Header.Add("X-Forwarded-Method", method) - r.Header.Add("X-Forwarded-Proto", p.Scheme) - r.Header.Add("X-Forwarded-Host", p.Host) - r.Header.Add("X-Forwarded-Uri", uri) + r.Header.Add("X-Forwarded-Proto", u.Scheme) + r.Header.Add("X-Forwarded-Host", u.Host) + r.Header.Add("X-Forwarded-Uri", u.RequestURI()) return r }