Skip to content

Commit

Permalink
Refactored expression system
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Jul 21, 2024
1 parent 2bdad48 commit b5c38dc
Show file tree
Hide file tree
Showing 48 changed files with 1,808 additions and 969 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ as-any = "0.3.1"
egg = "0.9.5"
symbolic_expressions = "5.0.3"
serde = {version="1.0.202", features=["derive"]}
thread_local = "1.1.8"
generational-box = "0.5.6"

[dev-dependencies]
dfdx = { version = "0.13", features = ["f16"] }
Expand Down
2 changes: 1 addition & 1 deletion crates/luminal_cpu/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ impl Compiler for GatherCompiler {
.as_data()
.unwrap()
.2
.shape()[2]
.dims()[2]
.to_usize()
.unwrap();

Expand Down
4 changes: 2 additions & 2 deletions crates/luminal_cpu/src/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct MatMul2D;

impl Operator for MatMul2D {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
let (a_strides, b_strides) = (inp[0].1.strides(), inp[1].1.strides());
let a_data = inp[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
let b_data = inp[1].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
Expand Down Expand Up @@ -151,7 +151,7 @@ pub struct BatchedMatMul2D;
// ABCxCD -> ABD
impl Operator for BatchedMatMul2D {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
let (a_strides, b_strides) = (inp[0].1.strides(), inp[1].1.strides());
let a_data = inp[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
let b_data = inp[1].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions crates/luminal_cpu/src/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::binary::Sub;

#[derive(Debug, Clone, PartialEq)]
pub struct ARange {
pub size: BigExpression,
pub size: Expression,
dyn_map: *const FxHashMap<char, usize>,
}

Expand Down Expand Up @@ -61,7 +61,7 @@ impl Compiler for ARangeCompiler {
};
let arange_op = graph
.add_op(ARange {
size: arange_amount.into(),
size: arange_amount,
dyn_map: &graph.dyn_map,
})
.finish();
Expand Down
8 changes: 4 additions & 4 deletions crates/luminal_metal/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
}

impl<T> MetalKernel for MetalSub<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
Expand Down Expand Up @@ -252,7 +252,7 @@ kernel void mkernel(device {type_name} *inp_a [[buffer(0)]], device {type_name}
}

impl<T> MetalKernel for MetalEqual<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
vec![input_shapes[0].n_elements() * size_of::<T>()]
}
fn metal_forward(
Expand Down Expand Up @@ -419,7 +419,7 @@ impl<T: MetalFloat> Operator for MetalGather<T> {
// Setup buffers
let indexes = tensors[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
let index_buffer = self.device.new_buffer_with_data(
unsafe { std::mem::transmute(indexes.as_ptr()) },
indexes.as_ptr() as *const _,
(indexes.len() * std::mem::size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
Expand Down Expand Up @@ -498,7 +498,7 @@ impl<T: MetalFloat> Compiler for MetalGatherCompiler<T> {
.as_data()
.unwrap()
.2;
let embed_dim = emb_shape.shape()[2].to_usize().unwrap();
let embed_dim = emb_shape.dims()[2].to_usize().unwrap();
let index_shape = graph
.edges_connecting(s.get(&indexes), s.get(&ind_copy))
.next()
Expand Down
4 changes: 2 additions & 2 deletions crates/luminal_metal/src/command_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,10 @@ impl std::fmt::Debug for CommandBufferWrapper {
}

impl MetalKernel for CommandBufferWrapper {
fn intermediate_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn intermediate_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
self.wrapper.intermediate_buffer_sizes(input_shapes)
}
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
self.wrapper.output_buffer_sizes(input_shapes)
}
fn metal_forward(
Expand Down
61 changes: 28 additions & 33 deletions crates/luminal_metal/src/elementwise_fusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
let mut subexpressions_b = graph
.try_get_op::<FusedElementwiseOp<T>>(b)
.map(|o| o.subexpressions.clone())
.unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::default())]);
.unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::new(()))]);
let a_to_b_indexes = graph
.edges_connecting(a, b)
.map(|e| e.weight().as_data().unwrap().0 as usize)
Expand All @@ -141,7 +141,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
let mut subexpressions_a = graph
.try_get_op::<FusedElementwiseOp<T>>(a)
.map(|o| o.subexpressions.clone())
.unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::default())]);
.unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::new(()))]);
subexpressions_a.last_mut().unwrap().1 = connecting_shape;
// Re-reference b intermediates
for i in (0..subexpressions_b.len()).rev() {
Expand Down Expand Up @@ -236,6 +236,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.into_iter()
.map(|s| s.simplify_cache(&mut simplification_cache))
.collect();
let g: *mut Graph = graph;
let new_op = graph
.add_op(FusedElementwiseOp::<T> {
kernel_str: "".to_string(),
Expand All @@ -247,6 +248,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
device: device.clone(),
output_buffer_sizes,
_phantom: Default::default(),
graph: g,
})
.finish();
// Add edges to new op
Expand Down Expand Up @@ -293,17 +295,20 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.into_iter()
.map(|s| s.simplify_cache(&mut simplification_cache))
.collect();
let sh = ShapeTracker::new(());
let g: *mut Graph = graph;
let new_op = graph
.add_op(FusedElementwiseOp::<T> {
kernel_str: "".to_string(),
kernel: None,
dyn_map: &graph.dyn_map,
dyn_chars: vec![],
subexpressions: vec![(op_string, ShapeTracker::default())],
subexpressions: vec![(op_string, sh)],
queue: queue.clone(),
device: device.clone(),
output_buffer_sizes,
_phantom: Default::default(),
graph: g,
})
.finish();
// Add edges to new op
Expand Down Expand Up @@ -346,8 +351,9 @@ pub struct FusedElementwiseOp<T> {
pub subexpressions: Vec<(String, ShapeTracker)>,
pub queue: CommandQueue,
pub device: Device,
pub output_buffer_sizes: Vec<BigExpression>,
pub output_buffer_sizes: Vec<Expression>,
pub _phantom: PhantomData<T>,
pub graph: *mut Graph,
}
crate::debug_type!(FusedElementwiseOp);

Expand All @@ -357,7 +363,7 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
input_shapes: Vec<ShapeTracker>,
input_regexes: &mut FxHashMap<usize, Regex>,
intermediate_match: &Regex,
simplification_cache: &mut FxHashMap<BigExpression, BigExpression>,
simplification_cache: &mut FxHashMap<Expression, Expression>,
) {
let mut subexpressions = self.subexpressions.clone();
let shapes_used = subexpressions
Expand Down Expand Up @@ -415,7 +421,7 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
s.iter()
.rev()
.take(s.len() - 1)
.fold(BigExpression::from('z'), |acc, inp| {
.fold(Expression::from('z'), |acc, inp| {
inp.index_expression().substitute('z', acc)
})
})
Expand Down Expand Up @@ -451,8 +457,8 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
input_regexes.get(&i).unwrap()
};
let (ind, val) = (
ind_exp.clone().simplify_cache(simplification_cache),
val_exp.clone().simplify_cache(simplification_cache),
ind_exp.simplify_cache(simplification_cache),
val_exp.simplify_cache(simplification_cache),
);
*subexp = re
.replace_all(
Expand All @@ -475,10 +481,10 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
.iter()
.rev()
.fold(
(BigExpression::from(true), BigExpression::from('z')),
(Expression::from(true), Expression::from('z')),
|(_, ind_acc), inp| {
(
inp.valid_expression().substitute('z', ind_acc.clone()),
inp.valid_expression().substitute('z', ind_acc),
inp.index_expression().substitute('z', ind_acc),
)
},
Expand Down Expand Up @@ -526,7 +532,7 @@ out[idx] = ({type_name})({});
}

impl<T> MetalKernel for FusedElementwiseOp<T> {
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<Expression> {
self.output_buffer_sizes.clone()
}
fn metal_forward(
Expand Down Expand Up @@ -669,7 +675,7 @@ mod tests {
let inp = random_vec_rng(10, &mut rng);
let a = cx.named_tensor("a", (2, 5)).set(inp);
let mut padded = a
.slice((..Expression::from(1), ..))
.slice((..1, ..))
.cos()
.pad(((0, 1), (0, 0)))
.exp2()
Expand Down Expand Up @@ -711,7 +717,7 @@ mod tests {
const HEAD_DIM: usize = 4;
let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32);
let freqs = 1000000_f32.pow(freqs);
let pos = cx.arange(SEQ) + BigExpression::from(0);
let pos = cx.arange(SEQ) + 0;
let mut emb = pos.expand(1, 1).matmul(freqs.expand(0, SEQ)).retrieve();

cx.execute();
Expand Down Expand Up @@ -775,16 +781,12 @@ mod tests {
.keep();
let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32);
let freqs = 1000000_f32.pow(freqs);
let pos = cx.arange(SEQ) + BigExpression::from(0);
let pos = cx.arange(SEQ) + 0;
let emb = pos.expand(1, 1).matmul(freqs.expand(0, SEQ));
// Split input into evens and odds
let split = a.reshape((BATCH, N_HEADS, SEQ, HEAD_DIM / 2, 2));
let x0 = split
.slice((.., .., .., .., ..Expression::from(1)))
.contiguous();
let x1 = split
.slice((.., .., .., .., Expression::from(1)..))
.contiguous();
let x0 = split.slice((.., .., .., .., ..1)).contiguous();
let x1 = split.slice((.., .., .., .., 1..)).contiguous();

// Apply sin and cos embeddings
let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape);
Expand Down Expand Up @@ -852,15 +854,9 @@ mod tests {
}
}

fn apply_rotary_embeddings_ggml(
input: GraphTensor,
prev_seq: BigExpression,
) -> GraphTensor {
fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: Expression) -> GraphTensor {
assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim
let batch = input.shape()[0].small();
let n_heads = input.shape()[1].small();
let seq = input.shape()[2].small();
let head_dim = input.shape()[3].small();
let (batch, n_heads, seq, head_dim) = input.dims4();
// Get freqs
let freqs =
(input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32);
Expand Down Expand Up @@ -892,9 +888,8 @@ mod tests {
type Output = (GraphTensor, KVCache);
fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output {
// x: batch, seq, hidden
let batch = x.shape()[0].small();
let seq = x.shape()[1].small();
let prev_seq = k_cache.shape()[2].small();
let (batch, seq, _) = x.dims3();
let (_, _, prev_seq, _) = k_cache.dims4();
// Apply the Projections
let queries = x
.matmul(self.q_proj.permute((1, 0)))
Expand All @@ -912,8 +907,8 @@ mod tests {
.permute((0, 2, 1, 3));

// Rotary embed queries and keys
let queries = apply_rotary_embeddings_ggml(queries, prev_seq.big());
let keys = apply_rotary_embeddings_ggml(keys, prev_seq.big());
let queries = apply_rotary_embeddings_ggml(queries, prev_seq);
let keys = apply_rotary_embeddings_ggml(keys, prev_seq);

// Add KV cache
let keys = k_cache.concat_along(keys, 2);
Expand Down
20 changes: 8 additions & 12 deletions crates/luminal_metal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ impl MetalFloat for f16 {

pub trait MetalKernel: Debug {
/// Annotate the buffer sizes of the intermediate buffers
fn intermediate_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
fn intermediate_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<Expression> {
vec![]
}
/// Annotate the buffer sizes of the output buffers
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression>;
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression>;
/// Set up the kernel on the buffer
fn metal_forward(
&self,
Expand Down Expand Up @@ -227,7 +227,7 @@ impl Deref for MetalKernelWrapper {
}

impl MetalKernel for () {
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<Expression> {
vec![]
}
fn metal_forward(
Expand Down Expand Up @@ -365,14 +365,10 @@ fn render_dyn_dim_inputs(shapes: &[ShapeTracker], offset: usize) -> (Vec<char>,
let symbols: Vec<char> = shapes
.iter()
.flat_map(|st| {
st.shape()
st.dims()
.into_iter()
.chain(
st.padding
.into_iter()
.flat_map(|i| [i.0.into(), i.1.into()]),
)
.chain(st.mask.into_iter().flat_map(|i| [i.0.into(), i.1.into()]))
.chain(st.padding.into_iter().flat_map(|i| [i.0, i.1]))
.chain(st.mask.into_iter().flat_map(|i| [i.0, i.1]))
})
.flat_map(|d| d.to_symbols())
.unique()
Expand All @@ -389,9 +385,9 @@ fn render_dyn_dim_inputs(shapes: &[ShapeTracker], offset: usize) -> (Vec<char>,
)
}

fn expr_to_metal_string(expr: &BigExpression) -> String {
fn expr_to_metal_string(expr: &Expression) -> String {
let mut symbols = vec![];
for term in expr.terms.clone() {
for term in expr.terms.read().clone() {
let new_symbol = match term {
Term::Num(n) => n.to_string(),
Term::Var(c) => {
Expand Down
18 changes: 9 additions & 9 deletions crates/luminal_metal/src/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ impl<T> Debug for Matmul<T> {
const BM: u64 = 8;
const BN: u64 = 32;
impl<T> MetalKernel for Matmul<T> {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
let m = input_shapes[0].shape()[input_shapes[0].len() - 2].clone();
let n = input_shapes[1].shape()[input_shapes[1].len() - 1].clone();
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<Expression> {
let m = input_shapes[0].dims()[input_shapes[0].len() - 2];
let n = input_shapes[1].dims()[input_shapes[1].len() - 1];
let batch_size = input_shapes[0]
.shape()
.dims()
.into_iter()
.take(input_shapes[0].len() - 2)
.product::<BigExpression>()
.max(BigExpression::from(1));
.product::<Expression>()
.max(1);
vec![batch_size * m * n * size_of::<T>()]
}
fn metal_forward(
Expand All @@ -54,13 +54,13 @@ impl<T> MetalKernel for Matmul<T> {
let (a_shape, b_shape) = (
inputs[0]
.1
.shape()
.dims()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>(),
inputs[1]
.1
.shape()
.dims()
.into_iter()
.map(|i| i.to_usize().unwrap())
.collect::<Vec<_>>(),
Expand Down Expand Up @@ -158,7 +158,7 @@ impl<T: MetalFloat> Operator for Matmul<T> {
// Setup command queue / command buffer / encoder
let command_buffer = self.queue.new_command_buffer();

let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
let n = b_shape.last().unwrap().to_usize().unwrap();
let batch_size = a_shape
.iter()
Expand Down
Loading

0 comments on commit b5c38dc

Please sign in to comment.