From cd0195ca81de4e1c3dbe71c1785f70f9b90bb3e0 Mon Sep 17 00:00:00 2001
From: Joe Fioti
Date: Fri, 26 Apr 2024 13:56:34 -0500
Subject: [PATCH] Combine dims for index and valid expressions

---
 src/op.rs            |  29 ++++---
 src/shape/tracker.rs | 181 +++++++++++++++++++++++--------------------
 2 files changed, 111 insertions(+), 99 deletions(-)

diff --git a/src/op.rs b/src/op.rs
index 4e5e34ec..01cf2fc8 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -242,13 +242,12 @@ pub struct Add;
 impl Operator for Add {
     fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
-        let lhs_expr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
-        let rhs_expr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
+        let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
+        let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
         let mut stack = vec![];
         let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
         for (i, out) in out_data.iter_mut().enumerate() {
-            *out =
-                get_index(lhs, &lhs_expr, &mut stack, i) + get_index(rhs, &rhs_expr, &mut stack, i);
+            *out = get_index(lhs, &lexpr, &mut stack, i) + get_index(rhs, &rexpr, &mut stack, i);
         }
         vec![Tensor::new(out_data)]
     }
@@ -260,12 +259,11 @@ impl Operator for Mul {
     fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
         let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
-        let lhs_expr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
-        let rhs_expr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
+        let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
+        let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
         let mut stack = vec![];
         for (i, out) in out_data.iter_mut().enumerate() {
-            *out =
-                get_index(lhs, &lhs_expr, &mut stack, i) * get_index(rhs, &rhs_expr, &mut stack, i);
+            *out = get_index(lhs, &lexpr, &mut stack, i) * get_index(rhs, &rexpr, &mut stack, i);
         }
         vec![Tensor::new(out_data)]
     }
@@ -277,12 +275,11 @@ impl Operator for Mod {
     fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
         let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
-        let lhs_expr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
-        let rhs_expr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
+        let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
+        let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
         let mut stack = vec![];
         for (i, out) in out_data.iter_mut().enumerate() {
-            *out =
-                get_index(lhs, &lhs_expr, &mut stack, i) % get_index(rhs, &rhs_expr, &mut stack, i);
+            *out = get_index(lhs, &lexpr, &mut stack, i) % get_index(rhs, &rexpr, &mut stack, i);
         }
         vec![Tensor::new(out_data)]
     }
@@ -294,12 +291,12 @@ impl Operator for LessThan {
     fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
         let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
-        let lhs_expr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
-        let rhs_expr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
+        let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
+        let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
         let mut stack = vec![];
         for (i, out) in out_data.iter_mut().enumerate() {
-            *out = (get_index(lhs, &lhs_expr, &mut stack, i)
-                < get_index(rhs, &rhs_expr, &mut stack, i)) as i32 as f32;
+            *out = (get_index(lhs, &lexpr, &mut stack, i) < get_index(rhs, &rexpr, &mut stack, i))
+                as i32 as f32;
         }
         vec![Tensor::new(out_data)]
     }
diff --git a/src/shape/tracker.rs b/src/shape/tracker.rs
index dff9c1e1..aa5ac0d8 100644
--- a/src/shape/tracker.rs
+++ b/src/shape/tracker.rs
@@ -39,9 +39,9 @@ impl ShapeTracker {
     }
 
     /// Add dim along a certian axis
-    pub fn add_dim(&mut self, axis: usize, dim: Expression) {
+    pub fn add_dim(&mut self, axis: usize, dim: impl Into<Expression>) {
         self.indexes.insert(axis, self.dims.len());
-        self.dims.push(dim);
+        self.dims.push(dim.into());
         self.fake.push(false);
         self.mask.push((0.into(), i32::MAX.into()));
         self.padding.push((0.into(), 0.into()));
@@ -49,7 +49,7 @@ impl ShapeTracker {
     }
 
     /// Add fake dim along a certian axis
     pub fn expand(&mut self, axis: usize, dim: impl Into<Expression>) {
-        self.add_dim(axis, dim.into());
+        self.add_dim(axis, dim);
         self.fake[self.indexes[axis]] = true;
     }
@@ -98,47 +98,61 @@ impl ShapeTracker {
             .collect()
     }
 
+    /// Create an expression to translate logical indexes into physical indexes
     pub fn index_expression(&self) -> BigExpression {
-        // Create strides in original order
-        let strides = self.unordered_strides();
-        let mut ret = BigExpression::from(0);
-        let mut acc = BigExpression::from(1);
-        let index = BigExpression::from('z');
+        let shape = combine_dims(*self);
+        let strides = shape.unordered_strides(); // Dimension strides in original order
+        let mut ind_expr = BigExpression::from(0); // The final index expression
+        let mut current_elem_size = BigExpression::from(1); // Keep track of the size of each element of the current dim (last dim elem size: 1)
+
+        // For combined dims
+        // divide by last dims (smallest) element size
+        // mod by combined dim size
+        // multiply by last dims (smallest) stride
+
         // Loop through all dims in reverse order
-        for i in self.indexes.into_iter().rev() {
-            let logical_sh = pad_mask_dim(self.dims[i].big(), self.padding[i], self.mask[i]);
-            if !self.fake[i] {
-                let dim_ind = (index.clone() / acc.clone()) % logical_sh.clone();
-                ret += (dim_ind + self.mask[i].0 - self.padding[i].0) * strides[i].clone();
+        for i in shape.indexes.into_iter().rev() {
+            // Get logical dimension size with padding and mask
+            let current_size = pad_mask_dim(shape.dims[i], shape.padding[i], shape.mask[i]);
+            // Don't include fake dimensions in the index expression
+            if !shape.fake[i] {
+                let mut dim_ind = BigExpression::from('z');
+                // Remove other dim components
+                dim_ind /= current_elem_size.clone();
+                // Get position in current dim
+                dim_ind %= current_size.clone();
+                // Add offset
+                dim_ind += shape.mask[i].0 - shape.padding[i].0;
+                // Multiply by stride
+                dim_ind *= strides[i].clone();
+                // Add to index expression
+                ind_expr += dim_ind;
             }
-            acc = acc.clone() * logical_sh.clone();
+            // Keep track of element size for next dimension
+            current_elem_size *= current_size;
         }
-        ret.simplify()
+        ind_expr.simplify()
     }
 
-    /// If this BigExpression evaluates to 0, the logical index is invalid. Otherwise it is valid
+    /// If this expression evaluates to 0, the logical index is invalid. Otherwise it is valid
    pub fn valid_expression(&self) -> BigExpression {
-        if !self.is_reshaped() {
-            return 1.into();
-        }
+        let shape = combine_dims(*self);
         let mut ret = BigExpression::from(1);
         let mut acc = BigExpression::from(1);
         let logical = BigExpression::from('z');
-        for i in self.indexes.into_iter().rev() {
-            let (bottom_padding, top_padding) = self.padding[i];
-            let (bottom_slice, top_slice) = self.mask[i];
-            let logical_sh =
-                (self.dims[i].big() + bottom_padding + top_padding).min(top_slice) - bottom_slice;
-            if !self.fake[i] {
+        for i in shape.indexes.into_iter().rev() {
+            let (bottom_slice, top_slice) = shape.mask[i];
+            let logical_sh = pad_mask_dim(shape.dims[i], shape.padding[i], shape.mask[i]);
+            if !shape.fake[i] {
                 let dim_ind = (logical.clone() / acc.clone()) % logical_sh.clone();
-                let greater_than = bottom_padding.big() - bottom_slice.big().min(bottom_padding);
+                let greater_than = shape.padding[i].0.big() - bottom_slice;
                 if greater_than != 0 {
                     ret &= dim_ind.clone().gte(greater_than);
                 }
-                ret &= dim_ind.lt(self.dims[i].big() + bottom_padding);
+                ret &= dim_ind.lt(shape.dims[i].big() + shape.padding[i].0);
                 if top_slice
                     .to_usize()
-                    .map(|s| self.dims[i].to_usize().map(|dim| s < dim).unwrap_or(true))
+                    .map(|s| shape.dims[i].to_usize().map(|dim| s < dim).unwrap_or(true))
                     .unwrap_or(true)
                 {
                     ret = ret.min(top_slice);
@@ -149,39 +163,19 @@ impl ShapeTracker {
         ret.simplify()
     }
 
-    /// The number of elements in this tensor, including pads and mask
+    /// The number of elements in this tensor, including padding and mask
     pub fn n_elements(&self) -> BigExpression {
-        let r = self
-            .indexes
-            .into_iter()
-            .map(|i| (i, self.dims[i].big()))
-            // Add pads
-            .map(|(i, dim)| (i, dim + self.padding[i].0 + self.padding[i].1))
-            // Slice
-            .map(|(i, dim)| dim.min(self.mask[i].1) - self.mask[i].0)
-            .product();
-        if r == 0 {
-            1.into()
-        } else {
-            r
-        }
+        self.shape().into_iter().product::<BigExpression>().max(1)
     }
 
     /// The number of elements in this tensor, not including pads and mask
     pub fn n_physical_elements(&self) -> BigExpression {
-        let r = self
-            .dims
+        self.indexes
             .into_iter()
-            // Filter out fake dimensions
-            .enumerate()
-            .filter(|(i, _)| !self.fake[*i])
-            .map(|(_, i)| i.into())
-            .product();
-        if r == 0 {
-            1.into()
-        } else {
-            r
-        }
+            .filter(|i| !self.fake[*i])
+            .map(|i| self.dims[i].big())
+            .product::<BigExpression>()
+            .max(1)
     }
 
     /// The number of dimensions
@@ -202,19 +196,16 @@ impl ShapeTracker {
     /// Create a contiguous version
     pub fn contiguous(self) -> Self {
-        let new_dims = self
-            .indexes
-            .into_iter()
-            .map(|i| {
-                self.dims[i].min(self.mask[i].1 - self.mask[i].0)
-                    + self.padding[i].0
-                    + self.padding[i].1
-            })
-            .collect::<Vec<_>>();
-        Self::new(&new_dims)
+        Self::new(
+            &self
+                .shape()
+                .into_iter()
+                .map(|i| i.small())
+                .collect::<Vec<_>>(),
+        )
     }
 
-    /// Check if contiguous
+    /// Check if contiguous (no permutes or fake dimensions)
     pub fn is_contiguous(&self) -> bool {
         self.indexes.iter().enumerate().all(|(a, b)| a == *b) && self.fake.iter().all(|i| !*i)
     }
@@ -228,10 +219,7 @@ impl ShapeTracker {
     pub fn shape(&self) -> Vec<BigExpression> {
         self.indexes
             .into_iter()
-            .map(|i| {
-                (self.dims[i].big() + self.padding[i].0 - self.mask[i].0 + self.padding[i].1)
-                    .min(self.mask[i].1)
-            })
+            .map(|i| pad_mask_dim(self.dims[i], self.padding[i], self.mask[i]))
             .collect()
     }
 
@@ -242,32 +230,37 @@ impl ShapeTracker {
 
     /// Take a slice
     pub fn slice(&mut self, mask: &[(Expression, Expression)]) {
-        for (i, (s, e)) in mask.iter().enumerate() {
-            self.mask[self.indexes[i]].0 = self.mask[self.indexes[i]].0.max(s.max(0));
-            self.mask[self.indexes[i]].1 = self.mask[self.indexes[i]].1.min(e.max(0));
+        for (ind, (b, t)) in mask.iter().enumerate().map(|(i, m)| (self.indexes[i], m)) {
+            self.mask[ind].0 = self.mask[ind].0.max(b.max(0));
+            self.mask[ind].1 = self.mask[ind].1.min(t.max(0));
         }
     }
 
     /// Add padding
     pub fn pad(&mut self, padding: &[(Expression, Expression)]) {
-        for (i, (s, e)) in padding.iter().enumerate() {
+        for (ind, (s, e)) in padding
+            .iter()
+            .enumerate()
+            .map(|(i, m)| (self.indexes[i], m))
+        {
+            // Make sure we aren't padding a masked dimension
             if (e.to_usize().map(|n| n != 0).unwrap_or(true)
-                && self.mask[self.indexes[i]]
+                && self.mask[ind]
                     .1
                     .to_usize()
                     .map(|n| n as i32 != i32::MAX)
                     .unwrap_or(true))
                 || (s.to_usize().map(|n| n != 0).unwrap_or(true)
-                    && self.mask[self.indexes[i]]
+                    && self.mask[ind]
                         .0
                         .to_usize()
                         .map(|n| n as i32 != 0)
                        .unwrap_or(true))
            {
-                panic!("Adding padding to a slice isn't supported")
+                panic!("Adding padding to a masked shape isn't supported")
            }
-            self.padding[self.indexes[i]].0 += s.max(0);
-            self.padding[self.indexes[i]].1 += e.max(0);
+            self.padding[ind].0 += s.max(0);
+            self.padding[ind].1 += e.max(0);
        }
    }
@@ -311,11 +304,31 @@ impl ShapeTracker {
 }
 
 fn pad_mask_dim(
-    dim: BigExpression,
+    dim: impl Into<BigExpression>,
     padding: (Expression, Expression),
     mask: (Expression, Expression),
 ) -> BigExpression {
-    (dim + padding.0 + padding.1).min(mask.1) - mask.0
+    (dim.into() + padding.0 + padding.1).min(mask.1) - mask.0
+}
+
+// Combine non-permuted, non-padded, non-fake, non-masked dimensions together
+fn combine_dims(mut shape: ShapeTracker) -> ShapeTracker {
+    for i in (1..shape.len()).rev() {
+        if (shape.indexes[i] != i || shape.indexes[i - 1] != i - 1)
+            || (shape.fake[i] || shape.fake[i - 1])
+            || (shape.padding[i].0 != 0 || shape.padding[i].1 != 0)
+            || (shape.mask[i].0 != 0 || shape.mask[i].1 != i32::MAX)
+            || (shape.padding[i - 1].0 != 0 || shape.padding[i - 1].1 != 0)
+            || (shape.mask[i - 1].0 != 0 || shape.mask[i - 1].1 != i32::MAX)
+        {
+            continue;
+        }
+        // We can combine dimension i and i - 1
+        let dim_i = shape.dims[i];
+        shape.dims[i - 1] *= dim_i;
+        shape.remove_dim(i);
+    }
+    shape
 }
 
 /// Resolve shapes between the two trackers to the best of our ability
@@ -346,14 +359,16 @@ mod tests {
     use crate::prelude::*;
     #[test]
     fn test_idx_expr() {
-        let tracker = ShapeTracker::new(&[
+        let mut tracker = ShapeTracker::new(&[
             Expression::from(10),
             Expression::from(5),
             Expression::from(3),
         ]);
-        // tracker.permute(&[0, 2, 1]);
+        tracker.permute(&[0, 2, 1]);
+        println!("Shape: [10, 5, 3]");
         println!("Strides: {:?}", tracker.strides());
         println!("Ind: {:?}", tracker.index_expression());
+        println!("Val: {:?}", tracker.valid_expression());
     }
 
     #[test]