diff --git a/proxyclient/client.go b/proxyclient/client.go index a055a02..706fa55 100644 --- a/proxyclient/client.go +++ b/proxyclient/client.go @@ -15,7 +15,6 @@ import ( "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) const ( @@ -24,6 +23,7 @@ const ( defaultServerPath = "/connect" retryTimeout = 1 * time.Second certificateWatchInterval = 10 * time.Second + getSecretRetryTimeout = 5 * time.Second ) type PortForwarder interface { @@ -81,9 +81,7 @@ func New(ctx context.Context, serverSharedSecret, namespace, certSecretName, cer namespace: namespace, } - if err := client.buildDialer(ctx, secretController); err != nil { - return nil, fmt.Errorf("dialer build failed %w: ", err) - } + client.setUpBuildDialerCallback(ctx, certSecretName, secretController) for _, opt := range opts { opt(client) @@ -92,12 +90,16 @@ func New(ctx context.Context, serverSharedSecret, namespace, certSecretName, cer return client, nil } -func (c *ProxyClient) buildDialer(ctx context.Context, secretController v1.SecretController) error { - secretController.OnChange(ctx, "remotedialer-proxy", func(_ string, newSecret *corev1.Secret) (*corev1.Secret, error) { +func (c *ProxyClient) setUpBuildDialerCallback(ctx context.Context, certSecretName string, secretController v1.SecretController) { + secretController.OnChange(ctx, certSecretName, func(_ string, newSecret *corev1.Secret) (*corev1.Secret, error) { + if newSecret == nil { + return nil, nil + } + if newSecret.Name == c.certSecretName && newSecret.Namespace == c.namespace { rootCAs, err := buildCertFromSecret(c.namespace, c.certSecretName, newSecret) if err != nil { - logrus.Errorf("build certificate failed: %s", err.Error()) + logrus.Errorf("RDPClient: build certificate failed: %s", err.Error()) return nil, err } @@ -109,32 +111,11 @@ func (c *ProxyClient) buildDialer(ctx context.Context, secretController v1.Secre }, } c.dialerMtx.Unlock() - logrus.Infof("certificate updated successfully") + logrus.Infof("RDPClient: certificate updated successfully") } return newSecret, nil }) - - secret, err := secretController.Get(c.namespace, c.certSecretName, metav1.GetOptions{}) - if err != nil { - return err - } - - rootCAs, err := buildCertFromSecret(c.namespace, c.certSecretName, secret) - if err != nil { - return fmt.Errorf("build certificate failed: %w", err) - } - - c.dialerMtx.Lock() - c.dialer = &websocket.Dialer{ - TLSClientConfig: &tls.Config{ - RootCAs: rootCAs, - ServerName: c.certServerName, - }, - } - c.dialerMtx.Unlock() - - return nil } func buildCertFromSecret(namespace, certSecretName string, secret *corev1.Secret) (*x509.CertPool, error) { @@ -153,27 +134,46 @@ func buildCertFromSecret(namespace, certSecretName string, secret *corev1.Secret func (c *ProxyClient) Run(ctx context.Context) { go func() { + LookForDialer: + for { + select { + case <-ctx.Done(): + logrus.Infof("RDPClient: Received stop signal.") + return + + default: + logrus.Info("RDPClient: Checking if dialer is built...") + if c.dialer != nil { + logrus.Info("RDPClient: Dialer is built. Ready to start.") + break LookForDialer + } + + logrus.Infof("RDPClient: Dialer is not built yet, waiting %d secs to re-check.", getSecretRetryTimeout/time.Second) + time.Sleep(getSecretRetryTimeout) + } + } + for { select { case <-ctx.Done(): - logrus.Infof("ProxyClient: ClientConnect finished. If no error, the session closed cleanly.") + logrus.Infof("RDPClient: Received signal to stop.") return default: if err := c.forwarder.Start(); err != nil { - logrus.Errorf("remotedialer.ProxyClient error: %s ", err) + logrus.Errorf("RDPClient: %s ", err) time.Sleep(retryTimeout) continue } - logrus.Infof("ProxyClient connecting to %s", c.serverUrl) + logrus.Infof("RDPClient: connecting to %s", c.serverUrl) headers := http.Header{} headers.Set("X-API-Tunnel-Secret", c.serverConnectSecret) onConnectAuth := func(proto, address string) bool { return true } onConnect := func(sessionCtx context.Context, session *remotedialer.Session) error { - logrus.Infoln("ProxyClient: remotedialer session connected!") + logrus.Infoln("RDPClient: remotedialer session connected!") if c.onConnect != nil { return c.onConnect(sessionCtx, session) } @@ -185,21 +185,19 @@ func (c *ProxyClient) Run(ctx context.Context) { c.dialerMtx.Unlock() if err := remotedialer.ClientConnect(ctx, c.serverUrl, headers, dialer, onConnectAuth, onConnect); err != nil { - logrus.Errorf("remotedialer.ClientConnect error: %s", err.Error()) + logrus.Errorf("RDPClient: remotedialer.ClientConnect error: %s", err.Error()) c.forwarder.Stop() time.Sleep(retryTimeout) } } } }() - - <-ctx.Done() } func (c *ProxyClient) Stop() { if c.forwarder != nil { c.forwarder.Stop() - logrus.Infoln("ProxyClient: port-forward stopped.") + logrus.Infoln("RDPClient: port-forward stopped.") } }