diff --git a/pkg/net/socketpair.go b/pkg/net/socketpair.go new file mode 100644 index 00000000..a0395381 --- /dev/null +++ b/pkg/net/socketpair.go @@ -0,0 +1,76 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package net + +import ( + "fmt" + "net" + "os" +) + +// SocketPair contains the os.Files of a connected pair of sockets. +type SocketPair struct { + local, peer *os.File +} + +// LocalFile returns the socketpair fd for local usage as an *os.File. +func (sp SocketPair) LocalFile() *os.File { + return sp.local +} + +// PeerFile returns the socketpair fd for peer usage as an *os.File. +func (sp SocketPair) PeerFile() *os.File { + return sp.peer +} + +// LocalConn returns a net.Conn for the local end of the socketpair. +func (sp SocketPair) LocalConn() (net.Conn, error) { + file := sp.LocalFile() + defer file.Close() + conn, err := net.FileConn(file) + if err != nil { + return nil, fmt.Errorf("failed to create net.Conn for %s: %w", file.Name(), err) + } + return conn, nil +} + +// PeerConn returns a net.Conn for the peer end of the socketpair. +func (sp SocketPair) PeerConn() (net.Conn, error) { + file := sp.PeerFile() + defer file.Close() + conn, err := net.FileConn(file) + if err != nil { + return nil, fmt.Errorf("failed to create net.Conn for %s: %w", file.Name(), err) + } + return conn, nil +} + +// Close closes both ends of the socketpair. +func (sp SocketPair) Close() { + sp.LocalClose() + sp.PeerClose() +} + +// LocalClose closes the local end of the socketpair. +func (sp SocketPair) LocalClose() { + sp.local.Close() +} + +// PeerClose closes the peer end of the socketpair. +func (sp SocketPair) PeerClose() { + sp.peer.Close() +} diff --git a/pkg/net/socketpair_unix.go b/pkg/net/socketpair_unix.go index f0b5cdb0..63b420c4 100644 --- a/pkg/net/socketpair_unix.go +++ b/pkg/net/socketpair_unix.go @@ -20,78 +20,27 @@ package net import ( "fmt" - "net" "os" + "syscall" - syscall "golang.org/x/sys/unix" + "golang.org/x/sys/unix" ) -const ( - local = 0 - peer = 1 -) - -// SocketPair contains the file descriptors of a connected pair of sockets. -type SocketPair [2]int - // NewSocketPair returns a connected pair of sockets. func NewSocketPair() (SocketPair, error) { - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) - if err != nil { - return [2]int{-1, -1}, fmt.Errorf("failed to create socketpair: %w", err) - } - - return fds, nil -} - -// LocalFile returns the socketpair fd for local usage as an *os.File. -func (fds SocketPair) LocalFile() *os.File { - return os.NewFile(uintptr(fds[local]), fds.fileName()+"[0]") -} - -// PeerFile returns the socketpair fd for peer usage as an *os.File. -func (fds SocketPair) PeerFile() *os.File { - return os.NewFile(uintptr(fds[peer]), fds.fileName()+"[1]") -} - -// LocalConn returns a net.Conn for the local end of the socketpair. -func (fds SocketPair) LocalConn() (net.Conn, error) { - file := fds.LocalFile() - defer file.Close() - conn, err := net.FileConn(file) - if err != nil { - return nil, fmt.Errorf("failed to create net.Conn for %s[0]: %w", fds.fileName(), err) - } - return conn, nil -} - -// PeerConn returns a net.Conn for the peer end of the socketpair. -func (fds SocketPair) PeerConn() (net.Conn, error) { - file := fds.PeerFile() - defer file.Close() - conn, err := net.FileConn(file) + syscall.ForkLock.RLock() + defer syscall.ForkLock.RUnlock() + fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) if err != nil { - return nil, fmt.Errorf("failed to create net.Conn for %s[1]: %w", fds.fileName(), err) + return SocketPair{nil, nil}, fmt.Errorf("failed to create socketpair: %w", err) } - return conn, nil -} + unix.CloseOnExec(fds[0]) + unix.CloseOnExec(fds[1]) -// Close closes both ends of the socketpair. -func (fds SocketPair) Close() { - fds.LocalClose() - fds.PeerClose() -} - -// LocalClose closes the local end of the socketpair. -func (fds SocketPair) LocalClose() { - syscall.Close(fds[local]) -} - -// PeerClose closes the peer end of the socketpair. -func (fds SocketPair) PeerClose() { - syscall.Close(fds[peer]) -} + filename := fmt.Sprintf("socketpair-#%d:%d", fds[0], fds[1]) -func (fds SocketPair) fileName() string { - return fmt.Sprintf("socketpair-#%d:%d[0]", fds[local], fds[peer]) + return SocketPair{ + os.NewFile(uintptr(fds[0]), filename+"[0]"), + os.NewFile(uintptr(fds[1]), filename+"[1]"), + }, nil } diff --git a/pkg/net/socketpair_windows.go b/pkg/net/socketpair_windows.go index d88ea57d..4666d20e 100644 --- a/pkg/net/socketpair_windows.go +++ b/pkg/net/socketpair_windows.go @@ -20,21 +20,13 @@ package net import ( "fmt" - "net" "os" + "syscall" "unsafe" sys "golang.org/x/sys/windows" ) -// SocketPair contains a connected pair of sockets. -type SocketPair [2]sys.Handle - -const ( - local = 0 - peer = 1 -) - // NewSocketPair returns a connected pair of sockets. func NewSocketPair() (SocketPair, error) { /* return [2]sys.Handle{sys.InvalidHandle, sys.InvalidHandle}, @@ -46,7 +38,7 @@ func NewSocketPair() (SocketPair, error) { func emulateWithPreConnect() (SocketPair, error) { var ( - invalid = SocketPair{sys.InvalidHandle, sys.InvalidHandle} + invalid = SocketPair{nil, nil} sa sys.SockaddrInet4 //sn sys.Sockaddr l sys.Handle @@ -55,15 +47,14 @@ func emulateWithPreConnect() (SocketPair, error) { err error ) + syscall.ForkLock.RLock() + defer syscall.ForkLock.RUnlock() + l, err = socket(sys.AF_INET, sys.SOCK_STREAM, 0) if err != nil { return invalid, fmt.Errorf("failed to emulate socketpair (local Socket()): %w", err) } - defer func() { - if err != nil { - sys.CloseHandle(l) - } - }() + defer sys.CloseHandle(l) sa.Addr[0] = 127 sa.Addr[3] = 1 @@ -109,60 +100,15 @@ func emulateWithPreConnect() (SocketPair, error) { } }() - sys.CloseHandle(l) - return SocketPair{a, p}, nil -} - -// Close closes both ends of the socketpair. -func (sp SocketPair) Close() { - sp.LocalClose() - sp.PeerClose() -} - -// LocalFile returns the socketpair fd for local usage as an *os.File. -func (sp SocketPair) LocalFile() *os.File { - return os.NewFile(uintptr(sp[local]), sp.fileName()+"[0]") -} - -// PeerFile returns the socketpair fd for peer usage as an *os.File. -func (sp SocketPair) PeerFile() *os.File { - return os.NewFile(uintptr(sp[peer]), sp.fileName()+"[1]") -} + sys.CloseOnExec(a) + sys.CloseOnExec(p) -// LocalConn returns a net.Conn for the local end of the socketpair. -func (sp SocketPair) LocalConn() (net.Conn, error) { - file := sp.LocalFile() - defer file.Close() - conn, err := net.FileConn(file) - if err != nil { - return nil, fmt.Errorf("failed to create net.Conn for %s[0]: %w", sp.fileName(), err) - } - return conn, nil -} - -// PeerConn returns a net.Conn for the peer end of the socketpair. -func (sp SocketPair) PeerConn() (net.Conn, error) { - file := sp.PeerFile() - defer file.Close() - conn, err := net.FileConn(file) - if err != nil { - return nil, fmt.Errorf("failed to create net.Conn for %s[1]: %w", sp.fileName(), err) - } - return conn, nil -} - -// LocalClose closes the local end of the socketpair. -func (sp SocketPair) LocalClose() { - sys.CloseHandle(sp[local]) -} - -// PeerClose closes the peer end of the socketpair. -func (sp SocketPair) PeerClose() { - sys.CloseHandle(sp[peer]) -} + filename := fmt.Sprintf("socketpair-#%d:%d", a, p) -func (sp SocketPair) fileName() string { - return fmt.Sprintf("socketpair-#%d:%d[0]", sp[local], sp[peer]) + return SocketPair{ + os.NewFile(uintptr(a), filename+"[0]"), + os.NewFile(uintptr(p), filename+"[1]"), + }, nil } func socket(domain, typ, proto int) (sys.Handle, error) {