Combine dims for index and valid expressions
jafioti committed Apr 26, 2024
1 parent 5bc7417 commit cd0195c
Showing 2 changed files with 111 additions and 99 deletions.
29 changes: 13 additions & 16 deletions src/op.rs
@@ -242,13 +242,12 @@ pub struct Add;
 impl Operator for Add {
     fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
-        let lhs_expr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
-        let rhs_expr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
+        let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
+        let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
         let mut stack = vec![];
         let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
         for (i, out) in out_data.iter_mut().enumerate() {
-            *out =
-                get_index(lhs, &lhs_expr, &mut stack, i) + get_index(rhs, &rhs_expr, &mut stack, i);
+            *out = get_index(lhs, &lexpr, &mut stack, i) + get_index(rhs, &rexpr, &mut stack, i);
         }
         vec![Tensor::new(out_data)]
     }
@@ -260,12 +259,11 @@ impl Operator for Mul {
     fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
         let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
-        let lhs_expr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
-        let rhs_expr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
+        let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
+        let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
         let mut stack = vec![];
         for (i, out) in out_data.iter_mut().enumerate() {
-            *out =
-                get_index(lhs, &lhs_expr, &mut stack, i) * get_index(rhs, &rhs_expr, &mut stack, i);
+            *out = get_index(lhs, &lexpr, &mut stack, i) * get_index(rhs, &rexpr, &mut stack, i);
         }
         vec![Tensor::new(out_data)]
     }
@@ -277,12 +275,11 @@ impl Operator for Mod {
     fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
         let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
-        let lhs_expr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
-        let rhs_expr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
+        let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
+        let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
         let mut stack = vec![];
         for (i, out) in out_data.iter_mut().enumerate() {
-            *out =
-                get_index(lhs, &lhs_expr, &mut stack, i) % get_index(rhs, &rhs_expr, &mut stack, i);
+            *out = get_index(lhs, &lexpr, &mut stack, i) % get_index(rhs, &rexpr, &mut stack, i);
         }
         vec![Tensor::new(out_data)]
     }
@@ -294,12 +291,12 @@ impl Operator for LessThan {
     fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
         let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
-        let lhs_expr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
-        let rhs_expr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
+        let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
+        let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
         let mut stack = vec![];
         for (i, out) in out_data.iter_mut().enumerate() {
-            *out = (get_index(lhs, &lhs_expr, &mut stack, i)
-                < get_index(rhs, &rhs_expr, &mut stack, i)) as i32 as f32;
+            *out = (get_index(lhs, &lexpr, &mut stack, i) < get_index(rhs, &rexpr, &mut stack, i))
+                as i32 as f32;
         }
         vec![Tensor::new(out_data)]
     }
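Each of these kernels walks the output buffer once and, for every logical position z, consults a pair of compiled expressions per input: the valid expression (0 means z falls in padding or outside a slice) and the index expression (which maps z to a physical buffer offset). A minimal stand-in for that gather step is sketched below; gather, index_of, and valid_at are illustrative names, not this repository's actual get_index, and the read-as-zero behavior for invalid positions is an assumption consistent with zero-padding:

    // Stand-in for the per-element gather performed via `get_index`.
    // `valid_at` plays the role of the valid expression, `index_of` the
    // index expression; both are hypothetical closures, not the real API.
    fn gather(
        data: &[f32],
        index_of: impl Fn(usize) -> usize, // logical z -> physical offset
        valid_at: impl Fn(usize) -> bool,  // false => padded/masked position
        z: usize,
    ) -> f32 {
        if valid_at(z) {
            data[index_of(z)]
        } else {
            0.0 // invalid positions read as zero (assumed zero-padding)
        }
    }

    fn main() {
        // 3 physical elements viewed as a length-5 tensor, one zero pad per side.
        let data = [1.0, 2.0, 3.0];
        let out: Vec<f32> = (0..5)
            .map(|z| gather(&data, |z| z - 1, |z| (1..4).contains(&z), z))
            .collect();
        assert_eq!(out, vec![0.0, 1.0, 2.0, 3.0, 0.0]);
    }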
181 changes: 98 additions & 83 deletions src/shape/tracker.rs
@@ -39,17 +39,17 @@ impl ShapeTracker {
     }
 
     /// Add dim along a certain axis
-    pub fn add_dim(&mut self, axis: usize, dim: Expression) {
+    pub fn add_dim(&mut self, axis: usize, dim: impl Into<Expression>) {
         self.indexes.insert(axis, self.dims.len());
-        self.dims.push(dim);
+        self.dims.push(dim.into());
         self.fake.push(false);
         self.mask.push((0.into(), i32::MAX.into()));
         self.padding.push((0.into(), 0.into()));
     }
 
     /// Add fake dim along a certain axis
     pub fn expand(&mut self, axis: usize, dim: impl Into<Expression>) {
-        self.add_dim(axis, dim.into());
+        self.add_dim(axis, dim);
         self.fake[self.indexes[axis]] = true;
     }

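Since add_dim now takes impl Into<Expression> (matching expand), callers can pass integer literals directly, relying on the same conversions the defaults above use (0.into(), i32::MAX.into()). A quick usage sketch, assuming the ShapeTracker::new(&[Expression...]) constructor used in the test at the bottom of this diff:

    use crate::prelude::*;

    fn add_dim_usage() {
        let mut st = ShapeTracker::new(&[Expression::from(4), Expression::from(8)]);
        st.add_dim(1, 2); // `2` coerces via impl Into<Expression>; shape: [4, 2, 8]
        st.expand(0, 16); // fake dim, no data movement; logical shape: [16, 4, 2, 8]
    }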
@@ -98,47 +98,61 @@ impl ShapeTracker {
         .collect()
     }
 
     /// Create an expression to translate logical indexes into physical indexes
     pub fn index_expression(&self) -> BigExpression {
-        // Create strides in original order
-        let strides = self.unordered_strides();
-        let mut ret = BigExpression::from(0);
-        let mut acc = BigExpression::from(1);
-        let index = BigExpression::from('z');
+        let shape = combine_dims(*self);
+        let strides = shape.unordered_strides(); // Dimension strides in original order
+        let mut ind_expr = BigExpression::from(0); // The final index expression
+        let mut current_elem_size = BigExpression::from(1); // Size of an element of the current dim (last dim elem size: 1)
+
+        // For combined dims:
+        // divide by the last (smallest) dim's element size,
+        // mod by the combined dim size,
+        // multiply by the last (smallest) dim's stride.
+
-        for i in self.indexes.into_iter().rev() {
-            let logical_sh = pad_mask_dim(self.dims[i].big(), self.padding[i], self.mask[i]);
-            if !self.fake[i] {
-                let dim_ind = (index.clone() / acc.clone()) % logical_sh.clone();
-                ret += (dim_ind + self.mask[i].0 - self.padding[i].0) * strides[i].clone();
+        // Loop through all dims in reverse order
+        for i in shape.indexes.into_iter().rev() {
+            // Get logical dimension size with padding and mask
+            let current_size = pad_mask_dim(shape.dims[i], shape.padding[i], shape.mask[i]);
+            // Don't include fake dimensions in the index expression
+            if !shape.fake[i] {
+                let mut dim_ind = BigExpression::from('z');
+                // Remove other dim components
+                dim_ind /= current_elem_size.clone();
+                // Get position in current dim
+                dim_ind %= current_size.clone();
+                // Add offset
+                dim_ind += shape.mask[i].0 - shape.padding[i].0;
+                // Multiply by stride
+                dim_ind *= strides[i].clone();
+                // Add to index expression
+                ind_expr += dim_ind;
             }
-            acc = acc.clone() * logical_sh.clone();
+            // Keep track of element size for next dimension
+            current_elem_size *= current_size;
         }
-        ret.simplify()
+        ind_expr.simplify()
     }

-    /// If this BigExpression evaluates to 0, the logical index is invalid. Otherwise it is valid
+    /// If this expression evaluates to 0, the logical index is invalid. Otherwise it is valid
     pub fn valid_expression(&self) -> BigExpression {
         if !self.is_reshaped() {
             return 1.into();
         }
+        let shape = combine_dims(*self);
         let mut ret = BigExpression::from(1);
         let mut acc = BigExpression::from(1);
         let logical = BigExpression::from('z');
-        for i in self.indexes.into_iter().rev() {
-            let (bottom_padding, top_padding) = self.padding[i];
-            let (bottom_slice, top_slice) = self.mask[i];
-            let logical_sh =
-                (self.dims[i].big() + bottom_padding + top_padding).min(top_slice) - bottom_slice;
-            if !self.fake[i] {
+        for i in shape.indexes.into_iter().rev() {
+            let (bottom_slice, top_slice) = shape.mask[i];
+            let logical_sh = pad_mask_dim(shape.dims[i], shape.padding[i], shape.mask[i]);
+            if !shape.fake[i] {
                 let dim_ind = (logical.clone() / acc.clone()) % logical_sh.clone();
-                let greater_than = bottom_padding.big() - bottom_slice.big().min(bottom_padding);
+                let greater_than = shape.padding[i].0.big() - bottom_slice;
                 if greater_than != 0 {
                     ret &= dim_ind.clone().gte(greater_than);
                 }
-                ret &= dim_ind.lt(self.dims[i].big() + bottom_padding);
+                ret &= dim_ind.lt(shape.dims[i].big() + shape.padding[i].0);
                 if top_slice
                     .to_usize()
-                    .map(|s| self.dims[i].to_usize().map(|dim| s < dim).unwrap_or(true))
+                    .map(|s| shape.dims[i].to_usize().map(|dim| s < dim).unwrap_or(true))
                     .unwrap_or(true)
                 {
                     ret = ret.min(top_slice);
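The comment block in index_expression summarizes the scheme: for each real (possibly combined) dimension, divide z by the element size accumulated so far, mod by the logical dimension size, multiply by the stride, and sum the terms. Written out with concrete integers for the [10, 5, 3] tracker permuted with [0, 2, 1] (the case the updated test below exercises), the expression reduces to:

    // The index expression spelled out for original shape [10, 5, 3]
    // (row-major strides [15, 3, 1]) viewed with dims 1 and 2 swapped,
    // i.e. logical shape [10, 3, 5]. Plain usize math, no Expressions.
    fn physical_index(z: usize) -> usize {
        (z / 1 % 5) * 3 // innermost logical dim: original dim 1 (size 5, stride 3)
            + (z / 5 % 3) * 1 // next: original dim 2 (size 3, stride 1)
            + (z / 15 % 10) * 15 // outermost: original dim 0 (size 10, stride 15)
    }

    fn main() {
        // Logical [i, j, k] of the permuted view is original [i, k, j],
        // so z = i*15 + j*5 + k should map to i*15 + k*3 + j.
        assert_eq!(physical_index(1), 3); // [0, 0, 1] -> original [0, 1, 0]
        assert_eq!(physical_index(5), 1); // [0, 1, 0] -> original [0, 0, 1]
        assert_eq!(physical_index(16), 18); // [1, 0, 1] -> original [1, 1, 0]
    }

No dimensions combine in this permuted case; on a fully contiguous tracker the whole sum collapses to a single term.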
@@ -149,39 +163,19 @@ impl ShapeTracker {
         ret.simplify()
     }
 
-    /// The number of elements in this tensor, including pads and mask
+    /// The number of elements in this tensor, including padding and mask
     pub fn n_elements(&self) -> BigExpression {
-        let r = self
-            .indexes
-            .into_iter()
-            .map(|i| (i, self.dims[i].big()))
-            // Add pads
-            .map(|(i, dim)| (i, dim + self.padding[i].0 + self.padding[i].1))
-            // Slice
-            .map(|(i, dim)| dim.min(self.mask[i].1) - self.mask[i].0)
-            .product();
-        if r == 0 {
-            1.into()
-        } else {
-            r
-        }
+        self.shape().into_iter().product::<BigExpression>().max(1)
    }
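n_elements now just multiplies the logical dims reported by shape() and clamps with .max(1), so a zero-sized dimension still yields one element rather than a zero-length allocation. The same computation with plain integers, using the (dim + padding.0 + padding.1).min(mask.1) - mask.0 formula from pad_mask_dim further down:

    // Plain-integer version of n_elements: product of logical dim sizes,
    // clamped to at least 1.
    fn logical_dim(dim: i32, padding: (i32, i32), mask: (i32, i32)) -> i32 {
        (dim + padding.0 + padding.1).min(mask.1) - mask.0
    }

    fn main() {
        // A dim of size 4, padded by 1 on each side, then sliced to [1, 5).
        let logical = logical_dim(4, (1, 1), (1, 5));
        assert_eq!(logical, 4); // (4 + 2).min(5) - 1
        assert_eq!([logical].iter().product::<i32>().max(1), 4);
        // A zero-sized dim makes the product 0; the clamp keeps it at 1:
        assert_eq!([0].iter().product::<i32>().max(1), 1);
    }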

     /// The number of elements in this tensor, not including pads and mask
     pub fn n_physical_elements(&self) -> BigExpression {
-        let r = self
-            .dims
-            .into_iter()
-            .enumerate()
-            .filter(|(i, _)| !self.fake[*i])
-            .map(|(_, i)| i.into())
-            .product();
-        if r == 0 {
-            1.into()
-        } else {
-            r
-        }
+        self.indexes
+            .into_iter()
+            // Filter out fake dimensions
+            .filter(|i| !self.fake[*i])
+            .map(|i| self.dims[i].big())
+            .product::<BigExpression>()
+            .max(1)
     }
 
     /// The number of dimensions
@@ -202,19 +196,16 @@

     /// Create a contiguous version
     pub fn contiguous(self) -> Self {
-        let new_dims = self
-            .indexes
-            .into_iter()
-            .map(|i| {
-                self.dims[i].min(self.mask[i].1 - self.mask[i].0)
-                    + self.padding[i].0
-                    + self.padding[i].1
-            })
-            .collect::<Vec<_>>();
-        Self::new(&new_dims)
+        Self::new(
+            &self
+                .shape()
+                .into_iter()
+                .map(|i| i.small())
+                .collect::<Vec<_>>(),
+        )
     }
 
-    /// Check if contiguous
+    /// Check if contiguous (no permutes or fake dimensions)
     pub fn is_contiguous(&self) -> bool {
         self.indexes.iter().enumerate().all(|(a, b)| a == *b) && self.fake.iter().all(|i| !*i)
     }
@@ -228,10 +219,7 @@ impl ShapeTracker {
     pub fn shape(&self) -> Vec<BigExpression> {
         self.indexes
             .into_iter()
-            .map(|i| {
-                (self.dims[i].big() + self.padding[i].0 - self.mask[i].0 + self.padding[i].1)
-                    .min(self.mask[i].1)
-            })
+            .map(|i| pad_mask_dim(self.dims[i], self.padding[i], self.mask[i]))
             .collect()
     }

@@ -242,32 +230,37 @@ impl ShapeTracker {

     /// Take a slice
     pub fn slice(&mut self, mask: &[(Expression, Expression)]) {
-        for (i, (s, e)) in mask.iter().enumerate() {
-            self.mask[self.indexes[i]].0 = self.mask[self.indexes[i]].0.max(s.max(0));
-            self.mask[self.indexes[i]].1 = self.mask[self.indexes[i]].1.min(e.max(0));
+        for (ind, (b, t)) in mask.iter().enumerate().map(|(i, m)| (self.indexes[i], m)) {
+            self.mask[ind].0 = self.mask[ind].0.max(b.max(0));
+            self.mask[ind].1 = self.mask[ind].1.min(t.max(0));
         }
     }
 
     /// Add padding
     pub fn pad(&mut self, padding: &[(Expression, Expression)]) {
-        for (i, (s, e)) in padding.iter().enumerate() {
+        for (ind, (s, e)) in padding
+            .iter()
+            .enumerate()
+            .map(|(i, m)| (self.indexes[i], m))
+        {
             // Make sure we aren't padding a masked dimension
             if (e.to_usize().map(|n| n != 0).unwrap_or(true)
-                && self.mask[self.indexes[i]]
+                && self.mask[ind]
                     .1
                     .to_usize()
                     .map(|n| n as i32 != i32::MAX)
                     .unwrap_or(true))
                 || (s.to_usize().map(|n| n != 0).unwrap_or(true)
-                    && self.mask[self.indexes[i]]
+                    && self.mask[ind]
                         .0
                         .to_usize()
                         .map(|n| n as i32 != 0)
                         .unwrap_or(true))
             {
-                panic!("Adding padding to a slice isn't supported")
+                panic!("Adding padding to a masked shape isn't supported")
             }
-            self.padding[self.indexes[i]].0 += s.max(0);
-            self.padding[self.indexes[i]].1 += e.max(0);
+            self.padding[ind].0 += s.max(0);
+            self.padding[ind].1 += e.max(0);
         }
     }
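slice composes by intersection (max of the lower bounds, min of the upper bounds), while pad now refuses to run on an already masked dimension, since the tracker stores a single (padding, mask) pair per dim and cannot represent a pad applied after a slice. A single-dimension sketch of the mask bookkeeping, with plain i32s standing in for Expressions:

    // How two slices of one dimension combine under the rule above.
    fn sliced(mask: (i32, i32), new: (i32, i32)) -> (i32, i32) {
        (mask.0.max(new.0.max(0)), mask.1.min(new.1.max(0)))
    }

    fn main() {
        let full = (0, i32::MAX); // default mask: nothing sliced off
        let m = sliced(full, (2, 7)); // keep elements [2, 7)
        assert_eq!(m, (2, 7));
        // A second slice intersects with the first:
        assert_eq!(sliced(m, (0, 5)), (2, 5));
    }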

@@ -311,11 +304,31 @@ impl ShapeTracker {
 }
 
 fn pad_mask_dim(
-    dim: BigExpression,
+    dim: impl Into<BigExpression>,
     padding: (Expression, Expression),
     mask: (Expression, Expression),
 ) -> BigExpression {
-    (dim + padding.0 + padding.1).min(mask.1) - mask.0
+    (dim.into() + padding.0 + padding.1).min(mask.1) - mask.0
 }
+
+// Combine non-permuted, non-padded, non-fake, non-masked dimensions together
+fn combine_dims(mut shape: ShapeTracker) -> ShapeTracker {
+    for i in (1..shape.len()).rev() {
+        if (shape.indexes[i] != i || shape.indexes[i - 1] != i - 1)
+            || (shape.fake[i] || shape.fake[i - 1])
+            || (shape.padding[i].0 != 0 || shape.padding[i].1 != 0)
+            || (shape.mask[i].0 != 0 || shape.mask[i].1 != i32::MAX)
+            || (shape.padding[i - 1].0 != 0 || shape.padding[i - 1].1 != 0)
+            || (shape.mask[i - 1].0 != 0 || shape.mask[i - 1].1 != i32::MAX)
+        {
+            continue;
+        }
+        // We can combine dimension i and i - 1
+        let dim_i = shape.dims[i];
+        shape.dims[i - 1] *= dim_i;
+        shape.remove_dim(i);
+    }
+    shape
+}
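combine_dims is the core of this commit: wherever two neighboring dimensions are still in their original order and carry no fake, padding, or mask state, they fuse into one, so index_expression and valid_expression emit one divide/mod term for the pair instead of two. A standalone sketch of the merge rule on plain integers (the mergeable closure stands in for the flag checks above):

    // Fold adjacent dims right-to-left whenever the caller says they may merge.
    fn combine(dims: &mut Vec<usize>, mergeable: impl Fn(usize) -> bool) {
        for i in (1..dims.len()).rev() {
            if mergeable(i) {
                dims[i - 1] *= dims[i];
                dims.remove(i);
            }
        }
    }

    fn main() {
        // A fully contiguous [4, 8, 2] collapses to [64]: one div/mod term
        // in the index expression instead of three.
        let mut dims = vec![4, 8, 2];
        combine(&mut dims, |_| true);
        assert_eq!(dims, vec![64]);
        // After the permute in the test below, [10, 5, 3] -> [10, 3, 5],
        // no adjacent pair is in original order, so nothing merges there.
    }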

 /// Resolve shapes between the two trackers to the best of our ability
@@ -346,14 +359,16 @@ mod tests {
Expand Down Expand Up @@ -346,14 +359,16 @@ mod tests {
     use crate::prelude::*;
     #[test]
     fn test_idx_expr() {
-        let tracker = ShapeTracker::new(&[
+        let mut tracker = ShapeTracker::new(&[
             Expression::from(10),
             Expression::from(5),
             Expression::from(3),
         ]);
-        // tracker.permute(&[0, 2, 1]);
+        tracker.permute(&[0, 2, 1]);
         println!("Shape: [10, 5, 3]");
         println!("Strides: {:?}", tracker.strides());
         println!("Ind: {:?}", tracker.index_expression());
         println!("Val: {:?}", tracker.valid_expression());
     }
 
     #[test]

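To see the combined expressions the updated test prints, run it with output capture disabled:

    cargo test test_idx_expr -- --nocapture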