From f84a9009f6f44da7b412549dd7062e0fc7eedc4d Mon Sep 17 00:00:00 2001 From: Rod Vagg Date: Wed, 5 Jul 2023 19:06:38 +1000 Subject: [PATCH] fix: don't sendEvent if context cancelled Fixes: https://github.com/filecoin-project/lassie/issues/343 --- pkg/retriever/graphsyncretriever.go | 6 ++--- pkg/retriever/httpretriever.go | 2 +- pkg/retriever/parallelpeerretriever.go | 35 +++++++++++++++----------- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/pkg/retriever/graphsyncretriever.go b/pkg/retriever/graphsyncretriever.go index 1309c36e..e43f1c59 100644 --- a/pkg/retriever/graphsyncretriever.go +++ b/pkg/retriever/graphsyncretriever.go @@ -184,7 +184,7 @@ func (pg *ProtocolGraphsync) Retrieve( eventsSubscriber := func(event datatransfer.Event, channelState datatransfer.ChannelState) { switch event.Code { case datatransfer.Open: - shared.sendEvent(events.Proposed(retrieval.Clock.Now(), retrieval.request.RetrievalID, candidate)) + shared.sendEvent(ctx, events.Proposed(retrieval.Clock.Now(), retrieval.request.RetrievalID, candidate)) case datatransfer.NewVoucherResult: lastVoucher := channelState.LastVoucherResult() resType, err := retrievaltypes.DealResponseFromNode(lastVoucher.Voucher) @@ -192,12 +192,12 @@ func (pg *ProtocolGraphsync) Retrieve( return } if resType.Status == retrievaltypes.DealStatusAccepted { - shared.sendEvent(events.Accepted(retrieval.Clock.Now(), retrieval.request.RetrievalID, candidate)) + shared.sendEvent(ctx, events.Accepted(retrieval.Clock.Now(), retrieval.request.RetrievalID, candidate)) } case datatransfer.DataReceivedProgress: if !receivedFirstByte { receivedFirstByte = true - shared.sendEvent(events.FirstByte(retrieval.Clock.Now(), retrieval.request.RetrievalID, candidate, retrieval.Clock.Since(retrievalStart), multicodec.TransportGraphsyncFilecoinv1)) + shared.sendEvent(ctx, events.FirstByte(retrieval.Clock.Now(), retrieval.request.RetrievalID, candidate, retrieval.Clock.Since(retrievalStart), multicodec.TransportGraphsyncFilecoinv1)) } if lastBytesReceivedTimer != nil { doneLk.Lock() diff --git a/pkg/retriever/httpretriever.go b/pkg/retriever/httpretriever.go index 2ef1e893..62359034 100644 --- a/pkg/retriever/httpretriever.go +++ b/pkg/retriever/httpretriever.go @@ -117,7 +117,7 @@ func (ph *ProtocolHttp) Retrieve( var ttfb time.Duration rdr := newTimeToFirstByteReader(resp.Body, func() { ttfb = retrieval.Clock.Since(retrievalStart) - shared.sendEvent(events.FirstByte(retrieval.Clock.Now(), retrieval.request.RetrievalID, candidate, ttfb, multicodec.TransportIpfsGatewayHttp)) + shared.sendEvent(ctx, events.FirstByte(retrieval.Clock.Now(), retrieval.request.RetrievalID, candidate, ttfb, multicodec.TransportIpfsGatewayHttp)) }) cfg := verifiedcar.Config{ Root: retrieval.request.Cid, diff --git a/pkg/retriever/parallelpeerretriever.go b/pkg/retriever/parallelpeerretriever.go index df9512a7..354ecbe5 100644 --- a/pkg/retriever/parallelpeerretriever.go +++ b/pkg/retriever/parallelpeerretriever.go @@ -97,8 +97,9 @@ func (shared *retrievalShared) canSendResult() bool { // sendResult will only send a result to the parent goroutine if a retrieval has // finished (likely by a success), otherwise it will send the result -func (shared *retrievalShared) sendResult(result retrievalResult) bool { +func (shared *retrievalShared) sendResult(ctx context.Context, result retrievalResult) bool { select { + case <-ctx.Done(): case <-shared.finishChan: return false case shared.resultChan <- result: @@ -112,9 +113,9 @@ func (shared *retrievalShared) sendResult(result retrievalResult) bool { return true } -func (shared *retrievalShared) sendEvent(event events.EventWithProviderID) { +func (shared *retrievalShared) sendEvent(ctx context.Context, event events.EventWithProviderID) { retrievalEvent := event.(types.RetrievalEvent) - shared.sendResult(retrievalResult{PeerID: event.ProviderId(), Event: &retrievalEvent}) + shared.sendResult(ctx, retrievalResult{PeerID: event.ProviderId(), Event: &retrievalEvent}) } func (cfg *parallelPeerRetriever) Retrieve( @@ -193,7 +194,11 @@ func (retrieval *retrieval) RetrieveFromAsyncCandidates(asyncCandidates types.In select { case <-finishAll: case <-time.After(100 * time.Millisecond): - logger.Warn("Unable to successfully cancel all retrieval attempts withing 100ms") + logger.Errorf( + "Possible leak: unable to successfully cancel all %s retrieval attempts for %s within 100ms", + retrieval.Protocol.Code().String(), + retrieval.request.Cid.String(), + ) } return stats, err } @@ -297,7 +302,7 @@ func (retrieval *retrieval) runRetrievalCandidate( var retrievalErr error var done func() - shared.sendEvent(events.StartedRetrieval(retrieval.parallelPeerRetriever.Clock.Now(), retrieval.request.RetrievalID, candidate, retrieval.Protocol.Code())) + shared.sendEvent(ctx, events.StartedRetrieval(retrieval.parallelPeerRetriever.Clock.Now(), retrieval.request.RetrievalID, candidate, retrieval.Protocol.Code())) connectCtx := ctx if timeout != 0 { var timeoutFunc func() @@ -309,15 +314,15 @@ func (retrieval *retrieval) runRetrievalCandidate( connectTime, err := retrieval.Protocol.Connect(connectCtx, retrieval, startTime, candidate) if err != nil { if ctx.Err() == nil { // not cancelled, maybe timed out though - logger.Warnf("Failed to connect to SP %s: %v", candidate.MinerPeer.ID, err) + logger.Warnf("Failed to connect to SP %s on protocol %s: %v", candidate.MinerPeer.ID, retrieval.Protocol.Code().String(), err) retrievalErr = fmt.Errorf("%w: %v", ErrConnectFailed, err) if err := retrieval.Session.RecordFailure(retrieval.request.RetrievalID, candidate.MinerPeer.ID); err != nil { - logger.Errorf("Error recording retrieval failure: %v", err) + logger.Errorf("Error recording retrieval failure on protocol %s: %v", retrieval.Protocol.Code().String(), err) } - shared.sendEvent(events.FailedRetrieval(retrieval.parallelPeerRetriever.Clock.Now(), retrieval.request.RetrievalID, candidate, retrievalErr.Error())) + shared.sendEvent(ctx, events.FailedRetrieval(retrieval.parallelPeerRetriever.Clock.Now(), retrieval.request.RetrievalID, candidate, retrievalErr.Error())) } } else { - shared.sendEvent(events.ConnectedToProvider(retrieval.parallelPeerRetriever.Clock.Now(), retrieval.request.RetrievalID, candidate)) + shared.sendEvent(ctx, events.ConnectedToProvider(retrieval.parallelPeerRetriever.Clock.Now(), retrieval.request.RetrievalID, candidate)) retrieval.Session.RecordConnectTime(candidate.MinerPeer.ID, connectTime) @@ -332,12 +337,12 @@ func (retrieval *retrieval) runRetrievalCandidate( if errors.Is(retrievalErr, ErrRetrievalTimedOut) { msg = fmt.Sprintf("timeout after %s", timeout) } - shared.sendEvent(events.FailedRetrieval(retrieval.parallelPeerRetriever.Clock.Now(), retrieval.request.RetrievalID, candidate, msg)) + shared.sendEvent(ctx, events.FailedRetrieval(retrieval.parallelPeerRetriever.Clock.Now(), retrieval.request.RetrievalID, candidate, msg)) if err := retrieval.Session.RecordFailure(retrieval.request.RetrievalID, candidate.MinerPeer.ID); err != nil { - logger.Errorf("Error recording retrieval failure: %v", err) + logger.Errorf("Error recording retrieval failure for protocol %s: %v", retrieval.Protocol.Code().String(), err) } } else { - shared.sendEvent(events.Success( + shared.sendEvent(ctx, events.Success( retrieval.parallelPeerRetriever.Clock.Now(), retrieval.request.RetrievalID, candidate, @@ -359,13 +364,13 @@ func (retrieval *retrieval) runRetrievalCandidate( if shared.canSendResult() { if retrievalErr != nil { if ctx.Err() != nil { // cancelled, don't report the error - shared.sendResult(retrievalResult{PeerID: candidate.MinerPeer.ID}) + shared.sendResult(ctx, retrievalResult{PeerID: candidate.MinerPeer.ID}) } else { // an error of some kind to report - shared.sendResult(retrievalResult{PeerID: candidate.MinerPeer.ID, Err: retrievalErr}) + shared.sendResult(ctx, retrievalResult{PeerID: candidate.MinerPeer.ID, Err: retrievalErr}) } } else { // success, we have stats and no errors - shared.sendResult(retrievalResult{PeerID: candidate.MinerPeer.ID, Stats: stats}) + shared.sendResult(ctx, retrievalResult{PeerID: candidate.MinerPeer.ID, Stats: stats}) } } // else nothing to do, we were cancelled