From e1279d9780fa50b5e1352cf5115e7d7a99abba10 Mon Sep 17 00:00:00 2001
From: Joe Fioti
Date: Tue, 30 Jul 2024 10:35:46 -0500
Subject: [PATCH] Changed mean reduce

---
 src/hl_ops/binary.rs    | 30 +++++++++++++----------
 src/hl_ops/other.rs     | 24 +++----------------
 src/hl_ops/reduction.rs | 50 ++++++++++----------------------
 src/op.rs               |  3 ++-
 src/shape/symbolic.rs   | 53 ++++++++++++++++++++++++++++++++++++++++-
 src/shape/tracker.rs    | 23 ------------------
 6 files changed, 86 insertions(+), 97 deletions(-)

diff --git a/src/hl_ops/binary.rs b/src/hl_ops/binary.rs
index eff44f13..bd719613 100644
--- a/src/hl_ops/binary.rs
+++ b/src/hl_ops/binary.rs
@@ -10,8 +10,8 @@ use std::ops::{Add, Div, Mul, Rem, Sub};
 impl Add for GraphTensor {
     type Output = GraphTensor;
 
-    fn add(mut self, mut rhs: GraphTensor) -> Self::Output {
-        resolve_local_dyn_dims(&mut self.shape, &mut rhs.shape, false);
+    fn add(self, rhs: GraphTensor) -> Self::Output {
+        assert_eq!(self.dims(), rhs.dims(), "Dims must match to add tensors.");
         let new_id = self
             .graph()
             .add_op(op::Add)
@@ -61,8 +61,12 @@ impl SubAssign for GraphTensor {
 impl Mul for GraphTensor {
     type Output = GraphTensor;
 
-    fn mul(mut self, mut rhs: GraphTensor) -> Self::Output {
-        resolve_local_dyn_dims(&mut self.shape, &mut rhs.shape, false);
+    fn mul(self, rhs: GraphTensor) -> Self::Output {
+        assert_eq!(
+            self.dims(),
+            rhs.dims(),
+            "Dims must match to multiply tensors."
+        );
         let new_id = self
             .graph()
             .add_op(op::Mul)
@@ -114,8 +118,8 @@ impl DivAssign for GraphTensor {
 impl Rem for GraphTensor {
     type Output = GraphTensor;
 
-    fn rem(mut self, mut rhs: GraphTensor) -> Self::Output {
-        resolve_local_dyn_dims(&mut self.shape, &mut rhs.shape, false);
+    fn rem(self, rhs: GraphTensor) -> Self::Output {
+        assert_eq!(self.dims(), rhs.dims(), "Dims must match to mod tensors.");
         let new_id = self
             .graph()
             .add_op(op::Mod)
@@ -144,7 +148,7 @@ impl<S: Into<Expression>> Add<S> for GraphTensor {
     type Output = GraphTensor;
 
     fn add(self, rhs: S) -> Self::Output {
-        self + self.graph().constant_expr(rhs).expand_to(self.shape)
+        self + self.graph().constant(rhs).expand_to(self.shape)
     }
 }
 
@@ -160,7 +164,7 @@ impl<S: Into<Expression>> Sub<S> for GraphTensor {
     type Output = GraphTensor;
 
     fn sub(self, rhs: S) -> Self::Output {
-        self - self.graph().constant_expr(rhs).expand_to(self.shape)
+        self - self.graph().constant(rhs).expand_to(self.shape)
     }
 }
 
@@ -176,7 +180,7 @@ impl<S: Into<Expression>> Mul<S> for GraphTensor {
     type Output = GraphTensor;
 
     fn mul(self, rhs: S) -> Self::Output {
-        self * self.graph().constant_expr(rhs).expand_to(self.shape)
+        self * self.graph().constant(rhs).expand_to(self.shape)
     }
 }
 
@@ -193,7 +197,7 @@ impl<S: Into<Expression>> Div<S> for GraphTensor {
     type Output = GraphTensor;
 
     fn div(self, rhs: S) -> Self::Output {
-        self / self.graph().constant_expr(rhs).expand_to(self.shape)
+        self / self.graph().constant(rhs).expand_to(self.shape)
    }
 }
 
@@ -209,14 +213,14 @@ impl<S: Into<Expression>> Rem<S> for GraphTensor {
     type Output = GraphTensor;
 
     fn rem(self, rhs: S) -> Self::Output {
-        self % self.graph().constant_expr(rhs).expand_to(self.shape)
+        self % self.graph().constant(rhs).expand_to(self.shape)
     }
 }
 
 // Comparisons (based on https://github.com/tinygrad/tinygrad/blob/3e0c2d256fe9f4f5f85cd3e4d8733a51d7b4a984/tinygrad/tensor.py#L653)
 impl GraphTensor {
-    pub fn less_than(mut self, mut rhs: GraphTensor) -> GraphTensor {
-        resolve_local_dyn_dims(&mut self.shape, &mut rhs.shape, false);
+    pub fn less_than(self, rhs: GraphTensor) -> GraphTensor {
+        assert_eq!(self.dims(), rhs.dims(), "Dims must match to lt tensors.");
         let new_id = self
             .graph()
             .add_op(op::LessThan)
diff --git a/src/hl_ops/other.rs b/src/hl_ops/other.rs
index 555f100f..5a1b0634 100644
--- a/src/hl_ops/other.rs
+++ b/src/hl_ops/other.rs
@@ -71,14 +71,9 @@
         ConstantValue::Float(value as f32)
     }
 }
-impl From<Expression> for ConstantValue {
-    fn from(value: Expression) -> Self {
-        ConstantValue::Expression(value)
-    }
-}
-impl From<&Expression> for ConstantValue {
-    fn from(value: &Expression) -> Self {
-        ConstantValue::Expression(*value)
+impl<T: Into<Expression>> From<T> for ConstantValue {
+    fn from(value: T) -> Self {
+        ConstantValue::Expression(value.into())
     }
 }
 
@@ -92,19 +87,6 @@ impl Graph {
         )
     }
 
-    /// A scalar constant evaluated from an expression at runtime
-    pub fn constant_expr<E: Into<Expression>>(&mut self, expr: E) -> GraphTensor {
-        GraphTensor::from_id(
-            self.add_op(Constant(
-                ConstantValue::Expression(expr.into().simplify()),
-                &self.dyn_map,
-            ))
-            .finish(),
-            ShapeTracker::new(()),
-            self,
-        )
-    }
-
     /// ARange from 0 to N
     pub fn arange(&mut self, to: impl Into<Expression>) -> GraphTensor {
         let to = to.into();
diff --git a/src/hl_ops/reduction.rs b/src/hl_ops/reduction.rs
index 4738f8f5..afa82059 100644
--- a/src/hl_ops/reduction.rs
+++ b/src/hl_ops/reduction.rs
@@ -7,7 +7,6 @@ impl GraphTensor {
     /// Reduce a dimension of the tensor by summing all elements along that axis.
     pub fn sum_reduce(self, axes: impl ToAxes) -> GraphTensor {
         let mut shape = self.shape;
-
         let mut new_id = self.id;
         for dim in axes.to_axes().into_iter().rev() {
             new_id = self
@@ -40,31 +39,12 @@ impl GraphTensor {
 
     /// Reduce a dimension of the tensor by taking the mean of all elements along that axis.
     pub fn mean_reduce(self, axes: impl ToAxes) -> GraphTensor {
-        let mut shape = self.shape;
-        let mut node_id = self.id;
-        for dim in axes.to_axes().into_iter().rev() {
-            // Sum reduce
-            node_id = self
-                .graph()
-                .add_op(op::SumReduce(dim))
-                .input(node_id, 0, shape)
-                .finish();
-
-            // Divide by size of dimension
-            let div_tensor = self.graph().constant_expr(shape.remove_dim(dim)).id;
-            let mul_tensor = self
-                .graph()
-                .add_op(op::Recip)
-                .input(div_tensor, 0, ShapeTracker::new(()))
-                .finish();
-            node_id = self
-                .graph()
-                .add_op(op::Mul)
-                .input(node_id, 0, shape)
-                .input(mul_tensor, 0, ShapeTracker::fake(shape))
-                .finish();
-        }
-        GraphTensor::from_id(node_id, shape, self.graph_ref)
+        let mul_factor = 1 / axes
+            .to_axes()
+            .into_iter()
+            .map(|i| self.dims()[i])
+            .product::<Expression>();
+        (self * mul_factor).sum_reduce(axes)
     }
 
     /// Reduce a dimension of the tensor by multiplying all elements along that axis.
@@ -81,10 +61,8 @@ mod tests {
     fn test_sum_reduce() {
         let mut cx = Graph::new();
         let a_data = random_vec(6);
-        let a = cx.tensor((2, 3));
-        a.set(a_data.clone());
-        let b = a.sum_reduce(1);
-        b.retrieve();
+        let a = cx.tensor((2, 3)).set(a_data.clone());
+        let b = a.sum_reduce(1).retrieve();
 
         cx.execute();
 
@@ -99,10 +77,8 @@
     fn test_max_reduce() {
         let mut cx = Graph::new();
         let a_data = random_vec(6);
-        let a = cx.tensor((2, 3));
-        a.set(a_data.clone());
-        let b = a.max_reduce(1);
-        b.retrieve();
+        let a = cx.tensor((2, 3)).set(a_data.clone());
+        let b = a.max_reduce(1).retrieve();
 
         cx.execute();
 
@@ -117,10 +93,8 @@
     fn test_mean_reduce() {
         let mut cx = Graph::new();
         let a_data = random_vec(6);
-        let a = cx.tensor((2, 3));
-        a.set(a_data.clone());
-        let b = a.mean_reduce(1);
-        b.retrieve();
+        let a = cx.tensor((2, 3)).set(a_data.clone());
+        let b = a.mean_reduce(1).retrieve();
 
         cx.execute();
 
diff --git a/src/op.rs b/src/op.rs
index 4eb48971..c4246e14 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -156,7 +156,7 @@ impl Operator for Constant {
     fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         vec![Tensor::new(vec![match &self.0 {
             ConstantValue::Expression(e) => {
-                e.exec(unsafe { self.1.as_ref().unwrap() }).unwrap() as f32
+                e.exec_float(unsafe { self.1.as_ref().unwrap() }).unwrap() as f32
             }
             ConstantValue::Float(f) => *f,
         }])]
@@ -280,6 +280,7 @@ pub struct Mul;
 impl Operator for Mul {
     fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
+        println!("EXPR: {:?}", inp[0].1.dims());
         let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
         let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
         let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
diff --git a/src/shape/symbolic.rs b/src/shape/symbolic.rs
index 505976ec..bf4a7b6d 100644
--- a/src/shape/symbolic.rs
+++ b/src/shape/symbolic.rs
@@ -112,6 +112,22 @@ impl Term {
             _ => None,
         }
     }
+    pub fn as_float_op(self) -> Option<fn(f64, f64) -> f64> {
+        match self {
+            Term::Add => Some(|a, b| a + b),
+            Term::Sub => Some(|a, b| a - b),
+            Term::Mul => Some(|a, b| a * b),
+            Term::Div => Some(|a, b| a / b),
+            Term::Mod => Some(|a, b| a % b),
+            Term::Max => Some(|a, b| a.max(b)),
+            Term::Min => Some(|a, b| a.min(b)),
+            Term::And => Some(|a, b| (a.abs() > 1e-4 && b.abs() > 1e-4) as i32 as f64),
+            Term::Or => Some(|a, b| (a.abs() > 1e-4 || b.abs() > 1e-4) as i32 as f64),
+            Term::Gte => Some(|a, b| (a >= b) as i32 as f64),
+            Term::Lt => Some(|a, b| (a < b) as i32 as f64),
+            _ => None,
+        }
+    }
 }
 impl<T> PartialEq<T> for Expression
 where
@@ -345,6 +361,37 @@ impl Expression {
         }
         stack.pop().map(|i| i as usize)
     }
+    /// Evaluate the expression given variables.
+    pub fn exec_float(&self, variables: &FxHashMap<char, usize>) -> Option<f64> {
+        self.exec_stack_float(variables, &mut Vec::new())
+    }
+    /// Evaluate the expression given variables. This function requires a stack to be given for use as storage
+    pub fn exec_stack_float(
+        &self,
+        variables: &FxHashMap<char, usize>,
+        stack: &mut Vec<f64>,
+    ) -> Option<f64> {
+        for term in self.terms.read().iter() {
+            match term {
+                Term::Num(n) => stack.push(*n as f64),
+                Term::Var(c) =>
+                {
+                    #[allow(clippy::needless_borrow)]
+                    if let Some(n) = variables.get(&c) {
+                        stack.push(*n as f64)
+                    } else {
+                        return None;
+                    }
+                }
+                _ => {
+                    let a = stack.pop().unwrap();
+                    let b = stack.pop().unwrap();
+                    stack.push(term.as_float_op().unwrap()(a, b));
+                }
+            }
+        }
+        stack.pop()
+    }
     /// Retrieve all symbols in the expression.
     pub fn to_symbols(&self) -> Vec<char> {
         self.terms
@@ -523,7 +570,11 @@ impl<E: Into<Expression>> Div<E> for Expression {
             return 0.into();
         }
         if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
-            return (a / b).into();
+            if a % b == 0 {
+                if let Some(c) = a.checked_div(b) {
+                    return c.into();
+                }
+            }
         }
         let mut terms = rhs.terms.read().clone();
         terms.extend(self.terms.read().iter().copied());
diff --git a/src/shape/tracker.rs b/src/shape/tracker.rs
index a45f7cfb..cebc3d71 100644
--- a/src/shape/tracker.rs
+++ b/src/shape/tracker.rs
@@ -332,29 +332,6 @@ fn pad_mask_dim(
     (padding.0 + padding.1 + dim).min(mask.1) - mask.0
 }
 
-/// Resolve shapes between the two trackers to the best of our ability
-pub fn resolve_local_dyn_dims(a: &mut ShapeTracker, b: &mut ShapeTracker, default_to_one: bool) {
-    // B to A
-    for i in 0..a.dims.len() {
-        if a.dims[a.indexes[i]].is_unknown() {
-            a.dims[a.indexes[i]] = b.dims[b.indexes[i]];
-            if a.dims[a.indexes[i]].is_unknown() && default_to_one {
-                a.dims[a.indexes[i]] = 1.into();
-            }
-        }
-    }
-
-    // A to B
-    for i in 0..a.dims.len() {
-        if b.dims[b.indexes[i]].is_unknown() {
-            b.dims[b.indexes[i]] = a.dims[a.indexes[i]];
-            if b.dims[b.indexes[i]].is_unknown() && default_to_one {
-                b.dims[b.indexes[i]] = 1.into();
-            }
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use crate::prelude::*;
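
--
Notes on the changes above (usage sketches, not part of the diff):

Tensor-tensor binary ops no longer try to reconcile unknown dynamic
dimensions across operands (resolve_local_dyn_dims is deleted outright in
src/shape/tracker.rs); they now assert that both operands' dims already
match, failing fast at graph-build time. A sketch of the new behavior,
assuming the same Graph/GraphTensor API the tests above use:

    use luminal::prelude::*;

    fn main() {
        let mut cx = Graph::new();
        let a = cx.tensor((2, 3));
        let b = cx.tensor((3, 2));
        // let _ = a + b; // would panic: "Dims must match to add tensors."
        let _ = b;

        // Scalar right-hand sides still broadcast: the scalar becomes a graph
        // constant expanded to the tensor's shape (this assumes an integer
        // literal converts into an Expression, per the scalar Add<S> impl).
        let _c = a + 1;
    }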
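mean_reduce previously emitted a SumReduce/Recip/Mul chain per reduced axis;
it now scales once by 1 / (product of the reduced dims) and does a single
sum_reduce. Since the factor is an Expression, the same path covers dynamic
dimensions: the quotient stays symbolic (see the Div change) and is evaluated
in floating point at run time. A usage sketch in the style of the updated
tests; the concrete data and expected values are illustrative:

    use luminal::prelude::*;

    fn main() {
        let mut cx = Graph::new();
        // Two rows: [1, 2, 3] and [4, 5, 6].
        let a = cx.tensor((2, 3)).set(vec![1., 2., 3., 4., 5., 6.]);
        // Row-wise mean lowers to (a * (1/3)).sum_reduce(1).
        let b = a.mean_reduce(1).retrieve();
        cx.execute();
        // b should hold the row means [2.0, 5.0].
    }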
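The Div change constant-folds a / b only when the division is exact, which is
what keeps mean_reduce's 1 / n factor alive: the old `return (a / b).into();`
would integer-divide 1 / 3 down to 0 at build time, collapsing every mean to
zero. A standalone sketch of the folding rule (with an explicit zero guard
added here, since `a % b` panics when b == 0):

    /// Fold a / b to a constant only when the division is exact.
    fn fold_div(a: i32, b: i32) -> Option<i32> {
        if b != 0 && a % b == 0 {
            a.checked_div(b) // exact quotient, safe to fold
        } else {
            None // inexact (or zero divisor): keep the symbolic Div term
        }
    }

    fn main() {
        assert_eq!(fold_div(6, 3), Some(2)); // folds to a plain number
        assert_eq!(fold_div(1, 3), None);    // stays symbolic, e.g. mean's 1/n
    }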
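exec_float walks the Expression's reverse-Polish term list and evaluates it in
f64 (Constant::process now calls it instead of the integer exec), so a symbolic
1 / 3 comes out as 0.333... rather than truncating to 0. A self-contained
sketch of the scheme; Term here is a simplified stand-in for luminal's type so
the snippet runs on its own:

    use std::collections::HashMap;

    #[derive(Clone, Copy)]
    enum Term {
        Num(i32),
        Var(char),
        Div, // the real type also carries Add, Mul, Min, Max, etc.
    }

    /// Evaluate reverse-Polish terms in f64, resolving variables from `vars`.
    fn exec_float(terms: &[Term], vars: &HashMap<char, usize>) -> Option<f64> {
        let mut stack = Vec::new();
        for term in terms {
            match term {
                Term::Num(n) => stack.push(*n as f64),
                Term::Var(c) => stack.push(*vars.get(c)? as f64),
                Term::Div => {
                    // Terms are appended rhs-first (see the Div impl above),
                    // so the top of the stack is the left-hand operand.
                    let lhs = stack.pop()?;
                    let rhs = stack.pop()?;
                    stack.push(lhs / rhs);
                }
            }
        }
        stack.pop()
    }

    fn main() {
        // 1 / z with z = 3 (the mean_reduce factor for a dynamic dim):
        // reverse-Polish layout is [z, 1, /].
        let vars = HashMap::from([('z', 3usize)]);
        let terms = [Term::Var('z'), Term::Num(1), Term::Div];
        assert_eq!(exec_float(&terms, &vars), Some(1.0 / 3.0));
    }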