diff --git a/cmd/checkpoint/checkpoint.go b/cmd/checkpoint/checkpoint.go index 22cc0fdd..3729aa17 100644 --- a/cmd/checkpoint/checkpoint.go +++ b/cmd/checkpoint/checkpoint.go @@ -23,6 +23,8 @@ import ( "strings" checkpoint "github.com/NVIDIA/mig-parted/api/checkpoint/v1" + "github.com/NVIDIA/mig-parted/cmd/util" + "github.com/NVIDIA/mig-parted/internal/nvml" "github.com/NVIDIA/mig-parted/pkg/mig/state" "github.com/sirupsen/logrus" cli "github.com/urfave/cli/v2" @@ -83,6 +85,13 @@ func checkpointWrapper(c *cli.Context, f *Flags) error { return err } + nvml := nvml.New() + err = util.NvmlInit(nvml) + if err != nil { + return fmt.Errorf("error initializing NVML: %v", err) + } + defer util.TryNvmlShutdown(nvml) + migState, err := state.NewMigStateManager().Fetch() if err != nil { return fmt.Errorf("error fetching MIG state: %v", err) diff --git a/cmd/export/config.go b/cmd/export/config.go index 377b3838..07ca5a06 100644 --- a/cmd/export/config.go +++ b/cmd/export/config.go @@ -29,6 +29,12 @@ import ( ) func ExportMigConfigs(c *Context) (*v1.Spec, error) { + err := util.NvmlInit(c.Nvml) + if err != nil { + return nil, fmt.Errorf("error initializing NVML: %v", err) + } + defer util.TryNvmlShutdown(c.Nvml) + nvpci := nvpci.New() gpus, err := nvpci.GetGPUs() if err != nil { diff --git a/cmd/export/export.go b/cmd/export/export.go index e4a1ec6d..aece81fb 100644 --- a/cmd/export/export.go +++ b/cmd/export/export.go @@ -23,6 +23,7 @@ import ( "os" "github.com/NVIDIA/mig-parted/api/spec/v1" + "github.com/NVIDIA/mig-parted/internal/nvml" "github.com/sirupsen/logrus" cli "github.com/urfave/cli/v2" @@ -49,6 +50,7 @@ type Flags struct { type Context struct { *cli.Context Flags *Flags + Nvml nvml.Interface } func BuildCommand() *cli.Command { @@ -96,6 +98,7 @@ func exportWrapper(c *cli.Context, f *Flags) error { context := Context{ Context: c, Flags: f, + Nvml: nvml.New(), } spec, err := ExportMigConfigs(&context)