Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide SCTP Association OnClose callback #2858

Merged
merged 1 commit into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions sctptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
// OnStateChange func()

onErrorHandler func(error)
onCloseHandler func(error)

sctpAssociation *sctp.Association
onDataChannelHandler func(*DataChannel)
Expand Down Expand Up @@ -176,6 +177,7 @@
dataChannels = append(dataChannels, dc.dataChannel)
}
r.lock.RUnlock()

ACCEPT:
for {
dc, err := datachannel.Accept(a, &datachannel.Config{
Expand All @@ -185,6 +187,9 @@
if !errors.Is(err, io.EOF) {
r.log.Errorf("Failed to accept data channel: %v", err)
r.onError(err)
r.onClose(err)

Check warning on line 190 in sctptransport.go

View check run for this annotation

Codecov / codecov/patch

sctptransport.go#L190

Added line #L190 was not covered by tests
} else {
r.onClose(nil)
}
return
}
Expand Down Expand Up @@ -232,9 +237,14 @@
MaxRetransmits: maxRetransmits,
}, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
if err != nil {
// This data channel is invalid. Close it and log an error.
if err1 := dc.Close(); err1 != nil {
r.log.Errorf("Failed to close invalid data channel: %v", err1)

Check warning on line 242 in sctptransport.go

View check run for this annotation

Codecov / codecov/patch

sctptransport.go#L241-L242

Added lines #L241 - L242 were not covered by tests
}
r.log.Errorf("Failed to accept data channel: %v", err)
r.onError(err)
return
// We've received a datachannel with invalid configuration. We can still receive other datachannels.
continue ACCEPT

Check warning on line 247 in sctptransport.go

View check run for this annotation

Codecov / codecov/patch

sctptransport.go#L247

Added line #L247 was not covered by tests
}

<-r.onDataChannel(rtcDC)
Expand All @@ -251,8 +261,7 @@
}
}

// OnError sets an event handler which is invoked when
// the SCTP connection error occurs.
// OnError sets an event handler which is invoked when the SCTP Association errors.
func (r *SCTPTransport) OnError(f func(err error)) {
r.lock.Lock()
defer r.lock.Unlock()
Expand All @@ -269,6 +278,23 @@
}
}

// OnClose sets an event handler which is invoked when the SCTP Association closes.
func (r *SCTPTransport) OnClose(f func(err error)) {
r.lock.Lock()
defer r.lock.Unlock()
r.onCloseHandler = f
}

func (r *SCTPTransport) onClose(err error) {
r.lock.RLock()
handler := r.onCloseHandler
r.lock.RUnlock()

if handler != nil {
go handler(err)
}
}

// OnDataChannel sets an event handler which is invoked when a data
// channel message arrives from a remote peer.
func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
Expand Down
71 changes: 70 additions & 1 deletion sctptransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@

package webrtc

import "testing"
import (
"bytes"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestGenerateDataChannelID(t *testing.T) {
sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
Expand Down Expand Up @@ -55,3 +61,66 @@ func TestGenerateDataChannelID(t *testing.T) {
}
}
}

func TestSCTPTransportOnClose(t *testing.T) {
offerPC, answerPC, err := newPair()
require.NoError(t, err)

answerPC.OnDataChannel(func(dc *DataChannel) {
dc.OnMessage(func(_ DataChannelMessage) {
if err1 := dc.Send([]byte("hello")); err1 != nil {
t.Error("failed to send message")
}
})
})

recvMsg := make(chan struct{}, 1)
offerPC.OnConnectionStateChange(func(state PeerConnectionState) {
if state == PeerConnectionStateConnected {
defer func() {
offerPC.OnConnectionStateChange(nil)
}()

dc, createErr := offerPC.CreateDataChannel(expectedLabel, nil)
if createErr != nil {
t.Errorf("Failed to create a PC pair for testing")
return
}
dc.OnMessage(func(msg DataChannelMessage) {
if !bytes.Equal(msg.Data, []byte("hello")) {
t.Error("invalid msg received")
}
recvMsg <- struct{}{}
})
dc.OnOpen(func() {
if err1 := dc.Send([]byte("hello")); err1 != nil {
t.Error("failed to send initial msg", err1)
}
})
}
})

err = signalPair(offerPC, answerPC)
require.NoError(t, err)

select {
case <-recvMsg:
case <-time.After(5 * time.Second):
t.Fatal("timed out")
}

// setup SCTP OnClose callback
ch := make(chan error, 1)
answerPC.SCTP().OnClose(func(err error) {
ch <- err
})

err = offerPC.Close() // This will trigger sctp onclose callback on remote
require.NoError(t, err)

select {
case <-ch:
case <-time.After(5 * time.Second):
t.Fatal("timed out")
}
}
Loading