Skip to content

Commit

Permalink
fix: remove PowerShell from Windows registry interactions (#2993)
Browse files Browse the repository at this point in the history
remove powershell from windows registry txs

Signed-off-by: Evan Baker <rbtr@users.noreply.github.com>
  • Loading branch information
rbtr authored Nov 22, 2024
1 parent 87c3307 commit 887c445
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 71 deletions.
3 changes: 1 addition & 2 deletions cns/service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -798,8 +798,7 @@ func main() {
}

// Setting the remote ARP MAC address to 12-34-56-78-9a-bc on windows for external traffic if HNS is enabled
execClient := platform.NewExecClient(nil)
err = platform.SetSdnRemoteArpMacAddress(execClient)
err = platform.SetSdnRemoteArpMacAddress(rootCtx)
if err != nil {
logger.Errorf("Failed to set remote ARP MAC address: %v", err)
return
Expand Down
2 changes: 1 addition & 1 deletion platform/os_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func (p *execClient) KillProcessByName(processName string) error {

// SetSdnRemoteArpMacAddress sets the regkey for SDNRemoteArpMacAddress needed for multitenancy
// This operation is specific to windows OS
func SetSdnRemoteArpMacAddress(_ ExecClient) error {
func SetSdnRemoteArpMacAddress(context.Context) error {
return nil
}

Expand Down
107 changes: 66 additions & 41 deletions platform/os_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ import (
"github.com/pkg/errors"
"go.uber.org/zap"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
)

const (
Expand Down Expand Up @@ -60,20 +63,9 @@ const (
// for vlan tagged arp requests
SDNRemoteArpMacAddress = "12-34-56-78-9a-bc"

// Command to get SDNRemoteArpMacAddress registry key
GetSdnRemoteArpMacAddressCommand = "(Get-ItemProperty " +
"-Path HKLM:\\SYSTEM\\CurrentControlSet\\Services\\hns\\State -Name SDNRemoteArpMacAddress).SDNRemoteArpMacAddress"

// Command to set SDNRemoteArpMacAddress registry key
SetSdnRemoteArpMacAddressCommand = "Set-ItemProperty " +
"-Path HKLM:\\SYSTEM\\CurrentControlSet\\Services\\hns\\State -Name SDNRemoteArpMacAddress -Value \"12-34-56-78-9a-bc\""

// Command to check if system has hns state path or not
CheckIfHNSStatePathExistsCommand = "Test-Path " +
"-Path HKLM:\\SYSTEM\\CurrentControlSet\\Services\\hns\\State"

// Command to restart HNS service
RestartHnsServiceCommand = "Restart-Service -Name hns"
// Command to fetch netadapter and pnp id
// TODO: can we replace this (and things in endpoint_windows) with other utils from "golang.org/x/sys/windows"?
GetMacAddressVFPPnpIDMapping = "Get-NetAdapter | Select-Object MacAddress, PnpDeviceID| Format-Table -HideTableHeaders"

// Interval between successive checks for mellanox adapter's PriorityVLANTag value
defaultMellanoxMonitorInterval = 30 * time.Second
Expand Down Expand Up @@ -195,40 +187,73 @@ func (p *execClient) ExecutePowershellCommand(command string) (string, error) {
}

// SetSdnRemoteArpMacAddress sets the regkey for SDNRemoteArpMacAddress needed for multitenancy if hns is enabled
func SetSdnRemoteArpMacAddress(execClient ExecClient) error {
exists, err := execClient.ExecutePowershellCommand(CheckIfHNSStatePathExistsCommand)
func SetSdnRemoteArpMacAddress(ctx context.Context) error {
log.Printf("Setting SDNRemoteArpMacAddress regKey")
// open the registry key
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SYSTEM\CurrentControlSet\Services\hns\State`, registry.READ|registry.SET_VALUE)
if err != nil {
errMsg := fmt.Sprintf("Failed to check the existent of hns state path due to error %s", err.Error())
log.Printf(errMsg)
return errors.Errorf(errMsg)
if errors.Is(err, registry.ErrNotExist) {
return nil
}
return errors.Wrap(err, "could not open registry key")
}
if strings.EqualFold(exists, "false") {
log.Printf("hns state path does not exist, skip setting SdnRemoteArpMacAddress")
return nil
defer k.Close()
// check the key value
if v, _, _ := k.GetStringValue("SDNRemoteArpMacAddress"); v == SDNRemoteArpMacAddress {
log.Printf("SDNRemoteArpMacAddress regKey already set")
return nil // already set
}
if err = k.SetStringValue("SDNRemoteArpMacAddress", SDNRemoteArpMacAddress); err != nil {
return errors.Wrap(err, "could not set registry key")
}
log.Printf("SDNRemoteArpMacAddress regKey set successfully")
log.Printf("Restarting HNS service")
// connect to the service manager
m, err := mgr.Connect()
if err != nil {
return errors.Wrap(err, "could not connect to service manager")
}
defer m.Disconnect() //nolint:errcheck // ignore error
// open the HNS service
service, err := m.OpenService("hns")
if err != nil {
return errors.Wrap(err, "could not access service")
}
defer service.Close()
if err := restartService(ctx, service); err != nil {
return errors.Wrap(err, "could not restart service")
}
log.Printf("HNS service restarted successfully")
return nil
}

func restartService(ctx context.Context, s *mgr.Service) error {
// Stop the service
_, err := s.Control(svc.Stop)
if err != nil {
return errors.Wrap(err, "could not stop service")
}
if sdnRemoteArpMacAddressSet == false {
result, err := execClient.ExecutePowershellCommand(GetSdnRemoteArpMacAddressCommand)
// Wait for the service to stop
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
defer ticker.Stop()
for { // hacky cancellable do-while
status, err := s.Query()
if err != nil {
return err
return errors.Wrap(err, "could not query service status")
}

// Set the reg key if not already set or has incorrect value
if result != SDNRemoteArpMacAddress {
if _, err = execClient.ExecutePowershellCommand(SetSdnRemoteArpMacAddressCommand); err != nil {
log.Printf("Failed to set SDNRemoteArpMacAddress due to error %s", err.Error())
return err
}

log.Printf("[Azure CNS] SDNRemoteArpMacAddress regKey set successfully. Restarting hns service.")
if _, err := execClient.ExecutePowershellCommand(RestartHnsServiceCommand); err != nil {
log.Printf("Failed to Restart HNS Service due to error %s", err.Error())
return err
}
if status.State == svc.Stopped {
break
}
select {
case <-ctx.Done():
return errors.New("context cancelled")
case <-ticker.C:
}

sdnRemoteArpMacAddressSet = true
}

// Start the service again
if err := s.Start(); err != nil {
return errors.Wrap(err, "could not start service")
}
return nil
}

Expand Down
62 changes: 36 additions & 26 deletions platform/os_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package platform
import (
"errors"
"os/exec"
"strings"
"testing"

"github.com/Azure/azure-container-networking/platform/windows/adapter/mocks"
Expand Down Expand Up @@ -100,37 +99,48 @@ func TestExecuteCommandError(t *testing.T) {
assert.Equal(t, 1, xErr.ExitCode())
}

func TestSetSdnRemoteArpMacAddress_hnsNotEnabled(t *testing.T) {
mockExecClient := NewMockExecClient(false)
// testing skip setting SdnRemoteArpMacAddress when hns not enabled
mockExecClient.SetPowershellCommandResponder(func(_ string) (string, error) {
return "False", nil
})
err := SetSdnRemoteArpMacAddress(mockExecClient)
assert.NoError(t, err)
assert.Equal(t, false, sdnRemoteArpMacAddressSet)
func TestExecuteCommand(t *testing.T) {
_, err := NewExecClient(nil).ExecuteCommand(context.Background(), "ping", "localhost")
if err != nil {
t.Errorf("TestExecuteCommand failed with error %v", err)
}
}

// testing the scenario when there is an error in checking if hns is enabled or not
mockExecClient.SetPowershellCommandResponder(func(_ string) (string, error) {
return "", errTestFailure
})
err = SetSdnRemoteArpMacAddress(mockExecClient)
assert.ErrorAs(t, err, &errTestFailure)
assert.Equal(t, false, sdnRemoteArpMacAddressSet)
func TestExecuteCommandError(t *testing.T) {
_, err := NewExecClient(nil).ExecuteCommand(context.Background(), "dontaddtopath")
require.Error(t, err)
require.ErrorIs(t, err, exec.ErrNotFound)
}

func TestSetSdnRemoteArpMacAddress_hnsEnabled(t *testing.T) {
func TestFetchPnpIDMapping(t *testing.T) {
mockExecClient := NewMockExecClient(false)
// happy path
mockExecClient.SetPowershellCommandResponder(func(cmd string) (string, error) {
if strings.Contains(cmd, "Test-Path") {
return "True", nil
}
return "6C-A1-00-50-E4-2D PCI\\VEN_8086&DEV_2723&SUBSYS_00808086&REV_1A\\4&328243d9&0&00E0\n80-6D-97-1E-CF-4E USB\\VID_17EF&PID_A359\\3010019E3", nil
})
vfmapping, _ := FetchMacAddressPnpIDMapping(context.Background(), mockExecClient)
require.Len(t, vfmapping, 2)

// Test when no adapters are found
mockExecClient.SetPowershellCommandResponder(func(cmd string) (string, error) {
return "", nil
})
err := SetSdnRemoteArpMacAddress(mockExecClient)
assert.NoError(t, err)
assert.Equal(t, true, sdnRemoteArpMacAddressSet)
// reset sdnRemoteArpMacAddressSet
sdnRemoteArpMacAddressSet = false
vfmapping, _ = FetchMacAddressPnpIDMapping(context.Background(), mockExecClient)
require.Empty(t, vfmapping)
// Adding carriage returns
mockExecClient.SetPowershellCommandResponder(func(cmd string) (string, error) {
return "6C-A1-00-50-E4-2D PCI\\VEN_8086&DEV_2723&SUBSYS_00808086&REV_1A\\4&328243d9&0&00E0\r\n\r80-6D-97-1E-CF-4E USB\\VID_17EF&PID_A359\\3010019E3", nil
})

vfmapping, _ = FetchMacAddressPnpIDMapping(context.Background(), mockExecClient)
require.Len(t, vfmapping, 2)
}

// ping -t localhost will ping indefinitely and should exceed the 5 second timeout
func TestExecuteCommandTimeout(t *testing.T) {
const timeout = 5 * time.Second
client := NewExecClientTimeout(timeout)

_, err := client.ExecuteCommand(context.Background(), "ping", "-t", "localhost")
require.Error(t, err)
}
2 changes: 1 addition & 1 deletion test/validate/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (v *Validator) ValidateStateFile(ctx context.Context) error {
}

func (v *Validator) validateIPs(ctx context.Context, stateFileIps stateFileIpsFunc, cmd []string, checkType, namespace, labelSelector string) error {
log.Printf("Validating %s state file", checkType)
log.Printf("Validating %s state file for %s on %s", checkType, v.cni, v.os)
nodes, err := acnk8s.GetNodeListByLabelSelector(ctx, v.clientset, nodeSelectorMap[v.os])
if err != nil {
return errors.Wrapf(err, "failed to get node list")
Expand Down

0 comments on commit 887c445

Please sign in to comment.