diff --git a/channel.go b/channel.go index feafd9a6b..872261e6d 100644 --- a/channel.go +++ b/channel.go @@ -143,10 +143,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) { } func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error { - // TODO: Error on send rather than on recv - //if len(p) > messageLengthMax { - // return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax) - //} + if len(p) > messageLengthMax { + return OversizedMessageError(len(p)) + } + if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil { return err } diff --git a/channel_test.go b/channel_test.go index de8b66d38..9eab63148 100644 --- a/channel_test.go +++ b/channel_test.go @@ -89,21 +89,19 @@ func TestReadWriteMessage(t *testing.T) { func TestMessageOversize(t *testing.T) { var ( - w, r = net.Pipe() - wch, rch = newChannel(w), newChannel(r) - msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) - errs = make(chan error, 1) + w, _ = net.Pipe() + wch = newChannel(w) + msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) + errs = make(chan error, 1) ) go func() { - if err := wch.send(1, 1, 0, msg); err != nil { - errs <- err - } + errs <- wch.send(1, 1, 0, msg) }() - _, _, err := rch.recv() + err := <-errs if err == nil { - t.Fatalf("error expected reading with small buffer") + t.Fatalf("sending oversized message expected to fail") } status, ok := status.FromError(err) @@ -114,12 +112,4 @@ func TestMessageOversize(t *testing.T) { if status.Code() != codes.ResourceExhausted { t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted) } - - select { - case err := <-errs: - if err != nil { - t.Fatal(err) - } - default: - } } diff --git a/errors.go b/errors.go index ec14b7952..632dbe8bd 100644 --- a/errors.go +++ b/errors.go @@ -16,7 +16,12 @@ package ttrpc -import "errors" +import ( + "errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) var ( // ErrProtocol is a general error in the handling the protocol. @@ -32,3 +37,44 @@ var ( // ErrStreamClosed is when the streaming connection is closed. ErrStreamClosed = errors.New("ttrpc: stream closed") ) + +// OversizedMessageErr is used to indicate refusal to send an oversized message. +// It wraps a ResourceExhausted grpc Status together with the offending message +// length. +type OversizedMessageErr struct { + messageLength int + err error +} + +// OversizedMessageError returns an OversizedMessageErr error for the given message +// length if it exceeds the allowed maximum. Otherwise a nil error is returned. +func OversizedMessageError(messageLength int) error { + if messageLength <= messageLengthMax { + return nil + } + + return &OversizedMessageErr{ + messageLength: messageLength, + err: status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", messageLength, messageLengthMax), + } +} + +// Error returns the error message for the corresponding grpc Status for the error. +func (e *OversizedMessageErr) Error() string { + return e.err.Error() +} + +// Unwrap returns the corresponding error with our grpc status code. +func (e *OversizedMessageErr) Unwrap() error { + return e.err +} + +// RejectedLength retrieves the rejected message length which triggered the error. +func (e *OversizedMessageErr) RejectedLength() int { + return e.messageLength +} + +// MaximumLength retrieves the maximum allowed message length that triggered the error. +func (*OversizedMessageErr) MaximumLength() int { + return messageLengthMax +}