From 73a761b2351011d2d43a650adf1305caf0d38380 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Fri, 25 Oct 2024 17:45:12 -0400 Subject: [PATCH] Enforce discard limits on readers This enforces limits on discard to avoid unbounded reads. Where resources were already exhausted no further reads are done and discards have been removed. These discards were an optimization to reuse connections. When a stream is partially read all subsequent reads will now return EOF errors to avoid reading in a corrupted state. --- compression.go | 12 ++++-------- connect_ext_test.go | 7 +------ envelope.go | 33 ++++++++++++++------------------- protocol.go | 8 ++++---- protocol_connect.go | 14 +++++--------- protocol_grpc.go | 4 ++-- 6 files changed, 30 insertions(+), 48 deletions(-) diff --git a/compression.go b/compression.go index ee43412b..0a1db568 100644 --- a/compression.go +++ b/compression.go @@ -96,17 +96,13 @@ func (c *compressionPool) Decompress(dst *bytes.Buffer, src *bytes.Buffer, readM } return errorf(CodeInvalidArgument, "decompress: %w", err) } - if readMaxBytes > 0 && bytesRead > readMaxBytes { - discardedBytes, err := io.Copy(io.Discard, decompressor) - _ = c.putDecompressor(decompressor) - if err != nil { - return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", readMaxBytes, err) - } - return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, readMaxBytes) - } if err := c.putDecompressor(decompressor); err != nil { return errorf(CodeUnknown, "recycle decompressor: %w", err) } + if readMaxBytes > 0 && bytesRead > readMaxBytes { + // Resource is exhausted, fail fast without reading more data from the reader. + return errorf(CodeResourceExhausted, "decompressed message size is larger than configured max %d", readMaxBytes) + } return nil } diff --git a/connect_ext_test.go b/connect_ext_test.go index b93c5708..a783e85d 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -1197,7 +1197,6 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest)) assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message")) assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) - assert.True(t, strings.HasSuffix(err.Error(), fmt.Sprintf("message size %d is larger than configured max %d", proto.Size(pingRequest), readMaxBytes))) }) t.Run("read_max_large", func(t *testing.T) { t.Parallel() @@ -1206,16 +1205,14 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { } // Serializes to much larger than readMaxBytes (5 MiB) pingRequest := &pingv1.PingRequest{Text: strings.Repeat("abcde", 1024*1024)} - expectedSize := proto.Size(pingRequest) // With gzip request compression, the error should indicate the envelope size (before decompression) is too large. if compressed { - expectedSize = gzipCompressedSize(t, pingRequest) + expectedSize := gzipCompressedSize(t, pingRequest) assert.True(t, expectedSize > readMaxBytes, assert.Sprintf("expected compressed size %d > %d", expectedSize, readMaxBytes)) } _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest)) assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message")) assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) - assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes)) }) } newHTTP2Server := func(t *testing.T) *memhttp.Server { @@ -1378,7 +1375,6 @@ func TestClientWithReadMaxBytes(t *testing.T) { _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest)) assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message")) assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) - assert.True(t, strings.HasSuffix(err.Error(), fmt.Sprintf("message size %d is larger than configured max %d", proto.Size(pingRequest), readMaxBytes))) }) t.Run("read_max_large", func(t *testing.T) { t.Parallel() @@ -1397,7 +1393,6 @@ func TestClientWithReadMaxBytes(t *testing.T) { _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest)) assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message")) assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) - assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes)) }) } t.Run("connect", func(t *testing.T) { diff --git a/envelope.go b/envelope.go index bc85c551..ec296dee 100644 --- a/envelope.go +++ b/envelope.go @@ -228,9 +228,13 @@ type envelopeReader struct { compressionPool *compressionPool bufferPool *bufferPool readMaxBytes int + isEOF bool } func (r *envelopeReader) Unmarshal(message any) *Error { + if r.isEOF { + return NewError(CodeInternal, io.EOF) + } buffer := r.bufferPool.Get() var dontRelease *bytes.Buffer defer func() { @@ -240,25 +244,20 @@ func (r *envelopeReader) Unmarshal(message any) *Error { }() env := &envelope{Data: buffer} - err := r.Read(env) - switch { - case err == nil && env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil: + if err := r.Read(env); err != nil { + // Mark the reader as EOF so that subsequent reads return EOF. + r.isEOF = true + return err + } + if env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil { return errorf( CodeInternal, "protocol error: sent compressed message without compression support", ) - case err == nil && - (env.Flags == 0 || env.Flags == flagEnvelopeCompressed) && - env.Data.Len() == 0: + } else if (env.Flags == 0 || env.Flags == flagEnvelopeCompressed) && env.Data.Len() == 0 { // This is a standard message (because none of the top 7 bits are set) and // there's no data, so the zero value of the message is correct. return nil - case err != nil && errors.Is(err, io.EOF): - // The stream has ended. Propagate the EOF to the caller. - return err - case err != nil: - // Something's wrong. - return err } data := env.Data @@ -317,7 +316,7 @@ func (r *envelopeReader) Read(env *envelope) *Error { // The stream ended cleanly. That's expected, but we need to propagate an EOF // to the user so that they know that the stream has ended. We shouldn't // add any alarming text about protocol errors, though. - return NewError(CodeUnknown, err) + return NewError(CodeInternal, err) } err = wrapIfMaxBytesError(err, "read 5 byte message prefix") err = wrapIfContextDone(r.ctx, err) @@ -332,12 +331,8 @@ func (r *envelopeReader) Read(env *envelope) *Error { } size := int64(binary.BigEndian.Uint32(prefixes[1:5])) if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) { - n, err := io.CopyN(io.Discard, r.reader, size) - r.bytesRead += n - if err != nil && !errors.Is(err, io.EOF) { - return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", r.readMaxBytes, err) - } - return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", size, r.readMaxBytes) + // Resource is exhausted, fail fast without reading more data from the stream. + return errorf(CodeResourceExhausted, "received message size %d is larger than configured max %d", size, r.readMaxBytes) } // We've read the prefix, so we know how many bytes to expect. // CopyN will return an error if it doesn't read the requested diff --git a/protocol.go b/protocol.go index 9add614c..dc8e3d06 100644 --- a/protocol.go +++ b/protocol.go @@ -287,12 +287,12 @@ func isCommaOrSpace(c rune) bool { } func discard(reader io.Reader) (int64, error) { - if lr, ok := reader.(*io.LimitedReader); ok { - return io.Copy(io.Discard, lr) - } // We don't want to get stuck throwing data away forever, so limit how much // we're willing to do here. - lr := &io.LimitedReader{R: reader, N: discardLimit} + lr, ok := reader.(*io.LimitedReader) + if !ok { + lr = &io.LimitedReader{R: reader, N: discardLimit} + } return io.Copy(io.Discard, lr) } diff --git a/protocol_connect.go b/protocol_connect.go index 6828ab4d..5f7cec94 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -1088,8 +1088,8 @@ type connectUnaryUnmarshaler struct { codec Codec compressionPool *compressionPool bufferPool *bufferPool - alreadyRead bool readMaxBytes int + isEOF bool } func (u *connectUnaryUnmarshaler) Unmarshal(message any) *Error { @@ -1097,10 +1097,10 @@ func (u *connectUnaryUnmarshaler) Unmarshal(message any) *Error { } func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]byte, any) error) *Error { - if u.alreadyRead { + if u.isEOF { return NewError(CodeInternal, io.EOF) } - u.alreadyRead = true + u.isEOF = true data := u.bufferPool.Get() defer u.bufferPool.Put(data) reader := u.reader @@ -1118,12 +1118,8 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by return errorf(CodeUnknown, "read message: %w", err) } if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) { - // Attempt to read to end in order to allow connection re-use - discardedBytes, err := io.Copy(io.Discard, u.reader) - if err != nil { - return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", u.readMaxBytes, err) - } - return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, u.readMaxBytes) + // Resource is exhausted, fail fast without reading more data from the stream. + return errorf(CodeResourceExhausted, "message size is larger than configured max %d", u.readMaxBytes) } if data.Len() > 0 && u.compressionPool != nil { decompressed := u.bufferPool.Get() diff --git a/protocol_grpc.go b/protocol_grpc.go index e10ecad7..db1890e8 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -319,8 +319,8 @@ func (g *grpcClient) NewConn( } } else { conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header { - // To access HTTP trailers, we need to read the body to EOF. - _, _ = discard(call) + // Caller must guarantee the body is read to EOF to access + // trailers. return call.ResponseTrailer() } }