From c4a82d62aa08f0da1d0f6c0f47b3302d8a68db80 Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Thu, 19 Nov 2020 21:00:23 +0000 Subject: [PATCH] feat: Add TLS Handshake timeout support (#530) * feat: Add TLS Handshake timeout support Add support for configuring a timeout for TLS Handshake call via DialTLSHandshakeTimeout DialOption. If no option is specified then the default timeout is 10 seconds. Also: * Add a default connect timeout of 30 seconds matching that of net/http. Fixes #509 --- redis/conn.go | 57 +++++++++++++++++++++++++++++++++++----------- redis/conn_test.go | 39 +++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 13 deletions(-) diff --git a/redis/conn.go b/redis/conn.go index 33b43be6..5d7841c6 100644 --- a/redis/conn.go +++ b/redis/conn.go @@ -75,17 +75,27 @@ type DialOption struct { } type dialOptions struct { - readTimeout time.Duration - writeTimeout time.Duration - dialer *net.Dialer - dialContext func(ctx context.Context, network, addr string) (net.Conn, error) - db int - username string - password string - clientName string - useTLS bool - skipVerify bool - tlsConfig *tls.Config + readTimeout time.Duration + writeTimeout time.Duration + tlsHandshakeTimeout time.Duration + dialer *net.Dialer + dialContext func(ctx context.Context, network, addr string) (net.Conn, error) + db int + username string + password string + clientName string + useTLS bool + skipVerify bool + tlsConfig *tls.Config +} + +// DialTLSHandshakeTimeout specifies the maximum amount of time waiting to +// wait for a TLS handshake. Zero means no timeout. +// If no DialTLSHandshakeTimeout option is specified then the default is 30 seconds. +func DialTLSHandshakeTimeout(d time.Duration) DialOption { + return DialOption{func(do *dialOptions) { + do.tlsHandshakeTimeout = d + }} } // DialReadTimeout specifies the timeout for reading a single command reply. @@ -104,6 +114,7 @@ func DialWriteTimeout(d time.Duration) DialOption { // DialConnectTimeout specifies the timeout for connecting to the Redis server when // no DialNetDial option is specified. +// If no DialConnectTimeout option is specified then the default is 30 seconds. func DialConnectTimeout(d time.Duration) DialOption { return DialOption{func(do *dialOptions) { do.dialer.Timeout = d @@ -201,13 +212,21 @@ func Dial(network, address string, options ...DialOption) (Conn, error) { return DialContext(context.Background(), network, address, options...) } +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "TLS handshake timeout" } + // DialContext connects to the Redis server at the given network and // address using the specified options and context. func DialContext(ctx context.Context, network, address string, options ...DialOption) (Conn, error) { do := dialOptions{ dialer: &net.Dialer{ + Timeout: time.Second * 30, KeepAlive: time.Minute * 5, }, + tlsHandshakeTimeout: time.Second * 10, } for _, option := range options { option.f(&do) @@ -238,10 +257,22 @@ func DialContext(ctx context.Context, network, address string, options ...DialOp } tlsConn := tls.Client(netConn, tlsConfig) - if err := tlsConn.Handshake(); err != nil { - netConn.Close() + errc := make(chan error, 2) // buffered so we don't block timeout or Handshake + if d := do.tlsHandshakeTimeout; d != 0 { + timer := time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + defer timer.Stop() + } + go func() { + errc <- tlsConn.Handshake() + }() + if err := <-errc; err != nil { + // Timeout or Handshake error. + netConn.Close() // nolint: errcheck return nil, err } + netConn = tlsConn } diff --git a/redis/conn_test.go b/redis/conn_test.go index dbc66e74..97d7bec1 100644 --- a/redis/conn_test.go +++ b/redis/conn_test.go @@ -701,6 +701,45 @@ func TestDialUseTLS(t *testing.T) { checkPingPong(t, &buf, c) } +type blockedReader struct { + ch chan struct{} +} + +func (b blockedReader) Read(p []byte) (n int, err error) { + <-b.ch + return 0, nil +} + +func dialTestBlockedConn(ch chan struct{}, w io.Writer) redis.DialOption { + return redis.DialNetDial(func(network, addr string) (net.Conn, error) { + return &testConn{Reader: blockedReader{ch: ch}, Writer: w}, nil + }) +} + +func TestDialTLSHandshakeTimeout(t *testing.T) { + var buf bytes.Buffer + ch := make(chan struct{}) + var err error + go func() { + _, err = redis.Dial("tcp", "example.com:6379", + redis.DialTLSConfig(&clientTLSConfig), + redis.DialTLSHandshakeTimeout(time.Millisecond), + dialTestBlockedConn(ch, &buf), + redis.DialUseTLS(true)) + close(ch) + }() + select { + case <-time.After(time.Second): + t.Fatal("dial didn't timeout") + case <-ch: + if err == nil { + t.Fatal("dial didn't error") + } else if err.Error() != "TLS handshake timeout" { + t.Fatal("dial unexpected error:", err) + } + } +} + func TestDialTLSSKipVerify(t *testing.T) { var buf bytes.Buffer c, err := redis.Dial("tcp", "example.com:6379",