From 09f3004fce1ceb2ec686d644439c5135687d29bf Mon Sep 17 00:00:00 2001 From: Billy Keyes Date: Sat, 1 Feb 2020 10:58:43 -0800 Subject: [PATCH] Fix flushing when data is larger than the buffer (#2) copyFrom and copyLinesFrom did not increment the start offset after reading, so the same lines were copied over and over again. --- gitdiff/io.go | 11 ++++-- gitdiff/io_test.go | 86 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/gitdiff/io.go b/gitdiff/io.go index 515fb10..af5c847 100644 --- a/gitdiff/io.go +++ b/gitdiff/io.go @@ -147,11 +147,16 @@ func isLen(r io.ReaderAt, n int64) (bool, error) { return false, err } +const ( + byteBufferSize = 32 * 1024 // from io.Copy + lineBufferSize = 32 +) + // copyFrom writes bytes starting from offset off in src to dst stopping at the // end of src or at the first error. copyFrom returns the number of bytes // written and any error. func copyFrom(dst io.Writer, src io.ReaderAt, off int64) (written int64, err error) { - buf := make([]byte, 32*1024) // stolen from io.Copy + buf := make([]byte, byteBufferSize) for { nr, rerr := src.ReadAt(buf, off) if nr > 0 { @@ -167,6 +172,7 @@ func copyFrom(dst io.Writer, src io.ReaderAt, off int64) (written int64, err err err = io.ErrShortWrite break } + off += int64(nr) } if rerr != nil { if rerr != io.EOF { @@ -182,7 +188,7 @@ func copyFrom(dst io.Writer, src io.ReaderAt, off int64) (written int64, err err // the end of src or at the first error. copyLinesFrom returns the number of // lines written and any error. func copyLinesFrom(dst io.Writer, src LineReaderAt, off int64) (written int64, err error) { - buf := make([][]byte, 32) + buf := make([][]byte, lineBufferSize) ReadLoop: for { nr, rerr := src.ReadLinesAt(buf, off) @@ -201,6 +207,7 @@ ReadLoop: break ReadLoop } } + off += int64(nr) } if rerr != nil { if rerr != io.EOF { diff --git a/gitdiff/io_test.go b/gitdiff/io_test.go index 1d1050d..8d8a18b 100644 --- a/gitdiff/io_test.go +++ b/gitdiff/io_test.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "io" + "math/rand" "testing" ) @@ -114,3 +115,88 @@ func TestLineReaderAt(t *testing.T) { }) } } + +func TestCopyFrom(t *testing.T) { + tests := map[string]struct { + Bytes int64 + Offset int64 + }{ + "copyAll": { + Bytes: byteBufferSize / 2, + }, + "copyPartial": { + Bytes: byteBufferSize / 2, + Offset: byteBufferSize / 4, + }, + "copyLarge": { + Bytes: 8 * byteBufferSize, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + data := make([]byte, test.Bytes) + rand.Read(data) + + var dst bytes.Buffer + n, err := copyFrom(&dst, bytes.NewReader(data), test.Offset) + if err != nil { + t.Fatalf("unexpected error copying data: %v", err) + } + if n != test.Bytes-test.Offset { + t.Fatalf("incorrect number of bytes copied: expected %d, actual %d", test.Bytes-test.Offset, n) + } + + expected := data[test.Offset:] + if !bytes.Equal(expected, dst.Bytes()) { + t.Fatalf("incorrect data copied:\nexpected: %v\nactual: %v", expected, dst.Bytes()) + } + }) + } +} + +func TestCopyLinesFrom(t *testing.T) { + tests := map[string]struct { + Lines int64 + Offset int64 + }{ + "copyAll": { + Lines: lineBufferSize / 2, + }, + "copyPartial": { + Lines: lineBufferSize / 2, + Offset: lineBufferSize / 4, + }, + "copyLarge": { + Lines: 8 * lineBufferSize, + }, + } + + const lineLength = 128 + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + data := make([]byte, test.Lines*lineLength) + for i := range data { + data[i] = byte(32 + rand.Intn(95)) // ascii letters, numbers, symbols + if i%lineLength == lineLength-1 { + data[i] = '\n' + } + } + + var dst bytes.Buffer + n, err := copyLinesFrom(&dst, &lineReaderAt{r: bytes.NewReader(data)}, test.Offset) + if err != nil { + t.Fatalf("unexpected error copying data: %v", err) + } + if n != test.Lines-test.Offset { + t.Fatalf("incorrect number of lines copied: expected %d, actual %d", test.Lines-test.Offset, n) + } + + expected := data[test.Offset*lineLength:] + if !bytes.Equal(expected, dst.Bytes()) { + t.Fatalf("incorrect data copied:\nexpected: %v\nactual: %v", expected, dst.Bytes()) + } + }) + } +}