diff --git a/datachannel.go b/datachannel.go index e975a04f7b5..218ebfbec08 100644 --- a/datachannel.go +++ b/datachannel.go @@ -40,6 +40,8 @@ type DataChannel struct { readyState atomic.Value // DataChannelState bufferedAmountLowThreshold uint64 detachCalled bool + readLoopActive chan struct{} + isGracefulClosed bool // The binaryType represents attribute MUST, on getting, return the value to // which it was last set. On setting, if the new value is either the string @@ -225,6 +227,10 @@ func (d *DataChannel) OnOpen(f func()) { func (d *DataChannel) onOpen() { d.mu.RLock() handler := d.onOpenHandler + if d.isGracefulClosed { + d.mu.RUnlock() + return + } d.mu.RUnlock() if handler != nil { @@ -252,6 +258,10 @@ func (d *DataChannel) OnDial(f func()) { func (d *DataChannel) onDial() { d.mu.RLock() handler := d.onDialHandler + if d.isGracefulClosed { + d.mu.RUnlock() + return + } d.mu.RUnlock() if handler != nil { @@ -261,6 +271,10 @@ func (d *DataChannel) onDial() { // OnClose sets an event handler which is invoked when // the underlying data transport has been closed. +// Note: Due to backwards compatibility, there is a chance that +// OnClose can be called, even if the GracefulClose is used. +// If this is the case for you, you can deregister OnClose +// prior to GracefulClose. func (d *DataChannel) OnClose(f func()) { d.mu.Lock() defer d.mu.Unlock() @@ -292,6 +306,10 @@ func (d *DataChannel) OnMessage(f func(msg DataChannelMessage)) { func (d *DataChannel) onMessage(msg DataChannelMessage) { d.mu.RLock() handler := d.onMessageHandler + if d.isGracefulClosed { + d.mu.RUnlock() + return + } d.mu.RUnlock() if handler == nil { @@ -302,6 +320,10 @@ func (d *DataChannel) onMessage(msg DataChannelMessage) { func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlreadyNegotiated bool) { d.mu.Lock() + if d.isGracefulClosed { + d.mu.Unlock() + return + } d.dataChannel = dc bufferedAmountLowThreshold := d.bufferedAmountLowThreshold onBufferedAmountLow := d.onBufferedAmountLow @@ -326,7 +348,12 @@ func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlread d.mu.Lock() defer d.mu.Unlock() + if d.isGracefulClosed { + return + } + if !d.api.settingEngine.detach.DataChannels { + d.readLoopActive = make(chan struct{}) go d.readLoop() } } @@ -342,6 +369,10 @@ func (d *DataChannel) OnError(f func(err error)) { func (d *DataChannel) onError(err error) { d.mu.RLock() handler := d.onErrorHandler + if d.isGracefulClosed { + d.mu.RUnlock() + return + } d.mu.RUnlock() if handler != nil { @@ -350,6 +381,12 @@ func (d *DataChannel) onError(err error) { } func (d *DataChannel) readLoop() { + defer func() { + d.mu.Lock() + readLoopActive := d.readLoopActive + d.mu.Unlock() + defer close(readLoopActive) + }() buffer := make([]byte, dataChannelBufferSize) for { n, isString, err := d.dataChannel.ReadDataChannel(buffer) @@ -449,7 +486,32 @@ func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) { // Close Closes the DataChannel. It may be called regardless of whether // the DataChannel object was created by this peer or the remote peer. func (d *DataChannel) Close() error { + return d.close(false) +} + +// GracefulClose Closes the DataChannel. It may be called regardless of whether +// the DataChannel object was created by this peer or the remote peer. It also waits +// for any goroutines it started to complete. This is only safe to call outside of +// DataChannel callbacks or if in a callback, in its own goroutine. +func (d *DataChannel) GracefulClose() error { + return d.close(true) +} + +// Normally, close only stops writes from happening, so graceful=true +// will wait for reads to be finished based on underlying SCTP association +// closure or a SCTP reset stream from the other side. This is safe to call +// with graceful=true after tearing down a PeerConnection but not +// necessarily before. For example, if you used a vnet and dropped all packets +// right before closing the DataChannel, you'd need never see a reset stream. +func (d *DataChannel) close(shouldGracefullyClose bool) error { d.mu.Lock() + d.isGracefulClosed = true + readLoopActive := d.readLoopActive + if shouldGracefullyClose && readLoopActive != nil { + defer func() { + <-readLoopActive + }() + } haveSctpTransport := d.dataChannel != nil d.mu.Unlock() diff --git a/go.mod b/go.mod index 703a3a2deb8..23016b4ad49 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.19 require ( github.com/pion/datachannel v1.5.8 github.com/pion/dtls/v3 v3.0.0 - github.com/pion/ice/v3 v3.0.15 + github.com/pion/ice/v3 v3.0.16 github.com/pion/interceptor v0.1.29 github.com/pion/logging v0.2.2 github.com/pion/randutil v0.1.0 diff --git a/go.sum b/go.sum index 30c08c2d278..027ab9ac6fd 100644 --- a/go.sum +++ b/go.sum @@ -42,8 +42,8 @@ github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk= github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= github.com/pion/dtls/v3 v3.0.0 h1:m2hzwPkzqoBjVKXm5ymNuX01OAjht82TdFL6LoTzgi4= github.com/pion/dtls/v3 v3.0.0/go.mod h1:tiX7NaneB0wNoRaUpaMVP7igAlkMCTQkbpiY+OfeIi0= -github.com/pion/ice/v3 v3.0.15 h1:6FFM1k1Ei36keZN1drl8/xaEs+NpMMG6M+MsVRchXho= -github.com/pion/ice/v3 v3.0.15/go.mod h1:SdmubtIsCcvdb1ZInrTUz7Iaqi90/rYd1pzbzlMxsZg= +github.com/pion/ice/v3 v3.0.16 h1:YoPlNg3jU1UT/DDTa9v/g1vH6A2/pAzehevI1o66H8E= +github.com/pion/ice/v3 v3.0.16/go.mod h1:SdmubtIsCcvdb1ZInrTUz7Iaqi90/rYd1pzbzlMxsZg= github.com/pion/interceptor v0.1.29 h1:39fsnlP1U8gw2JzOFWdfCU82vHvhW9o0rZnZF56wF+M= github.com/pion/interceptor v0.1.29/go.mod h1:ri+LGNjRUc5xUNtDEPzfdkmSqISixVTBF/z/Zms/6T4= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= diff --git a/icegatherer.go b/icegatherer.go index 232145a974d..38f45d4c57d 100644 --- a/icegatherer.go +++ b/icegatherer.go @@ -190,13 +190,31 @@ func (g *ICEGatherer) Gather() error { // Close prunes all local candidates, and closes the ports. func (g *ICEGatherer) Close() error { + return g.close(false /* shouldGracefullyClose */) +} + +// GracefulClose prunes all local candidates, and closes the ports. It also waits +// for any goroutines it started to complete. This is only safe to call outside of +// ICEGatherer callbacks or if in a callback, in its own goroutine. +func (g *ICEGatherer) GracefulClose() error { + return g.close(true /* shouldGracefullyClose */) +} + +func (g *ICEGatherer) close(shouldGracefullyClose bool) error { g.lock.Lock() defer g.lock.Unlock() if g.agent == nil { return nil - } else if err := g.agent.Close(); err != nil { - return err + } + if shouldGracefullyClose { + if err := g.agent.GracefulClose(); err != nil { + return err + } + } else { + if err := g.agent.Close(); err != nil { + return err + } } g.agent = nil diff --git a/icetransport.go b/icetransport.go index 8e272bf984b..c23dafa293e 100644 --- a/icetransport.go +++ b/icetransport.go @@ -16,6 +16,7 @@ import ( "github.com/pion/ice/v3" "github.com/pion/logging" "github.com/pion/webrtc/v4/internal/mux" + "github.com/pion/webrtc/v4/internal/util" ) // ICETransport allows an application access to information about the ICE @@ -187,6 +188,17 @@ func (t *ICETransport) restart() error { // Stop irreversibly stops the ICETransport. func (t *ICETransport) Stop() error { + return t.stop(false /* shouldGracefullyClose */) +} + +// GracefulStop irreversibly stops the ICETransport. It also waits +// for any goroutines it started to complete. This is only safe to call outside of +// ICETransport callbacks or if in a callback, in its own goroutine. +func (t *ICETransport) GracefulStop() error { + return t.stop(true /* shouldGracefullyClose */) +} + +func (t *ICETransport) stop(shouldGracefullyClose bool) error { t.lock.Lock() defer t.lock.Unlock() @@ -197,8 +209,18 @@ func (t *ICETransport) Stop() error { } if t.mux != nil { - return t.mux.Close() + var closeErrs []error + if shouldGracefullyClose && t.gatherer != nil { + // we can't access icegatherer/icetransport.Close via + // mux's net.Conn Close so we call it earlier here. + closeErrs = append(closeErrs, t.gatherer.GracefulClose()) + } + closeErrs = append(closeErrs, t.mux.Close()) + return util.FlattenErrs(closeErrs) } else if t.gatherer != nil { + if shouldGracefullyClose { + return t.gatherer.GracefulClose() + } return t.gatherer.Close() } return nil diff --git a/internal/mux/mux.go b/internal/mux/mux.go index 20c399c8cb5..412f5f5adba 100644 --- a/internal/mux/mux.go +++ b/internal/mux/mux.go @@ -129,6 +129,10 @@ func (m *Mux) readLoop() { } if err = m.dispatch(buf[:n]); err != nil { + if errors.Is(err, io.ErrClosedPipe) { + // if the buffer was closed, that's not an error we care to report + return + } m.log.Errorf("mux: ending readLoop dispatch error %s", err.Error()) return } diff --git a/operations.go b/operations.go index bc366ac34db..67d24eebecc 100644 --- a/operations.go +++ b/operations.go @@ -13,12 +13,13 @@ type operation func() // Operations is a task executor. type operations struct { - mu sync.Mutex - busy bool - ops *list.List + mu sync.Mutex + busyCh chan struct{} + ops *list.List updateNegotiationNeededFlagOnEmptyChain *atomicBool onNegotiationNeeded func() + isClosed bool } func newOperations( @@ -33,21 +34,34 @@ func newOperations( } // Enqueue adds a new action to be executed. If there are no actions scheduled, -// the execution will start immediately in a new goroutine. +// the execution will start immediately in a new goroutine. If the queue has been +// closed, the operation will be dropped. The queue is only deliberately closed +// by a user. func (o *operations) Enqueue(op operation) { + o.mu.Lock() + defer o.mu.Unlock() + _ = o.tryEnqueue(op) +} + +// tryEnqueue attempts to enqueue the given operation. It returns false +// if the op is invalid or the queue is closed. mu must be locked by +// tryEnqueue's caller. +func (o *operations) tryEnqueue(op operation) bool { if op == nil { - return + return false } - o.mu.Lock() - running := o.busy + if o.isClosed { + return false + } o.ops.PushBack(op) - o.busy = true - o.mu.Unlock() - if !running { + if o.busyCh == nil { + o.busyCh = make(chan struct{}) go o.start() } + + return true } // IsEmpty checks if there are tasks in the queue @@ -62,12 +76,38 @@ func (o *operations) IsEmpty() bool { func (o *operations) Done() { var wg sync.WaitGroup wg.Add(1) - o.Enqueue(func() { + o.mu.Lock() + enqueued := o.tryEnqueue(func() { wg.Done() }) + o.mu.Unlock() + if !enqueued { + return + } wg.Wait() } +// GracefulClose waits for the operations queue to be cleared and forbids +// new operations from being enqueued. +func (o *operations) GracefulClose() { + o.mu.Lock() + if o.isClosed { + o.mu.Unlock() + return + } + // do not enqueue anymore ops from here on + // o.isClosed=true will also not allow a new busyCh + // to be created. + o.isClosed = true + + busyCh := o.busyCh + o.mu.Unlock() + if busyCh == nil { + return + } + <-busyCh +} + func (o *operations) pop() func() { o.mu.Lock() defer o.mu.Unlock() @@ -87,12 +127,17 @@ func (o *operations) start() { defer func() { o.mu.Lock() defer o.mu.Unlock() - if o.ops.Len() == 0 { - o.busy = false + // this wil lbe the most recent busy chan + close(o.busyCh) + + if o.ops.Len() == 0 || o.isClosed { + o.busyCh = nil return } + // either a new operation was enqueued while we // were busy, or an operation panicked + o.busyCh = make(chan struct{}) go o.start() }() diff --git a/operations_test.go b/operations_test.go index 428c2b4df97..3b84a1def5b 100644 --- a/operations_test.go +++ b/operations_test.go @@ -19,6 +19,8 @@ func TestOperations_Enqueue(t *testing.T) { onNegotiationNeededCalledCount++ onNegotiationNeededCalledCountMu.Unlock() }) + defer ops.GracefulClose() + for resultSet := 0; resultSet < 100; resultSet++ { results := make([]int, 16) resultSetCopy := resultSet @@ -46,5 +48,35 @@ func TestOperations_Enqueue(t *testing.T) { func TestOperations_Done(*testing.T) { ops := newOperations(&atomicBool{}, func() { }) + defer ops.GracefulClose() + ops.Done() +} + +func TestOperations_GracefulClose(t *testing.T) { + ops := newOperations(&atomicBool{}, func() { + }) + + counter := 0 + var counterMu sync.Mutex + incFunc := func() { + counterMu.Lock() + counter++ + counterMu.Unlock() + } + const times = 25 + for i := 0; i < times; i++ { + ops.Enqueue(incFunc) + } + ops.Done() + counterMu.Lock() + counterCur := counter + counterMu.Unlock() + assert.Equal(t, counterCur, times) + + ops.GracefulClose() + for i := 0; i < times; i++ { + ops.Enqueue(incFunc) + } ops.Done() + assert.Equal(t, counterCur, times) } diff --git a/peerconnection.go b/peerconnection.go index a764f98993e..0f30b6abf42 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -56,6 +56,8 @@ type PeerConnection struct { idpLoginURL *string isClosed *atomicBool + isGracefulClosed *atomicBool + isGracefulClosedDone chan struct{} isNegotiationNeeded *atomicBool updateNegotiationNeededFlagOnEmptyChain *atomicBool @@ -117,6 +119,8 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection, ICECandidatePoolSize: 0, }, isClosed: &atomicBool{}, + isGracefulClosed: &atomicBool{}, + isGracefulClosedDone: make(chan struct{}), isNegotiationNeeded: &atomicBool{}, updateNegotiationNeededFlagOnEmptyChain: &atomicBool{}, lastOffer: "", @@ -2092,13 +2096,34 @@ func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes return pc.dtlsTransport.WriteRTCP(pkts) } -// Close ends the PeerConnection +// Close ends the PeerConnection. func (pc *PeerConnection) Close() error { + return pc.close(false /* shouldGracefullyClose */) +} + +// GracefulClose ends the PeerConnection. It also waits +// for any goroutines it started to complete. This is only safe to call outside of +// PeerConnection callbacks or if in a callback, in its own goroutine. +func (pc *PeerConnection) GracefulClose() error { + return pc.close(true /* shouldGracefullyClose */) +} + +func (pc *PeerConnection) close(shouldGracefullyClose bool) error { // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #1) // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #2) + alreadyGracefullyClosed := shouldGracefullyClose && pc.isGracefulClosed.swap(true) if pc.isClosed.swap(true) { + if alreadyGracefullyClosed { + // similar but distinct condition where we may be waiting for some + // other graceful close to finish. Incorrectly using isClosed may + // leak a goroutine. + <-pc.isGracefulClosedDone + } return nil } + if shouldGracefullyClose && !alreadyGracefullyClosed { + defer close(pc.isGracefulClosedDone) + } // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #3) pc.signalingState.Set(SignalingStateClosed) @@ -2142,12 +2167,28 @@ func (pc *PeerConnection) Close() error { // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #8, #9, #10) if pc.iceTransport != nil { - closeErrs = append(closeErrs, pc.iceTransport.Stop()) + if shouldGracefullyClose { + // note that it isn't canon to stop gracefully + closeErrs = append(closeErrs, pc.iceTransport.GracefulStop()) + } else { + closeErrs = append(closeErrs, pc.iceTransport.Stop()) + } } // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #11) pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State()) + if shouldGracefullyClose { + pc.ops.GracefulClose() + + // note that it isn't canon to stop gracefully + pc.sctpTransport.lock.Lock() + for _, d := range pc.sctpTransport.dataChannels { + closeErrs = append(closeErrs, d.GracefulClose()) + } + pc.sctpTransport.lock.Unlock() + } + return util.FlattenErrs(closeErrs) } @@ -2321,8 +2362,11 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re } pc.dtlsTransport.internalOnCloseHandler = func() { - pc.log.Info("Closing PeerConnection from DTLS CloseNotify") + if pc.isClosed.get() { + return + } + pc.log.Info("Closing PeerConnection from DTLS CloseNotify") go func() { if pcClosErr := pc.Close(); pcClosErr != nil { pc.log.Warnf("Failed to close PeerConnection from DTLS CloseNotify: %s", pcClosErr) diff --git a/peerconnection_close_test.go b/peerconnection_close_test.go index 5360d701fc2..e2fbfeca88b 100644 --- a/peerconnection_close_test.go +++ b/peerconnection_close_test.go @@ -179,3 +179,69 @@ func TestPeerConnection_Close_DuringICE(t *testing.T) { t.Error("pcOffer.Close() Timeout") } } + +func TestPeerConnection_CloseWithIncomingMessages(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + report := test.CheckRoutinesStrict(t) + defer report() + + pcOffer, pcAnswer, err := newPair() + if err != nil { + t.Fatal(err) + } + + var dcAnswer *DataChannel + answerDataChannelOpened := make(chan struct{}) + pcAnswer.OnDataChannel(func(d *DataChannel) { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if d.Label() != "data" { + return + } + dcAnswer = d + close(answerDataChannelOpened) + }) + + dcOffer, err := pcOffer.CreateDataChannel("data", nil) + if err != nil { + t.Fatal(err) + } + + offerDataChannelOpened := make(chan struct{}) + dcOffer.OnOpen(func() { + close(offerDataChannelOpened) + }) + + err = signalPair(pcOffer, pcAnswer) + if err != nil { + t.Fatal(err) + } + + <-offerDataChannelOpened + <-answerDataChannelOpened + + msgNum := 0 + dcOffer.OnMessage(func(_ DataChannelMessage) { + t.Log("msg", msgNum) + msgNum++ + }) + + // send 50 messages, then close pcOffer, and then send another 50 + for i := 0; i < 100; i++ { + if i == 50 { + err = pcOffer.GracefulClose() + if err != nil { + t.Fatal(err) + } + } + _ = dcAnswer.Send([]byte("hello!")) + } + + err = pcAnswer.GracefulClose() + if err != nil { + t.Fatal(err) + } +}