diff --git a/transport.go b/transport.go index 7b81f08..51dd2c0 100644 --- a/transport.go +++ b/transport.go @@ -115,3 +115,12 @@ func DoerTransport(cl interface { }) Transport { return RoundTripFunc(cl.Do) } + +// ErrorTransport always returns the specified error instead of connecting. +// It is intended for use in testing +// or to prevent accidental use of http.DefaultClient. +func ErrorTransport(err error) Transport { + return RoundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, err + }) +} diff --git a/transport_example_test.go b/transport_example_test.go index 737351b..256108f 100644 --- a/transport_example_test.go +++ b/transport_example_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/md5" + "errors" "fmt" "io" "net/http" @@ -111,9 +112,8 @@ func ExampleLogTransport() { fmt.Println("Error!", err) } // Works for bad responses too - baseTrans = requests.RoundTripFunc(func(req *http.Request) (*http.Response, error) { - return nil, fmt.Errorf("can't connect") - }) + baseTrans = requests.ErrorTransport(errors.New("can't connect")) + trans = requests.LogTransport(baseTrans, logger) if err := requests.