diff --git a/CMakeLists.txt b/CMakeLists.txt index 91110864..a2098e34 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -122,9 +122,11 @@ else() set(DORADO_ENABLE_PCH TRUE) endif() -if(CMAKE_SYSTEM_NAME STREQUAL "Linux") - if((CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64*|^arm*") AND (${CUDAToolkit_VERSION} VERSION_LESS 11.0)) +if((CMAKE_SYSTEM_NAME STREQUAL "Linux") AND (CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64*|^arm*")) + if(${CUDAToolkit_VERSION} VERSION_LESS 11.0) add_compile_definitions(DORADO_TX2) + else() + add_compile_definitions(DORADO_ORIN) endif() endif() diff --git a/dorado/torch_utils/auto_detect_device.h b/dorado/torch_utils/auto_detect_device.h index 76f8851f..4a6c2135 100644 --- a/dorado/torch_utils/auto_detect_device.h +++ b/dorado/torch_utils/auto_detect_device.h @@ -1,6 +1,8 @@ #pragma once #if DORADO_CUDA_BUILD +#include "torch_utils/gpu_monitor.h" + #include #endif @@ -12,7 +14,8 @@ inline std::string get_auto_detected_device() { #if DORADO_METAL_BUILD return "metal"; #elif DORADO_CUDA_BUILD - return torch::cuda::is_available() ? "cuda:all" : "cpu"; + // Using get_device_count will force a wait for NVML to load, which will ensure the driver has started up. + return utils::gpu_monitor::get_device_count() > 0 ? "cuda:all" : "cpu"; #else return "cpu"; #endif diff --git a/dorado/torch_utils/cuda_utils.cpp b/dorado/torch_utils/cuda_utils.cpp index be406d58..58ce6506 100644 --- a/dorado/torch_utils/cuda_utils.cpp +++ b/dorado/torch_utils/cuda_utils.cpp @@ -1,5 +1,6 @@ #include "cuda_utils.h" +#include "torch_utils/gpu_monitor.h" #include "utils/PostCondition.h" #include "utils/math_utils.h" @@ -220,8 +221,9 @@ bool try_parse_cuda_device_string(const std::string &device_string, std::vector &devices, std::string &error_message) { std::vector device_ids{}; - if (!details::try_parse_device_ids(device_string, torch::cuda::device_count(), device_ids, - error_message)) { + + if (!details::try_parse_device_ids(device_string, utils::gpu_monitor::get_device_count(), + device_ids, error_message)) { return false; } @@ -243,7 +245,7 @@ std::vector parse_cuda_device_string(const std::string &device_stri std::vector get_cuda_device_info(const std::string &device_string, bool include_unused) { - const auto num_devices = torch::cuda::device_count(); + const auto num_devices = utils::gpu_monitor::get_device_count(); std::string error_message{}; std::vector requested_device_ids{}; if (!details::try_parse_device_ids(device_string, num_devices, requested_device_ids, diff --git a/dorado/torch_utils/gpu_monitor.cpp b/dorado/torch_utils/gpu_monitor.cpp index e82bc455..d58df36a 100644 --- a/dorado/torch_utils/gpu_monitor.cpp +++ b/dorado/torch_utils/gpu_monitor.cpp @@ -15,11 +15,15 @@ #else // _WIN32 #include #endif // _WIN32 +#if defined(DORADO_ORIN) || defined(DORADO_TX2) +#include +#endif // defined(DORADO_ORIN) || defined(DORADO_TX2) #endif // HAS_NVML #include #include +#include #include #include #include @@ -164,7 +168,7 @@ class NvmlApi final { void init() { if (!platform_open() || !load_symbols()) { - spdlog::warn("Failed to load NVML"); + spdlog::info("Failed to load NVML"); clear_symbols(); platform_close(); return; @@ -172,9 +176,23 @@ class NvmlApi final { // Fall back to the original nvmlInit() if _v2 doesn't exist. auto *do_init = m_Init_v2 ? m_Init_v2 : m_Init; - nvmlReturn_t result = do_init(); + + // We retry initialisation for a certain amount of time, to allow the driver to load on system startup + auto start = std::chrono::system_clock::now(); + auto wait_seconds = std::chrono::seconds(10); + nvmlReturn_t result; + do { + result = do_init(); + if (result == NVML_SUCCESS) { + break; + } + spdlog::warn("Failed to initialize NVML: {}, retrying in 1s...", m_ErrorString(result)); + std::this_thread::sleep_for(std::chrono::seconds(1)); + } while ((std::chrono::system_clock::now() - start) < wait_seconds); + if (result != NVML_SUCCESS) { - spdlog::warn("Failed to initialize NVML: {}", m_ErrorString(result)); + spdlog::warn("Failed to initialize NVML after {} seconds: {}", wait_seconds.count(), + m_ErrorString(result)); clear_symbols(); platform_close(); } @@ -433,10 +451,11 @@ class DeviceInfoCache final { spdlog::warn("Call to DeviceGetCount failed: {}", m_nvml.ErrorString(result)); } } -#if defined(DORADO_TX2) +#if defined(DORADO_ORIN) || defined(DORADO_TX2) if (m_device_count == 0) { - // TX2 may not have NVML, in which case just report that we have 1. - m_device_count = 1; + // TX2/Orin may not have NVML, in which case ask torch how many devices it thinks there are. + m_device_count = torch::cuda::device_count(); + spdlog::info("Setting device count to {} as reported from torch", m_device_count); } #endif }