diff --git a/crates/luminal_metal/src/tests/fp16.rs b/crates/luminal_metal/src/tests/fp16.rs index 8a8b8794..7b95eb69 100644 --- a/crates/luminal_metal/src/tests/fp16.rs +++ b/crates/luminal_metal/src/tests/fp16.rs @@ -792,6 +792,20 @@ fn test_movement() { assert_exact(&c.data(), &d_c.as_vec()); } +#[test] +fn test_slice_add() { + let mut cx = Graph::new(); + let a = cx.tensor().set(random_array::<256>()); + let mut b = (a.slice(0..64) + a.slice(64..128) + a.slice(128..192) + a.slice(192..256)) + .realize::>() + .expand::, _>() + .retrieve(); + + cx.compile(MetalCompiler::::default(), &mut b); + cx.execute(); + cx.display(); +} + #[test] fn test_conv2d() { let mut cx = Graph::new(); diff --git a/crates/luminal_metal/src/tests/fp32.rs b/crates/luminal_metal/src/tests/fp32.rs index 46e04e7d..4ce6365d 100644 --- a/crates/luminal_metal/src/tests/fp32.rs +++ b/crates/luminal_metal/src/tests/fp32.rs @@ -461,7 +461,7 @@ fn test_conv2d() { ); cx.execute(); - assert_close_precision( + assert_close( &out1.data(), &[ 3.9600, -0.3300, -1.7800, 4.0400, 1.5300, 0.2900, 2.8700, 3.0000, 0.9600, -1.8700, diff --git a/crates/luminal_nn/src/convolution.rs b/crates/luminal_nn/src/convolution.rs index 7087e4cf..d76f2add 100644 --- a/crates/luminal_nn/src/convolution.rs +++ b/crates/luminal_nn/src/convolution.rs @@ -318,7 +318,7 @@ mod tests { exp_out1.retrieve(); - let model: Conv2D = Conv2D::initialize(&mut cx); + let model = Conv2D::::initialize(&mut cx); model.weight.set(vec![ 0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300, 0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200, diff --git a/src/shape/slice.rs b/src/shape/slice.rs index 1e30b9d4..3aabe4b2 100644 --- a/src/shape/slice.rs +++ b/src/shape/slice.rs @@ -20,40 +20,94 @@ fn get_end_bound + Copy, S: Into>( } } -fn dim_to_size(r: Expression) -> usize { - r.to_usize().unwrap_or(i32::MAX as usize) -} - pub trait RangeToDim { type Dimension: Dimension; + fn bounds(&self, size: impl Into) -> (Expression, Expression); } impl RangeToDim for RangeFrom { type Dimension = Dyn<'-'>; + fn bounds(&self, size: impl Into) -> (Expression, Expression) { + ( + get_start_bound(self.start_bound()), + get_end_bound(self.end_bound(), size), + ) + } } impl RangeToDim for RangeTo { type Dimension = Dyn<'-'>; + fn bounds(&self, size: impl Into) -> (Expression, Expression) { + ( + get_start_bound(self.start_bound()), + get_end_bound(self.end_bound(), size), + ) + } } impl RangeToDim for RangeToInclusive { type Dimension = Dyn<'-'>; + fn bounds(&self, size: impl Into) -> (Expression, Expression) { + ( + get_start_bound(self.start_bound()), + get_end_bound(self.end_bound(), size), + ) + } } impl RangeToDim for Range { type Dimension = Dyn<'-'>; + fn bounds(&self, size: impl Into) -> (Expression, Expression) { + ( + get_start_bound(self.start_bound()), + get_end_bound(self.end_bound(), size), + ) + } } impl RangeToDim for RangeFrom { type Dimension = Dyn<'-'>; + fn bounds(&self, size: impl Into) -> (Expression, Expression) { + ( + get_start_bound(self.start_bound()), + get_end_bound(self.end_bound(), size), + ) + } } impl RangeToDim for RangeTo { type Dimension = Dyn<'-'>; + fn bounds(&self, size: impl Into) -> (Expression, Expression) { + ( + get_start_bound(self.start_bound()), + get_end_bound(self.end_bound(), size), + ) + } } impl RangeToDim for RangeToInclusive { type Dimension = Dyn<'-'>; + fn bounds(&self, size: impl Into) -> (Expression, Expression) { + ( + get_start_bound(self.start_bound()), + get_end_bound(self.end_bound(), size), + ) + } } impl RangeToDim for Range { type Dimension = Dyn<'-'>; + fn bounds(&self, size: impl Into) -> (Expression, Expression) { + ( + get_start_bound(self.start_bound()), + get_end_bound(self.end_bound(), size), + ) + } } impl RangeToDim for RangeFull { type Dimension = D; + fn bounds(&self, size: impl Into) -> (Expression, Expression) { + (0.into(), size.into()) + } +} +impl> RangeToDim for (R,) { + type Dimension = R::Dimension; + fn bounds(&self, size: impl Into) -> (Expression, Expression) { + self.0.bounds(size) + } } pub trait SliceOfShape { @@ -68,34 +122,21 @@ impl SliceOfShape for () { } } -impl + RangeToDim> SliceOfShape<(A,)> for (R,) { +impl> SliceOfShape<(A,)> for R { type OutputShape = (R::Dimension,); fn to_range_vec(&self) -> Vec<(Expression, Expression)> { - vec![( - get_start_bound(self.0.start_bound()), - get_end_bound(self.0.end_bound(), dim_to_size(A::const_size())), - )] + vec![self.bounds(A::const_size())] } } -impl< - A: Dimension, - B: Dimension, - R1: RangeBounds + RangeToDim, - R2: RangeBounds + RangeToDim, - > SliceOfShape<(A, B)> for (R1, R2) +impl, R2: RangeToDim> SliceOfShape<(A, B)> + for (R1, R2) { type OutputShape = (R1::Dimension, R2::Dimension); fn to_range_vec(&self) -> Vec<(Expression, Expression)> { vec![ - ( - get_start_bound(self.0.start_bound()), - get_end_bound(self.0.end_bound(), dim_to_size(A::const_size())), - ), - ( - get_start_bound(self.1.start_bound()), - get_end_bound(self.1.end_bound(), dim_to_size(B::const_size())), - ), + self.0.bounds(A::const_size()), + self.1.bounds(B::const_size()), ] } } @@ -104,26 +145,17 @@ impl< A: Dimension, B: Dimension, C: Dimension, - R1: RangeBounds + RangeToDim, - R2: RangeBounds + RangeToDim, - R3: RangeBounds + RangeToDim, + R1: RangeToDim, + R2: RangeToDim, + R3: RangeToDim, > SliceOfShape<(A, B, C)> for (R1, R2, R3) { type OutputShape = (R1::Dimension, R2::Dimension, R3::Dimension); fn to_range_vec(&self) -> Vec<(Expression, Expression)> { vec![ - ( - get_start_bound(self.0.start_bound()), - get_end_bound(self.0.end_bound(), dim_to_size(A::const_size())), - ), - ( - get_start_bound(self.1.start_bound()), - get_end_bound(self.1.end_bound(), dim_to_size(B::const_size())), - ), - ( - get_start_bound(self.2.start_bound()), - get_end_bound(self.2.end_bound(), dim_to_size(C::const_size())), - ), + self.0.bounds(A::const_size()), + self.1.bounds(B::const_size()), + self.2.bounds(C::const_size()), ] } } @@ -133,31 +165,19 @@ impl< B: Dimension, C: Dimension, D: Dimension, - R1: RangeBounds + RangeToDim, - R2: RangeBounds + RangeToDim, - R3: RangeBounds + RangeToDim, - R4: RangeBounds + RangeToDim, + R1: RangeToDim, + R2: RangeToDim, + R3: RangeToDim, + R4: RangeToDim, > SliceOfShape<(A, B, C, D)> for (R1, R2, R3, R4) { type OutputShape = (R1::Dimension, R2::Dimension, R3::Dimension, R4::Dimension); fn to_range_vec(&self) -> Vec<(Expression, Expression)> { vec![ - ( - get_start_bound(self.0.start_bound()), - get_end_bound(self.0.end_bound(), dim_to_size(A::const_size())), - ), - ( - get_start_bound(self.1.start_bound()), - get_end_bound(self.1.end_bound(), dim_to_size(B::const_size())), - ), - ( - get_start_bound(self.2.start_bound()), - get_end_bound(self.2.end_bound(), dim_to_size(C::const_size())), - ), - ( - get_start_bound(self.3.start_bound()), - get_end_bound(self.3.end_bound(), dim_to_size(D::const_size())), - ), + self.0.bounds(A::const_size()), + self.1.bounds(B::const_size()), + self.2.bounds(C::const_size()), + self.3.bounds(D::const_size()), ] } } @@ -168,11 +188,11 @@ impl< C: Dimension, D: Dimension, E: Dimension, - R1: RangeBounds + RangeToDim, - R2: RangeBounds + RangeToDim, - R3: RangeBounds + RangeToDim, - R4: RangeBounds + RangeToDim, - R5: RangeBounds + RangeToDim, + R1: RangeToDim, + R2: RangeToDim, + R3: RangeToDim, + R4: RangeToDim, + R5: RangeToDim, > SliceOfShape<(A, B, C, D, E)> for (R1, R2, R3, R4, R5) { type OutputShape = ( @@ -184,26 +204,11 @@ impl< ); fn to_range_vec(&self) -> Vec<(Expression, Expression)> { vec![ - ( - get_start_bound(self.0.start_bound()), - get_end_bound(self.0.end_bound(), dim_to_size(A::const_size())), - ), - ( - get_start_bound(self.1.start_bound()), - get_end_bound(self.1.end_bound(), dim_to_size(B::const_size())), - ), - ( - get_start_bound(self.2.start_bound()), - get_end_bound(self.2.end_bound(), dim_to_size(C::const_size())), - ), - ( - get_start_bound(self.3.start_bound()), - get_end_bound(self.3.end_bound(), dim_to_size(D::const_size())), - ), - ( - get_start_bound(self.4.start_bound()), - get_end_bound(self.4.end_bound(), dim_to_size(E::const_size())), - ), + self.0.bounds(A::const_size()), + self.1.bounds(B::const_size()), + self.2.bounds(C::const_size()), + self.3.bounds(D::const_size()), + self.4.bounds(E::const_size()), ] } } diff --git a/src/tests/mod.rs b/src/tests/mod.rs index f8b8d151..daccfe70 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -103,9 +103,22 @@ pub fn assert_exact(a_vec: &[T], b_vec: &[T]) { } } +pub fn random_array() -> [f32; N] { + let mut rng = thread_rng(); + random_array_rng(&mut rng) +} + +pub fn random_array_rng(rng: &mut R) -> [f32; N] { + let mut arr = [0.; N]; + for i in &mut arr { + *i = rng.gen_range(-0.5..0.5); + } + arr +} + pub fn random_vec(n: usize) -> Vec { let mut rng = thread_rng(); - (0..n).map(|_| rng.gen_range(-0.5..0.5)).collect() + random_vec_rng(n, &mut rng) } pub fn random_vec_rng(n: usize, rng: &mut R) -> Vec { @@ -127,13 +140,8 @@ macro_rules! test_imports { Axis as LAxis, Const as LConst, *, }, tests::{ - assert_close, - assert_close_precision, - assert_exact, - // harness::{test_compilers_close, test_compilers_exact}, - random_vec, - random_vec_rng, - test_graphs, + assert_close, assert_close_precision, assert_exact, random_array, random_array_rng, + random_vec, random_vec_rng, test_graphs, }, }; };