Skip to content

Commit

Permalink
Initial version of elementwise fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Jan 16, 2024
1 parent 1c0f525 commit 54912c4
Show file tree
Hide file tree
Showing 6 changed files with 325 additions and 6 deletions.
4 changes: 2 additions & 2 deletions resources/luminal_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub fn luminal_eq(input: TokenStream) -> TokenStream {

// Create the expanded trait implementation
let expanded = quote! {
impl #impl_generics PartialEq for #name #ty_generics #where_clause {
impl #impl_generics core::cmp::PartialEq for #name #ty_generics #where_clause {
fn eq(&self, _other: &Self) -> bool {
false
}
Expand All @@ -41,7 +41,7 @@ pub fn luminal_print(input: TokenStream) -> TokenStream {

// Create an identifier for the trait implementation
let gen = quote! {
impl #impl_generics Debug for #name #ty_generics #where_clause {
impl #impl_generics std::fmt::Debug for #name #ty_generics #where_clause {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, stringify!(#name))
}
Expand Down
1 change: 0 additions & 1 deletion src/compilers/metal/command_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::{
any::Any,
cell::UnsafeCell,
collections::{HashMap, HashSet},
fmt::Debug,
sync::Arc,
};

Expand Down
317 changes: 317 additions & 0 deletions src/compilers/metal/elementwise_fusion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
use std::{any::Any, collections::HashMap, marker::PhantomData, sync::Arc};

use itertools::Itertools;
use metal_rs::{
objc::rc::autoreleasepool, Buffer, CommandBufferRef, CommandQueue, ComputePassDescriptor,
ComputePipelineState, Device, MTLResourceOptions,
};
use petgraph::{visit::EdgeRef, Direction};

use crate::{
op::{InputTensor, Operator},
prelude::{metal::get_buffer_from_tensor, *},
};

use self::symbolic::BigExpression;

use super::{compile_function, input_dyn_dims, render_dyn_dim_inputs, DispatchNElements};

/// Compiler pass that fuses chains of adjacent elementwise ops into single
/// Metal compute kernels, eliminating intermediate buffers and dispatches.
/// `T` is the scalar float type (e.g. `f16`) the generated kernels operate on;
/// it is only used at the type level, hence the `PhantomData`.
#[derive(Default, Debug)]
pub struct ElementwiseFusionCompiler<T>(PhantomData<T>);

impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
    /// Two-pass fusion:
    ///
    /// Pass 1 repeatedly finds a pair of elementwise ops `a -> b` joined by a
    /// single contiguous edge, splices `a`'s equation string into `b`'s
    /// (remapping the `inputN` placeholders onto `b`'s input slots), rewires
    /// `a`'s inputs into `b`, and deletes `a`.
    ///
    /// Pass 2 renders and compiles a Metal kernel for every fused op created
    /// in pass 1.
    fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
        let device = Device::system_default().unwrap();
        let queue = device.new_command_queue();
        // Find two elementwise ops that have a contiguous edge
        let (mut a, mut b) = (NodeIndex::default(), NodeIndex::default());
        // An op advertises itself as fusable by answering the "elementwise"
        // custom query with its equation string (see FusedElementwiseOp::custom).
        let mut selector = SelectOp::new()
            .check(|o, _| o.custom("elementwise", Box::<()>::default()).is_some())
            .ptr(&mut a)
            .edge(
                SelectOp::new()
                    .check(|o, _| o.custom("elementwise", Box::<()>::default()).is_some())
                    .ptr(&mut b),
            )
            .search(graph);
        let mut fused_ops = vec![];

        while selector.next_match() {
            // More than one connecting edge
            // (also skip when `a` is in no_delete — it is removed below, so it
            // must not be an op whose output is retrieved elsewhere)
            if graph.no_delete.contains(&a)
                || graph
                    .graph
                    .edges_connecting(a, b)
                    .filter(|e| !e.weight().is_schedule())
                    .count()
                    > 1
            {
                continue;
            }
            // Connecting shape isn't contiguous
            let (to_input, _, connecting_shape) = graph
                .graph
                .edges_connecting(a, b)
                .find_map(|e| e.weight().as_data())
                .unwrap();
            if !connecting_shape.is_contiguous()
                || connecting_shape.is_sliced()
                || connecting_shape.is_padded()
            {
                continue;
            }

            // Fuse into a FusedElementwiseOp
            let new_op;
            // `a`'s equation, whose `inputN` placeholders still refer to `a`'s
            // own incoming edges; remapped onto `b`'s input numbering below.
            let mut a_equation = graph
                .graph
                .node_weight_mut(a)
                .unwrap()
                .custom("elementwise", Box::<()>::default())
                .unwrap()
                .downcast_ref::<String>()
                .unwrap()
                .clone();
            // NOTE(review): `n_edges` starts as *a*'s incoming edge count but is
            // then used as the `input_order` for new edges added to *b* — this
            // can collide with b's existing orders. It appears to be repaired by
            // the renumbering pass further down, but confirm that is intended.
            let mut n_edges = graph
                .graph
                .edges_directed(a, Direction::Incoming)
                .filter(|e| !e.weight().is_schedule())
                .count() as u8;
            // Adjust variables in a_equation to the new inputs
            for input_edge in graph
                .graph
                .edges_directed(a, Direction::Incoming)
                .filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
                .collect_vec()
            {
                // Find edge or add it
                // If `b` already consumes the same (source, output) pair, reuse
                // that existing input slot; otherwise wire a new data edge from
                // the source directly into `b`.
                if let Some(n) = graph
                    .graph
                    .edges_directed(b, Direction::Incoming)
                    .filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
                    .find(|(src, inp_ind, _, _)| *src == input_edge.0 && *inp_ind == input_edge.2)
                {
                    // NOTE(review): plain substring replace — "input1" also
                    // matches inside "input10"/"input11"; breaks with ≥10
                    // inputs. Same concern for every `replace` below. Confirm.
                    a_equation = a_equation
                        .replace(&format!("input{}", input_edge.1), &format!("input{}", n.1));
                } else {
                    graph.graph.add_edge(
                        input_edge.0,
                        b,
                        Dependency::Data {
                            input_order: n_edges,
                            output_order: input_edge.2,
                            shape: input_edge.3,
                        },
                    );
                    a_equation = a_equation.replace(
                        &format!("input{}", input_edge.1),
                        &format!("input{}", n_edges),
                    );
                    n_edges += 1;
                }
            }
            if let Some(fused_op) = graph
                .graph
                .node_weight_mut(b)
                .unwrap()
                .as_any_mut()
                .downcast_mut::<FusedElementwiseOp<T>>()
            {
                // B is already fused, just combine with b
                new_op = b;
                // Render a into b as input to_input
                fused_op.equation = fused_op
                    .equation
                    .replace(&format!("input{to_input}"), &a_equation);
                // Since we are removing the input from a, we must decrement all inputs larger than that
                // NOTE(review): the range starts at `to_input` itself, not
                // `to_input + 1`, although the comment says "larger than";
                // when `to_input == 0` this would produce "input-1". Verify.
                for i in to_input..n_edges {
                    fused_op.equation = fused_op
                        .equation
                        .replace(&format!("input{i}"), &format!("input{}", i - 1));
                }
            } else {
                // B is a plain elementwise op — pull out its equation string.
                let mut b_equation = graph
                    .graph
                    .node_weight_mut(b)
                    .unwrap()
                    .custom("elementwise", Box::<()>::default())
                    .unwrap()
                    .downcast_ref::<String>()
                    .unwrap()
                    .clone();
                b_equation = b_equation.replace(&format!("input{to_input}"), &a_equation);
                // Since we are removing the input from a, we must decrement all inputs larger than that
                // NOTE(review): same inclusive-range / substring concerns as in
                // the already-fused branch above.
                for i in to_input..n_edges {
                    b_equation =
                        b_equation.replace(&format!("input{i}"), &format!("input{}", i - 1));
                }
                // B is not a fused op, let's create a new one
                // The kernel is left as None here; it is compiled in pass 2.
                new_op = graph
                    .add_op(FusedElementwiseOp::<T> {
                        kernel: None,
                        dyn_map: &graph.dyn_map,
                        dyn_chars: vec![],
                        equation: b_equation,
                        queue: queue.clone(),
                        device: device.clone(),
                        _phantom: Default::default(),
                    })
                    .finish();
            }
            // Remove a
            // Transfer any external ids (remap targets, no_delete, to_retrieve)
            // pointing at `a` onto the surviving fused op before deletion.
            move_references(
                &mut remap,
                &mut graph.no_delete,
                &mut graph.to_retrieve,
                a,
                new_op,
            );
            graph.graph.remove_node(a);
            // Bring input indexes back in line
            // Renumber the fused op's incoming data edges to a dense 0..n range
            // ordered by their previous input_order.
            for (i, e) in graph
                .graph
                .edges_directed(new_op, Direction::Incoming)
                .filter(|e| !e.weight().is_schedule())
                .sorted_by_key(|e| e.weight().as_data().unwrap().0)
                .map(|e| e.id())
                .enumerate()
                .collect_vec()
            {
                if let Dependency::Data { input_order, .. } =
                    graph.graph.edge_weight_mut(e).unwrap()
                {
                    *input_order = i as u8;
                }
            }
            fused_ops.push(new_op);
        }
        // Compile all the kernels we placed
        let type_name = T::type_name();
        for fused_op in fused_ops {
            let edges = graph
                .graph
                .edges_directed(fused_op, Direction::Incoming)
                .filter_map(|e| e.weight().as_data())
                .collect_vec();
            if let Some(op) = graph
                .graph
                .node_weight_mut(fused_op)
                .unwrap()
                .as_any_mut()
                .downcast_mut::<FusedElementwiseOp<T>>()
            {
                // NOTE(review): `render_dyn_dim_inputs` is given offset 0 while
                // `metal_forward` binds dyn dims starting at buffer
                // `inputs.len() + 1` — these disagree. Also the rendered kernel
                // body references `out` and `n_element`, but neither is declared
                // as a kernel parameter below. Looks like this can't bind/compile
                // as written — confirm against a later revision.
                let (dyn_chars, rendered) =
                    render_dyn_dim_inputs(&edges.iter().map(|i| i.2).collect_vec(), 0);
                let kernel = format!(
                    "
#include <metal_stdlib>
using namespace metal;
kernel void mkernel({} uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_element) {{
out[idx] = {};
}}
}}",
                    edges
                        .iter()
                        .map(|(inp_ind, _, _)| format!(
                            "device {type_name}* input{inp_ind} [buffer({inp_ind})],"
                        ))
                        .collect_vec()
                        .join(" "),
                    op.equation
                );
                op.kernel = Some(compile_function("mkernel", &kernel, &device));
                op.dyn_chars = dyn_chars;
            }
        }
    }
}

/// A single Metal kernel standing in for a fused chain of elementwise ops.
/// Created by `ElementwiseFusionCompiler`; the kernel source is generated from
/// `equation` and compiled in the compiler's second pass.
#[derive(LuminalPrint, LuminalEq, Clone)]
pub struct FusedElementwiseOp<T> {
    // Compiled pipeline state; None until the compiler's second pass fills it in.
    kernel: Option<ComputePipelineState>,
    // Raw pointer into the graph's dyn-dim map, dereferenced unsafely at run
    // time — assumes the Graph outlives this op. TODO confirm that invariant.
    dyn_map: *const HashMap<char, usize>,
    // Dynamic dimension symbols the kernel needs bound at dispatch time.
    dyn_chars: Vec<char>,
    // The fused elementwise expression, referencing inputs as `inputN`.
    equation: String,
    queue: CommandQueue,
    device: Device,
    // Ties the op to its scalar float type without storing a value.
    _phantom: PhantomData<T>,
}
impl<T> MetalKernel for FusedElementwiseOp<T> {
    /// Output byte size: physical element count of the first input times the
    /// scalar size — elementwise ops produce one element per input element.
    fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
        vec![input_shapes[0].n_physical_elements() * std::mem::size_of::<T>()]
    }
    /// Encode the fused kernel onto `command_buffer`.
    ///
    /// Buffer layout: inputs at indices 0..n-1, the output at index n, and
    /// dynamic-dimension values from index n+1 onward.
    fn metal_forward(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        command_buffer: &CommandBufferRef,
        _: &[&Buffer],
        output_buffers: &[&Buffer],
    ) {
        let encoder =
            command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
        // Panics if the compiler pass never compiled the kernel (kernel == None).
        encoder.set_compute_pipeline_state(self.kernel.as_ref().unwrap());

        // Set function inputs
        for (i, (buf, _)) in inputs.iter().enumerate() {
            encoder.set_buffer(i as u64, Some(*buf), 0);
        }
        encoder.set_buffer(inputs.len() as u64, Some(output_buffers[0]), 0);
        // Bind dyn-dim values starting right after the output buffer.
        // NOTE(review): the generated kernel guards on `n_element`, but nothing
        // here binds an `n_element` value — confirm how the guard is satisfied.
        input_dyn_dims(
            &self.dyn_chars,
            unsafe { self.dyn_map.as_ref().unwrap() },
            &encoder,
            inputs.len() + 1,
        );

        // Execute
        let out_size = inputs[0].1.n_physical_elements().to_usize().unwrap();
        encoder.dispatch_1d(out_size);
        encoder.end_encoding();
    }
}

impl<T: MetalFloat> Operator for FusedElementwiseOp<T> {
    /// Allocate an output buffer, encode the fused kernel, and block until the
    /// GPU finishes, returning the result as a single tensor.
    fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        autoreleasepool(|| {
            // Resolve the output size in bytes from the current dyn-dim values.
            let shapes = tensors.iter().map(|(_, s)| *s).collect_vec();
            let dyn_dims = unsafe { self.dyn_map.as_ref().unwrap() };
            let out_bytes = self.output_buffer_sizes(&shapes)[0].exec(dyn_dims).unwrap() as u64;
            let output = self
                .device
                .new_buffer(out_bytes, MTLResourceOptions::StorageModeShared);

            // Encode the kernel and run it synchronously.
            let command_buffer = self.queue.new_command_buffer();
            let input_buffers = tensors
                .iter()
                .map(|(t, s)| (get_buffer_from_tensor(t), *s))
                .collect_vec();
            self.metal_forward(&input_buffers, command_buffer, &[], &[&output]);
            command_buffer.commit();
            command_buffer.wait_until_completed();

            vec![Tensor::new(output)]
        })
    }

    /// Answer compiler introspection queries about this op.
    fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
        match key {
            // Expose the kernel for command-buffer batching.
            "metal" => Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
                self.clone(),
            ))))),
            // This op can accept non contiguous inputs
            "non_contiguous" => Some(Box::new(())),
            // Advertise the equation so further fusion passes can chain onto us.
            "elementwise" => Some(Box::new(self.equation.clone())),
            _ => None,
        }
    }
}
7 changes: 5 additions & 2 deletions src/compilers/metal/fp16/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ pub type MetalFp16Compiler = (
rms_norm::RMSNormCompiler,
super::other::CopyCompiler<f16>,
super::other::ContiguousElimination<f16>,
super::command_buffer::CommandBufferCompiler,
super::storage_buffer::StorageBufferCompiler,
super::elementwise_fusion::ElementwiseFusionCompiler<f16>,
(
super::command_buffer::CommandBufferCompiler,
super::storage_buffer::StorageBufferCompiler,
),
);

#[cfg(test)]
Expand Down
1 change: 1 addition & 0 deletions src/compilers/metal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod fp16;
pub use fp16::*;
mod binary;
mod command_buffer;
mod elementwise_fusion;
mod other;
mod prim;
mod storage_buffer;
Expand Down
1 change: 0 additions & 1 deletion src/compilers/metal/storage_buffer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::{
cell::UnsafeCell,
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
fmt::Debug,
sync::Arc,
};

Expand Down

0 comments on commit 54912c4

Please sign in to comment.