From 194f22137264ff738b08189cb7c71e3a8d1d35f5 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Fri, 26 Apr 2024 14:16:27 -0500 Subject: [PATCH] Combined dims for expressions --- src/shape/tracker.rs | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/shape/tracker.rs b/src/shape/tracker.rs index aa5ac0d8..8a1f9985 100644 --- a/src/shape/tracker.rs +++ b/src/shape/tracker.rs @@ -100,7 +100,9 @@ impl ShapeTracker { /// Create an expression to translate logical indexes into physical indexes pub fn index_expression(&self) -> BigExpression { + println!("ORIG: {:?}", self); let shape = combine_dims(*self); + println!("Combined: {:?}", shape); let strides = shape.unordered_strides(); // Dimension strides in original order let mut ind_expr = BigExpression::from(0); // The final index expression let mut current_elem_size = BigExpression::from(1); // Keep track of the size of each element of the current dim (last dim elem size: 1) @@ -314,18 +316,25 @@ fn pad_mask_dim( // Combine non-permuted, non-padded, non-fake, non-masked dimensions together fn combine_dims(mut shape: ShapeTracker) -> ShapeTracker { for i in (1..shape.len()).rev() { - if (shape.indexes[i] != i || shape.indexes[i - 1] != i - 1) - || (shape.fake[i] || shape.fake[i - 1]) - || (shape.padding[i].0 != 0 || shape.padding[i].1 != 0) - || (shape.mask[i].0 != 0 || shape.mask[i].1 != i32::MAX) - || (shape.padding[i - 1].0 != 0 || shape.padding[i - 1].1 != 0) - || (shape.mask[i - 1].0 != 0 || shape.mask[i - 1].1 != i32::MAX) + let (ind_i, ind_i_minus_1) = (shape.indexes[i], shape.indexes[i - 1]); + // Test permute + if (ind_i != ind_i_minus_1 + 1) + // Fakes + || (shape.fake[ind_i] || shape.fake[ind_i_minus_1]) + // Dim i padding + || (shape.padding[ind_i].0 != 0 || shape.padding[ind_i].1 != 0) + // Dim i mask + || (shape.mask[ind_i].0 != 0 || shape.mask[ind_i].1 != i32::MAX) + // Dim i - 1 padding + || (shape.padding[ind_i_minus_1].0 != 0 || shape.padding[ind_i_minus_1].1 != 0) + // Dim i - 1 mask + || (shape.mask[ind_i_minus_1].0 != 0 || shape.mask[ind_i_minus_1].1 != i32::MAX) { continue; } // We can combine dimension i and i - 1 - let dim_i = shape.dims[i]; - shape.dims[i - 1] *= dim_i; + let dim_i = shape.dims[ind_i]; + shape.dims[ind_i_minus_1] *= dim_i; shape.remove_dim(i); } shape @@ -364,7 +373,7 @@ mod tests { Expression::from(5), Expression::from(3), ]); - tracker.permute(&[0, 2, 1]); + tracker.permute(&[2, 0, 1]); println!("Shape: [10, 5, 3]"); println!("Strides: {:?}", tracker.strides()); println!("Ind: {:?}", tracker.index_expression());