diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index 8902bead78..fdac01431a 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -56,6 +56,21 @@ import ( // for V2 where the specific message is targeted at the UVM itself. const UVMContainerID = "00000000-0000-0000-0000-000000000000" +var ( + // scsiActualControllerNumberFn is the function to retrieves the actual controller + // number assigned to a SCSI controller. + scsiActualControllerNumberFn = scsi.ActualControllerNumber + // scsiGetDevicePathFn is the function to retrieves the device path for a SCSI device. + scsiGetDevicePathFn = scsi.GetDevicePath + // scsiMountFn is the function to mount a SCSI device. + scsiMountFn = scsi.Mount + // scsiUnmountFn is the function to unmount a SCSI device. + scsiUnmountFn = scsi.Unmount + // readVeritySuperBlockFn is the function to read ext4 super block + // for a given VHD to then further read the dm-verity super block and root hash. + readVeritySuperBlockFn = verity.ReadVeritySuperBlock +) + // Host is the structure tracking all UVM host state including all containers // and processes. type Host struct { @@ -566,7 +581,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * // find the actual controller number on the bus and update the incoming request. var cNum uint8 - cNum, err := scsi.ActualControllerNumber(ctx, mvd.Controller) + cNum, err := scsiActualControllerNumberFn(ctx, mvd.Controller) if err != nil { return err } @@ -575,7 +590,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * if !mvd.ReadOnly { localCtx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - source, err := scsi.GetDevicePath(localCtx, mvd.Controller, mvd.Lun, mvd.Partition) + source, err := scsiGetDevicePathFn(localCtx, mvd.Controller, mvd.Lun, mvd.Partition) if err != nil { return err } @@ -1018,11 +1033,11 @@ func modifyMappedVirtualDisk( // it is a block device meant for a container mount. In the latter case, // we don't want to check the verity information. if len(securityPolicy.EncodedSecurityPolicy()) > 0 && enforcePolicy { - devPath, err := scsi.GetDevicePath(ctx, mvd.Controller, mvd.Lun, mvd.Partition) + devPath, err := scsiGetDevicePathFn(ctx, mvd.Controller, mvd.Lun, mvd.Partition) if err != nil { return err } - verityInfo, err = verity.ReadVeritySuperBlock(ctx, devPath) + verityInfo, err = readVeritySuperBlockFn(ctx, devPath) if err != nil { return err } @@ -1049,7 +1064,7 @@ func modifyMappedVirtualDisk( Filesystem: mvd.Filesystem, BlockDev: mvd.BlockDev, } - return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, + return scsiMountFn(mountCtx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, mvd.ReadOnly, mvd.Options, config) } return nil @@ -1067,7 +1082,7 @@ func modifyMappedVirtualDisk( Filesystem: mvd.Filesystem, BlockDev: mvd.BlockDev, } - if err := scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.Partition, + if err := scsiUnmountFn(ctx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, config); err != nil { return err } @@ -1116,7 +1131,7 @@ func modifyMappedVPMemDevice(ctx context.Context, if vpd.MappingInfo != nil { return fmt.Errorf("multi mapping is not supported with verity") } - verityInfo, err = verity.ReadVeritySuperBlock(ctx, pmem.GetDevicePath(vpd.DeviceNumber)) + verityInfo, err = readVeritySuperBlockFn(ctx, pmem.GetDevicePath(vpd.DeviceNumber)) if err != nil { return err } diff --git a/internal/guest/runtime/hcsv2/uvm_test.go b/internal/guest/runtime/hcsv2/uvm_test.go new file mode 100644 index 0000000000..5a0d7861cd --- /dev/null +++ b/internal/guest/runtime/hcsv2/uvm_test.go @@ -0,0 +1,153 @@ +//go:build linux +// +build linux + +package hcsv2 + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + + "github.com/Microsoft/hcsshim/internal/guest/storage/scsi" + "github.com/Microsoft/hcsshim/internal/guestpath" + "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" + "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/pkg/securitypolicy" +) + +const ( + testContainerID = "test-container" + testDevicePath = "/dev/sde" + testMountPathOnGuest = "/mount/path" +) + +var testMountPathForContainerDeviceOnGuest = fmt.Sprintf(guestpath.LCOWSCSIMountPrefixFmt, 3) + +func Test_ModifyHostSettings_VirtualDisk(t *testing.T) { + tests := []struct { + name string + requestType guestrequest.RequestType + mountPath string + containerMount bool + readonly bool + expectError bool + errorMessage string + }{ + { + name: "ValidMountOnGuest_Add_RW", + requestType: guestrequest.RequestTypeAdd, + mountPath: testMountPathOnGuest, + containerMount: false, + readonly: false, + expectError: false, + errorMessage: "", + }, + { + name: "ValidMountOnGuest_Add_RO", + requestType: guestrequest.RequestTypeAdd, + mountPath: testMountPathOnGuest, + containerMount: false, + readonly: true, + expectError: false, + errorMessage: "", + }, + { + name: "ValidMountOnGuest_ContainerDevice_Add_RW", + requestType: guestrequest.RequestTypeAdd, + mountPath: testMountPathForContainerDeviceOnGuest, + containerMount: true, + readonly: false, + expectError: false, + errorMessage: "", + }, + { + name: "ValidMountOnGuest_ContainerDevice_Add_RO", + requestType: guestrequest.RequestTypeAdd, + mountPath: testMountPathForContainerDeviceOnGuest, + containerMount: true, + readonly: true, + expectError: false, + errorMessage: "", + }, + { + name: "ValidMountOnGuest_Remove", + requestType: guestrequest.RequestTypeRemove, + mountPath: testMountPathForContainerDeviceOnGuest, + expectError: false, + errorMessage: "", + }, + { + name: "InvalidMountOnGuest_ContainerDevice", + requestType: guestrequest.RequestTypeAdd, + mountPath: "/invalid/mount/path", + containerMount: true, + expectError: true, + errorMessage: "invalid mount path inside guest", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := NewHost(nil, nil, &securitypolicy.OpenDoorSecurityPolicyEnforcer{}, os.Stdout) + ctx := context.Background() + + // Mock functions + scsiActualControllerNumberFn = func(ctx context.Context, controller uint8) (uint8, error) { + return controller, nil + } + scsiGetDevicePathFn = func(ctx context.Context, controller uint8, lun uint8, partition uint64) (string, error) { + return testDevicePath, nil + } + scsiMountFn = func(ctx context.Context, controller uint8, lun uint8, partition uint64, mountPath string, readOnly bool, options []string, config *scsi.Config) error { + return nil + } + scsiUnmountFn = func(ctx context.Context, controller uint8, lun uint8, partition uint64, mountPath string, config *scsi.Config) error { + return nil + } + // Restore the original functions after the test. + defer func() { + scsiActualControllerNumberFn = scsi.ActualControllerNumber + scsiGetDevicePathFn = scsi.GetDevicePath + scsiMountFn = scsi.Mount + scsiUnmountFn = scsi.Unmount + }() + + // Create the modification request. + req := &guestrequest.ModificationRequest{ + ResourceType: guestresource.ResourceTypeMappedVirtualDisk, + RequestType: guestrequest.RequestTypeAdd, + Settings: &guestresource.LCOWMappedVirtualDisk{ + ReadOnly: tt.readonly, + ContainerMount: tt.containerMount, + Controller: 0, + Lun: 0, + Partition: 1, + Encrypted: false, + MountPath: tt.mountPath, + }, + } + + // Run the test. + err := h.modifyHostSettings(ctx, testContainerID, req) + if err != nil { + // If an error was expected then validate the error message. + if tt.expectError && !strings.Contains(err.Error(), tt.errorMessage) { + t.Fatalf("expected error %s, got: %v", tt.errorMessage, err) + } + + // If the error was not expected, then fail the test. + if !tt.expectError { + t.Fatalf("expected no error, got: %v", err) + } + } + + if err == nil { + if tt.expectError { + t.Fatalf("expected error %s but got nil", tt.errorMessage) + } + } + }) + } +}