diff --git a/platform/osRegistryInterface.go b/platform/osRegistryInterface.go new file mode 100644 index 0000000000..9eb7ae2140 --- /dev/null +++ b/platform/osRegistryInterface.go @@ -0,0 +1,46 @@ +//go:build windows +// +build windows + +package platform + +import "golang.org/x/sys/windows/registry" + +// Registry interface for interacting with the Windows registry +type Registry interface { + OpenKey(k registry.Key, path string, access uint32) (RegistryKey, error) +} + +// RegistryKey interface to represent an open registry key +type RegistryKey interface { + GetStringValue(name string) (string, uint32, error) + SetStringValue(name, value string) error + Close() error +} + +type WindowsRegistry struct{} + +// WindowsRegistryKey implements the RegistryKey interface +type WindowsRegistryKey struct { + key registry.Key +} + +func (r *WindowsRegistry) OpenKey(k registry.Key, path string, access uint32) (RegistryKey, error) { + key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, access) + if err != nil { + return nil, err + } + return &WindowsRegistryKey{key}, nil +} + +func (k *WindowsRegistryKey) GetStringValue(name string) (val string, valtype uint32, err error) { + value, valType, err := k.key.GetStringValue(name) + return value, valType, err +} + +func (k *WindowsRegistryKey) SetStringValue(name, value string) error { + return k.key.SetStringValue(name, value) +} + +func (k *WindowsRegistryKey) Close() error { + return k.key.Close() +} diff --git a/platform/os_windows.go b/platform/os_windows.go index 63900c6e5d..e89620941d 100644 --- a/platform/os_windows.go +++ b/platform/os_windows.go @@ -20,6 +20,9 @@ import ( "github.com/pkg/errors" "go.uber.org/zap" "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" ) const ( @@ -61,24 +64,12 @@ const ( // for vlan tagged arp requests SDNRemoteArpMacAddress = "12-34-56-78-9a-bc" - // Command to get SDNRemoteArpMacAddress registry key - GetSdnRemoteArpMacAddressCommand = "(Get-ItemProperty " + - "-Path HKLM:\\SYSTEM\\CurrentControlSet\\Services\\hns\\State -Name SDNRemoteArpMacAddress).SDNRemoteArpMacAddress" - - // Command to set SDNRemoteArpMacAddress registry key - SetSdnRemoteArpMacAddressCommand = "Set-ItemProperty " + - "-Path HKLM:\\SYSTEM\\CurrentControlSet\\Services\\hns\\State -Name SDNRemoteArpMacAddress -Value \"12-34-56-78-9a-bc\"" - - // Command to check if system has hns state path or not - CheckIfHNSStatePathExistsCommand = "Test-Path " + - "-Path HKLM:\\SYSTEM\\CurrentControlSet\\Services\\hns\\State" - // Command to fetch netadapter and pnp id + //TODO can we replace this (and things in endpoint_windows) with "golang.org/x/sys/windows" + //var adapterInfo windows.IpAdapterInfo + //var bufferSize uint32 = uint32(unsafe.Sizeof(adapterInfo)) GetMacAddressVFPPnpIDMapping = "Get-NetAdapter | Select-Object MacAddress, PnpDeviceID| Format-Table -HideTableHeaders" - // Command to restart HNS service - RestartHnsServiceCommand = "Restart-Service -Name hns" - // Interval between successive checks for mellanox adapter's PriorityVLANTag value defaultMellanoxMonitorInterval = 30 * time.Second @@ -257,32 +248,39 @@ func (p *execClient) ExecutePowershellCommandWithContext(ctx context.Context, co } // SetSdnRemoteArpMacAddress sets the regkey for SDNRemoteArpMacAddress needed for multitenancy if hns is enabled -func SetSdnRemoteArpMacAddress(execClient ExecClient) error { - exists, err := execClient.ExecutePowershellCommand(CheckIfHNSStatePathExistsCommand) +func SetSdnRemoteArpMacAddress(reg Registry) error { + key, err := reg.OpenKey(registry.LOCAL_MACHINE, "SYSTEM\\CurrentControlSet\\Services\\hns\\State", registry.READ|registry.SET_VALUE) if err != nil { + if err == registry.ErrNotExist { + log.Printf("hns state path does not exist, skip setting SdnRemoteArpMacAddress") + return nil + } errMsg := fmt.Sprintf("Failed to check the existent of hns state path due to error %s", err.Error()) log.Printf(errMsg) return errors.Errorf(errMsg) } - if strings.EqualFold(exists, "false") { - log.Printf("hns state path does not exist, skip setting SdnRemoteArpMacAddress") - return nil - } + if sdnRemoteArpMacAddressSet == false { - result, err := execClient.ExecutePowershellCommand(GetSdnRemoteArpMacAddressCommand) + + //Was (Get-ItemProperty -Path HKLM:\\SYSTEM\\CurrentControlSet\\Services\\hns\\State -Name SDNRemoteArpMacAddress).SDNRemoteArpMacAddress" + result, _, err := key.GetStringValue("SDNRemoteArpMacAddress") if err != nil { return err } // Set the reg key if not already set or has incorrect value if result != SDNRemoteArpMacAddress { - if _, err = execClient.ExecutePowershellCommand(SetSdnRemoteArpMacAddressCommand); err != nil { + + //was "Set-ItemProperty -Path HKLM:\\SYSTEM\\CurrentControlSet\\Services\\hns\\State -Name SDNRemoteArpMacAddress -Value \"12-34-56-78-9a-bc\"" + + if err := key.SetStringValue("SDNRemoteArpMacAddress", SDNRemoteArpMacAddress); err != nil { log.Printf("Failed to set SDNRemoteArpMacAddress due to error %s", err.Error()) return err } - log.Printf("[Azure CNS] SDNRemoteArpMacAddress regKey set successfully. Restarting hns service.") - if _, err := execClient.ExecutePowershellCommand(RestartHnsServiceCommand); err != nil { + + // was "Restart-Service -Name hns" + if err := restartService("hns"); err != nil { log.Printf("Failed to Restart HNS Service due to error %s", err.Error()) return err } @@ -294,6 +292,50 @@ func SetSdnRemoteArpMacAddress(execClient ExecClient) error { return nil } +// straight out of chat gpt +func restartService(serviceName string) error { + // Connect to the service manager + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("could not connect to service manager: %v", err) + } + defer m.Disconnect() + + // Open the service by name + service, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("could not access service: %v", err) + } + defer service.Close() + + // Stop the service + _, err = service.Control(svc.Stop) + if err != nil { + return fmt.Errorf("could not stop service: %v", err) + } + + // Wait for the service to stop + status, err := service.Query() + if err != nil { + return fmt.Errorf("could not query service status: %v", err) + } + for status.State != svc.Stopped { + time.Sleep(500 * time.Millisecond) + status, err = service.Query() + if err != nil { + return fmt.Errorf("could not query service status: %v", err) + } + } + + // Start the service again + err = service.Start() + if err != nil { + return fmt.Errorf("could not start service: %v", err) + } + + return nil +} + func HasMellanoxAdapter() bool { m := &mellanox.Mellanox{} return hasNetworkAdapter(m) @@ -364,6 +406,7 @@ func GetProcessNameByID(pidstr string) (string, error) { pidstr = strings.Trim(pidstr, "\r\n") cmd := fmt.Sprintf("Get-Process -Id %s|Format-List", pidstr) p := NewExecClient(nil) + //TODO not riemovign this because it seems to only be called in test? out, err := p.ExecutePowershellCommand(cmd) if err != nil { log.Printf("Process is not running. Output:%v, Error %v", out, err) diff --git a/platform/os_windows_test.go b/platform/os_windows_test.go index 5cb5dacc12..d7ce12f11f 100644 --- a/platform/os_windows_test.go +++ b/platform/os_windows_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "os/exec" - "strings" "testing" "time" @@ -12,10 +11,59 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sys/windows/registry" ) var errTestFailure = errors.New("test failure") +// MockRegistry is a mock implementation of the Registry interface +type MockRegistry struct { + Keys map[string]*MockRegistryKey +} + +// OpenKey opens a mock registry key. +func (r *MockRegistry) OpenKey(k registry.Key, path string, access uint32) (RegistryKey, error) { + // Directly check if the key exists in the mock registry by its path + if key, exists := r.Keys[path]; exists { + return key, nil + } + return nil, errors.New("key does not exist") +} + +// MockRegistryKey is a mock implementation of the RegistryKey interface +type MockRegistryKey struct { + Values map[string]string +} + +func (k *MockRegistryKey) GetStringValue(name string) (string, uint32, error) { + if value, exists := k.Values[name]; exists { + return value, registry.SZ, nil + } + return "", registry.SZ, registry.ErrNotExist +} + +func (k *MockRegistryKey) SetStringValue(name, value string) error { + k.Values[name] = value + return nil +} + +func (k *MockRegistryKey) Close() error { + return nil +} + +func initMockRegistry() *MockRegistry { + mockRegistry := &MockRegistry{ + Keys: map[string]*MockRegistryKey{ + `SOFTWARE\MockCompany\MockApp`: { + Values: map[string]string{ + "MockValue": "MockData", + }, + }, + }, + } + return mockRegistry +} + // Test if hasNetworkAdapter returns false on actual error or empty adapter name(an error) func TestHasNetworkAdapterReturnsError(t *testing.T) { ctrl := gomock.NewController(t) @@ -116,34 +164,36 @@ func TestExecuteCommandError(t *testing.T) { } func TestSetSdnRemoteArpMacAddress_hnsNotEnabled(t *testing.T) { - mockExecClient := NewMockExecClient(false) + //mockExecClient := NewMockExecClient(false) + mockRegistry := initMockRegistry() // testing skip setting SdnRemoteArpMacAddress when hns not enabled - mockExecClient.SetPowershellCommandResponder(func(_ string) (string, error) { - return "False", nil - }) - err := SetSdnRemoteArpMacAddress(mockExecClient) + // mockExecClient.SetPowershellCommandResponder(func(_ string) (string, error) { + // return "False", nil + // }) + err := SetSdnRemoteArpMacAddress(mockRegistry) assert.NoError(t, err) assert.Equal(t, false, sdnRemoteArpMacAddressSet) // testing the scenario when there is an error in checking if hns is enabled or not - mockExecClient.SetPowershellCommandResponder(func(_ string) (string, error) { - return "", errTestFailure - }) - err = SetSdnRemoteArpMacAddress(mockExecClient) + // mockExecClient.SetPowershellCommandResponder(func(_ string) (string, error) { + // return "", errTestFailure + // }) + err = SetSdnRemoteArpMacAddress(mockRegistry) assert.ErrorAs(t, err, &errTestFailure) assert.Equal(t, false, sdnRemoteArpMacAddressSet) } func TestSetSdnRemoteArpMacAddress_hnsEnabled(t *testing.T) { - mockExecClient := NewMockExecClient(false) + //mockExecClient := NewMockExecClient(false) + mockRegistry := initMockRegistry() // happy path - mockExecClient.SetPowershellCommandResponder(func(cmd string) (string, error) { - if strings.Contains(cmd, "Test-Path") { - return "True", nil - } - return "", nil - }) - err := SetSdnRemoteArpMacAddress(mockExecClient) + // mockExecClient.SetPowershellCommandResponder(func(cmd string) (string, error) { + // if strings.Contains(cmd, "Test-Path") { + // return "True", nil + // } + // return "", nil + // }) + err := SetSdnRemoteArpMacAddress(mockRegistry) assert.NoError(t, err) assert.Equal(t, true, sdnRemoteArpMacAddressSet) // reset sdnRemoteArpMacAddressSet