diff --git a/cns/middlewares/k8sSwiftV2.go b/cns/middlewares/k8sSwiftV2.go index a11290c205..5c7134fe4b 100644 --- a/cns/middlewares/k8sSwiftV2.go +++ b/cns/middlewares/k8sSwiftV2.go @@ -10,6 +10,7 @@ import ( "github.com/Azure/azure-container-networking/cns/middlewares/utils" "github.com/Azure/azure-container-networking/cns/types" "github.com/Azure/azure-container-networking/crd/multitenancy/api/v1alpha1" + "github.com/Azure/azure-container-networking/network/policy" "github.com/pkg/errors" v1 "k8s.io/api/core/v1" k8stypes "k8s.io/apimachinery/pkg/types" @@ -40,7 +41,9 @@ var _ cns.IPConfigsHandlerMiddleware = (*K8sSWIFTv2Middleware)(nil) // and release IP configs handlers. func (k *K8sSWIFTv2Middleware) IPConfigsRequestHandlerWrapper(defaultHandler, failureHandler cns.IPConfigsHandlerFunc) cns.IPConfigsHandlerFunc { return func(ctx context.Context, req cns.IPConfigsRequest) (*cns.IPConfigsResponse, error) { - podInfo, respCode, message := k.validateIPConfigsRequest(ctx, &req) + podInfo, respCode, defaultDenyACLbool, message := k.GetPodInfoForIPConfigsRequest(ctx, &req) + + logger.Printf("defaultDenyACLbool value is: %v", defaultDenyACLbool) if respCode != types.Success { return &cns.IPConfigsResponse{ @@ -55,6 +58,31 @@ func (k *K8sSWIFTv2Middleware) IPConfigsRequestHandlerWrapper(defaultHandler, fa if !req.SecondaryInterfacesExist { return ipConfigsResp, err } + + // ipConfigsResp has infra IP configs -> if defaultDenyACLbool is enabled, add the default deny endpoint policies as a property in PodIpInfo + for i := range ipConfigsResp.PodIPInfo { + ipInfo := &ipConfigsResp.PodIPInfo[i] + // there will be no pod connectivity to and from those pods + var defaultDenyEgressPolicy, defaultDenyIngressPolicy policy.Policy + + if defaultDenyACLbool && ipInfo.NICType == cns.InfraNIC { + defaultDenyEgressPolicy, err = getEndpointPolicy(string(policy.ACLPolicy), cns.ActionTypeBlock, cns.DirectionTypeOut, 10_000) + if err != nil { + logger.Errorf("failed to add default deny acl's for pod %v with err %v", podInfo.Name(), err) + } + + defaultDenyIngressPolicy, err = getEndpointPolicy(string(policy.ACLPolicy), cns.ActionTypeBlock, cns.DirectionTypeIn, 10_000) + if err != nil { + logger.Errorf("failed to add default deny acl's for pod %v with err %v", podInfo.Name(), err) + } + + ipInfo.EndpointPolicies = append(ipInfo.EndpointPolicies, defaultDenyEgressPolicy, defaultDenyIngressPolicy) + logger.Printf("Created endpoint policies for defaultDenyEgressPolicy and defaultDenyIngressPolicy") + + break + } + } + // If the pod is v2, get the infra IP configs from the handler first and then add the SWIFTv2 IP config defer func() { // Release the default IP config if there is an error @@ -100,21 +128,23 @@ func (k *K8sSWIFTv2Middleware) IPConfigsRequestHandlerWrapper(defaultHandler, fa } } -// validateIPConfigsRequest validates if pod is multitenant by checking the pod labels, used in SWIFT V2 AKS scenario. +// GetPodInfoForIPConfigsRequest validates if pod is multitenant by checking the pod labels, used in SWIFT V2 AKS scenario. // nolint -func (k *K8sSWIFTv2Middleware) validateIPConfigsRequest(ctx context.Context, req *cns.IPConfigsRequest) (podInfo cns.PodInfo, respCode types.ResponseCode, message string) { +func (k *K8sSWIFTv2Middleware) GetPodInfoForIPConfigsRequest(ctx context.Context, req *cns.IPConfigsRequest) (podInfo cns.PodInfo, respCode types.ResponseCode, defaultDenyACL bool, message string) { + defaultDenyACLbool := false + // Retrieve the pod from the cluster podInfo, err := cns.UnmarshalPodInfo(req.OrchestratorContext) if err != nil { errBuf := errors.Wrapf(err, "failed to unmarshalling pod info from ipconfigs request %+v", req) - return nil, types.UnexpectedError, errBuf.Error() + return nil, types.UnexpectedError, defaultDenyACLbool, errBuf.Error() } logger.Printf("[SWIFTv2Middleware] validate ipconfigs request for pod %s", podInfo.Name()) podNamespacedName := k8stypes.NamespacedName{Namespace: podInfo.Namespace(), Name: podInfo.Name()} pod := v1.Pod{} if err := k.Cli.Get(ctx, podNamespacedName, &pod); err != nil { errBuf := errors.Wrapf(err, "failed to get pod %+v", podNamespacedName) - return nil, types.UnexpectedError, errBuf.Error() + return nil, types.UnexpectedError, defaultDenyACLbool, errBuf.Error() } // check the pod labels for Swift V2, set the request's SecondaryInterfaceSet flag to true and check if its MTPNC CRD is ready @@ -126,12 +156,16 @@ func (k *K8sSWIFTv2Middleware) validateIPConfigsRequest(ctx context.Context, req mtpnc := v1alpha1.MultitenantPodNetworkConfig{} mtpncNamespacedName := k8stypes.NamespacedName{Namespace: podInfo.Namespace(), Name: podInfo.Name()} if err := k.Cli.Get(ctx, mtpncNamespacedName, &mtpnc); err != nil { - return nil, types.UnexpectedError, fmt.Errorf("failed to get pod's mtpnc from cache : %w", err).Error() + return nil, types.UnexpectedError, defaultDenyACLbool, fmt.Errorf("failed to get pod's mtpnc from cache : %w", err).Error() } // Check if the MTPNC CRD is ready. If one of the fields is empty, return error if !mtpnc.IsReady() { - return nil, types.UnexpectedError, errMTPNCNotReady.Error() + return nil, types.UnexpectedError, defaultDenyACLbool, errMTPNCNotReady.Error() } + + // copying defaultDenyACL bool from mtpnc + defaultDenyACLbool = mtpnc.Status.DefaultDenyACL + // If primary Ip is set in status field, it indicates the presence of secondary interfaces if mtpnc.Status.PrimaryIP != "" { req.SecondaryInterfacesExist = true @@ -140,7 +174,7 @@ func (k *K8sSWIFTv2Middleware) validateIPConfigsRequest(ctx context.Context, req for _, interfaceInfo := range interfaceInfos { if interfaceInfo.DeviceType == v1alpha1.DeviceTypeInfiniBandNIC { if interfaceInfo.MacAddress == "" || interfaceInfo.NCID == "" { - return nil, types.UnexpectedError, errMTPNCNotReady.Error() + return nil, types.UnexpectedError, defaultDenyACLbool, errMTPNCNotReady.Error() } req.BackendInterfaceExist = true req.BackendInterfaceMacAddresses = append(req.BackendInterfaceMacAddresses, interfaceInfo.MacAddress) @@ -154,7 +188,7 @@ func (k *K8sSWIFTv2Middleware) validateIPConfigsRequest(ctx context.Context, req logger.Printf("[SWIFTv2Middleware] pod %s has secondary interface : %v", podInfo.Name(), req.SecondaryInterfacesExist) logger.Printf("[SWIFTv2Middleware] pod %s has backend interface : %v", podInfo.Name(), req.BackendInterfaceExist) // retrieve podinfo from orchestrator context - return podInfo, types.Success, "" + return podInfo, types.Success, defaultDenyACLbool, "" } // getIPConfig returns the pod's SWIFT V2 IP configuration. diff --git a/cns/middlewares/k8sSwiftV2_linux.go b/cns/middlewares/k8sSwiftV2_linux.go index e9a93de0e2..646c3bbd93 100644 --- a/cns/middlewares/k8sSwiftV2_linux.go +++ b/cns/middlewares/k8sSwiftV2_linux.go @@ -9,6 +9,7 @@ import ( "github.com/Azure/azure-container-networking/cns/logger" "github.com/Azure/azure-container-networking/cns/middlewares/utils" "github.com/Azure/azure-container-networking/crd/multitenancy/api/v1alpha1" + "github.com/Azure/azure-container-networking/network/policy" "github.com/pkg/errors" ) @@ -103,3 +104,7 @@ func (k *K8sSWIFTv2Middleware) assignSubnetPrefixLengthFields(_ *cns.PodIpInfo, } func (k *K8sSWIFTv2Middleware) addDefaultRoute(*cns.PodIpInfo, string) {} + +func getEndpointPolicy(_, _, _ string, _ int) (policy.Policy, error) { + return policy.Policy{}, nil +} diff --git a/cns/middlewares/k8sSwiftV2_linux_test.go b/cns/middlewares/k8sSwiftV2_linux_test.go index 76be6b2149..2c7f498ff2 100644 --- a/cns/middlewares/k8sSwiftV2_linux_test.go +++ b/cns/middlewares/k8sSwiftV2_linux_test.go @@ -144,7 +144,7 @@ func TestValidateMultitenantIPConfigsRequestSuccess(t *testing.T) { happyReq.OrchestratorContext = b happyReq.SecondaryInterfacesExist = false - _, respCode, err := middleware.validateIPConfigsRequest(context.TODO(), happyReq) + _, respCode, _, err := middleware.GetPodInfoForIPConfigsRequest(context.TODO(), happyReq) assert.Equal(t, err, "") assert.Equal(t, respCode, types.Success) assert.Equal(t, happyReq.SecondaryInterfacesExist, true) @@ -158,7 +158,7 @@ func TestValidateMultitenantIPConfigsRequestSuccess(t *testing.T) { happyReq2.OrchestratorContext = b happyReq2.SecondaryInterfacesExist = false - _, respCode, err = middleware.validateIPConfigsRequest(context.TODO(), happyReq2) + _, respCode, _, err = middleware.GetPodInfoForIPConfigsRequest(context.TODO(), happyReq2) assert.Equal(t, err, "") assert.Equal(t, respCode, types.Success) assert.Equal(t, happyReq.SecondaryInterfacesExist, true) @@ -172,7 +172,7 @@ func TestValidateMultitenantIPConfigsRequestSuccess(t *testing.T) { happyReq3.OrchestratorContext = b happyReq3.SecondaryInterfacesExist = false - _, respCode, err = middleware.validateIPConfigsRequest(context.TODO(), happyReq3) + _, respCode, _, err = middleware.GetPodInfoForIPConfigsRequest(context.TODO(), happyReq3) assert.Equal(t, err, "") assert.Equal(t, respCode, types.Success) assert.Equal(t, happyReq3.SecondaryInterfacesExist, false) @@ -188,7 +188,7 @@ func TestValidateMultitenantIPConfigsRequestFailure(t *testing.T) { InfraContainerID: testPod1Info.InfraContainerID(), } failReq.OrchestratorContext = []byte("invalid") - _, respCode, _ := middleware.validateIPConfigsRequest(context.TODO(), failReq) + _, respCode, _, _ := middleware.GetPodInfoForIPConfigsRequest(context.TODO(), failReq) assert.Equal(t, respCode, types.UnexpectedError) // Pod doesn't exist in cache test @@ -198,19 +198,19 @@ func TestValidateMultitenantIPConfigsRequestFailure(t *testing.T) { } b, _ := testPod2Info.OrchestratorContext() failReq.OrchestratorContext = b - _, respCode, _ = middleware.validateIPConfigsRequest(context.TODO(), failReq) + _, respCode, _, _ = middleware.GetPodInfoForIPConfigsRequest(context.TODO(), failReq) assert.Equal(t, respCode, types.UnexpectedError) // Failed to get MTPNC b, _ = testPod3Info.OrchestratorContext() failReq.OrchestratorContext = b - _, respCode, _ = middleware.validateIPConfigsRequest(context.TODO(), failReq) + _, respCode, _, _ = middleware.GetPodInfoForIPConfigsRequest(context.TODO(), failReq) assert.Equal(t, respCode, types.UnexpectedError) // MTPNC not ready b, _ = testPod4Info.OrchestratorContext() failReq.OrchestratorContext = b - _, respCode, _ = middleware.validateIPConfigsRequest(context.TODO(), failReq) + _, respCode, _, _ = middleware.GetPodInfoForIPConfigsRequest(context.TODO(), failReq) assert.Equal(t, respCode, types.UnexpectedError) } diff --git a/cns/middlewares/k8sSwiftV2_windows.go b/cns/middlewares/k8sSwiftV2_windows.go index 2be2fbd1df..3d56987083 100644 --- a/cns/middlewares/k8sSwiftV2_windows.go +++ b/cns/middlewares/k8sSwiftV2_windows.go @@ -1,9 +1,12 @@ package middlewares import ( + "encoding/json" + "github.com/Azure/azure-container-networking/cns" "github.com/Azure/azure-container-networking/cns/middlewares/utils" "github.com/Azure/azure-container-networking/crd/multitenancy/api/v1alpha1" + "github.com/Azure/azure-container-networking/network/policy" "github.com/pkg/errors" ) @@ -58,3 +61,42 @@ func (k *K8sSWIFTv2Middleware) addDefaultRoute(podIPInfo *cns.PodIpInfo, gwIP st } podIPInfo.Routes = append(podIPInfo.Routes, route) } + +// get policy of type endpoint policy given the params +func getEndpointPolicy(policyType, action, direction string, priority int) (policy.Policy, error) { + endpointPolicy, err := createEndpointPolicy(policyType, action, direction, priority) + if err != nil { + return policy.Policy{}, errors.Wrap(err, "failed to create endpoint policy") + } + + additionalArgs := policy.Policy{ + Type: policy.EndpointPolicy, + Data: endpointPolicy, + } + + return additionalArgs, nil +} + +// create policy given the params +func createEndpointPolicy(policyType, action, direction string, priority int) ([]byte, error) { + type EndpointPolicy struct { + Type string `json:"Type"` + Action string `json:"Action"` + Direction string `json:"Direction"` + Priority int `json:"Priority"` + } + + endpointPolicy := EndpointPolicy{ + Type: policyType, + Action: action, + Direction: direction, + Priority: priority, + } + + rawPolicy, err := json.Marshal(endpointPolicy) + if err != nil { + return nil, errors.Wrap(err, "error marshalling policy to json") + } + + return rawPolicy, nil +} diff --git a/cns/middlewares/k8sSwiftV2_windows_test.go b/cns/middlewares/k8sSwiftV2_windows_test.go index dab24685f9..5d03734007 100644 --- a/cns/middlewares/k8sSwiftV2_windows_test.go +++ b/cns/middlewares/k8sSwiftV2_windows_test.go @@ -1,12 +1,16 @@ package middlewares import ( + "encoding/json" + "fmt" "reflect" "testing" "github.com/Azure/azure-container-networking/cns" "github.com/Azure/azure-container-networking/cns/middlewares/mock" "github.com/Azure/azure-container-networking/crd/multitenancy/api/v1alpha1" + "github.com/Azure/azure-container-networking/network/policy" + "github.com/stretchr/testify/require" "gotest.tools/v3/assert" ) @@ -100,3 +104,92 @@ func TestAddDefaultRoute(t *testing.T) { t.Errorf("got '%+v', expected '%+v'", ipInfo.Routes, expectedRoutes) } } + +func TestAddDefaultDenyACL(t *testing.T) { + const policyType = "ACL" + const action = "Block" + const ingressDir = "In" + const egressDir = "Out" + const priority = 10000 + + valueIn := []byte(fmt.Sprintf(`{ + "Type": "%s", + "Action": "%s", + "Direction": "%s", + "Priority": %d + }`, + policyType, + action, + ingressDir, + priority, + )) + + valueOut := []byte(fmt.Sprintf(`{ + "Type": "%s", + "Action": "%s", + "Direction": "%s", + "Priority": %d + }`, + policyType, + action, + egressDir, + priority, + )) + + expectedDefaultDenyEndpoint := []policy.Policy{ + { + Type: policy.EndpointPolicy, + Data: valueOut, + }, + { + Type: policy.EndpointPolicy, + Data: valueIn, + }, + } + var allEndpoints []policy.Policy + var defaultDenyEgressPolicy, defaultDenyIngressPolicy policy.Policy + var err error + + defaultDenyEgressPolicy, err = getEndpointPolicy("ACL", "Block", "Out", 10_000) + if err != nil { + fmt.Printf("failed to create endpoint policy") + } + defaultDenyIngressPolicy, err = getEndpointPolicy("ACL", "Block", "In", 10_000) + if err != nil { + fmt.Printf("failed to create endpoint policy") + } + + allEndpoints = append(allEndpoints, defaultDenyEgressPolicy, defaultDenyIngressPolicy) + + // Normalize both slices so there is no extra spacing, new lines, etc + normalizedExpected := normalizeKVPairs(t, expectedDefaultDenyEndpoint) + normalizedActual := normalizeKVPairs(t, allEndpoints) + if !reflect.DeepEqual(normalizedExpected, normalizedActual) { + t.Errorf("got '%+v', expected '%+v'", normalizedActual, normalizedExpected) + } + assert.Equal(t, err, nil) +} + +// normalizeKVPairs normalizes the JSON values in the KV pairs by unmarshaling them into a map, then marshaling them back to compact JSON to remove any extra space, new lines, etc +func normalizeKVPairs(t *testing.T, policies []policy.Policy) []policy.Policy { + normalized := make([]policy.Policy, len(policies)) + + for i, kv := range policies { + var unmarshaledValue map[string]interface{} + // Unmarshal the Value into a map + err := json.Unmarshal(kv.Data, &unmarshaledValue) + require.NoError(t, err, "Failed to unmarshal JSON value") + + // Marshal it back to compact JSON + normalizedValue, err := json.Marshal(unmarshaledValue) + require.NoError(t, err, "Failed to re-marshal JSON value") + + // Replace Value with the normalized compact JSON + normalized[i] = policy.Policy{ + Type: policy.EndpointPolicy, + Data: normalizedValue, + } + } + + return normalized +}