diff --git a/httpbin/helpers.go b/httpbin/helpers.go index 726d1505..71cc6d07 100644 --- a/httpbin/helpers.go +++ b/httpbin/helpers.go @@ -59,6 +59,9 @@ func getURL(r *http.Request) *url.URL { if scheme == "" && r.Header.Get("X-Forwarded-Ssl") == "on" { scheme = "https" } + if scheme == "" && r.TLS != nil { + scheme = "https" + } if scheme == "" { scheme = "http" } diff --git a/httpbin/helpers_test.go b/httpbin/helpers_test.go index 18d9208a..1057df46 100644 --- a/httpbin/helpers_test.go +++ b/httpbin/helpers_test.go @@ -1,9 +1,11 @@ package httpbin import ( + "crypto/tls" "fmt" "io" "net/http" + "net/url" "reflect" "testing" "time" @@ -33,6 +35,82 @@ func assertError(t *testing.T, got, expected error) { } } +func mustParse(s string) *url.URL { + u, e := url.Parse(s) + if e != nil { + panic(e) + } + return u +} + +func TestGetURL(t *testing.T) { + baseUrl, _ := url.Parse("http://example.com/something?foo=bar") + tests := []struct { + name string + input *http.Request + expected *url.URL + }{ + { + "basic test", + &http.Request{ + URL: baseUrl, + Header: http.Header{}, + }, + mustParse("http://example.com/something?foo=bar"), + }, + { + "if TLS is not nil, scheme is https", + &http.Request{ + URL: baseUrl, + TLS: &tls.ConnectionState{}, + Header: http.Header{}, + }, + mustParse("https://example.com/something?foo=bar"), + }, + { + "if X-Forwarded-Proto is present, scheme is that value", + &http.Request{ + URL: baseUrl, + Header: http.Header{"X-Forwarded-Proto": {"https"}}, + }, + mustParse("https://example.com/something?foo=bar"), + }, + { + "if X-Forwarded-Proto is present, scheme is that value (2)", + &http.Request{ + URL: baseUrl, + Header: http.Header{"X-Forwarded-Proto": {"bananas"}}, + }, + mustParse("bananas://example.com/something?foo=bar"), + }, + { + "if X-Forwarded-Ssl is 'on', scheme is https", + &http.Request{ + URL: baseUrl, + Header: http.Header{"X-Forwarded-Ssl": {"on"}}, + }, + mustParse("https://example.com/something?foo=bar"), + }, + { + "if request URL host is empty, host is request.host", + &http.Request{ + URL: mustParse("http:///just/a/path"), + Host: "zombo.com", + }, + mustParse("http://zombo.com/just/a/path"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + res := getURL(test.input) + if res.String() != test.expected.String() { + t.Fatalf("expected %s, got %s", test.expected, res) + } + }) + } +} + func TestParseDuration(t *testing.T) { okTests := []struct { input string