diff --git a/include/matx/generators/random.h b/include/matx/generators/random.h index 9cd89854..ef391b5d 100644 --- a/include/matx/generators/random.h +++ b/include/matx/generators/random.h @@ -36,6 +36,7 @@ #include #include #include +#include namespace matx { @@ -162,55 +163,7 @@ __MATX_INLINE__ __MATX_DEVICE__ void get_random(cuda::std::complex &val, } }; -/** - * Generates random numbers - * - * @tparam - * Type of random number - * - * Generate random numbers based on a size and seed. Uses the Philox 4x32 - * generator with 10 rounds. - */ -template class [[deprecated("Use random() operator instead of randomGenerator_t")]] randomGenerator_t { -private: - index_t total_threads_; - bool init_; - curandStatePhilox4_32_10_t *states_; - uint64_t seed_; - -public: - randomGenerator_t() = delete; - - /** - * Constructs a random number generator - * - * This call will allocate memory sufficiently large enough to store state of - * the RNG - * - * @param total_threads - * Number of random values to generate - * @param seed - * Seed for the RNG - */ - __MATX_INLINE__ randomGenerator_t(index_t total_threads, uint64_t seed) - : total_threads_(total_threads) - { -#ifdef __CUDACC__ - matxAlloc((void **)&states_, - total_threads_ * sizeof(curandStatePhilox4_32_10_t), - MATX_DEVICE_MEMORY); - - int threads = 128; - int blocks = static_cast((total_threads_ + threads - 1) / threads); - curand_setup_kernel<<>>(states_, seed, total_threads); -#endif - }; - /** - * Destroy the RNG and free all memory - */ - __MATX_INLINE__ ~randomGenerator_t() { matxFree(states_); } -}; namespace detail { @@ -237,6 +190,7 @@ namespace detail { index_t total_size_; mutable curandStatePhilox4_32_10_t *states_; uint64_t seed_; + mutable std::mt19937 rng_; mutable bool init_ = false; mutable bool device_; @@ -321,13 +275,7 @@ namespace detail { } else if constexpr (is_host_executor_v) { if (!init_) { - [[maybe_unused]] curandStatus_t ret; - - ret = curandCreateGeneratorHost(&gen_, CURAND_RNG_PSEUDO_MT19937); - MATX_ASSERT_STR_EXP(ret, CURAND_STATUS_SUCCESS, matxCudaError, "Failed to create random number generator"); - - ret = curandSetPseudoRandomGeneratorSeed(gen_, seed_); - MATX_ASSERT_STR_EXP(ret, CURAND_STATUS_SUCCESS, matxCudaError, "Error setting random seed"); + rng_.seed(seed_); // In the future we may allocate a buffer, but for now we generate a single number at a time // matxAlloc((void **)&val, total_size_ * sizeof(T), MATX_HOST_MEMORY, stream); @@ -410,44 +358,32 @@ namespace detail { std::is_same_v> ) { - + using inner_type = typename inner_op_type_t::type; + if (fParams_.dist_ == UNIFORM) { - if constexpr (std::is_same_v) { - curandGenerateUniform(gen_, &val, 1); + std::uniform_real_distribution uniform_real_dist{0.0, 1.0}; + + if constexpr (std::is_same_v || std::is_same_v) { + val = uniform_real_dist(rng_); } - else if constexpr (std::is_same_v) { - curandGenerateUniformDouble(gen_, &val, 1); - } - else if constexpr (std::is_same_v>) { - float *tmp = reinterpret_cast(&val); - curandGenerateUniform(gen_, &tmp[0], 1); - curandGenerateUniform(gen_, &tmp[1], 1); - } - else if constexpr (std::is_same_v>) { - double *tmp = reinterpret_cast(&val); - curandGenerateUniformDouble(gen_, &tmp[0], 1); - curandGenerateUniformDouble(gen_, &tmp[1], 1); + else if constexpr ( std::is_same_v> || + std::is_same_v>) { + val = {uniform_real_dist(rng_), uniform_real_dist(rng_)}; } val = fParams_.alpha_ * val + fParams_.beta_; } else if (fParams_.dist_ == NORMAL) { - if constexpr (std::is_same_v) { - curandGenerateNormal(gen_, &val, 1, fParams_.beta_, fParams_.alpha_); - } - else if constexpr (std::is_same_v) { - curandGenerateNormalDouble(gen_, &val, 1, fParams_.beta_, fParams_.alpha_); - } - else if constexpr (std::is_same_v>) { - float *tmp = reinterpret_cast(&val); - curandGenerateNormal(gen_, &tmp[0], 1, fParams_.beta_, fParams_.alpha_); - curandGenerateNormal(gen_, &tmp[1], 1, fParams_.beta_, fParams_.alpha_); + std::normal_distribution normal_dist(0.0, 1.0); + + if constexpr (std::is_same_v || std::is_same_v) { + val = normal_dist(rng_); } - else if constexpr (std::is_same_v>) { - double *tmp = reinterpret_cast(&val); - curandGenerateNormalDouble(gen_, &tmp[0], 1, fParams_.beta_, fParams_.alpha_); - curandGenerateNormalDouble(gen_, &tmp[1], 1, fParams_.beta_, fParams_.alpha_); + else if constexpr (std::is_same_v> || std::is_same_v>) { + val = {normal_dist(rng_), normal_dist(rng_)}; } + + val = fParams_.alpha_ * val + fParams_.beta_; } else { val = 0; @@ -460,13 +396,8 @@ namespace detail { std::is_same_v ) { - float fScale; - curandGenerateUniform(gen_, &fScale, 1); - - // Scale to the provided min and max range - double fMax = static_cast(iParams_.max_); - double fMin = static_cast(iParams_.min_); - val = static_cast(fScale * (fMax - fMin) + fMin); + std::uniform_int_distribution uniform_int_dist(iParams_.min_, iParams_.max_); + val = uniform_int_dist(rng_); } #endif