Skip to content

Commit

Permalink
updated unit test returns
Browse files Browse the repository at this point in the history
  • Loading branch information
rejain456 committed Jan 21, 2025
1 parent 7913564 commit 0c43db6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
34 changes: 23 additions & 11 deletions cns/middlewares/k8sSwiftV2_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import (
"github.com/pkg/errors"
)

var defaultDenyEgressPolicy policy.Policy = getEndpointPolicy(cns.DirectionTypeOut)
var defaultDenyEgressPolicy policy.Policy = mustGetEndpointPolicy(cns.DirectionTypeOut)

var defaultDenyIngressPolicy policy.Policy = getEndpointPolicy(cns.DirectionTypeIn)
var defaultDenyIngressPolicy policy.Policy = mustGetEndpointPolicy(cns.DirectionTypeIn)

// for AKS L1VH, do not set default route on infraNIC to avoid customer pod reaching all infra vnet services
// default route is set for secondary interface NIC(i.e,delegatedNIC)
Expand Down Expand Up @@ -70,20 +70,32 @@ func (k *K8sSWIFTv2Middleware) addDefaultRoute(podIPInfo *cns.PodIpInfo, gwIP st
podIPInfo.Routes = append(podIPInfo.Routes, route)
}

func mustGetEndpointPolicy(direction string) policy.Policy {
endpointPolicy, err := getEndpointPolicy(direction)
if err != nil {
panic(err)
}
return endpointPolicy
}

// get policy of type endpoint policy given the params
func getEndpointPolicy(direction string) policy.Policy {
endpointPolicy := createEndpointPolicy(direction)
func getEndpointPolicy(direction string) (policy.Policy, error) {
endpointPolicy, err := createEndpointPolicy(direction)

if err != nil {
return policy.Policy{}, fmt.Errorf("error creating endpoint policy: %w", err)
}

additionalArgs := policy.Policy{
Type: policy.EndpointPolicy,
Data: endpointPolicy,
}

return additionalArgs
return additionalArgs, nil
}

// create policy given the params
func createEndpointPolicy(direction string) []byte {
func createEndpointPolicy(direction string) ([]byte, error) {
endpointPolicy := struct {
Type string `json:"Type"`
Action string `json:"Action"`
Expand All @@ -98,10 +110,10 @@ func createEndpointPolicy(direction string) []byte {

rawPolicy, err := json.Marshal(endpointPolicy)
if err != nil {
logger.Errorf("error marshalling policy to json, err is: %v", err)
return nil, fmt.Errorf("error marshalling policy to json, err is: %w", err)
}

return rawPolicy
return rawPolicy, nil
}

// IPConfigsRequestHandlerWrapper is the middleware function for handling SWIFT v2 IP configs requests for AKS-SWIFT. This function wrapped the default SWIFT request
Expand Down Expand Up @@ -136,7 +148,7 @@ func (k *K8sSWIFTv2Middleware) IPConfigsRequestHandlerWrapper(defaultHandler, fa
}

// GetDefaultDenyBool takes in mtpnc and returns the value of defaultDenyACLBool from it
defaultDenyACLBool, err := GetDefaultDenyBool(mtpnc)
defaultDenyACLBool := GetDefaultDenyBool(mtpnc)

// ipConfigsResp has infra IP configs -> if defaultDenyACLbool is enabled, add the default deny endpoint policies as a property in PodIpInfo
for i := range ipConfigsResp.PodIPInfo {
Expand Down Expand Up @@ -193,7 +205,7 @@ func (k *K8sSWIFTv2Middleware) IPConfigsRequestHandlerWrapper(defaultHandler, fa
}
}

func GetDefaultDenyBool(mtpnc v1alpha1.MultitenantPodNetworkConfig) (bool, error) {
func GetDefaultDenyBool(mtpnc v1alpha1.MultitenantPodNetworkConfig) bool {
// returns the value of DefaultDenyACL from mtpnc
return mtpnc.Status.DefaultDenyACL, nil
return mtpnc.Status.DefaultDenyACL
}
9 changes: 5 additions & 4 deletions cns/middlewares/k8sSwiftV2_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/Azure/azure-container-networking/cns/middlewares/mock"
"github.com/Azure/azure-container-networking/crd/multitenancy/api/v1alpha1"
"github.com/Azure/azure-container-networking/network/policy"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"gotest.tools/v3/assert"
)
Expand Down Expand Up @@ -150,16 +151,16 @@ func TestAddDefaultDenyACL(t *testing.T) {
var defaultDenyEgressPolicy, defaultDenyIngressPolicy policy.Policy
var err error

defaultDenyEgressPolicy = getEndpointPolicy("Out")
defaultDenyIngressPolicy = getEndpointPolicy("In")
defaultDenyEgressPolicy = mustGetEndpointPolicy("Out")
defaultDenyIngressPolicy = mustGetEndpointPolicy("In")

allEndpoints = append(allEndpoints, defaultDenyEgressPolicy, defaultDenyIngressPolicy)

// Normalize both slices so there is no extra spacing, new lines, etc
normalizedExpected := normalizeKVPairs(t, expectedDefaultDenyEndpoint)
normalizedActual := normalizeKVPairs(t, allEndpoints)
if !reflect.DeepEqual(normalizedExpected, normalizedActual) {
t.Errorf("got '%+v', expected '%+v'", normalizedActual, normalizedExpected)
if !cmp.Equal(normalizedExpected, normalizedActual) {
t.Error("received policy differs from expectation: diff", cmp.Diff(normalizedExpected, normalizedActual))
}
assert.Equal(t, err, nil)
}
Expand Down

0 comments on commit 0c43db6

Please sign in to comment.