Skip to content

Commit

Permalink
Contiguous elimination
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Jan 6, 2024
1 parent d0afd42 commit 9a0261a
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 61 deletions.
2 changes: 2 additions & 0 deletions examples/mistral/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ fn main() {
transfer_data(&kv_cache, &mut cx1, &cache_src_set, &mut cx2);
drop(cx1);

// cx2.display_set(&[NodeIndex::new(51), NodeIndex::new(85), NodeIndex::new(87)]);

// Decode loop
let mut token_decode_times = vec![];
for _ in 0..tokens_to_generate {
Expand Down
36 changes: 34 additions & 2 deletions src/compilers/metal/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,28 @@ impl<T: MetalFloat> Operator for MetalSub<T> {
})
}

fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
fn custom(&mut self, key: &str, input: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
input_shapes[0],
input_shapes[1],
self.2.clone(),
self.1.clone(),
&mut HashMap::new(),
self.6,
)
}
}
None
}
}
Expand Down Expand Up @@ -332,12 +348,28 @@ impl<T: MetalFloat> Operator for MetalEqual<T> {
})
}

fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
fn custom(&mut self, key: &str, input: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
input_shapes[0],
input_shapes[1],
self.2.clone(),
self.1.clone(),
&mut HashMap::new(),
self.6,
)
}
}
None
}
}
Expand Down
13 changes: 8 additions & 5 deletions src/compilers/metal/command_buffer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
any::Any,
cell::UnsafeCell,
collections::{HashMap, HashSet},
fmt::Debug,
Expand Down Expand Up @@ -28,12 +29,14 @@ impl Compiler for CommandBufferCompiler {
let is_metal: HashSet<NodeIndex> = graph
.graph
.node_indices()
.collect::<Vec<_>>()
.into_iter()
.filter(|i| {
graph
.graph
.node_weight(*i)
.node_weight_mut(*i)
.unwrap()
.custom("metal")
.custom("metal", Box::new(()))
.is_some()
})
.collect();
Expand Down Expand Up @@ -170,9 +173,9 @@ impl Compiler for CommandBufferCompiler {
// Wrap node in MetalKernelOperation
let wrapper = graph
.graph
.node_weight(*node)
.node_weight_mut(*node)
.unwrap()
.custom("metal")
.custom("metal", Box::new(()))
.unwrap()
.downcast::<MetalKernelWrapper>()
.unwrap();
Expand Down Expand Up @@ -280,7 +283,7 @@ impl Operator for CommandBufferWrapper {
}

#[allow(clippy::arc_with_non_send_sync)]
fn custom(&self, key: &str) -> Option<Box<dyn std::any::Any>> {
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
Expand Down
6 changes: 3 additions & 3 deletions src/compilers/metal/fp16/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ impl Operator for MetalVecMat {
})
}

fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
Expand Down Expand Up @@ -404,7 +404,7 @@ impl Operator for MetalMatmul2D {
})
}

fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
Expand Down Expand Up @@ -580,7 +580,7 @@ impl Operator for MetalBatchMatmul2D {
})
}

fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
Expand Down
17 changes: 16 additions & 1 deletion src/compilers/metal/fp16/mean_reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,28 @@ impl Operator for MetalMeanReduce {
})
}

fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
fn custom(&mut self, key: &str, input: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
#[allow(clippy::arc_with_non_send_sync)]
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
self.2.clone(),
self.1.clone(),
self.3,
input_shapes[0],
self.5,
);
}
}
None
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/compilers/metal/fp16/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub type MetalFp16Compiler = (
mean_reduce::MeanReduceCompiler,
rms_norm::RMSNormCompiler,
super::other::CopyCompiler<f16>,
// super::other::ContiguousElimination<f16>,
super::other::ContiguousElimination<f16>,
super::command_buffer::CommandBufferCompiler,
super::storage_buffer::StorageBufferCompiler,
);
Expand Down
12 changes: 10 additions & 2 deletions src/compilers/metal/fp16/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,16 @@ impl Operator for MetalCos {
})
}

fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
}
None
}
}
Expand Down Expand Up @@ -258,12 +262,16 @@ impl Operator for MetalExp {
})
}

fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
}
None
}
}
Expand Down
17 changes: 14 additions & 3 deletions src/compilers/metal/fp16/rms_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct MetalRMSNorm(
ComputePipelineState, // RMSNorm kernel
Device,
ShapeTracker, // Input shape
f32, // Epsilon
*const HashMap<char, usize>,
);

Expand Down Expand Up @@ -79,6 +80,7 @@ kernel void mkernel(device float *inp [[buffer(0)]], device half *x [[buffer(1)]
compile_function(&rms_norm_code_name, &rms_norm_code, &dev),
dev,
inp_shape,
epsilon,
dyn_map,
)
}
Expand Down Expand Up @@ -127,7 +129,7 @@ impl MetalKernel for MetalRMSNorm {
encoder.set_int(3, front_size as u32);
encoder.set_int(4, back_size as u32);
encoder.set_int(5, dim_size as u32);
input_dyn_dims(&[self.3], unsafe { self.4.as_ref().unwrap() }, encoder, 6);
input_dyn_dims(&[self.3], unsafe { self.5.as_ref().unwrap() }, encoder, 6);

encoder.dispatch_1d(meaned_elements);
encoder.end_encoding();
Expand All @@ -141,7 +143,7 @@ impl MetalKernel for MetalRMSNorm {
encoder.set_buffer(1, Some(inputs[0].0), 0);
encoder.set_buffer(2, Some(output_buffers[0]), 0);
encoder.set_int(3, inputs[0].1.n_elements().to_usize().unwrap() as u32);
input_dyn_dims(&[self.3], unsafe { self.4.as_ref().unwrap() }, encoder, 4);
input_dyn_dims(&[self.3], unsafe { self.5.as_ref().unwrap() }, encoder, 4);

// Execute
encoder.dispatch_1d(inputs[0].1.n_elements().to_usize().unwrap());
Expand Down Expand Up @@ -183,13 +185,22 @@ impl Operator for MetalRMSNorm {
})
}

fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
fn custom(&mut self, key: &str, input: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
#[allow(clippy::arc_with_non_send_sync)]
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(self.4, self.2.clone(), input_shapes[0], self.5)
}
}
None
}
}
Expand Down
22 changes: 19 additions & 3 deletions src/compilers/metal/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ impl<T: MetalFloat> Operator for MetalARange<T> {
})
}

fn custom(&self, key: &str) -> Option<Box<dyn Any>> {
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
#[allow(clippy::arc_with_non_send_sync)]
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
Expand Down Expand Up @@ -315,7 +315,7 @@ impl<T: MetalFloat> Compiler for ContiguousElimination<T> {
.ptr(&mut contig)
.edge(
SelectOp::new()
.check(|op, _| op.custom("non_contiguous").is_some())
.check(|op, _| op.custom("non_contiguous", Box::new(())).is_some())
.ptr(&mut op),
);
let mut selector = pattern.search(graph);
Expand All @@ -330,12 +330,19 @@ impl<T: MetalFloat> Compiler for ContiguousElimination<T> {
continue;
}
// Shape going from contig to op
// let first_shape = graph
// .graph
// .edges_directed(contig, Direction::Incoming)
// .find_map(|e| e.weight().as_data())
// .unwrap()
// .2;
let second_shape = graph
.graph
.edges_connecting(contig, op)
.find_map(|e| e.weight().as_data())
.unwrap()
.2;
// Here we should check if second shape and first shape are mergeable instead of just checking if second_shape is contiguous
if second_shape.is_contiguous()
&& !second_shape.is_sliced()
&& !second_shape.is_padded()
Expand All @@ -354,7 +361,16 @@ impl<T: MetalFloat> Compiler for ContiguousElimination<T> {
source,
);
graph.graph.remove_node(contig);
selector.clear_cached_results();
let new_shapes = graph
.get_sources(op)
.into_iter()
.map(|(_, _, s)| s)
.collect::<Vec<_>>();
graph
.graph
.node_weight_mut(op)
.unwrap()
.custom("recompile_shapes", Box::new(new_shapes));
}
}
}
Expand Down
Loading

0 comments on commit 9a0261a

Please sign in to comment.