diff --git a/src/include/kernel_utils.h b/src/include/kernel_utils.h index dfc29d2..5da5e21 100644 --- a/src/include/kernel_utils.h +++ b/src/include/kernel_utils.h @@ -46,9 +46,8 @@ inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block = CUD // Dispatches for float and double #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ [&] { \ - const auto& the_type = TYPE; \ + const at::ScalarType _st = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ switch (_st) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ @@ -60,9 +59,8 @@ inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block = CUD // Dispatches for float, double, and half #define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ [&] { \ - const auto& the_type = TYPE; \ + const at::ScalarType _st = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ switch (_st) { \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \