Skip to content

Commit

Permalink
tests: add tests for handling of connection context
Browse files Browse the repository at this point in the history
  • Loading branch information
jhenstridge authored and jsouthworth committed Sep 25, 2019
1 parent 0953362 commit 5ae69d1
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 14 deletions.
39 changes: 25 additions & 14 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ type Conn struct {
ctx context.Context
cancelCtx context.CancelFunc

closeOnce sync.Once
closeErr error

busObj BusObject
unixFD bool
uuid string
Expand Down Expand Up @@ -246,6 +249,11 @@ func newConn(tr transport, opts ...ConnOption) (*Conn, error) {
conn.ctx = context.Background()
}
conn.ctx, conn.cancelCtx = context.WithCancel(conn.ctx)
go func() {
<-conn.ctx.Done()
conn.Close()
}()

conn.calls = newCallTracker()
if conn.handler == nil {
conn.handler = NewDefaultHandler()
Expand All @@ -272,24 +280,27 @@ func (conn *Conn) BusObject() BusObject {
// and the channels passed to Eavesdrop and Signal are closed. This method must
// not be called on shared connections.
func (conn *Conn) Close() error {
conn.outHandler.close()
if term, ok := conn.signalHandler.(Terminator); ok {
term.Terminate()
}
conn.closeOnce.Do(func() {
conn.outHandler.close()
if term, ok := conn.signalHandler.(Terminator); ok {
term.Terminate()
}

if term, ok := conn.handler.(Terminator); ok {
term.Terminate()
}
if term, ok := conn.handler.(Terminator); ok {
term.Terminate()
}

conn.eavesdroppedLck.Lock()
if conn.eavesdropped != nil {
close(conn.eavesdropped)
}
conn.eavesdroppedLck.Unlock()
conn.eavesdroppedLck.Lock()
if conn.eavesdropped != nil {
close(conn.eavesdropped)
}
conn.eavesdroppedLck.Unlock()

conn.cancelCtx()
conn.cancelCtx()

return conn.transport.Close()
conn.closeErr = conn.transport.Close()
})
return conn.closeErr
}

// Context returns the context associated with the connection. The
Expand Down
105 changes: 105 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dbus

import (
"context"
"encoding/binary"
"io"
"io/ioutil"
Expand Down Expand Up @@ -453,3 +454,107 @@ func TestInterceptors(t *testing.T) {
t.Fatal(err)
}
}

func TestCloseCancelsConnectionContext(t *testing.T) {
bus, err := SessionBusPrivate()
if err != nil {
t.Fatal(err)
}
defer bus.Close()

if err = bus.Auth(nil); err != nil {
t.Fatal(err)
}
if err = bus.Hello(); err != nil {
t.Fatal(err)
}
if err != nil {
t.Fatal(err)
}

// The context is not done at this point
ctx := bus.Context()
select {
case <-ctx.Done():
t.Fatal("context should not be done")
default:
}

err = bus.Close()
if err != nil {
t.Fatal(err)
}
select {
case <-ctx.Done():
// expected
case <-time.After(5 * time.Second):
t.Fatal("context is not done after connection closed")
}
}

func TestDisconnectCancelsConnectionContext(t *testing.T) {
reader, pipewriter := io.Pipe()
defer pipewriter.Close()
defer reader.Close()

bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard})
if err != nil {
t.Fatal(err)
}

go func() {
_, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n"))
if err != nil {
t.Errorf("error writing to pipe: %v", err)
}
}()
err = bus.Auth([]Auth{fakeAuth{}})
if err != nil {
t.Fatal(err)
}

ctx := bus.Context()

pipewriter.Close()
select {
case <-ctx.Done():
// expected
case <-time.After(5 * time.Second):
t.Fatal("context is not done after connection closed")
}
}

func TestCancellingContextClosesConnection(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

reader, pipewriter := io.Pipe()
defer pipewriter.Close()
defer reader.Close()

bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard}, WithContext(ctx))
if err != nil {
t.Fatal(err)
}

go func() {
_, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n"))
if err != nil {
t.Errorf("error writing to pipe: %v", err)
}
}()
err = bus.Auth([]Auth{fakeAuth{}})
if err != nil {
t.Fatal(err)
}

// Cancel the connection's parent context and give time for
// other goroutines to schedule.
cancel()
time.Sleep(50 * time.Millisecond)

err = bus.BusObject().Call("org.freedesktop.DBus.Peer.Ping", 0).Store()
if err != ErrClosed {
t.Errorf("expected connection to be closed, but got: %v", err)
}
}

0 comments on commit 5ae69d1

Please sign in to comment.