Commit: Changed mean reduce

jafioti committed Jul 30, 2024
1 parent 81db899 commit e1279d9
Showing 6 changed files with 86 additions and 97 deletions.
30 changes: 17 additions & 13 deletions src/hl_ops/binary.rs
@@ -10,8 +10,8 @@ use std::ops::{Add, Div, Mul, Rem, Sub};
 impl Add for GraphTensor {
     type Output = GraphTensor;
 
-    fn add(mut self, mut rhs: GraphTensor) -> Self::Output {
-        resolve_local_dyn_dims(&mut self.shape, &mut rhs.shape, false);
+    fn add(self, rhs: GraphTensor) -> Self::Output {
+        assert_eq!(self.dims(), rhs.dims(), "Dims must match to add tensors.");
         let new_id = self
             .graph()
             .add_op(op::Add)
@@ -61,8 +61,8 @@ impl SubAssign for GraphTensor {
 impl Mul for GraphTensor {
     type Output = GraphTensor;
 
-    fn mul(mut self, mut rhs: GraphTensor) -> Self::Output {
-        resolve_local_dyn_dims(&mut self.shape, &mut rhs.shape, false);
+    fn mul(self, rhs: GraphTensor) -> Self::Output {
+        assert_eq!(
+            self.dims(),
+            rhs.dims(),
+            "Dims must match to multiply tensors."
+        );
         let new_id = self
             .graph()
             .add_op(op::Mul)
@@ -114,8 +118,8 @@ impl DivAssign for GraphTensor {
 impl Rem<GraphTensor> for GraphTensor {
     type Output = GraphTensor;
 
-    fn rem(mut self, mut rhs: GraphTensor) -> Self::Output {
-        resolve_local_dyn_dims(&mut self.shape, &mut rhs.shape, false);
+    fn rem(self, rhs: GraphTensor) -> Self::Output {
+        assert_eq!(self.dims(), rhs.dims(), "Dims must match to mod tensors.");
         let new_id = self
             .graph()
             .add_op(op::Mod)
@@ -144,7 +148,7 @@ impl<S: Into<Expression>> Add<S> for GraphTensor {
     type Output = GraphTensor;
 
     fn add(self, rhs: S) -> Self::Output {
-        self + self.graph().constant_expr(rhs).expand_to(self.shape)
+        self + self.graph().constant(rhs).expand_to(self.shape)
     }
 }
 
@@ -160,7 +164,7 @@ impl<S: Into<Expression>> Sub<S> for GraphTensor {
     type Output = GraphTensor;
 
     fn sub(self, rhs: S) -> Self::Output {
-        self - self.graph().constant_expr(rhs).expand_to(self.shape)
+        self - self.graph().constant(rhs).expand_to(self.shape)
     }
 }
 
@@ -176,7 +180,7 @@ impl<S: Into<Expression>> Mul<S> for GraphTensor {
     type Output = GraphTensor;
 
     fn mul(self, rhs: S) -> Self::Output {
-        self * self.graph().constant_expr(rhs).expand_to(self.shape)
+        self * self.graph().constant(rhs).expand_to(self.shape)
     }
 }
 
@@ -193,7 +197,7 @@ impl<S: Into<Expression>> Div<S> for GraphTensor {
     type Output = GraphTensor;
 
     fn div(self, rhs: S) -> Self::Output {
-        self / self.graph().constant_expr(rhs).expand_to(self.shape)
+        self / self.graph().constant(rhs).expand_to(self.shape)
     }
 }
 
@@ -209,14 +213,14 @@ impl<S: Into<Expression>> Rem<S> for GraphTensor {
     type Output = GraphTensor;
 
     fn rem(self, rhs: S) -> Self::Output {
-        self % self.graph().constant_expr(rhs).expand_to(self.shape)
+        self % self.graph().constant(rhs).expand_to(self.shape)
     }
 }
 
 // Comparisons (based on https://github.com/tinygrad/tinygrad/blob/3e0c2d256fe9f4f5f85cd3e4d8733a51d7b4a984/tinygrad/tensor.py#L653)
 impl GraphTensor {
-    pub fn less_than(mut self, mut rhs: GraphTensor) -> GraphTensor {
-        resolve_local_dyn_dims(&mut self.shape, &mut rhs.shape, false);
+    pub fn less_than(self, rhs: GraphTensor) -> GraphTensor {
+        assert_eq!(self.dims(), rhs.dims(), "Dims must match to lt tensors.");
         let new_id = self
             .graph()
             .add_op(op::LessThan)
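These binary-op hunks all make the same swap: the implicit `resolve_local_dyn_dims` call (the helper deleted from src/shape/tracker.rs at the bottom of this commit) is replaced by a hard `assert_eq!` on the operand dims. A minimal sketch of what that means at a call site, assuming this is the luminal crate and using the `tensor`/`set`/`retrieve`/`execute` API that appears in the tests further down (the prelude path and `data()` are my assumptions):

```rust
use luminal::prelude::*; // assumed crate/prelude path

fn main() {
    let mut cx = Graph::new();

    // Matching dims: the new assert_eq! passes and Add is recorded as before.
    let a = cx.tensor((2, 3)).set(vec![1., 2., 3., 4., 5., 6.]);
    let b = cx.tensor((2, 3)).set(vec![1.; 6]);
    let c = (a + b).retrieve();
    cx.execute();
    println!("{:?}", c.data()); // expecting [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]

    // Mismatched dims: previously the op would try to resolve unknown dynamic
    // dims between the two shapes; after this commit it panics up front:
    // let d = cx.tensor((3, 2));
    // let _ = a + d; // panics: "Dims must match to add tensors."
}
```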
24 changes: 3 additions & 21 deletions src/hl_ops/other.rs
@@ -71,14 +71,9 @@ impl From<f64> for ConstantValue {
         ConstantValue::Float(value as f32)
     }
 }
-impl From<Expression> for ConstantValue {
-    fn from(value: Expression) -> Self {
-        ConstantValue::Expression(value)
-    }
-}
-impl From<&Expression> for ConstantValue {
-    fn from(value: &Expression) -> Self {
-        ConstantValue::Expression(*value)
+impl<T: Into<Expression>> From<T> for ConstantValue {
+    fn from(value: T) -> Self {
+        ConstantValue::Expression(value.into())
     }
 }
 
@@ -92,19 +87,6 @@ impl Graph {
         )
     }
 
-    /// A scalar constant evaluated from an expression at runtime
-    pub fn constant_expr<E: Into<Expression>>(&mut self, expr: E) -> GraphTensor {
-        GraphTensor::from_id(
-            self.add_op(Constant(
-                ConstantValue::Expression(expr.into().simplify()),
-                &self.dyn_map,
-            ))
-            .finish(),
-            ShapeTracker::new(()),
-            self,
-        )
-    }
-
     /// ARange from 0 to N
     pub fn arange(&mut self, to: impl Into<Expression>) -> GraphTensor {
         let to = to.into();
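The blanket `impl<T: Into<Expression>> From<T> for ConstantValue` is what lets the call sites in binary.rs switch from `constant_expr` to plain `constant`, and why `Graph::constant_expr` can be deleted here. A hedged sketch of the resulting API, assuming `Graph::constant` takes `impl Into<ConstantValue>` (its signature is not shown in this diff) and that `char` converts into a symbolic `Expression`:

```rust
use luminal::prelude::*; // assumed crate/prelude path

fn constants(cx: &mut Graph) -> (GraphTensor, GraphTensor) {
    // A plain float constant, unchanged by this commit:
    let half = cx.constant(0.5); // ConstantValue::Float(0.5)

    // An expression constant: with the new blanket impl, anything Into<Expression>
    // (here the assumed dyn-dim symbol 's') coerces straight into
    // ConstantValue::Expression, so the old cx.constant_expr('s') spelling
    // is no longer needed. The value is read from the graph's dyn-dim map at run time.
    let seq_len = cx.constant('s');

    (half, seq_len)
}
```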
50 changes: 12 additions & 38 deletions src/hl_ops/reduction.rs
@@ -7,7 +7,6 @@ impl GraphTensor {
     /// Reduce a dimension of the tensor by summing all elements along that axis.
     pub fn sum_reduce(self, axes: impl ToAxes) -> GraphTensor {
         let mut shape = self.shape;
-
         let mut new_id = self.id;
         for dim in axes.to_axes().into_iter().rev() {
             new_id = self
@@ -40,31 +39,12 @@ impl GraphTensor {
 
     /// Reduce a dimension of the tensor by taking the mean of all elements along that axis.
     pub fn mean_reduce(self, axes: impl ToAxes) -> GraphTensor {
-        let mut shape = self.shape;
-        let mut node_id = self.id;
-        for dim in axes.to_axes().into_iter().rev() {
-            // Sum reduce
-            node_id = self
-                .graph()
-                .add_op(op::SumReduce(dim))
-                .input(node_id, 0, shape)
-                .finish();
-
-            // Divide by size of dimension
-            let div_tensor = self.graph().constant_expr(shape.remove_dim(dim)).id;
-            let mul_tensor = self
-                .graph()
-                .add_op(op::Recip)
-                .input(div_tensor, 0, ShapeTracker::new(()))
-                .finish();
-            node_id = self
-                .graph()
-                .add_op(op::Mul)
-                .input(node_id, 0, shape)
-                .input(mul_tensor, 0, ShapeTracker::fake(shape))
-                .finish();
-        }
-        GraphTensor::from_id(node_id, shape, self.graph_ref)
+        let mul_factor = 1 / axes
+            .to_axes()
+            .into_iter()
+            .map(|i| self.dims()[i])
+            .product::<Expression>();
+        (self * mul_factor).sum_reduce(axes)
     }
 
     /// Reduce a dimension of the tensor by multiplying all elements along that axis.
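This is the change the commit is named for: instead of emitting a SumReduce/Recip/Mul chain per axis, `mean_reduce` now multiplies by a single symbolic factor, `1 / product(reduced dims)`, and then does an ordinary `sum_reduce`. As I read it, that factor is a symbolic Div expression (kept unfolded by the Div change in src/shape/symbolic.rs below) which the `Constant` op evaluates as a float via the new `exec_float`, so `1 / 3` becomes 0.333... instead of truncating to 0. A small worked sketch on a (2, 3) tensor, mirroring `test_mean_reduce` below (the data values and the `data()` call are my own):

```rust
use luminal::prelude::*; // assumed crate/prelude path

fn main() {
    let mut cx = Graph::new();
    let a = cx.tensor((2, 3)).set(vec![1., 2., 3., 4., 5., 6.]);

    // mean over axis 1: mul_factor = 1 / 3 (symbolic), so effectively
    //   row 0: (1 + 2 + 3) * (1/3) = 2.0
    //   row 1: (4 + 5 + 6) * (1/3) = 5.0
    let b = a.mean_reduce(1).retrieve();
    cx.execute();
    println!("{:?}", b.data()); // expecting roughly [2.0, 5.0]
}
```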
@@ -81,10 +61,8 @@ mod tests {
     fn test_sum_reduce() {
         let mut cx = Graph::new();
         let a_data = random_vec(6);
-        let a = cx.tensor((2, 3));
-        a.set(a_data.clone());
-        let b = a.sum_reduce(1);
-        b.retrieve();
+        let a = cx.tensor((2, 3)).set(a_data.clone());
+        let b = a.sum_reduce(1).retrieve();
 
         cx.execute();
 
@@ -99,10 +77,8 @@
     fn test_max_reduce() {
         let mut cx = Graph::new();
         let a_data = random_vec(6);
-        let a = cx.tensor((2, 3));
-        a.set(a_data.clone());
-        let b = a.max_reduce(1);
-        b.retrieve();
+        let a = cx.tensor((2, 3)).set(a_data.clone());
+        let b = a.max_reduce(1).retrieve();
 
         cx.execute();
 
@@ -117,10 +93,8 @@
     fn test_mean_reduce() {
         let mut cx = Graph::new();
         let a_data = random_vec(6);
-        let a = cx.tensor((2, 3));
-        a.set(a_data.clone());
-        let b = a.mean_reduce(1);
-        b.retrieve();
+        let a = cx.tensor((2, 3)).set(a_data.clone());
+        let b = a.mean_reduce(1).retrieve();
 
         cx.execute();
 
3 changes: 2 additions & 1 deletion src/op.rs
@@ -156,7 +156,7 @@ impl Operator for Constant {
     fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         vec![Tensor::new(vec![match &self.0 {
             ConstantValue::Expression(e) => {
-                e.exec(unsafe { self.1.as_ref().unwrap() }).unwrap() as f32
+                e.exec_float(unsafe { self.1.as_ref().unwrap() }).unwrap() as f32
             }
             ConstantValue::Float(f) => *f,
         }])]
@@ -280,6 +280,7 @@ pub struct Mul;
 impl Operator for Mul {
     fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
         let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
+        println!("EXPR: {:?}", inp[0].1.dims());
         let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
         let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
         let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
53 changes: 52 additions & 1 deletion src/shape/symbolic.rs
@@ -112,6 +112,22 @@ impl Term {
             _ => None,
         }
     }
+    pub fn as_float_op(self) -> Option<fn(f64, f64) -> f64> {
+        match self {
+            Term::Add => Some(|a, b| a + b),
+            Term::Sub => Some(|a, b| a - b),
+            Term::Mul => Some(|a, b| a * b),
+            Term::Div => Some(|a, b| a / b),
+            Term::Mod => Some(|a, b| a % b),
+            Term::Max => Some(|a, b| a.max(b)),
+            Term::Min => Some(|a, b| a.min(b)),
+            Term::And => Some(|a, b| (a.abs() > 1e-4 && b.abs() > 1e-4) as i32 as f64),
+            Term::Or => Some(|a, b| (a.abs() > 1e-4 || b.abs() > 1e-4) as i32 as f64),
+            Term::Gte => Some(|a, b| (a >= b) as i32 as f64),
+            Term::Lt => Some(|a, b| (a < b) as i32 as f64),
+            _ => None,
+        }
+    }
 }
 
 impl<T> PartialEq<T> for Expression
@@ -345,6 +361,37 @@ impl Expression {
         }
         stack.pop().map(|i| i as usize)
     }
+    /// Evaluate the expression given variables.
+    pub fn exec_float(&self, variables: &FxHashMap<char, usize>) -> Option<f64> {
+        self.exec_stack_float(variables, &mut Vec::new())
+    }
+    /// Evaluate the expression given variables. This function requires a stack to be given for use as storage
+    pub fn exec_stack_float(
+        &self,
+        variables: &FxHashMap<char, usize>,
+        stack: &mut Vec<f64>,
+    ) -> Option<f64> {
+        for term in self.terms.read().iter() {
+            match term {
+                Term::Num(n) => stack.push(*n as f64),
+                Term::Var(c) =>
+                {
+                    #[allow(clippy::needless_borrow)]
+                    if let Some(n) = variables.get(&c) {
+                        stack.push(*n as f64)
+                    } else {
+                        return None;
+                    }
+                }
+                _ => {
+                    let a = stack.pop().unwrap();
+                    let b = stack.pop().unwrap();
+                    stack.push(term.as_float_op().unwrap()(a, b));
+                }
+            }
+        }
+        stack.pop()
+    }
     /// Retrieve all symbols in the expression.
     pub fn to_symbols(&self) -> Vec<char> {
         self.terms
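`exec_float` mirrors the existing usize-based `exec` above but keeps fractional intermediate results, and it is what the `Constant` op in src/op.rs now calls. A sketch of how I would expect it to be used; the symbol, values, `Expression: From<char>`, and the arithmetic operator overloads are my assumptions based on the surrounding impls, and FxHashMap is presumably the rustc-hash type already used in these signatures:

```rust
use rustc_hash::FxHashMap; // assumed source of the FxHashMap in the signatures above

fn demo() {
    // Build 2*x + 1 as a symbolic expression.
    let expr = Expression::from('x') * 2 + 1;

    let mut vars = FxHashMap::default();
    vars.insert('x', 3usize);

    assert_eq!(expr.exec(&vars), Some(7));         // pre-existing integer path
    assert_eq!(expr.exec_float(&vars), Some(7.0)); // new float path
}
```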
@@ -523,7 +570,11 @@ impl<E: Into<Expression>> Div<E> for Expression {
             return 0.into();
         }
         if let (Some(a), Some(b)) = (self.as_num(), rhs.as_num()) {
-            return (a / b).into();
+            if a % b == 0 {
+                if let Some(c) = a.checked_div(b) {
+                    return c.into();
+                }
+            }
         }
         let mut terms = rhs.terms.read().clone();
         terms.extend(self.terms.read().iter().copied());
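As I read this hunk, constant folding of division now only happens when the quotient is exact; a non-exact quotient such as 1 / 3 stays a symbolic Div term instead of truncating, and only the float evaluator collapses it. A hedged illustration (the integer-to-Expression conversions and operand order are inferred from the surrounding code):

```rust
use rustc_hash::FxHashMap; // assumed FxHashMap type, as above

fn demo() {
    let no_vars = FxHashMap::default();

    // Exact division: still folds to a plain number at expression-build time.
    let six_thirds = Expression::from(6) / 3;
    assert_eq!(six_thirds.exec(&no_vars), Some(2));

    // Not exact: previously folded with integer division to 0; now it stays
    // a symbolic Div, and the float evaluator yields the fractional value,
    // which is what mean_reduce's 1 / dim factor relies on.
    let one_third = Expression::from(1) / 3;
    assert_eq!(one_third.exec_float(&no_vars), Some(1.0 / 3.0));
}
```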
23 changes: 0 additions & 23 deletions src/shape/tracker.rs
@@ -332,29 +332,6 @@ fn pad_mask_dim(
     (padding.0 + padding.1 + dim).min(mask.1) - mask.0
 }
 
-/// Resolve shapes between the two trackers to the best of our ability
-pub fn resolve_local_dyn_dims(a: &mut ShapeTracker, b: &mut ShapeTracker, default_to_one: bool) {
-    // B to A
-    for i in 0..a.dims.len() {
-        if a.dims[a.indexes[i]].is_unknown() {
-            a.dims[a.indexes[i]] = b.dims[b.indexes[i]];
-            if a.dims[a.indexes[i]].is_unknown() && default_to_one {
-                a.dims[a.indexes[i]] = 1.into();
-            }
-        }
-    }
-
-    // A to B
-    for i in 0..a.dims.len() {
-        if b.dims[b.indexes[i]].is_unknown() {
-            b.dims[b.indexes[i]] = a.dims[a.indexes[i]];
-            if b.dims[b.indexes[i]].is_unknown() && default_to_one {
-                b.dims[b.indexes[i]] = 1.into();
-            }
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use crate::prelude::*;
