diff --git a/test/e2e/storage/drivers/proxy/portproxy.go b/test/e2e/storage/drivers/proxy/portproxy.go index 7f1d604882c56..aef56974ad884 100644 --- a/test/e2e/storage/drivers/proxy/portproxy.go +++ b/test/e2e/storage/drivers/proxy/portproxy.go @@ -24,13 +24,11 @@ import ( "io/ioutil" "net" "net/http" - "strconv" "sync" "sync/atomic" "time" v1 "k8s.io/api/core/v1" - apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/httpstream" @@ -94,35 +92,11 @@ func Listen(ctx context.Context, clientset kubernetes.Interface, restConfig *res addr: addr, } - // Port forwarding is allowed to fail and will be restarted when it does. - prepareForwarding := func() (*remotePort, error) { - pod, err := clientset.CoreV1().Pods(addr.Namespace).Get(ctx, addr.PodName, metav1.GetOptions{}) - if err != nil { - return nil, err - } - for i, status := range pod.Status.ContainerStatuses { - if pod.Spec.Containers[i].Name == addr.ContainerName && - status.State.Running == nil { - return nil, fmt.Errorf("container %q is not running", addr.ContainerName) - } - } - - streamConn, _, err := dialer.Dial(portforward.PortForwardProtocolV1Name) - if err != nil { - return nil, fmt.Errorf("dialer failed: %v", err) - } - rp := &remotePort{ - streamConn: streamConn, - } - return rp, nil - } - var connectionsCreated, connectionsClosed int32 - runForwarding := func(rp *remotePort) { - defer rp.Close() - klog.V(5).Infof("%s: starting connection polling", prefix) - defer klog.V(5).Infof("%s: connection polling ended", prefix) + runForwarding := func() { + klog.V(2).Infof("%s: starting connection polling", prefix) + defer klog.V(2).Infof("%s: connection polling ended", prefix) // This delay determines how quickly we notice when someone has // connected inside the cluster. With socat, we cannot make this too small @@ -145,9 +119,9 @@ func Listen(ctx context.Context, clientset kubernetes.Interface, restConfig *res } klog.V(5).Infof("%s: trying to create a new connection #%d, %d open", prefix, connectionsCreated, openConnections) - stream, err := rp.dial(ctx, prefix, addr.Port) + stream, err := dial(ctx, fmt.Sprintf("%s #%d", prefix, connectionsCreated), dialer, addr.Port) if err != nil { - klog.V(5).Infof("%s: no connection: %v", prefix, err) + klog.Errorf("%s: no connection: %v", prefix, err) break } // Make the connection available to Accept below. @@ -166,18 +140,24 @@ func Listen(ctx context.Context, clientset kubernetes.Interface, restConfig *res // Portforwarding and polling for connections run in the background. go func() { for { - fw, err := prepareForwarding() - if err == nil { - runForwarding(fw) - } else { - if apierrors.IsNotFound(err) { - // This is normal, the pod isn't running yet. Log with lower severity. - klog.V(5).Infof("prepare forwarding %s: %v", addr, err) - } else { - klog.Errorf("prepare forwarding %s: %v", addr, err) + running := false + pod, err := clientset.CoreV1().Pods(addr.Namespace).Get(ctx, addr.PodName, metav1.GetOptions{}) + if err != nil { + klog.V(5).Infof("checking for container %q in pod %s/%s: %v", addr.ContainerName, addr.Namespace, addr.PodName, err) + } + for i, status := range pod.Status.ContainerStatuses { + if pod.Spec.Containers[i].Name == addr.ContainerName && + status.State.Running != nil { + running = true + break } } + if running { + klog.V(2).Infof("container %q in pod %s/%s is running", addr.ContainerName, addr.Namespace, addr.PodName) + runForwarding() + } + select { case <-ctx.Done(): return @@ -209,27 +189,32 @@ func (a Addr) String() string { return fmt.Sprintf("%s/%s:%d", a.Namespace, a.PodName, a.Port) } -// remotePort is a stripped down version of client-go/tools/portforward minus -// the local listeners. -type remotePort struct { +type stream struct { + httpstream.Stream streamConn httpstream.Connection - - requestIDLock sync.Mutex - requestID int } -func (rp *remotePort) dial(ctx context.Context, prefix string, port int) (httpstream.Stream, error) { - requestID := rp.nextRequestID() +func dial(ctx context.Context, prefix string, dialer httpstream.Dialer, port int) (s *stream, finalErr error) { + streamConn, _, err := dialer.Dial(portforward.PortForwardProtocolV1Name) + if err != nil { + return nil, fmt.Errorf("dialer failed: %v", err) + } + requestID := "1" + defer func() { + if finalErr != nil { + streamConn.Close() + } + }() // create error stream headers := http.Header{} headers.Set(v1.StreamType, v1.StreamTypeError) headers.Set(v1.PortHeader, fmt.Sprintf("%d", port)) - headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID)) + headers.Set(v1.PortForwardRequestIDHeader, requestID) // We're not writing to this stream, just reading an error message from it. // This happens asynchronously. - errorStream, err := rp.streamConn.CreateStream(headers) + errorStream, err := streamConn.CreateStream(headers) if err != nil { return nil, fmt.Errorf("error creating error stream: %v", err) } @@ -246,24 +231,20 @@ func (rp *remotePort) dial(ctx context.Context, prefix string, port int) (httpst // create data stream headers.Set(v1.StreamType, v1.StreamTypeData) - dataStream, err := rp.streamConn.CreateStream(headers) + dataStream, err := streamConn.CreateStream(headers) if err != nil { return nil, fmt.Errorf("error creating data stream: %v", err) } - return dataStream, nil + return &stream{ + Stream: dataStream, + streamConn: streamConn, + }, nil } -func (rp *remotePort) Close() { - rp.streamConn.Close() -} - -func (rp *remotePort) nextRequestID() int { - rp.requestIDLock.Lock() - defer rp.requestIDLock.Unlock() - id := rp.requestID - rp.requestID++ - return id +func (s *stream) Close() { + s.Stream.Close() + s.streamConn.Close() } type listener struct { @@ -292,7 +273,7 @@ func (l *listener) Accept() (net.Conn, error) { } type connection struct { - stream httpstream.Stream + stream *stream addr Addr counter int32 closed *int32 @@ -346,7 +327,8 @@ func (c *connection) Close() error { atomic.AddInt32(c.closed, 1) c.closed = nil } - return c.stream.Close() + c.stream.Close() + return nil } func (l *listener) Addr() net.Addr {