From 618e30080314412245b6eb06e1d26bf8db80fce5 Mon Sep 17 00:00:00 2001
From: Bartlomiej Plotka
Date: Wed, 6 Jul 2022 21:57:11 +0200
Subject: [PATCH] re-implement ReaderAt implementation with race protection (#1673)

Signed-off-by: bwplotka
Co-authored-by: Harshavardhana
---
 api-put-object-streaming.go | 50 ++++++++++++++++++------------------------------
 hook-reader.go              | 48 ++++++++++++++++++++++++++++----------------
 2 files changed, 52 insertions(+), 46 deletions(-)

diff --git a/api-put-object-streaming.go b/api-put-object-streaming.go
index 7f145bb97..11b3a5255 100644
--- a/api-put-object-streaming.go
+++ b/api-put-object-streaming.go
@@ -130,32 +130,32 @@ func (c *Client) putObjectMultipartStreamFromReadAt(ctx context.Context, bucketN
 	var complMultipartUpload completeMultipartUpload
 
 	// Declare a channel that sends the next part number to be uploaded.
-	// Buffered to 10000 because thats the maximum number of parts allowed
-	// by S3.
-	uploadPartsCh := make(chan uploadPartReq, 10000)
+	uploadPartsCh := make(chan uploadPartReq)
 
 	// Declare a channel that sends back the response of a part upload.
-	// Buffered to 10000 because thats the maximum number of parts allowed
-	// by S3.
-	uploadedPartsCh := make(chan uploadedPartRes, 10000)
+	uploadedPartsCh := make(chan uploadedPartRes)
 
 	// Used for readability, lastPartNumber is always totalPartsCount.
 	lastPartNumber := totalPartsCount
 
+	partitionCtx, partitionCancel := context.WithCancel(ctx)
+	defer partitionCancel()
+
 	// Send each part number to the channel to be processed.
-	for p := 1; p <= totalPartsCount; p++ {
-		uploadPartsCh <- uploadPartReq{PartNum: p}
-	}
-	close(uploadPartsCh)
-
-	partsBuf := make([][]byte, opts.getNumThreads())
-	for i := range partsBuf {
-		partsBuf[i] = make([]byte, 0, partSize)
-	}
+	go func() {
+		defer close(uploadPartsCh)
+
+		for p := 1; p <= totalPartsCount; p++ {
+			select {
+			case <-partitionCtx.Done():
+				return
+			case uploadPartsCh <- uploadPartReq{PartNum: p}:
+			}
+		}
+	}()
 
 	// Receive each part number from the channel allowing three parallel uploads.
 	for w := 1; w <= opts.getNumThreads(); w++ {
-		go func(w int, partSize int64) {
+		go func(partSize int64) {
 			for {
 				var uploadReq uploadPartReq
 				var ok bool
@@ -181,21 +181,11 @@ func (c *Client) putObjectMultipartStreamFromReadAt(ctx context.Context, bucketN
 					partSize = lastPartSize
 				}
 
-				n, rerr := readFull(io.NewSectionReader(reader, readOffset, partSize), partsBuf[w-1][:partSize])
-				if rerr != nil && rerr != io.ErrUnexpectedEOF && rerr != io.EOF {
-					uploadedPartsCh <- uploadedPartRes{
-						Error: rerr,
-					}
-					// Exit the goroutine.
-					return
-				}
-
-				// Get a section reader on a particular offset.
-				hookReader := newHook(bytes.NewReader(partsBuf[w-1][:n]), opts.Progress)
+				sectionReader := newHook(io.NewSectionReader(reader, readOffset, partSize), opts.Progress)
 
 				// Proceed to upload the part.
 				objPart, err := c.uploadPart(ctx, bucketName, objectName,
-					uploadID, hookReader, uploadReq.PartNum,
+					uploadID, sectionReader, uploadReq.PartNum,
 					"", "", partSize,
 					opts.ServerSideEncryption,
 					!opts.DisableContentSha256,
@@ -218,7 +208,7 @@ func (c *Client) putObjectMultipartStreamFromReadAt(ctx context.Context, bucketN
 					Part:    uploadReq.Part,
 				}
 			}
-		}(w, partSize)
+		}(partSize)
 	}
 
 	// Gather the responses as they occur and update any
@@ -229,12 +219,12 @@ func (c *Client) putObjectMultipartStreamFromReadAt(ctx context.Context, bucketN
 			return UploadInfo{}, ctx.Err()
 		case uploadRes := <-uploadedPartsCh:
 			if uploadRes.Error != nil {
 				return UploadInfo{}, uploadRes.Error
 			}
+
 			// Update the totalUploadedSize.
 			totalUploadedSize += uploadRes.Size
-
 			// Store the parts to be completed in order.
 			complMultipartUpload.Parts = append(complMultipartUpload.Parts, CompletePart{
 				ETag:       uploadRes.Part.ETag,
 				PartNumber: uploadRes.Part.PartNumber,
diff --git a/hook-reader.go b/hook-reader.go
index f251c1e95..07bc7dbcf 100644
--- a/hook-reader.go
+++ b/hook-reader.go
@@ -20,6 +20,7 @@ package minio
 import (
 	"fmt"
 	"io"
+	"sync"
 )
 
 // hookReader hooks additional reader in the source stream. It is
@@ -27,6 +28,7 @@ import (
 // notified about the exact number of bytes read from the primary
 // source on each Read operation.
 type hookReader struct {
+	mu     sync.RWMutex
 	source io.Reader
 	hook   io.Reader
 }
@@ -34,6 +36,9 @@ type hookReader struct {
 // Seek implements io.Seeker. Seeks source first, and if necessary
 // seeks hook if Seek method is appropriately found.
 func (hr *hookReader) Seek(offset int64, whence int) (n int64, err error) {
+	hr.mu.Lock()
+	defer hr.mu.Unlock()
+
 	// Verify for source has embedded Seeker, use it.
 	sourceSeeker, ok := hr.source.(io.Seeker)
 	if ok {
@@ -43,18 +48,21 @@ func (hr *hookReader) Seek(offset int64, whence int) (n int64, err error) {
 		}
 	}
 
-	// Verify if hook has embedded Seeker, use it.
-	hookSeeker, ok := hr.hook.(io.Seeker)
-	if ok {
-		var m int64
-		m, err = hookSeeker.Seek(offset, whence)
-		if err != nil {
-			return 0, err
-		}
-		if n != m {
-			return 0, fmt.Errorf("hook seeker seeked %d bytes, expected source %d bytes", m, n)
+	if hr.hook != nil {
+		// Verify if hook has embedded Seeker, use it.
+		hookSeeker, ok := hr.hook.(io.Seeker)
+		if ok {
+			var m int64
+			m, err = hookSeeker.Seek(offset, whence)
+			if err != nil {
+				return 0, err
+			}
+			if n != m {
+				return 0, fmt.Errorf("hook seeker seeked %d bytes, expected source %d bytes", m, n)
+			}
 		}
 	}
+
 	return n, nil
 }
 
@@ -62,14 +70,19 @@ func (hr *hookReader) Seek(offset int64, whence int) (n int64, err error) {
 // value 'n' number of bytes are reported through the hook. Returns
 // error for all non io.EOF conditions.
 func (hr *hookReader) Read(b []byte) (n int, err error) {
+	hr.mu.RLock()
+	defer hr.mu.RUnlock()
+
 	n, err = hr.source.Read(b)
 	if err != nil && err != io.EOF {
 		return n, err
 	}
-	// Progress the hook with the total read bytes from the source.
-	if _, herr := hr.hook.Read(b[:n]); herr != nil {
-		if herr != io.EOF {
-			return n, herr
+	if hr.hook != nil {
+		// Progress the hook with the total read bytes from the source.
+		if _, herr := hr.hook.Read(b[:n]); herr != nil {
+			if herr != io.EOF {
+				return n, herr
+			}
 		}
 	}
 	return n, err
@@ -79,7 +92,10 @@ func (hr *hookReader) Read(b []byte) (n int, err error) {
 // reports the data read from the source to the hook.
 func newHook(source, hook io.Reader) io.Reader {
 	if hook == nil {
-		return source
+		return &hookReader{source: source}
+	}
+	return &hookReader{
+		source: source,
+		hook:   hook,
 	}
-	return &hookReader{source, hook}
 }
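
A minimal sketch (not part of the patch) of how the locking added to hookReader could be exercised under the race detector. The test name, the nil hook, and the buffer sizes are illustrative assumptions, and the file is assumed to live alongside hook-reader.go in the minio package so it can be driven with go test -race:

package minio

import (
	"bytes"
	"io"
	"sync"
	"testing"
)

// TestHookReaderConcurrentReadSeek is a hypothetical sketch: one goroutine
// reads through the hookReader while another seeks it, so the race detector
// can confirm that the new mutex serializes access to the shared source.
func TestHookReaderConcurrentReadSeek(t *testing.T) {
	src := bytes.NewReader(make([]byte, 1<<20)) // in-memory source implementing io.ReadSeeker
	r := newHook(src, nil)                      // nil hook: the source is still wrapped after this patch

	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		defer wg.Done()
		buf := make([]byte, 4<<10)
		for i := 0; i < 1000; i++ {
			if _, err := r.Read(buf); err != nil && err != io.EOF {
				t.Error(err)
				return
			}
		}
	}()

	go func() {
		defer wg.Done()
		seeker := r.(io.Seeker) // *hookReader implements Seek
		for i := 0; i < 1000; i++ {
			if _, err := seeker.Seek(0, io.SeekStart); err != nil {
				t.Error(err)
				return
			}
		}
	}()

	wg.Wait()
}

Note the design choice the sketch relies on: before this patch, newHook with a nil hook returned the bare source, so no wrapper (and no lock) sat between callers and the underlying reader; after the patch every caller gets a *hookReader, so the mutex applies uniformly whether or not a progress reader is attached.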