diff --git a/ircutils/sasl.go b/ircutils/sasl.go index a4c612c..31d3d0d 100644 --- a/ircutils/sasl.go +++ b/ircutils/sasl.go @@ -3,7 +3,6 @@ package ircutils import ( "encoding/base64" "errors" - "strings" ) var ( @@ -25,6 +24,7 @@ func EncodeSASLResponse(raw []byte) (result []string) { } response := base64.StdEncoding.EncodeToString(raw) + result = make([]string, 0, (len(response)/400)+1) lastLen := 0 for len(response) > 0 { // TODO once we require go 1.21, this can be: lastLen = min(len(response), 400) @@ -48,11 +48,11 @@ func EncodeSASLResponse(raw []byte) (result []string) { // Do not copy a SASLBuffer after first use. type SASLBuffer struct { maxLength int - buffer strings.Builder + buf []byte } // NewSASLBuffer returns a new SASLBuffer. maxLength is the maximum amount of -// base64'ed data to buffer (0 for no limit). +// data to buffer (0 for no limit). func NewSASLBuffer(maxLength int) *SASLBuffer { result := new(SASLBuffer) result.Initialize(maxLength) @@ -69,37 +69,43 @@ func (b *SASLBuffer) Initialize(maxLength int) { // response along with any decoding or protocol errors detected. func (b *SASLBuffer) Add(value string) (done bool, output []byte, err error) { if value == "+" { - output, err = b.getAndReset() - return true, output, err + // total size is a multiple of 400 (possibly 0) + output = b.buf + b.Clear() + return true, output, nil } if len(value) > 400 { - b.buffer.Reset() + b.Clear() return true, nil, ErrSASLTooLong } - if b.maxLength != 0 && (b.buffer.Len()+len(value)) > b.maxLength { - b.buffer.Reset() + curLen := len(b.buf) + chunkDecodedLen := base64.StdEncoding.DecodedLen(len(value)) + if b.maxLength != 0 && (curLen+chunkDecodedLen) > b.maxLength { + b.Clear() return true, nil, ErrSASLLimitExceeded } - b.buffer.WriteString(value) + // "append-make pattern" as in the bytes.Buffer implementation: + b.buf = append(b.buf, make([]byte, chunkDecodedLen)...) + n, err := base64.StdEncoding.Decode(b.buf[curLen:], []byte(value)) + b.buf = b.buf[0 : curLen+n] + if err != nil { + b.Clear() + return true, nil, err + } if len(value) < 400 { - output, err = b.getAndReset() - return true, output, err + output = b.buf + b.Clear() + return true, output, nil } else { - // 400 bytes, wait for continuation line or + return false, nil, nil } } // Clear resets the buffer state. func (b *SASLBuffer) Clear() { - b.buffer.Reset() -} - -func (b *SASLBuffer) getAndReset() (output []byte, err error) { - output, err = base64.StdEncoding.DecodeString(b.buffer.String()) - b.buffer.Reset() - return + // we can't reuse this buffer in general since we may have returned it + b.buf = nil } diff --git a/ircutils/sasl_test.go b/ircutils/sasl_test.go index 9f90f64..8248fa5 100644 --- a/ircutils/sasl_test.go +++ b/ircutils/sasl_test.go @@ -31,7 +31,7 @@ func TestSplitResponse(t *testing.T) { } func TestBuffer(t *testing.T) { - b := NewSASLBuffer(1600) + b := NewSASLBuffer(1200) // less than 400 bytes done, output, err := b.Add("c2hpdmFyYW0Ac2hpdmFyYW0Ac2hpdmFyYW1wYXNzcGhyYXNl") @@ -58,7 +58,7 @@ func TestBuffer(t *testing.T) { // a single + done, output, err = b.Add("+") assertEqual(done, true) - assertEqual(len(output), 0) + assertEqual(output, []byte(nil)) assertEqual(err, nil) // length limit