Skip to content

Commit

Permalink
Cleaned up symbolic expressions further
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Apr 26, 2024
1 parent b359ecf commit a39176d
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 86 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ Cargo.lock
*.npx
*.npz
/**/llama-7b-hf
/**/mistral-7b-hf
/**/setup_weights/target
*.model
*.gguf
76 changes: 54 additions & 22 deletions crates/luminal_symbolic/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub trait ExpressionStorage:
fn push(&mut self, term: Term);
fn pop(&mut self) -> Option<Term>;
fn remove(&mut self, index: usize) -> Term;
fn into_vec(self) -> Vec<Term>;
}

// Implement the main storage types
Expand All @@ -44,6 +45,9 @@ impl ExpressionStorage for Vec<Term> {
fn remove(&mut self, index: usize) -> Term {
Vec::remove(self, index)
}
fn into_vec(self) -> Vec<Term> {
self
}
}

impl<const C: usize> ExpressionStorage for ArrayVec<[Term; C]>
Expand All @@ -62,14 +66,26 @@ where
fn remove(&mut self, index: usize) -> Term {
ArrayVec::remove(self, index)
}
fn into_vec(self) -> Vec<Term> {
self.to_vec()
}
}

/// A symbolic expression, generic over its term storage backend.
// NOTE(review): `PartialEq` is deliberately not derived; a generic
// `PartialEq<T>` impl (comparing against anything convertible into an
// expression) is provided separately. `Eq` relies on that impl being
// reflexive for `Self`.
#[derive(Clone, Copy, Hash, Eq)]
pub struct GenericExpression<S: ExpressionStorage> {
    /// The terms making up the expression. Operators appear to be pushed
    /// after their operands (RPN-style) — TODO confirm ordering invariant.
    pub terms: S,
}

impl<S: ExpressionStorage, T> PartialEq<T> for GenericExpression<S>
where
    for<'a> &'a T: Into<Self>,
{
    /// Equality against any value convertible into this expression type:
    /// convert the right-hand side, then compare term sequences.
    fn eq(&self, other: &T) -> bool {
        let converted: Self = other.into();
        converted.terms == self.terms
    }
}

impl<S: ExpressionStorage> Default for GenericExpression<S> {
fn default() -> Self {
let mut s = S::default();
Expand Down Expand Up @@ -261,6 +277,18 @@ where
}
}

impl Expression {
pub fn big(&self) -> BigExpression {
BigExpression::from(*self)
}
}

impl BigExpression {
pub fn small(&self) -> Expression {
Expression::from(self)
}
}

impl<S: ExpressionStorage> From<Term> for GenericExpression<S> {
fn from(value: Term) -> Self {
let mut terms = S::default();
Expand Down Expand Up @@ -317,6 +345,16 @@ impl<S: ExpressionStorage> From<&bool> for GenericExpression<S> {
}
}

impl<S: ExpressionStorage, T: ExpressionStorage> From<&GenericExpression<T>>
    for GenericExpression<S>
{
    /// Rebuild the expression in a (possibly different) storage backend by
    /// cloning the source terms and copying each one across.
    fn from(value: &GenericExpression<T>) -> Self {
        let mut terms = S::default();
        for term in value.terms.clone().into_vec() {
            terms.push(term);
        }
        Self { terms }
    }
}

impl From<Expression> for BigExpression {
fn from(value: Expression) -> Self {
Self {
Expand All @@ -325,12 +363,6 @@ impl From<Expression> for BigExpression {
}
}

impl From<&Expression> for Expression {
fn from(value: &Expression) -> Self {
*value
}
}

impl From<BigExpression> for Expression {
fn from(value: BigExpression) -> Self {
let mut terms = ArrayVec::new();
Expand All @@ -343,10 +375,10 @@ impl<S: ExpressionStorage, E: Into<Self>> Add<E> for GenericExpression<S> {
type Output = Self;
fn add(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
if rhs == Self::from(0) {
if rhs == 0 {
return self;
}
if self == Self::from(0) {
if self == 0 {
return rhs;
}
if self == rhs {
Expand All @@ -362,7 +394,7 @@ impl<S: ExpressionStorage, E: Into<Self>> Sub<E> for GenericExpression<S> {
type Output = Self;
fn sub(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
if rhs == Self::from(0) {
if rhs == 0 {
return self;
}
if self == rhs {
Expand All @@ -378,13 +410,13 @@ impl<S: ExpressionStorage, E: Into<Self>> Mul<E> for GenericExpression<S> {
type Output = Self;
fn mul(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
if rhs == Self::from(1) {
if rhs == 1 {
return self;
}
if self == Self::from(1) {
if self == 1 {
return rhs;
}
if rhs == Self::from(0) || self == Self::from(0) {
if rhs == 0 || self == 0 {
return 0.into();
}
rhs.terms.extend(self.terms);
Expand All @@ -397,13 +429,13 @@ impl<S: ExpressionStorage, E: Into<Self>> Div<E> for GenericExpression<S> {
type Output = Self;
fn div(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
if rhs == Self::from(1) {
if rhs == 1 {
return self;
}
if self == rhs {
return 1.into();
}
if self == Self::from(0) {
if self == 0 {
return 0.into();
}
rhs.terms.extend(self.terms);
Expand All @@ -416,8 +448,8 @@ impl<S: ExpressionStorage, E: Into<Self>> Rem<E> for GenericExpression<S> {
type Output = Self;
fn rem(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
if rhs == Self::from(1) || rhs == self {
return Self::from(0);
if rhs == 1 || rhs == self {
return 0.into();
}
rhs.terms.extend(self.terms);
rhs.terms.push(Term::Mod);
Expand All @@ -429,13 +461,13 @@ impl<S: ExpressionStorage, E: Into<Self>> BitAnd<E> for GenericExpression<S> {
type Output = Self;
fn bitand(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
if rhs == Self::from(0) || self == Self::from(0) {
return Self::from(0);
if rhs == 0 || self == 0 {
return 0.into();
}
if rhs == Self::from(1) {
if rhs == 1 {
return self;
}
if self == Self::from(1) {
if self == 1 {
return rhs;
}
rhs.terms.extend(self.terms);
Expand All @@ -448,7 +480,7 @@ impl<S: ExpressionStorage, E: Into<Self>> BitOr<E> for GenericExpression<S> {
type Output = Self;
fn bitor(self, rhs: E) -> Self::Output {
let mut rhs = rhs.into();
if rhs == Self::from(1) || self == Self::from(1) {
if rhs == 1 || self == 1 {
return 1.into();
}
rhs.terms.extend(self.terms);
Expand Down
2 changes: 1 addition & 1 deletion crates/luminal_symbolic/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn test_expressions() {
fn test_minimizations() {
    // ((a * 1) + 0) / 1 + (1 - 1) should simplify down to just `a`,
    // exercising the identity/zero short-circuits in the operator impls.
    let expr = ((BigExpression::from('a') * 1) + 0) / 1 + (1 - 1);
    let reduced_expr = expr.simplify();
    assert_eq!(reduced_expr, 'a');
}

#[test]
Expand Down
34 changes: 16 additions & 18 deletions src/hl_ops/movement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ impl<S: Shape> GraphTensor<S> {
let ranges = slice.to_range_vec();
// This exists because currently padding and slicing on the same dimension (even on opposite sides) is unsupported
if ranges.iter().zip(self.shape.indexes).any(|(range, ind)| {
(range.0 != 0.into() || range.1 != i32::MAX.into())
&& (self.shape.padding[self.shape.indexes[ind]].0 != 0.into()
|| self.shape.padding[self.shape.indexes[ind]].1 != 0.into())
(range.0 != 0 || range.1 != i32::MAX)
&& (self.shape.padding[self.shape.indexes[ind]].0 != 0
|| self.shape.padding[self.shape.indexes[ind]].1 != 0)
}) {
self = self.contiguous();
}
Expand All @@ -114,7 +114,7 @@ impl<S: Shape> GraphTensor<S> {
self = self.contiguous();
// Expand a new dimension to do the slicing on
let n_rows = total_size / (spacing + size);
self.shape.expand(n_dims, (spacing + size).into());
self.shape.expand(n_dims, spacing + size);
// self = self.contiguous();
self.shape.dims[self.shape.indexes[n_dims - 1]] = n_rows;
self.shape.fake[self.shape.indexes[n_dims]] = false;
Expand All @@ -141,23 +141,21 @@ impl<S: Shape> GraphTensor<S> {
let number_of_windows = ((dim_size - full_kernel) / stride) + 1;
// Expand new dimension
self.shape.expand(n_dims - 1, number_of_windows);

let orig_width = BigExpression::from(dim_size);

self = self.contiguous();
if n_dims > 1 {
// View as single dimension of matrix with wider width
let mat_size = (orig_width.clone() + stride) * number_of_windows;
let actual_size = orig_width.clone() * self.shape.dims[self.shape.indexes[n_dims - 1]];
let mat_size = (dim_size.big() + stride.big()) * number_of_windows.big();
let actual_size =
dim_size.big() * self.shape.dims[self.shape.indexes[n_dims - 1]].big();
// Reshape into single dimension to pad
self.shape.remove_dim(n_dims);
self.shape.dims[self.shape.indexes[n_dims - 1]] = actual_size.clone().into();
self.shape.padding[self.shape.indexes[n_dims - 1]].1 = (mat_size - actual_size).into();
self.shape.dims[self.shape.indexes[n_dims - 1]] = actual_size.small();
self.shape.padding[self.shape.indexes[n_dims - 1]].1 = (mat_size - actual_size).small();
self = self.contiguous();
// Reshape back (mats should be full now)
self.shape.add_dim(n_dims, (orig_width + stride).into());
self.shape.add_dim(n_dims, dim_size + stride);
} else {
self.shape.dims[self.shape.indexes[n_dims]] = (orig_width + stride).into();
self.shape.dims[self.shape.indexes[n_dims]] = dim_size + stride;
}
self.shape.dims[self.shape.indexes[n_dims - 1]] = number_of_windows;
// Slice down to kernel size
Expand All @@ -184,9 +182,9 @@ impl<S: Shape> GraphTensor<S> {
.collect::<Vec<_>>();
// This exists because currently padding and slicing on the same dimension (even on opposite sides) is unsupported
if ranges.iter().zip(self.shape.indexes).any(|(range, ind)| {
(range.0 != 0.into() || range.1 != 0.into())
&& (self.shape.slices[self.shape.indexes[ind]].0 != 0.into()
|| self.shape.slices[self.shape.indexes[ind]].1 != i32::MAX.into())
(range.0 != 0 || range.1 != 0)
&& (self.shape.slices[self.shape.indexes[ind]].0 != 0
|| self.shape.slices[self.shape.indexes[ind]].1 != i32::MAX)
}) {
self = self.contiguous();
}
Expand All @@ -201,9 +199,9 @@ impl<S: Shape> GraphTensor<S> {
let dim = Ax::as_array()[0];
// Create padding
let mut a_padding = vec![(Expression::default(), Expression::default()); self.shape.len()];
a_padding[dim].1 = rhs.shape.shape()[dim].clone().into();
a_padding[dim].1 = rhs.shape.shape()[dim].small();
let mut b_padding = vec![(Expression::default(), Expression::default()); rhs.shape.len()];
b_padding[dim].0 = self.shape.shape()[dim].clone().into();
b_padding[dim].0 = self.shape.shape()[dim].small();
// Pad and add
(self.pad(&a_padding) + rhs.pad(&b_padding)).sync_shape()
}
Expand Down
4 changes: 2 additions & 2 deletions src/hl_ops/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl From<BigExpression> for ConstantValue {
}
impl From<Expression> for ConstantValue {
    /// Wrap an owned `Expression`, converting via the borrowing path
    /// (`&Expression -> _`) — presumably widening into the expression type
    /// held by the `Expression` variant; confirm against the enum definition.
    fn from(value: Expression) -> Self {
        ConstantValue::Expression((&value).into())
    }
}
impl From<&BigExpression> for ConstantValue {
Expand All @@ -65,7 +65,7 @@ impl From<&BigExpression> for ConstantValue {
}
impl From<&Expression> for ConstantValue {
    /// Borrowed counterpart of `From<Expression>`: convert the referenced
    /// expression directly into the variant payload.
    fn from(value: &Expression) -> Self {
        ConstantValue::Expression(value.into())
    }
}

Expand Down
2 changes: 1 addition & 1 deletion src/hl_ops/reduction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl<S: Shape> GraphTensor<S> {
ShapeTracker::fake(
&shape
.shape()
.into_iter()
.iter()
.map(Expression::from)
.collect::<Vec<_>>(),
),
Expand Down
Loading

0 comments on commit a39176d

Please sign in to comment.