Skip to content

Commit

Permalink
Combined dims for expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Apr 26, 2024
1 parent cd0195c commit 194f221
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions src/shape/tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ impl ShapeTracker {

/// Create an expression to translate logical indexes into physical indexes
pub fn index_expression(&self) -> BigExpression {
println!("ORIG: {:?}", self);
let shape = combine_dims(*self);
println!("Combined: {:?}", shape);
let strides = shape.unordered_strides(); // Dimension strides in original order
let mut ind_expr = BigExpression::from(0); // The final index expression
let mut current_elem_size = BigExpression::from(1); // Keep track of the size of each element of the current dim (last dim elem size: 1)
Expand Down Expand Up @@ -314,18 +316,25 @@ fn pad_mask_dim(
// Combine non-permuted, non-padded, non-fake, non-masked dimensions together
fn combine_dims(mut shape: ShapeTracker) -> ShapeTracker {
for i in (1..shape.len()).rev() {
if (shape.indexes[i] != i || shape.indexes[i - 1] != i - 1)
|| (shape.fake[i] || shape.fake[i - 1])
|| (shape.padding[i].0 != 0 || shape.padding[i].1 != 0)
|| (shape.mask[i].0 != 0 || shape.mask[i].1 != i32::MAX)
|| (shape.padding[i - 1].0 != 0 || shape.padding[i - 1].1 != 0)
|| (shape.mask[i - 1].0 != 0 || shape.mask[i - 1].1 != i32::MAX)
let (ind_i, ind_i_minus_1) = (shape.indexes[i], shape.indexes[i - 1]);
// Test permute
if (ind_i != ind_i_minus_1 + 1)
// Fakes
|| (shape.fake[ind_i] || shape.fake[ind_i_minus_1])
// Dim i padding
|| (shape.padding[ind_i].0 != 0 || shape.padding[ind_i].1 != 0)
// Dim i mask
|| (shape.mask[ind_i].0 != 0 || shape.mask[ind_i].1 != i32::MAX)
// Dim i - 1 padding
|| (shape.padding[ind_i_minus_1].0 != 0 || shape.padding[ind_i_minus_1].1 != 0)
// Dim i - 1 mask
|| (shape.mask[ind_i_minus_1].0 != 0 || shape.mask[ind_i_minus_1].1 != i32::MAX)
{
continue;
}
// We can combine dimension i and i - 1
let dim_i = shape.dims[i];
shape.dims[i - 1] *= dim_i;
let dim_i = shape.dims[ind_i];
shape.dims[ind_i_minus_1] *= dim_i;
shape.remove_dim(i);
}
shape
Expand Down Expand Up @@ -364,7 +373,7 @@ mod tests {
Expression::from(5),
Expression::from(3),
]);
tracker.permute(&[0, 2, 1]);
tracker.permute(&[2, 0, 1]);
println!("Shape: [10, 5, 3]");
println!("Strides: {:?}", tracker.strides());
println!("Ind: {:?}", tracker.index_expression());
Expand Down

0 comments on commit 194f221

Please sign in to comment.