From e4a6245f352ff245fdd1153b104f5f12b02bb6d6 Mon Sep 17 00:00:00 2001 From: Justin Israel Date: Wed, 28 Jun 2023 16:00:38 +1200 Subject: [PATCH 1/4] Implement optional concurrent "Range" requests (refs #86) --- v3/client.go | 83 ++++++++++---- v3/client_test.go | 58 ++++++++++ v3/go.mod | 2 + v3/go.sum | 2 + v3/pkg/grabtest/handler.go | 49 +++++++- v3/request.go | 9 ++ v3/response.go | 10 +- v3/transfer.go | 227 +++++++++++++++++++++++++++++++++---- 8 files changed, 387 insertions(+), 53 deletions(-) create mode 100644 v3/go.sum diff --git a/v3/client.go b/v3/client.go index d960815..476420f 100644 --- a/v3/client.go +++ b/v3/client.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "io" + "net" "net/http" "os" "path/filepath" @@ -48,11 +49,23 @@ type Client struct { // NewClient returns a new file download Client, using default configuration. func NewClient() *Client { + dialer := &net.Dialer{} return &Client{ UserAgent: "grab", HTTPClient: &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, + DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { + conn, err := dialer.DialContext(ctx, network, addr) + if err == nil { + // Default net.TCPConn calls SetNoDelay(true) + // which likely could be an impact on performance + // with large file downloads, and many ACKs on networks + // with higher latency + err = conn.(*net.TCPConn).SetNoDelay(false) + } + return conn, err + }, }, }, } @@ -86,6 +99,7 @@ func (c *Client) Do(req *Request) *Response { ctx: ctx, cancel: cancel, bufferSize: req.BufferSize, + transfer: (*transfer)(nil), } if resp.bufferSize == 0 { // default to Client.BufferSize @@ -330,13 +344,19 @@ func (c *Client) headRequest(resp *Response) stateFunc { } resp.optionsKnown = true - if resp.Request.NoResume { - return c.getRequest - } + // If we are going to do a range request, then we need to perform + // the HEAD req to check for support. 
+ // Otherwise, we may not need to do the HEAD request if we have + // enough information already. + if resp.Request.RangeRequestMax <= 0 { + if resp.Request.NoResume { + return c.getRequest + } - if resp.Filename != "" && resp.fi == nil { - // destination path is already known and does not exist - return c.getRequest + if resp.Filename != "" && resp.fi == nil { + // destination path is already known and does not exist + return c.getRequest + } } hreq := new(http.Request) @@ -365,6 +385,13 @@ func (c *Client) headRequest(resp *Response) stateFunc { } func (c *Client) getRequest(resp *Response) stateFunc { + if resp.Request.RangeRequestMax > 0 && resp.acceptRanges { + // For a concurrent range request, we don't do a single + // GET request here. It will be handled later in the transfer, + // based on the HEAD response + return c.openWriter + } + resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest) if resp.err != nil { return c.closeResponse @@ -410,11 +437,12 @@ func (c *Client) readResponse(resp *Response) stateFunc { resp.Filename = filepath.Join(resp.Request.Filename, filename) } - if !resp.Request.NoStore && resp.requestMethod() == "HEAD" { - if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" { - resp.CanResume = true + if resp.requestMethod() == "HEAD" { + resp.acceptRanges = resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" + if !resp.Request.NoStore { + resp.CanResume = resp.acceptRanges + return c.statFileInfo } - return c.statFileInfo } return c.openWriter } @@ -431,6 +459,12 @@ func (c *Client) openWriter(resp *Response) stateFunc { } } + if resp.bufferSize < 1 { + resp.bufferSize = 32 * 1024 + } + + var writerAt io.WriterAt + if resp.Request.NoStore { resp.writer = &resp.storeBuffer } else { @@ -453,11 +487,12 @@ func (c *Client) openWriter(resp *Response) stateFunc { return c.closeResponse } resp.writer = f + writerAt = f // seek to start or end - whence := os.SEEK_SET + whence := io.SeekStart if resp.bytesResumed > 0 { 
- whence = os.SEEK_END + whence = io.SeekEnd } _, resp.err = f.Seek(0, whence) if resp.err != nil { @@ -469,13 +504,21 @@ func (c *Client) openWriter(resp *Response) stateFunc { if resp.bufferSize < 1 { resp.bufferSize = 32 * 1024 } - b := make([]byte, resp.bufferSize) - resp.transfer = newTransfer( - resp.Request.Context(), - resp.Request.RateLimiter, - resp.writer, - resp.HTTPResponse.Body, - b) + + if resp.Request.RangeRequestMax > 0 && resp.acceptRanges && writerAt != nil { + // TODO: should we inspect resp.HTTPResponse.ContentLength + // and have a threshold under which a certain size should + // not use range requests? ie < 1MB? 256KB? + resp.transfer = newTransferRanges(c.HTTPClient, resp, writerAt) + + } else { + resp.transfer = newTransfer( + resp.Request.Context(), + resp.Request.RateLimiter, + resp.writer, + resp.HTTPResponse.Body, + resp.bufferSize) + } // next step is copyFile, but this will be called later in another goroutine return nil @@ -507,7 +550,7 @@ func (c *Client) copyFile(resp *Response) stateFunc { t.Truncate(0) } - bytesCopied, resp.err = resp.transfer.copy() + bytesCopied, resp.err = resp.transfer.Copy() if resp.err != nil { return c.closeResponse } diff --git a/v3/client_test.go b/v3/client_test.go index cbd0a81..04090d4 100644 --- a/v3/client_test.go +++ b/v3/client_test.go @@ -912,3 +912,61 @@ func TestNoStore(t *testing.T) { }) }) } + +// TestRangeRequest tests the option of using parallel range requests to download +// chunks of the remote resource +func TestRangeRequest(t *testing.T) { + size := int64(32768) + testCases := []struct { + Name string + Chunks int + StatusCode int + }{ + {"NumChunksNeg", -1, http.StatusOK}, + {"NumChunks0", 0, http.StatusOK}, + {"NumChunks1", 1, http.StatusPartialContent}, + {"NumChunks5", 5, http.StatusPartialContent}, + } + + for _, test := range testCases { + t.Run(test.Name, func(t *testing.T) { + opts := []grabtest.HandlerOption{ + grabtest.ContentLength(int(size)), + grabtest.StatusCode(func(r 
*http.Request) int { + if test.Chunks > 0 { + return http.StatusPartialContent + } + return http.StatusOK + }), + } + + grabtest.WithTestServer(t, func(url string) { + name := fmt.Sprintf(".testRangeRequest-%s", test.Name) + req := mustNewRequest(name, url) + req.RangeRequestMax = test.Chunks + + resp := DefaultClient.Do(req) + defer os.Remove(resp.Filename) + + err := resp.Err() + if err == ErrBadLength { + t.Errorf("error: %v", err) + } else if err != nil { + panic(err) + } else if resp.Size() != size { + t.Errorf("expected %v bytes, got %v bytes", size, resp.Size()) + } + + if resp.HTTPResponse.StatusCode != test.StatusCode { + t.Errorf("expected status code %v, got %d", test.StatusCode, resp.HTTPResponse.StatusCode) + } + + if bps := resp.BytesPerSecond(); bps <= 0 { + t.Errorf("expected BytesPerSecond > 0, got %v", bps) + } + + testComplete(t, resp) + }, opts...) + }) + } +} diff --git a/v3/go.mod b/v3/go.mod index be7f202..5cac744 100644 --- a/v3/go.mod +++ b/v3/go.mod @@ -1,3 +1,5 @@ module github.com/cavaliergopher/grab/v3 go 1.14 + +require golang.org/x/sync v0.3.0 diff --git a/v3/go.sum b/v3/go.sum new file mode 100644 index 0000000..4c4db29 --- /dev/null +++ b/v3/go.sum @@ -0,0 +1,2 @@ +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= diff --git a/v3/pkg/grabtest/handler.go b/v3/pkg/grabtest/handler.go index efe829c..b11ddae 100644 --- a/v3/pkg/grabtest/handler.go +++ b/v3/pkg/grabtest/handler.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/http/httptest" + "strconv" + "strings" "testing" "time" ) @@ -106,20 +108,55 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat)) // set content-length - offset := 0 + var offset int64 + length := int64(h.contentLength) if h.acceptRanges { if reqRange := r.Header.Get("Range"); reqRange != "" { - if _, err := fmt.Sscanf(reqRange, 
"bytes=%d-", &offset); err != nil { + const b = `bytes=` + var limit int64 + start, end, ok := strings.Cut(reqRange[len(b):], "-") + if !ok { httpError(w, http.StatusBadRequest) return } - if offset >= h.contentLength { - httpError(w, http.StatusRequestedRangeNotSatisfiable) + var err error + if start != "" { + offset, err = strconv.ParseInt(start, 10, 64) + if err != nil { + httpError(w, http.StatusBadRequest) + return + } + if offset > length { + offset = length + } + } + if end != "" { + limit, err = strconv.ParseInt(end, 10, 64) + if err != nil { + httpError(w, http.StatusBadRequest) + return + } + } + + if start != "" && end == "" { + length = length - offset + } else if start == "" && end != "" { + // unsupported range format: - + httpError(w, http.StatusBadRequest) + } else { + length = limit - offset + } + + if length > int64(h.contentLength) { + code := http.StatusRequestedRangeNotSatisfiable + msg := fmt.Sprintf("%s: requested range length %d "+ + "is greater than ContentLength %d", http.StatusText(code), length, h.contentLength) + http.Error(w, msg, code) return } } } - w.Header().Set("Content-Length", fmt.Sprintf("%d", h.contentLength-offset)) + w.Header().Set("Content-Length", fmt.Sprintf("%d", length)) // apply header blacklist for _, key := range h.headerBlacklist { @@ -133,7 +170,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == "GET" { // use buffered io to reduce overhead on the reader bw := bufio.NewWriterSize(w, 4096) - for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ { + for i := offset; !isRequestClosed(r) && i < int64(offset+length); i++ { bw.Write([]byte{byte(i)}) if h.rateLimiter != nil { bw.Flush() diff --git a/v3/request.go b/v3/request.go index f86cfc3..33292ff 100644 --- a/v3/request.go +++ b/v3/request.go @@ -99,6 +99,15 @@ type Request struct { // the Response object. 
AfterCopy Hook + // RangeRequestMax enables the use of "Range" requests, if supported by the + // server, to download multiple chunks. A value > 0 defines how many chunks + // to execute concurrently. + // If the server does not support the "Accept-Ranges" header, then the original single + // request behaviour is used. + // Note that the BufferSize will be applied as a separate buffer for each of + // the concurrent range request chunks + RangeRequestMax int + // hash, checksum and deleteOnError - set via SetChecksum. hash hash.Hash checksum []byte diff --git a/v3/response.go b/v3/response.go index 05bbca1..75f476b 100644 --- a/v3/response.go +++ b/v3/response.go @@ -47,6 +47,10 @@ type Response struct { // previous downloads, as the 'Accept-Ranges: bytes' header is set. CanResume bool + // specifies that the remote server advertised that it supports partial + // requests, as the 'Accept-Ranges: bytes' header is set. + acceptRanges bool + // DidResume specifies that the file transfer resumed a previously incomplete // transfer. DidResume bool @@ -84,7 +88,7 @@ type Response struct { // transfer is responsible for copying data from the remote server to a local // file, tracking progress and allowing for cancelation. - transfer *transfer + transfer transferer // bufferSize specifies the size in bytes of the transfer buffer. 
bufferSize int @@ -242,8 +246,8 @@ func (c *Response) checksumUnsafe() ([]byte, error) { return nil, err } defer f.Close() - t := newTransfer(c.Request.Context(), nil, c.Request.hash, f, nil) - if _, err = t.copy(); err != nil { + t := newTransfer(c.Request.Context(), nil, c.Request.hash, f, 0) + if _, err = t.Copy(); err != nil { return nil, err } sum := c.Request.hash.Sum(nil) diff --git a/v3/transfer.go b/v3/transfer.go index d938b49..ddfb3a1 100644 --- a/v3/transfer.go +++ b/v3/transfer.go @@ -2,47 +2,76 @@ package grab import ( "context" + "errors" + "fmt" "io" + "net/http" "sync/atomic" "time" + "golang.org/x/sync/errgroup" + "github.com/cavaliergopher/grab/v3/pkg/bps" ) +type transferer interface { + // Copy performs the bytes copy from the reader to the writer, + // reports the progress, and transfer rate. + // Returns bytes written and error, using same behaviour as io.Copy.Buffer + Copy() (written int64, err error) + // N returns the number of bytes transferred. + N() int64 + // BPS returns the current bytes per second transfer rate using a simple moving average. 
+ BPS() float64 +} + +func newGauge() bps.Gauge { + // five second moving average sampling every second + return bps.NewSMA(6) + +} + type transfer struct { - n int64 // must be 64bit aligned on 386 - ctx context.Context - gauge bps.Gauge - lim RateLimiter - w io.Writer - r io.Reader - b []byte + n int64 // must be 64bit aligned on 386 + ctx context.Context + gauge bps.Gauge + lim RateLimiter + w io.Writer + r io.Reader + bufSize int } -func newTransfer(ctx context.Context, lim RateLimiter, dst io.Writer, src io.Reader, buf []byte) *transfer { +func newTransfer(ctx context.Context, lim RateLimiter, dst io.Writer, src io.Reader, bufSize int) *transfer { return &transfer{ - ctx: ctx, - gauge: bps.NewSMA(6), // five second moving average sampling every second - lim: lim, - w: dst, - r: src, - b: buf, + ctx: ctx, + gauge: newGauge(), + lim: lim, + w: dst, + r: src, + bufSize: bufSize, } } -// copy behaves similarly to io.CopyBuffer except that it checks for cancelation +// Copy behaves similarly to io.CopyBuffer except that it checks for cancelation // of the given context.Context, reports progress in a thread-safe manner and // tracks the transfer rate. 
-func (c *transfer) copy() (written int64, err error) { +func (c *transfer) Copy() (written int64, err error) { + if c == nil { + return 0, errors.New("nil *transfer instance") + } + // maintain a bps gauge in another goroutine ctx, cancel := context.WithCancel(c.ctx) defer cancel() go bps.Watch(ctx, c.gauge, c.N, time.Second) // start the transfer - if c.b == nil { - c.b = make([]byte, 32*1024) + bufSize := c.bufSize + if bufSize < 1 { + bufSize = 32 * 1024 } + buf := make([]byte, bufSize) + for { select { case <-c.ctx.Done(): @@ -51,9 +80,9 @@ func (c *transfer) copy() (written int64, err error) { default: // keep working } - nr, er := c.r.Read(c.b) + nr, er := c.r.Read(buf) if nr > 0 { - nw, ew := c.w.Write(c.b[0:nr]) + nw, ew := c.w.Write(buf[0:nr]) if nw > 0 { written += int64(nw) atomic.StoreInt64(&c.n, written) @@ -85,17 +114,167 @@ func (c *transfer) copy() (written int64, err error) { } // N returns the number of bytes transferred. -func (c *transfer) N() (n int64) { +func (c *transfer) N() int64 { if c == nil { return 0 } - n = atomic.LoadInt64(&c.n) - return + return atomic.LoadInt64(&c.n) } // BPS returns the current bytes per second transfer rate using a simple moving // average. 
-func (c *transfer) BPS() (bps float64) { +func (c *transfer) BPS() float64 { + if c == nil || c.gauge == nil { + return 0 + } + return c.gauge.BPS() +} + +type transferRanges struct { + n int64 // must be 64bit aligned on 386 + ctx context.Context + client HTTPClient + gauge bps.Gauge + lim RateLimiter + w io.WriterAt + r *http.Request + length int64 + workers int + bufSize int +} + +func newTransferRanges(client HTTPClient, headResp *Response, dst io.WriterAt) *transferRanges { + return &transferRanges{ + ctx: headResp.Request.Context(), + client: client, + gauge: newGauge(), + lim: headResp.Request.RateLimiter, + w: dst, + r: headResp.Request.HTTPRequest, + length: headResp.HTTPResponse.ContentLength, + workers: headResp.Request.RangeRequestMax, + bufSize: headResp.bufferSize, + } +} + +// Copy performs concurrent http Range requests to transfer chunks and write them at +// offsets to the WriterAt instance. +// Checks for cancelation of the given context.Context, reports progress in a +// thread-safe manner and tracks the transfer rate. 
+func (c *transferRanges) Copy() (written int64, err error) { + if c == nil { + return 0, errors.New("nil *transferRanges instance") + } + + if c.length == 0 { + err = errors.New("cannot transfer ranges: ContentLength is 0") + return + } + + if c.workers < 1 { + c.workers = 1 + } + + if c.bufSize < 1 { + c.bufSize = 32 * 1024 + } + + c.n = 0 + + // maintain a bps gauge in another goroutine + ctx, cancel := context.WithCancel(c.ctx) + defer cancel() + go bps.Watch(ctx, c.gauge, c.N, time.Second) + + wg, ctx := errgroup.WithContext(ctx) + + chunkSize := c.length / int64(c.workers) + var start, end int64 + for i := 1; i <= c.workers; i++ { + if i == c.workers { + end = c.length + } else { + end = start + chunkSize + } + offset := start + limit := end + wg.Go(func() error { + return c.requestChunk(ctx, offset, limit) + }) + start = end + } + + err = wg.Wait() + return c.N(), err +} + +func (c *transferRanges) requestChunk(ctx context.Context, offset, limit int64) error { + req := c.r.Clone(ctx) + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, limit)) + + resp, err := c.client.Do(req) + if err != nil { + return err + } + if resp.StatusCode != http.StatusPartialContent { + return fmt.Errorf("server responded with %d status code, expected %d for range request", + resp.StatusCode, http.StatusPartialContent) + } + + defer resp.Body.Close() + + // start the transfer + buf := make([]byte, c.bufSize) + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // keep working + } + nr, er := resp.Body.Read(buf) + if nr > 0 { + nw, ew := c.w.WriteAt(buf[0:nr], offset) + if nw > 0 { + atomic.AddInt64(&c.n, int64(nw)) + } + if ew != nil { + return ew + } + if nr != nw { + return io.ErrShortWrite + } + offset += int64(nw) + // wait for rate limiter + if c.lim != nil { + if er = c.lim.WaitN(ctx, nr); er != nil { + return er + } + } + } + if er != nil { + if er != io.EOF { + return er + } + break + } + } + + return nil +} + +// N returns the total number of 
bytes transferred across all concurrent chunks. +func (c *transferRanges) N() int64 { + if c == nil { + return 0 + } + return atomic.LoadInt64(&c.n) +} + +// BPS returns the current bytes per second transfer rate using a simple moving +// average, across all concurrent chunks. +func (c *transferRanges) BPS() float64 { if c == nil || c.gauge == nil { return 0 } From 294f943fa7669807822630181058cd46f6d6948a Mon Sep 17 00:00:00 2001 From: Justin Israel Date: Wed, 28 Jun 2023 22:23:40 +0000 Subject: [PATCH 2/4] grabtest: handler should expect status 206 for Range requests --- v3/pkg/grabtest/handler.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/v3/pkg/grabtest/handler.go b/v3/pkg/grabtest/handler.go index b11ddae..fcaf82d 100644 --- a/v3/pkg/grabtest/handler.go +++ b/v3/pkg/grabtest/handler.go @@ -35,11 +35,16 @@ type handler struct { func NewHandler(options ...HandlerOption) (http.Handler, error) { h := &handler{ - statusCodeFunc: func(req *http.Request) int { return http.StatusOK }, methodWhitelist: []string{"GET", "HEAD"}, contentLength: DefaultHandlerContentLength, acceptRanges: true, } + h.statusCodeFunc = func(req *http.Request) int { + if h.acceptRanges && strings.HasPrefix(req.Header.Get("Range"), "bytes=") { + return http.StatusPartialContent + } + return http.StatusOK + } for _, option := range options { if err := option(h); err != nil { return nil, err From b11b2d27d9f471cfd6b85aaa91a0730411cce951 Mon Sep 17 00:00:00 2001 From: Justin Israel Date: Tue, 4 Jul 2023 14:50:43 +1200 Subject: [PATCH 3/4] On 86_range_requests: Implement a min size check before deciding to use a "Range" request --- v3/client.go | 34 ++++++++++++++++++---------------- v3/client_test.go | 16 +++++++++++----- v3/request.go | 6 ++++++ 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/v3/client.go b/v3/client.go index 476420f..42782cd 100644 --- a/v3/client.go +++ b/v3/client.go @@ -386,10 +386,12 @@ func (c *Client) headRequest(resp *Response) 
stateFunc { func (c *Client) getRequest(resp *Response) stateFunc { if resp.Request.RangeRequestMax > 0 && resp.acceptRanges { - // For a concurrent range request, we don't do a single - // GET request here. It will be handled later in the transfer, - // based on the HEAD response - return c.openWriter + if resp.HTTPResponse.ContentLength >= resp.Request.RangeRequestMinSize { + // For a concurrent range request, we don't do a single + // GET request here. It will be handled later in the transfer, + // based on the HEAD response + return c.openWriter + } } resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest) @@ -506,20 +508,20 @@ func (c *Client) openWriter(resp *Response) stateFunc { } if resp.Request.RangeRequestMax > 0 && resp.acceptRanges && writerAt != nil { - // TODO: should we inspect resp.HTTPResponse.ContentLength - // and have a threshold under which a certain size should - // not use range requests? ie < 1MB? 256KB? - resp.transfer = newTransferRanges(c.HTTPClient, resp, writerAt) - - } else { - resp.transfer = newTransfer( - resp.Request.Context(), - resp.Request.RateLimiter, - resp.writer, - resp.HTTPResponse.Body, - resp.bufferSize) + if resp.HTTPResponse.ContentLength >= resp.Request.RangeRequestMinSize { + resp.transfer = newTransferRanges(c.HTTPClient, resp, writerAt) + // next step is copyFile, but this will be called later in another goroutine + return nil + } } + resp.transfer = newTransfer( + resp.Request.Context(), + resp.Request.RateLimiter, + resp.writer, + resp.HTTPResponse.Body, + resp.bufferSize) + // next step is copyFile, but this will be called later in another goroutine return nil } diff --git a/v3/client_test.go b/v3/client_test.go index 04090d4..3707010 100644 --- a/v3/client_test.go +++ b/v3/client_test.go @@ -921,11 +921,16 @@ func TestRangeRequest(t *testing.T) { Name string Chunks int StatusCode int + MinSize int64 }{ - {"NumChunksNeg", -1, http.StatusOK}, - {"NumChunks0", 0, http.StatusOK}, - {"NumChunks1", 1, 
http.StatusPartialContent}, - {"NumChunks5", 5, http.StatusPartialContent}, + {Name: "NumChunksNeg", Chunks: -1, StatusCode: http.StatusOK}, + {Name: "NumChunks0", StatusCode: http.StatusOK}, + {Name: "NumChunks1", Chunks: 1, StatusCode: http.StatusPartialContent}, + {Name: "NumChunks5", Chunks: 5, StatusCode: http.StatusPartialContent}, + + // should not run a Range request because the Content-Length is + // not large enough + {Name: "RangeRequestMinSize", Chunks: 5, MinSize: size + 1, StatusCode: http.StatusOK}, } for _, test := range testCases { @@ -933,7 +938,7 @@ func TestRangeRequest(t *testing.T) { opts := []grabtest.HandlerOption{ grabtest.ContentLength(int(size)), grabtest.StatusCode(func(r *http.Request) int { - if test.Chunks > 0 { + if test.Chunks > 0 && size >= test.MinSize { return http.StatusPartialContent } return http.StatusOK @@ -944,6 +949,7 @@ func TestRangeRequest(t *testing.T) { name := fmt.Sprintf(".testRangeRequest-%s", test.Name) req := mustNewRequest(name, url) req.RangeRequestMax = test.Chunks + req.RangeRequestMinSize = test.MinSize resp := DefaultClient.Do(req) defer os.Remove(resp.Filename) diff --git a/v3/request.go b/v3/request.go index 33292ff..5065ca6 100644 --- a/v3/request.go +++ b/v3/request.go @@ -108,6 +108,12 @@ type Request struct { // the concurrent range request chunks RangeRequestMax int + // RangeRequestMinSize sets the minimum size in bytes for the Content-Length of the + // remote file to be downloaded using a "Range" request. If the file size + // is less than RangeRequestMinSize, then the entire body will be downloaded + // using a normal request. + RangeRequestMinSize int64 + // hash, checksum and deleteOnError - set via SetChecksum. 
hash hash.Hash checksum []byte From 8b4f8d2a810e32e9cd3860177e05aaeb01dd0b84 Mon Sep 17 00:00:00 2001 From: Justin Israel Date: Sun, 27 Aug 2023 14:21:06 +1200 Subject: [PATCH 4/4] truncate failed range requests to lowest successful offset, to fix resume support --- v3/client.go | 37 ++++++++++------- v3/client_test.go | 101 +++++++++++++++++++++++++++++++++++++++++++++- v3/response.go | 11 ++++- v3/transfer.go | 52 ++++++++++++++++++++++-- 4 files changed, 180 insertions(+), 21 deletions(-) diff --git a/v3/client.go b/v3/client.go index 42782cd..3959de7 100644 --- a/v3/client.go +++ b/v3/client.go @@ -3,6 +3,7 @@ package grab import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -369,7 +370,8 @@ func (c *Client) headRequest(resp *Response) stateFunc { } resp.HTTPResponse.Body.Close() - if resp.HTTPResponse.StatusCode != http.StatusOK { + if resp.HTTPResponse.StatusCode != http.StatusOK && + resp.HTTPResponse.StatusCode != http.StatusPartialContent { return c.getRequest } @@ -385,13 +387,11 @@ func (c *Client) headRequest(resp *Response) stateFunc { } func (c *Client) getRequest(resp *Response) stateFunc { - if resp.Request.RangeRequestMax > 0 && resp.acceptRanges { - if resp.HTTPResponse.ContentLength >= resp.Request.RangeRequestMinSize { - // For a concurrent range request, we don't do a single - // GET request here. It will be handled later in the transfer, - // based on the HEAD response - return c.openWriter - } + if resp.isRangeRequest() { + // For a concurrent range request, we don't do a single + // GET request here. 
It will be handled later in the transfer, + // based on the HEAD response + return c.openWriter } resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest) @@ -473,7 +473,7 @@ func (c *Client) openWriter(resp *Response) stateFunc { // compute write flags flag := os.O_CREATE | os.O_WRONLY if resp.fi != nil { - if resp.DidResume { + if resp.DidResume && !resp.isRangeRequest() { flag = os.O_APPEND | os.O_WRONLY } else { // truncate later in copyFile, if not cancelled @@ -507,12 +507,10 @@ func (c *Client) openWriter(resp *Response) stateFunc { resp.bufferSize = 32 * 1024 } - if resp.Request.RangeRequestMax > 0 && resp.acceptRanges && writerAt != nil { - if resp.HTTPResponse.ContentLength >= resp.Request.RangeRequestMinSize { - resp.transfer = newTransferRanges(c.HTTPClient, resp, writerAt) - // next step is copyFile, but this will be called later in another goroutine - return nil - } + if resp.isRangeRequest() && writerAt != nil { + resp.transfer = newTransferRanges(c.HTTPClient, resp, writerAt) + // next step is copyFile, but this will be called later in another goroutine + return nil } resp.transfer = newTransfer( @@ -554,6 +552,15 @@ func (c *Client) copyFile(resp *Response) stateFunc { bytesCopied, resp.err = resp.transfer.Copy() if resp.err != nil { + // If we ran parallel ranges and some of them failed, we need + // to truncate the file to the lowest successful range to avoid + // having any gaps during a subsequent resume operation. 
+ var rangesErr transferRangesErr + if errors.As(resp.err, &rangesErr) { + if t, ok := resp.writer.(truncater); ok { + t.Truncate(rangesErr.LastOffsetEnd) + } + } return c.closeResponse } closeWriter(resp) diff --git a/v3/client_test.go b/v3/client_test.go index 3707010..6dd653f 100644 --- a/v3/client_test.go +++ b/v3/client_test.go @@ -16,6 +16,7 @@ import ( "os" "path/filepath" "strings" + "sync" "testing" "time" @@ -206,8 +207,8 @@ func TestAutoResume(t *testing.T) { segs := 8 size := 1048576 sum := grabtest.DefaultHandlerSHA256ChecksumBytes //grab/v3test.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83") - filename := ".testAutoResume" + filename := ".testAutoResume" defer os.Remove(filename) for i := 0; i < segs; i++ { @@ -229,6 +230,31 @@ func TestAutoResume(t *testing.T) { }) } + filename2 := ".testAutoResumeRange" + defer os.Remove(filename2) + + for i := 0; i < segs; i++ { + segsize := (i + 1) * (size / segs) + t.Run(fmt.Sprintf("RangeWith%vBytes", segsize), func(t *testing.T) { + grabtest.WithTestServer(t, func(url string) { + req := mustNewRequest(filename2, url) + req.RangeRequestMinSize = 1 + req.RangeRequestMax = 5 + if i == segs-1 { + req.SetChecksum(sha256.New(), sum, false) + } + resp := mustDo(req) + if i > 0 && !resp.DidResume { + t.Errorf("expected Response.DidResume to be true") + } + testComplete(t, resp) + }, + grabtest.ContentLength(segsize), + grabtest.StatusCode(func(r *http.Request) int { return http.StatusPartialContent }), + ) + }) + } + t.Run("WithFailure", func(t *testing.T) { grabtest.WithTestServer(t, func(url string) { // request smaller segment @@ -976,3 +1002,76 @@ func TestRangeRequest(t *testing.T) { }) } } + +type rangeTestClient struct { + fn func(req *http.Request) (*http.Response, error) +} + +func (c *rangeTestClient) Do(req *http.Request) (*http.Response, error) { + return c.fn(req) +} + +func TestRangeRequestAutoResume(t *testing.T) { + const ( + NumChunks = 8 + Size = 1048576 + 
BadChunkStart = 393216 + ) + sum := grabtest.DefaultHandlerSHA256ChecksumBytes + expectErr := fmt.Errorf("TEST: cancelled") + + client := NewClient() + var wg sync.WaitGroup + client.HTTPClient = &rangeTestClient{func(req *http.Request) (*http.Response, error) { + wg.Add(1) + // Catch a range in the middle and wait for the other + // ranges to finish. Then, fail this range. + if strings.HasPrefix(req.Header.Get("Range"), fmt.Sprintf("bytes=%v", BadChunkStart)) { + go func() { + time.Sleep(100 * time.Millisecond) + wg.Done() + }() + wg.Wait() + return nil, expectErr + } + defer wg.Done() + return DefaultClient.HTTPClient.Do(req) + }} + + filename := ".testRangeRequestAutoResume" + defer os.Remove(filename) + + opts := []grabtest.HandlerOption{ + grabtest.ContentLength(int(Size)), + grabtest.StatusCode(func(r *http.Request) int { + return http.StatusPartialContent + }), + } + + grabtest.WithTestServer(t, func(url string) { + // run a request with parallel range chunks, where a + // chunk in the middle is not written + req := mustNewRequest(filename, url) + req.RangeRequestMinSize = 1 + req.RangeRequestMax = NumChunks + req.SetChecksum(sha256.New(), sum, false) + resp := client.Do(req) + if err := resp.Err(); !errors.Is(err, expectErr) { + t.Fatal(err.Error()) + } + testComplete(t, resp) + if resp.BytesComplete() >= resp.Size() { + t.Fatalf("Expected BytesComplete() [%v] < Size() [%v]", resp.BytesComplete(), resp.Size()) + } + + st, err := os.Stat(resp.Filename) + if err != nil { + t.Fatalf(err.Error()) + } + if st.Size() > BadChunkStart { + t.Fatalf("Partially written file size %v is not <= %v", st.Size(), BadChunkStart) + } + }, + opts..., + ) +} diff --git a/v3/response.go b/v3/response.go index 75f476b..aa8e024 100644 --- a/v3/response.go +++ b/v3/response.go @@ -82,7 +82,7 @@ type Response struct { // enabled. 
storeBuffer bytes.Buffer - // bytesCompleted specifies the number of bytes which were already + // bytesResumed specifies the number of bytes which were already // transferred before this transfer began. bytesResumed int64 @@ -260,3 +260,12 @@ func (c *Response) closeResponseBody() error { } return c.HTTPResponse.Body.Close() } + +func (c *Response) isRangeRequest() bool { + if c.Request.RangeRequestMax > 0 && c.acceptRanges { + if c.HTTPResponse.ContentLength >= c.Request.RangeRequestMinSize { + return true + } + } + return false +} diff --git a/v3/transfer.go b/v3/transfer.go index ddfb3a1..d92c049 100644 --- a/v3/transfer.go +++ b/v3/transfer.go @@ -130,6 +130,26 @@ func (c *transfer) BPS() float64 { return c.gauge.BPS() } +// transferRangesErr wraps a http error with extra information +// about the offset ranges that were successfully written +type transferRangesErr struct { + // wrapped error + err error + // the end byte offset of the last successfully written range + LastOffsetEnd int64 +} + +func (e transferRangesErr) Error() string { + if e.err != nil { + return e.err.Error() + } + return "transferRangesErr: nil" +} + +func (e transferRangesErr) Unwrap() error { + return e.err +} + type transferRanges struct { n int64 // must be 64bit aligned on 386 ctx context.Context @@ -139,6 +159,7 @@ type transferRanges struct { w io.WriterAt r *http.Request length int64 + offset int64 workers int bufSize int } @@ -152,6 +173,7 @@ func newTransferRanges(client HTTPClient, headResp *Response, dst io.WriterAt) * w: dst, r: headResp.Request.HTTPRequest, length: headResp.HTTPResponse.ContentLength, + offset: headResp.bytesResumed, workers: headResp.Request.RangeRequestMax, bufSize: headResp.bufferSize, } @@ -188,23 +210,45 @@ func (c *transferRanges) Copy() (written int64, err error) { wg, ctx := errgroup.WithContext(ctx) - chunkSize := c.length / int64(c.workers) + chunkSize := (c.length - c.offset) / int64(c.workers) var start, end int64 + start += c.offset + completed 
:= make([]int64, c.workers) for i := 1; i <= c.workers; i++ { if i == c.workers { - end = c.length + end = c.offset + c.length } else { end = start + chunkSize } + if end > c.length { + end = c.length + } + id := i - 1 offset := start limit := end wg.Go(func() error { - return c.requestChunk(ctx, offset, limit) + e := c.requestChunk(ctx, offset, limit) + if e == nil { + // when a chunk succeeds, record the ending offset + completed[id] = limit + } + return e }) start = end } - err = wg.Wait() + if err = wg.Wait(); err != nil { + rangeErr := transferRangesErr{err: err} + // find the last successful end offset before an error + for _, offset := range completed { + if offset == 0 { + break + } + rangeErr.LastOffsetEnd = offset + } + err = rangeErr + } + return c.N(), err }