diff --git a/pipe.go b/pipe.go index 25cc8110..72daac0f 100644 --- a/pipe.go +++ b/pipe.go @@ -22,6 +22,7 @@ import ( //sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe //sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW +//sys disconnectNamedPipe(pipe syscall.Handle) (err error) = DisconnectNamedPipe //sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo //sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW //sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc @@ -30,6 +31,12 @@ import ( //sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) = ntdll.RtlDosPathNameToNtPathName_U //sys rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) = ntdll.RtlDefaultNpAcl +type PipeConn interface { + net.Conn + Disconnect() error + Flush() error +} + type ioStatusBlock struct { Status, Information uintptr } @@ -80,6 +87,8 @@ type win32Pipe struct { path string } +var _ PipeConn = (*win32Pipe)(nil) + type win32MessageBytePipe struct { win32Pipe writeClosed bool @@ -103,6 +112,10 @@ func (f *win32Pipe) SetDeadline(t time.Time) error { return f.SetWriteDeadline(t) } +func (f *win32Pipe) Disconnect() error { + return disconnectNamedPipe(f.win32File.handle) +} + // CloseWrite closes the write side of a message pipe in byte mode. func (f *win32MessageBytePipe) CloseWrite() error { if f.writeClosed { diff --git a/pipe_test.go b/pipe_test.go index cb6632a4..90342af4 100644 --- a/pipe_test.go +++ b/pipe_test.go @@ -402,6 +402,66 @@ func TestTimeoutPendingWrite(t *testing.T) { <-serverDone } +func TestDisconnectPipe(t *testing.T) { + l, err := ListenPipe(testPipeName, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + const testData = "foo" + serverDone := make(chan struct{}) + + go func() { + s, err := l.Accept() + if err != nil { + t.Error(err) + return + } + defer func() { + s.Close() + close(serverDone) + }() + + if _, err := s.Write([]byte(testData)); err != nil { + t.Error(err) + return + } + + if err := s.(PipeConn).Flush(); err != nil { + t.Error(err) + return + } + + if err := s.(PipeConn).Disconnect(); err != nil { + t.Error(err) + return + } + }() + + client, err := DialPipe(testPipeName, nil) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + buf := make([]byte, len(testData)) + if _, err = client.Read(buf); err != nil { + t.Fatal(err) + } + + dataRead := string(buf) + if dataRead != testData { + t.Fatalf("incorrect data read %q", dataRead) + } + + if _, err = client.Read(buf); err == nil { + t.Fatal("read should fail") + } + + <-serverDone +} + type CloseWriter interface { CloseWrite() error } diff --git a/zsyscall_windows.go b/zsyscall_windows.go index 469b16f6..0583702c 100644 --- a/zsyscall_windows.go +++ b/zsyscall_windows.go @@ -65,6 +65,7 @@ var ( procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe") procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort") procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW") + procDisconnectNamedPipe = modkernel32.NewProc("DisconnectNamedPipe") procGetCurrentThread = modkernel32.NewProc("GetCurrentThread") procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW") procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo") @@ -331,6 +332,14 @@ func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances return } +func disconnectNamedPipe(pipe syscall.Handle) (err error) { + r1, _, e1 := syscall.Syscall(procDisconnectNamedPipe.Addr(), 1, uintptr(pipe), 0, 0) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + func getCurrentThread() (h syscall.Handle) { r0, _, _ := syscall.Syscall(procGetCurrentThread.Addr(), 0, 0, 0, 0) h = syscall.Handle(r0)