diff --git a/conn.go b/conn.go index 06d97f7e..b55bc99c 100644 --- a/conn.go +++ b/conn.go @@ -33,6 +33,9 @@ type Conn struct { ctx context.Context cancelCtx context.CancelFunc + closeOnce sync.Once + closeErr error + busObj BusObject unixFD bool uuid string @@ -246,6 +249,11 @@ func newConn(tr transport, opts ...ConnOption) (*Conn, error) { conn.ctx = context.Background() } conn.ctx, conn.cancelCtx = context.WithCancel(conn.ctx) + go func() { + <-conn.ctx.Done() + conn.Close() + }() + conn.calls = newCallTracker() if conn.handler == nil { conn.handler = NewDefaultHandler() @@ -272,24 +280,27 @@ func (conn *Conn) BusObject() BusObject { // and the channels passed to Eavesdrop and Signal are closed. This method must // not be called on shared connections. func (conn *Conn) Close() error { - conn.outHandler.close() - if term, ok := conn.signalHandler.(Terminator); ok { - term.Terminate() - } + conn.closeOnce.Do(func() { + conn.outHandler.close() + if term, ok := conn.signalHandler.(Terminator); ok { + term.Terminate() + } - if term, ok := conn.handler.(Terminator); ok { - term.Terminate() - } + if term, ok := conn.handler.(Terminator); ok { + term.Terminate() + } - conn.eavesdroppedLck.Lock() - if conn.eavesdropped != nil { - close(conn.eavesdropped) - } - conn.eavesdroppedLck.Unlock() + conn.eavesdroppedLck.Lock() + if conn.eavesdropped != nil { + close(conn.eavesdropped) + } + conn.eavesdroppedLck.Unlock() - conn.cancelCtx() + conn.cancelCtx() - return conn.transport.Close() + conn.closeErr = conn.transport.Close() + }) + return conn.closeErr } // Context returns the context associated with the connection. The diff --git a/conn_test.go b/conn_test.go index 13c15e51..ef3b8d93 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,6 +1,7 @@ package dbus import ( + "context" "encoding/binary" "io" "io/ioutil" @@ -453,3 +454,107 @@ func TestInterceptors(t *testing.T) { t.Fatal(err) } } + +func TestCloseCancelsConnectionContext(t *testing.T) { + bus, err := SessionBusPrivate() + if err != nil { + t.Fatal(err) + } + defer bus.Close() + + if err = bus.Auth(nil); err != nil { + t.Fatal(err) + } + if err = bus.Hello(); err != nil { + t.Fatal(err) + } + if err != nil { + t.Fatal(err) + } + + // The context is not done at this point + ctx := bus.Context() + select { + case <-ctx.Done(): + t.Fatal("context should not be done") + default: + } + + err = bus.Close() + if err != nil { + t.Fatal(err) + } + select { + case <-ctx.Done(): + // expected + case <-time.After(5 * time.Second): + t.Fatal("context is not done after connection closed") + } +} + +func TestDisconnectCancelsConnectionContext(t *testing.T) { + reader, pipewriter := io.Pipe() + defer pipewriter.Close() + defer reader.Close() + + bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard}) + if err != nil { + t.Fatal(err) + } + + go func() { + _, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n")) + if err != nil { + t.Errorf("error writing to pipe: %v", err) + } + }() + err = bus.Auth([]Auth{fakeAuth{}}) + if err != nil { + t.Fatal(err) + } + + ctx := bus.Context() + + pipewriter.Close() + select { + case <-ctx.Done(): + // expected + case <-time.After(5 * time.Second): + t.Fatal("context is not done after connection closed") + } +} + +func TestCancellingContextClosesConnection(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + reader, pipewriter := io.Pipe() + defer pipewriter.Close() + defer reader.Close() + + bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard}, WithContext(ctx)) + if err != nil { + t.Fatal(err) + } + + go func() { + _, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n")) + if err != nil { + t.Errorf("error writing to pipe: %v", err) + } + }() + err = bus.Auth([]Auth{fakeAuth{}}) + if err != nil { + t.Fatal(err) + } + + // Cancel the connection's parent context and give time for + // other goroutines to schedule. + cancel() + time.Sleep(50 * time.Millisecond) + + err = bus.BusObject().Call("org.freedesktop.DBus.Peer.Ping", 0).Store() + if err != ErrClosed { + t.Errorf("expected connection to be closed, but got: %v", err) + } +}