Skip to content

Commit

Permalink
feat: add rand-0_9 crate feature (#702)
Browse files Browse the repository at this point in the history
* feat: add rand-0_9 crate feature
* Added rand 0.9 test into build system

---------

Co-authored-by: Paul Mason <paul@paulmason.me>
  • Loading branch information
robjtede and paupino authored Feb 23, 2025
1 parent 9689971 commit 41ce632
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ num-traits = { default-features = false, features = ["i128"], version = "0.2" }
postgres-types = { default-features = false, optional = true, version = "0.2" }
proptest = { default-features = false, optional = true, features = ["std"], version = "1.0" }
rand = { default-features = false, optional = true, version = "0.8" }
rand-0_9 = { default-features = false, optional = true, package = "rand", version = "0.9" }
rkyv = { default-features = false, features = ["size_32", "std"], optional = true, version = "0.7.42" }
rocket = { default-features = false, optional = true, version = "0.5.0-rc.3" }
serde = { default-features = false, optional = true, version = "1.0" }
Expand All @@ -42,6 +43,7 @@ criterion = { default-features = false, version = "0.5" }
csv = "1"
futures = { default-features = false, version = "0.3" }
rand = { default-features = false, features = ["getrandom"], version = "0.8" }
rand-0_9 = { default-features = false, features = ["thread_rng"], package = "rand", version = "0.9" }
rkyv-0_8 = { version = "0.8", package = "rkyv" }
serde = { default-features = false, features = ["derive"], version = "1.0" }
serde_json = "1.0"
Expand Down
10 changes: 10 additions & 0 deletions make/tests/misc.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,15 @@ command = "cargo"
args = ["test", "--workspace", "--features=rkyv", "--features=rkyv-safe", "rkyv_tests", "--", "--skip", "generated"]

[tasks.test-rand]
dependencies = [
"test-rand-0_8",
"test-rand-0_9"
]

[tasks.test-rand-0_8]
command = "cargo"
args = ["test", "--workspace", "--features=rand", "rand_tests", "--", "--skip", "generated"]

[tasks.test-rand-0_9]
command = "cargo"
args = ["test", "--workspace", "--features=rand-0_9", "rand_tests", "--", "--skip", "generated"]
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ mod postgres;
mod proptest;
#[cfg(feature = "rand")]
mod rand;
#[cfg(feature = "rand-0_9")]
mod rand_0_9;
#[cfg(feature = "rocket-traits")]
mod rocket;
#[cfg(all(
Expand Down
178 changes: 178 additions & 0 deletions src/rand_0_9.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
use crate::Decimal;
use rand_0_9::{
distr::{
uniform::{SampleBorrow, SampleUniform, UniformInt, UniformSampler},
Distribution, StandardUniform,
},
Rng,
};

impl Distribution<Decimal> for StandardUniform {
fn sample<R>(&self, rng: &mut R) -> Decimal
where
R: Rng + ?Sized,
{
Decimal::from_parts(
rng.next_u32(),
rng.next_u32(),
rng.next_u32(),
rng.random(),
rng.random_range(0..=Decimal::MAX_SCALE),
)
}
}

impl SampleUniform for Decimal {
type Sampler = DecimalSampler;
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DecimalSampler {
mantissa_sampler: UniformInt<i128>,
scale: u32,
}

impl UniformSampler for DecimalSampler {
type X = Decimal;

/// Creates a new sampler that will yield random decimal objects between `low` and `high`.
///
/// The sampler will always provide decimals at the same scale as the inputs; if the inputs
/// have different scales, the higher scale is used.
///
/// # Example
///
/// ```
/// # use rand_0_9 as rand;
/// # use rand::Rng;
/// # use rust_decimal_macros::dec;
/// let mut rng = rand::rng();
/// let random = rng.random_range(dec!(1.00)..dec!(2.00));
/// assert!(random >= dec!(1.00));
/// assert!(random < dec!(2.00));
/// assert_eq!(random.scale(), 2);
/// ```
#[inline]
fn new<B1, B2>(low: B1, high: B2) -> Result<Self, rand_0_9::distr::uniform::Error>
where
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
let (low, high) = sync_scales(*low.borrow(), *high.borrow());
let high = Decimal::from_i128_with_scale(high.mantissa() - 1, high.scale());
UniformSampler::new_inclusive(low, high)
}

/// Creates a new sampler that will yield random decimal objects between `low` and `high`.
///
/// The sampler will always provide decimals at the same scale as the inputs; if the inputs
/// have different scales, the higher scale is used.
///
/// # Example
///
/// ```
/// # use rand_0_9 as rand;
/// # use rand::Rng;
/// # use rust_decimal_macros::dec;
/// let mut rng = rand::rng();
/// let random = rng.random_range(dec!(1.00)..=dec!(2.00));
/// assert!(random >= dec!(1.00));
/// assert!(random <= dec!(2.00));
/// assert_eq!(random.scale(), 2);
/// ```
#[inline]
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, rand_0_9::distr::uniform::Error>
where
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
let (low, high) = sync_scales(*low.borrow(), *high.borrow());

// Return our sampler, which contains an underlying i128 sampler so we
// outsource the actual randomness implementation.
Ok(Self {
mantissa_sampler: UniformInt::new_inclusive(low.mantissa(), high.mantissa())?,
scale: low.scale(),
})
}

#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
let mantissa = self.mantissa_sampler.sample(rng);
Decimal::from_i128_with_scale(mantissa, self.scale)
}
}

/// Return equivalent Decimal objects with the same scale as one another.
#[inline]
fn sync_scales(mut a: Decimal, mut b: Decimal) -> (Decimal, Decimal) {
if a.scale() == b.scale() {
return (a, b);
}

// Set scales to match one another, because we are relying on mantissas'
// being comparable in order outsource the actual sampling implementation.
a.rescale(a.scale().max(b.scale()));
b.rescale(a.scale().max(b.scale()));

// Edge case: If the values have _wildly_ different scales, the values may not have rescaled far enough to match one another.
//
// In this case, we accept some precision loss because the randomization approach we are using assumes that the scales will necessarily match.
if a.scale() != b.scale() {
a.rescale(a.scale().min(b.scale()));
b.rescale(a.scale().min(b.scale()));
}

(a, b)
}

#[cfg(test)]
mod rand_tests {
use std::collections::HashSet;

use super::*;

macro_rules! dec {
($e:expr) => {
Decimal::from_str_exact(stringify!($e)).unwrap()
};
}

#[test]
fn has_random_decimal_instances() {
let mut rng = rand_0_9::rng();
let random: [Decimal; 32] = rng.random();
assert!(random.windows(2).any(|slice| { slice[0] != slice[1] }));
}

#[test]
fn generates_within_range() {
let mut rng = rand_0_9::rng();
for _ in 0..128 {
let random = rng.random_range(dec!(1.00)..dec!(1.05));
assert!(random < dec!(1.05));
assert!(random >= dec!(1.00));
}
}

#[test]
fn generates_within_inclusive_range() {
let mut rng = rand_0_9::rng();
let mut values: HashSet<Decimal> = HashSet::new();
for _ in 0..256 {
let random = rng.random_range(dec!(1.00)..=dec!(1.01));
// The scale is 2, so 1.00 and 1.01 are the only two valid choices.
assert!(random == dec!(1.00) || random == dec!(1.01));
values.insert(random);
}
// Somewhat flaky, will fail 1 out of every 2^255 times this is run.
// Probably acceptable in the real world.
assert_eq!(values.len(), 2);
}

#[test]
fn test_edge_case_scales_match() {
let (low, high) = sync_scales(dec!(1.000_000_000_000_000_000_01), dec!(100_000_000_000_000_000_001));
assert_eq!(low.scale(), high.scale());
}
}

0 comments on commit 41ce632

Please sign in to comment.