diff --git a/sctptransport.go b/sctptransport.go index 800de1dfa20..819d207568f 100644 --- a/sctptransport.go +++ b/sctptransport.go @@ -45,6 +45,7 @@ type SCTPTransport struct { // OnStateChange func() onErrorHandler func(error) + onCloseHandler func(error) sctpAssociation *sctp.Association onDataChannelHandler func(*DataChannel) @@ -176,6 +177,7 @@ func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) { dataChannels = append(dataChannels, dc.dataChannel) } r.lock.RUnlock() + ACCEPT: for { dc, err := datachannel.Accept(a, &datachannel.Config{ @@ -185,6 +187,9 @@ ACCEPT: if !errors.Is(err, io.EOF) { r.log.Errorf("Failed to accept data channel: %v", err) r.onError(err) + r.onClose(err) + } else { + r.onClose(nil) } return } @@ -232,9 +237,14 @@ ACCEPT: MaxRetransmits: maxRetransmits, }, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc")) if err != nil { + // This data channel is invalid. Close it and log an error. + if err1 := dc.Close(); err1 != nil { + r.log.Errorf("Failed to close invalid data channel: %v", err1) + } r.log.Errorf("Failed to accept data channel: %v", err) r.onError(err) - return + // We've received a datachannel with invalid configuration. We can still receive other datachannels. + continue ACCEPT } <-r.onDataChannel(rtcDC) @@ -251,8 +261,7 @@ ACCEPT: } } -// OnError sets an event handler which is invoked when -// the SCTP connection error occurs. +// OnError sets an event handler which is invoked when the SCTP Association errors. func (r *SCTPTransport) OnError(f func(err error)) { r.lock.Lock() defer r.lock.Unlock() @@ -269,6 +278,23 @@ func (r *SCTPTransport) onError(err error) { } } +// OnClose sets an event handler which is invoked when the SCTP Association closes. +func (r *SCTPTransport) OnClose(f func(err error)) { + r.lock.Lock() + defer r.lock.Unlock() + r.onCloseHandler = f +} + +func (r *SCTPTransport) onClose(err error) { + r.lock.RLock() + handler := r.onCloseHandler + r.lock.RUnlock() + + if handler != nil { + go handler(err) + } +} + // OnDataChannel sets an event handler which is invoked when a data // channel message arrives from a remote peer. func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) { diff --git a/sctptransport_test.go b/sctptransport_test.go index 9943e8f0629..c02e228f05a 100644 --- a/sctptransport_test.go +++ b/sctptransport_test.go @@ -6,7 +6,13 @@ package webrtc -import "testing" +import ( + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/require" +) func TestGenerateDataChannelID(t *testing.T) { sctpTransportWithChannels := func(ids []uint16) *SCTPTransport { @@ -55,3 +61,66 @@ func TestGenerateDataChannelID(t *testing.T) { } } } + +func TestSCTPTransportOnClose(t *testing.T) { + offerPC, answerPC, err := newPair() + require.NoError(t, err) + + answerPC.OnDataChannel(func(dc *DataChannel) { + dc.OnMessage(func(_ DataChannelMessage) { + if err1 := dc.Send([]byte("hello")); err1 != nil { + t.Error("failed to send message") + } + }) + }) + + recvMsg := make(chan struct{}, 1) + offerPC.OnConnectionStateChange(func(state PeerConnectionState) { + if state == PeerConnectionStateConnected { + defer func() { + offerPC.OnConnectionStateChange(nil) + }() + + dc, createErr := offerPC.CreateDataChannel(expectedLabel, nil) + if createErr != nil { + t.Errorf("Failed to create a PC pair for testing") + return + } + dc.OnMessage(func(msg DataChannelMessage) { + if !bytes.Equal(msg.Data, []byte("hello")) { + t.Error("invalid msg received") + } + recvMsg <- struct{}{} + }) + dc.OnOpen(func() { + if err1 := dc.Send([]byte("hello")); err1 != nil { + t.Error("failed to send initial msg", err1) + } + }) + } + }) + + err = signalPair(offerPC, answerPC) + require.NoError(t, err) + + select { + case <-recvMsg: + case <-time.After(5 * time.Second): + t.Fatal("timed out") + } + + // setup SCTP OnClose callback + ch := make(chan error, 1) + answerPC.SCTP().OnClose(func(err error) { + ch <- err + }) + + err = offerPC.Close() // This will trigger sctp onclose callback on remote + require.NoError(t, err) + + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatal("timed out") + } +}