Skip to content

Commit

Permalink
Disabled elementwise fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Jan 19, 2024
1 parent 8f2d13d commit ebb0df6
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 62 deletions.
3 changes: 2 additions & 1 deletion src/compilers/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ pub type PreGenericCompiler = (

pub type PostGenericCompiler = (
// RemoveUnusedNodes, // Broken right now, unclear why
CSE,
// CSE,
(),
);

/// Eliminate complementary unary sequential operations like `x.log().exp()`
Expand Down
244 changes: 184 additions & 60 deletions src/compilers/metal/elementwise_fusion.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use std::{any::Any, collections::HashMap, marker::PhantomData, sync::Arc};
use std::{
any::Any,
collections::{HashMap, HashSet},
marker::PhantomData,
sync::Arc,
};

use itertools::Itertools;
use metal_rs::{
Expand Down Expand Up @@ -37,25 +42,25 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.ptr(&mut b),
)
.search(graph);
let mut fused_ops = vec![];
let mut fused_ops = HashSet::new();

while selector.next_match() {
// More than one connecting edge
if graph.no_delete.contains(&a)
|| graph
.graph
.edges_connecting(a, b)
.edges_directed(a, Direction::Outgoing)
.filter(|e| !e.weight().is_schedule())
.count()
> 1
{
continue;
}
// Connecting shape isn't contiguous
let (to_input, _, connecting_shape) = graph
let (edge_id, (to_input, _, connecting_shape)) = graph
.graph
.edges_connecting(a, b)
.find_map(|e| e.weight().as_data())
.find_map(|e| e.weight().as_data().map(|i| (e.id(), i)))
.unwrap();
if !connecting_shape.is_contiguous()
|| connecting_shape.is_sliced()
Expand All @@ -69,44 +74,114 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
let mut a_equation = graph
.node_custom::<String, _>(a, "elementwise", ())
.unwrap();
let mut n_edges = graph
let mut curr_input = to_input;
// Remove edge a -> b, and decrement indexes of all edges higher than it
graph.graph.remove_edge(edge_id);
for edge in graph
.graph
.edges_directed(a, Direction::Incoming)
.filter(|e| !e.weight().is_schedule())
.count() as u8;
// Adjust variables in a_equation to the new inputs
.edges_directed(b, Direction::Incoming)
.map(|e| e.id())
.collect_vec()
{
if let Some(Dependency::Data { input_order, .. }) =
graph.graph.edge_weight_mut(edge)
{
if *input_order > curr_input {
*input_order -= 1;
}
}
}
// Add edges if they don't exist
for input_edge in graph
.graph
.edges_directed(a, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
.sorted_by_key(|i| i.1)
.collect_vec()
{
// Find edge or add it
if let Some(n) = graph
if graph
.graph
.edges_directed(b, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
.find(|(src, inp_ind, _, _)| *src == input_edge.0 && *inp_ind == input_edge.2)
.find(|(src, _, out_ind, _)| *src == input_edge.0 && *out_ind == input_edge.2)
.is_none()
{
a_equation = a_equation
.replace(&format!("input{}", input_edge.1), &format!("input{}", n.1));
} else {
println!("Adding edge {curr_input}");
// Move all edges >= curr_input up by one
for edge in graph
.graph
.edges_directed(b, Direction::Incoming)
.map(|e| e.id())
.collect_vec()
{
if let Some(Dependency::Data { input_order, .. }) =
graph.graph.edge_weight_mut(edge)
{
if *input_order >= curr_input {
*input_order += 1;
}
}
}
// Add edge
graph.graph.add_edge(
input_edge.0,
b,
Dependency::Data {
input_order: n_edges,
input_order: curr_input,
output_order: input_edge.2,
shape: input_edge.3,
},
);
a_equation = a_equation.replace(
&format!("input{}", input_edge.1),
&format!("input{}", n_edges),
curr_input += 1;
println!(
"Current: {:?}",
graph
.graph
.edges_directed(b, Direction::Incoming)
.filter_map(|e| e.weight().as_data())
.map(|w| w.0)
.collect_vec()
);
n_edges += 1;
}
}
// Alter a_equation to reflect the correct input indexes
println!("Pre equation: {a_equation}");
println!(
"Ind: {:?}",
graph
.graph
.edges_directed(a, Direction::Incoming)
.map(|e| e.source())
.collect_vec()
);
println!(
"Ind2: {:?}",
graph
.graph
.edges_directed(b, Direction::Incoming)
.map(|e| (e.source(), e.weight().as_data().unwrap().0))
.collect_vec()
);
let mut replacements = vec![];
for input_edge in graph
.graph
.edges_directed(a, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
.sorted_by_key(|i| i.1)
{
let n = graph
.graph
.edges_directed(b, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
.find(|(src, _, out_ind, _)| *src == input_edge.0 && *out_ind == input_edge.2)
.unwrap();
println!("replacing {} with {}", input_edge.1, n.1);
replacements.push((format!("input{}", input_edge.1), format!("input{}", n.1)));
}
a_equation = multi_replace(&a_equation, &replacements);
println!("Post equation: {a_equation}");

if let Some(fused_op) = graph
.graph
.node_weight_mut(b)
Expand All @@ -120,23 +195,22 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
fused_op.equation = fused_op
.equation
.replace(&format!("input{to_input}"), &format!("({a_equation})"));
// Since we are removing the input from a, we must decrement all inputs larger than that
for i in to_input + 1..n_edges {
fused_op.equation = fused_op
.equation
.replace(&format!("input{i}"), &format!("input{}", i - 1));
}
println!(
"B edges: {:?}",
graph
.graph
.edges_directed(new_op, Direction::Incoming)
.filter_map(|e| e.weight().as_data())
.map(|w| w.0)
.collect_vec()
);
} else {
let mut b_equation = graph
.node_custom::<String, _>(b, "elementwise", ())
.unwrap();
b_equation =
b_equation.replace(&format!("input{to_input}"), &format!("({a_equation})"));
// Since we are removing the input from a, we must decrement all inputs larger than that
for i in to_input + 1..n_edges {
b_equation =
b_equation.replace(&format!("input{i}"), &format!("input{}", i - 1));
}
println!("Eq: {}", b_equation);
// B is not a fused op, let's create a new one
new_op = graph
.add_op(FusedElementwiseOp::<T> {
Expand All @@ -159,6 +233,16 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
new_op,
);
graph.graph.remove_node(b);
fused_ops.remove(&b);
println!(
"New Op edges: {:?}",
graph
.graph
.edges_directed(new_op, Direction::Incoming)
.filter_map(|e| e.weight().as_data())
.map(|w| w.0)
.collect_vec()
);
}
// Remove a
move_references(
Expand All @@ -169,23 +253,10 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
new_op,
);
graph.graph.remove_node(a);
// Bring input indexes back in line
for (i, e) in graph
.graph
.edges_directed(new_op, Direction::Incoming)
.filter(|e| !e.weight().is_schedule())
.sorted_by_key(|e| e.weight().as_data().unwrap().0)
.map(|e| e.id())
.enumerate()
.collect_vec()
{
if let Dependency::Data { input_order, .. } =
graph.graph.edge_weight_mut(e).unwrap()
{
*input_order = i as u8;
}
}
fused_ops.push(new_op);
fused_ops.remove(&a);
fused_ops.insert(new_op);
selector.reset();
println!("Finished {}", new_op.index());
}
// Compile all the kernels we placed
let type_name = T::type_name();
Expand All @@ -195,21 +266,43 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.edges_directed(fused_op, Direction::Incoming)
.filter_map(|e| e.weight().as_data())
.collect_vec();
println!("Node: {:?}", fused_op);
println!(
"Edges: {:?}",
graph
.graph
.edges_directed(fused_op, Direction::Incoming)
.sorted_by_key(|e| e.weight().as_data().unwrap().0)
.map(|e| (
e.source(),
format!("{:?}", graph.graph.node_weight(e.source()))
))
.collect_vec()
);
if let Some(op) = graph
.graph
.node_weight_mut(fused_op)
.unwrap()
.as_any_mut()
.downcast_mut::<FusedElementwiseOp<T>>()
{
let (dyn_chars, rendered) =
render_dyn_dim_inputs(&edges.iter().map(|i| i.2).collect_vec(), 0);
let (dyn_chars, rendered) = render_dyn_dim_inputs(
&edges.iter().map(|i| i.2).collect_vec(),
edges.len() + 2,
);
for (inp_ind, _, sh) in &edges {
let (ind, val) = get_idx_valid_exps(*sh);
op.equation = op.equation.replace(
&format!("input{inp_ind}"),
&format!("({val} != 0) ? input{inp_ind}[{ind}] : 0.0"),
);
if sh.is_contiguous() && !sh.is_sliced() && !sh.is_padded() {
op.equation = op.equation.replace(
&format!("input{inp_ind}"),
&format!("input{inp_ind}[{ind}]"),
);
} else {
op.equation = op.equation.replace(
&format!("input{inp_ind}"),
&format!("(({val} != 0) ? input{inp_ind}[{ind}] : 0.0)"),
);
}
}
let kernel = format!(
"
Expand All @@ -231,13 +324,42 @@ kernel void mkernel({} device {type_name} *out [[buffer({})]], device uint& n_el
edges.len() + 1,
op.equation
);
println!("{kernel}");
op.kernel = Some(compile_function("mkernel", &kernel, &device));
op.dyn_chars = dyn_chars;
}
}
}
}

/// Applies every `(from, to)` pair in `replacements` to `input` simultaneously.
///
/// The input is scanned once from left to right. At each position the longest
/// `from` pattern that matches is replaced by its `to` string, and scanning
/// resumes after the matched text. Text produced by a replacement is never
/// re-scanned, so swaps such as `input0 -> input1` together with
/// `input1 -> input0` behave as expected.
///
/// Preferring the longest pattern matters when one pattern is a prefix of
/// another (`input1` vs `input10`): a sequential `str::replace` per pattern
/// would let the shorter pattern consume the start of the longer one.
///
/// Pairs with an empty `from` are ignored. If several pairs share the same
/// `from`, the first one in `replacements` wins.
fn multi_replace(input: &str, replacements: &[(String, String)]) -> String {
    // Longest patterns first; the sort is stable, so equal-length patterns
    // keep the caller's order.
    let mut ordered: Vec<&(String, String)> = replacements
        .iter()
        .filter(|pair| !pair.0.is_empty())
        .collect();
    ordered.sort_by_key(|pair| std::cmp::Reverse(pair.0.len()));

    let mut output = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(ch) = rest.chars().next() {
        match ordered
            .iter()
            .find(|pair| rest.starts_with(pair.0.as_str()))
        {
            Some(pair) => {
                output.push_str(&pair.1);
                // `starts_with` succeeded, so `pair.0.len()` is a char boundary.
                rest = &rest[pair.0.len()..];
            }
            None => {
                output.push(ch);
                rest = &rest[ch.len_utf8()..];
            }
        }
    }

    output
}

#[derive(LuminalPrint, LuminalEq, Clone)]
pub struct FusedElementwiseOp<T> {
kernel: Option<ComputePipelineState>,
Expand Down Expand Up @@ -334,16 +456,18 @@ mod tests {
#[test]
fn test_fusion() {
let mut cx = Graph::new();
let a = cx.tensor::<R1<10>>().set(random_vec(10)).keep();
let mut b = a.exp2().sin().retrieve();
let a = cx.named_tensor::<R1<10>>("a").set(random_vec(10)).keep();
let b = cx.named_tensor::<R1<10>>("b").set(random_vec(10)).keep();
let mut c = (a.exp2() - b.sin()).relu().retrieve();

cx.execute();
let unopt_b = b.data();
b.drop();
let unopt_c = c.data();
c.drop();

cx.compile(GenericCompiler::<MetalFp16Compiler>::default(), &mut b);
cx.compile(GenericCompiler::<MetalFp16Compiler>::default(), &mut c);
// cx.display();
cx.execute();

assert_close(&b.data(), &unopt_b);
assert_close(&c.data(), &unopt_c);
}
}
2 changes: 1 addition & 1 deletion src/compilers/metal/fp16/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub type MetalFp16Compiler = (
rms_norm::RMSNormCompiler,
super::other::CopyCompiler<f16>,
super::other::ContiguousElimination<f16>,
super::elementwise_fusion::ElementwiseFusionCompiler<f16>,
// super::elementwise_fusion::ElementwiseFusionCompiler<f16>,
(
super::command_buffer::CommandBufferCompiler,
super::storage_buffer::StorageBufferCompiler,
Expand Down
4 changes: 4 additions & 0 deletions src/core/compiler_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,10 @@ impl GraphSearch {
/// Empties `self.to_return`, discarding any results currently held there.
pub fn clear_cached_results(&mut self) {
self.to_return.clear();
}
/// Clears both the cached results (`to_return`) and `returned_anchors`.
///
/// NOTE(review): presumably this makes the search forget which anchors it has
/// already yielded, so it can be restarted after the graph has been mutated
/// — confirm against `next_match`.
pub fn reset(&mut self) {
self.clear_cached_results();
self.returned_anchors.clear();
}
}

fn backtrack_match(
Expand Down

0 comments on commit ebb0df6

Please sign in to comment.