diff --git a/cns/middlewares/k8sSwiftV2_windows.go b/cns/middlewares/k8sSwiftV2_windows.go index 6c3840d61e..8dfaba035c 100644 --- a/cns/middlewares/k8sSwiftV2_windows.go +++ b/cns/middlewares/k8sSwiftV2_windows.go @@ -14,9 +14,9 @@ import ( "github.com/pkg/errors" ) -var defaultDenyEgressPolicy policy.Policy = getEndpointPolicy(cns.DirectionTypeOut) +var defaultDenyEgressPolicy policy.Policy = mustGetEndpointPolicy(cns.DirectionTypeOut) -var defaultDenyIngressPolicy policy.Policy = getEndpointPolicy(cns.DirectionTypeIn) +var defaultDenyIngressPolicy policy.Policy = mustGetEndpointPolicy(cns.DirectionTypeIn) // for AKS L1VH, do not set default route on infraNIC to avoid customer pod reaching all infra vnet services // default route is set for secondary interface NIC(i.e,delegatedNIC) @@ -70,20 +70,32 @@ func (k *K8sSWIFTv2Middleware) addDefaultRoute(podIPInfo *cns.PodIpInfo, gwIP st podIPInfo.Routes = append(podIPInfo.Routes, route) } +func mustGetEndpointPolicy(direction string) policy.Policy { + endpointPolicy, err := getEndpointPolicy(direction) + if err != nil { + panic(err) + } + return endpointPolicy +} + // get policy of type endpoint policy given the params -func getEndpointPolicy(direction string) policy.Policy { - endpointPolicy := createEndpointPolicy(direction) +func getEndpointPolicy(direction string) (policy.Policy, error) { + endpointPolicy, err := createEndpointPolicy(direction) + + if err != nil { + return policy.Policy{}, fmt.Errorf("error creating endpoint policy: %w", err) + } additionalArgs := policy.Policy{ Type: policy.EndpointPolicy, Data: endpointPolicy, } - return additionalArgs + return additionalArgs, nil } // create policy given the params -func createEndpointPolicy(direction string) []byte { +func createEndpointPolicy(direction string) ([]byte, error) { endpointPolicy := struct { Type string `json:"Type"` Action string `json:"Action"` @@ -98,10 +110,10 @@ func createEndpointPolicy(direction string) []byte { rawPolicy, err := json.Marshal(endpointPolicy) if err != nil { - logger.Errorf("error marshalling policy to json, err is: %v", err) + return nil, fmt.Errorf("error marshalling policy to json, err is: %w", err) } - return rawPolicy + return rawPolicy, nil } // IPConfigsRequestHandlerWrapper is the middleware function for handling SWIFT v2 IP configs requests for AKS-SWIFT. This function wrapped the default SWIFT request @@ -136,7 +148,7 @@ func (k *K8sSWIFTv2Middleware) IPConfigsRequestHandlerWrapper(defaultHandler, fa } // GetDefaultDenyBool takes in mtpnc and returns the value of defaultDenyACLBool from it - defaultDenyACLBool, err := GetDefaultDenyBool(mtpnc) + defaultDenyACLBool := GetDefaultDenyBool(mtpnc) // ipConfigsResp has infra IP configs -> if defaultDenyACLbool is enabled, add the default deny endpoint policies as a property in PodIpInfo for i := range ipConfigsResp.PodIPInfo { @@ -193,7 +205,7 @@ func (k *K8sSWIFTv2Middleware) IPConfigsRequestHandlerWrapper(defaultHandler, fa } } -func GetDefaultDenyBool(mtpnc v1alpha1.MultitenantPodNetworkConfig) (bool, error) { +func GetDefaultDenyBool(mtpnc v1alpha1.MultitenantPodNetworkConfig) bool { // returns the value of DefaultDenyACL from mtpnc - return mtpnc.Status.DefaultDenyACL, nil + return mtpnc.Status.DefaultDenyACL } diff --git a/cns/middlewares/k8sSwiftV2_windows_test.go b/cns/middlewares/k8sSwiftV2_windows_test.go index 36a02c1d7b..a7e277cfc6 100644 --- a/cns/middlewares/k8sSwiftV2_windows_test.go +++ b/cns/middlewares/k8sSwiftV2_windows_test.go @@ -10,6 +10,7 @@ import ( "github.com/Azure/azure-container-networking/cns/middlewares/mock" "github.com/Azure/azure-container-networking/crd/multitenancy/api/v1alpha1" "github.com/Azure/azure-container-networking/network/policy" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "gotest.tools/v3/assert" ) @@ -150,16 +151,16 @@ func TestAddDefaultDenyACL(t *testing.T) { var defaultDenyEgressPolicy, defaultDenyIngressPolicy policy.Policy var err error - defaultDenyEgressPolicy = getEndpointPolicy("Out") - defaultDenyIngressPolicy = getEndpointPolicy("In") + defaultDenyEgressPolicy = mustGetEndpointPolicy("Out") + defaultDenyIngressPolicy = mustGetEndpointPolicy("In") allEndpoints = append(allEndpoints, defaultDenyEgressPolicy, defaultDenyIngressPolicy) // Normalize both slices so there is no extra spacing, new lines, etc normalizedExpected := normalizeKVPairs(t, expectedDefaultDenyEndpoint) normalizedActual := normalizeKVPairs(t, allEndpoints) - if !reflect.DeepEqual(normalizedExpected, normalizedActual) { - t.Errorf("got '%+v', expected '%+v'", normalizedActual, normalizedExpected) + if !cmp.Equal(normalizedExpected, normalizedActual) { + t.Error("received policy differs from expectation: diff", cmp.Diff(normalizedExpected, normalizedActual)) } assert.Equal(t, err, nil) }