Skip to content

Commit

Permalink
Fixed elementwise fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Jan 21, 2024
1 parent b1c435b commit 308938e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/compilers/metal/command_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ impl Operator for CommandBufferWrapper {
}
}

#[cfg(test)]
#[test]
fn test_common_buffer() {
crate::test_imports!();
Expand Down
44 changes: 32 additions & 12 deletions src/compilers/metal/elementwise_fusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.node_custom::<String, _>(a, "elementwise", ())
.unwrap();
let mut curr_input = to_input;
// Keep track of original edges to a and b
let a_orig_edges = graph
.graph
.edges_directed(a, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(i, ind, _)| (e.source(), i, ind)))
.sorted_by_key(|i| i.1)
.collect::<Vec<_>>();
let b_orig_edges = graph
.graph
.edges_directed(b, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(i, ind, _)| (e.source(), i, ind)))
.sorted_by_key(|i| i.1)
.collect::<Vec<_>>();
// Remove edge a -> b, and decrement indexes of all edges higher than it
graph.graph.remove_edge(edge_id);
for edge in graph
Expand Down Expand Up @@ -137,21 +150,29 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
}
// Alter a_equation to reflect the correct input indexes
let mut replacements = vec![];
for input_edge in graph
.graph
.edges_directed(a, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
.sorted_by_key(|i| i.1)
{
for (src, inp_ind, out_ind) in a_orig_edges {
let n = graph
.graph
.edges_directed(b, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
.find(|(src, _, out_ind, _)| *src == input_edge.0 && *out_ind == input_edge.2)
.find(|(c_src, _, c_out_ind, _)| *c_src == src && *c_out_ind == out_ind)
.unwrap();
replacements.push((format!("input{}", input_edge.1), format!("input{}", n.1)));
replacements.push((format!("input{inp_ind}"), format!("input{}", n.1)));
}
a_equation = multi_replace(&a_equation, &replacements);
// Alter b_equation to reflect the correct input indexes
replacements.clear();
for (src, inp_ind, out_ind) in b_orig_edges {
if inp_ind > to_input {
let n = graph
.graph
.edges_directed(b, Direction::Incoming)
.filter_map(|e| e.weight().as_data().map(|(a, b, c)| (e.source(), a, b, c)))
.find(|(c_src, _, c_out_ind, _)| *c_src == src && *c_out_ind == out_ind)
.unwrap();
replacements.push((format!("input{inp_ind}"), format!("input{}", n.1)));
}
}

if let Some(fused_op) = graph
.graph
Expand All @@ -163,15 +184,14 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
// B is already fused, just combine with b
new_op = b;
// Render a into b as input to_input
fused_op.equation = fused_op
.equation
fused_op.equation = multi_replace(&fused_op.equation, &replacements)
.replace(&format!("input{to_input}"), &format!("({a_equation})"));
} else {
let mut b_equation = graph
.node_custom::<String, _>(b, "elementwise", ())
.unwrap();
b_equation =
b_equation.replace(&format!("input{to_input}"), &format!("({a_equation})"));
b_equation = multi_replace(&b_equation, &replacements)
.replace(&format!("input{to_input}"), &format!("({a_equation})"));
// B is not a fused op, let's create a new one
new_op = graph
.add_op(FusedElementwiseOp::<T> {
Expand Down
2 changes: 1 addition & 1 deletion src/compilers/metal/fp16/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub type MetalFp16Compiler = (
std_norm::StdNormCompiler,
super::other::CopyCompiler<f16>,
super::other::ContiguousElimination<f16>,
// super::elementwise_fusion::ElementwiseFusionCompiler<f16>,
super::elementwise_fusion::ElementwiseFusionCompiler<f16>,
(
super::command_buffer::CommandBufferCompiler,
super::storage_buffer::StorageBufferCompiler,
Expand Down

0 comments on commit 308938e

Please sign in to comment.