Fixed fusion bugs
jafioti committed Jan 16, 2024
1 parent fa04b05 commit 69c207b
Showing 9 changed files with 72 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/compilers/generic.rs
@@ -378,7 +378,7 @@ mod tests {
fn test_log_exp() {
let mut cx = Graph::new();
let a = cx.tensor::<R0>();
- let _ = a.log_2().exp_2().retrieve();
+ let _ = a.log2().exp2().retrieve();

cx.compile(GenericCompiler::<()>::default(), ());
assert_eq!(cx.graph.node_count(), 1);
59 changes: 51 additions & 8 deletions src/compilers/metal/elementwise_fusion.rs
@@ -14,7 +14,10 @@ use crate::{

use self::symbolic::BigExpression;

- use super::{compile_function, input_dyn_dims, render_dyn_dim_inputs, DispatchNElements};
+ use super::{
+     compile_function, get_idx_valid_exps, input_dyn_dims, render_dyn_dim_inputs, DispatchNElements,
+     SetInt,
+ };

#[derive(Default, Debug)]
pub struct ElementwiseFusionCompiler<T>(PhantomData<T>);
@@ -118,7 +121,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.equation
.replace(&format!("input{to_input}"), &a_equation);
// Since we are removing the input from a, we must decrement all inputs larger than that
- for i in to_input..n_edges {
+ for i in to_input + 1..n_edges {
fused_op.equation = fused_op
.equation
.replace(&format!("input{i}"), &format!("input{}", i - 1));
@@ -129,7 +132,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.unwrap();
b_equation = b_equation.replace(&format!("input{to_input}"), &a_equation);
// Since we are removing the input from a, we must decrement all inputs larger than that
- for i in to_input..n_edges {
+ for i in to_input + 1..n_edges {
b_equation =
b_equation.replace(&format!("input{i}"), &format!("input{}", i - 1));
}
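Both loops above apply the same renumbering rule: once the input at index `to_input` is spliced out, only the indices strictly above it shift down by one, which is why the loops now start at `to_input + 1` instead of `to_input`. A minimal sketch of that rule, using plain indices rather than the compiler's equation strings:

    // Minimal sketch (plain indices instead of equation strings): removing the input
    // at index `removed` only shifts the indices strictly above it down by one.
    fn renumber(indices: &mut Vec<usize>, removed: usize) {
        indices.retain(|&i| i != removed);
        for i in indices.iter_mut() {
            if *i > removed {
                *i -= 1; // mirrors the `to_input + 1..n_edges` loops above
            }
        }
    }

    fn main() {
        let mut indices = vec![0, 1, 2, 3];
        renumber(&mut indices, 1);
        // Indices 2 and 3 become 1 and 2; index 0 is untouched.
        assert_eq!(indices, vec![0, 1, 2]);
    }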
@@ -145,6 +148,16 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
_phantom: Default::default(),
})
.finish();
+ move_incoming_edge(b, new_op, &mut graph.graph);
+ move_outgoing_edge(b, new_op, &mut graph.graph);
+ move_references(
+     &mut remap,
+     &mut graph.no_delete,
+     &mut graph.to_retrieve,
+     b,
+     new_op,
+ );
+ graph.graph.remove_node(b);
}
// Remove a
move_references(
@@ -190,22 +203,31 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
{
let (dyn_chars, rendered) =
render_dyn_dim_inputs(&edges.iter().map(|i| i.2).collect_vec(), 0);
+ for (inp_ind, _, sh) in &edges {
+     let (ind, val) = get_idx_valid_exps(*sh);
+     op.equation = op.equation.replace(
+         &format!("input{inp_ind}"),
+         &format!("({val} != 0) ? input{inp_ind}[{ind}] : 0.0"),
+     );
+ }
let kernel = format!(
"
#include <metal_stdlib>
using namespace metal;
- kernel void mkernel({} uint idx [[thread_position_in_grid]]{rendered}) {{
- if (idx < n_element) {{
+ kernel void mkernel({} device {type_name} *out [[buffer({})]], device uint& n_elements [[buffer({})]], uint idx [[thread_position_in_grid]]{rendered}) {{
+ if (idx < n_elements) {{
out[idx] = {};
}}
}}",
edges
.iter()
.map(|(inp_ind, _, _)| format!(
-     "device {type_name}* input{inp_ind} [buffer({inp_ind})],"
+     "device {type_name}* input{inp_ind} [[buffer({inp_ind})]],"
))
.collect_vec()
.join(" "),
+ edges.len(),
+ edges.len() + 1,
op.equation
);
op.kernel = Some(compile_function("mkernel", &kernel, &device));
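For orientation, a rough sketch of the source this format string might produce for a single fused exp2-then-sin chain over one statically shaped fp16 input (here `{type_name}` is assumed to render as `half`). Buffers 0..N-1 hold the inputs, buffer N the output, and buffer N+1 `n_elements`, with any dynamic-dimension arguments following; in the real output each `inputN` in the equation is also wrapped in the index/validity expression produced by `get_idx_valid_exps`, which this sketch simplifies to a plain `input0[idx]`.

    // Illustrative only; the actual text depends on get_idx_valid_exps and
    // render_dyn_dim_inputs for the shapes involved.
    const EXAMPLE_FUSED_KERNEL: &str = r#"
    #include <metal_stdlib>
    using namespace metal;
    kernel void mkernel(device half* input0 [[buffer(0)]],
                        device half *out [[buffer(1)]],
                        device uint& n_elements [[buffer(2)]],
                        uint idx [[thread_position_in_grid]]) {
        if (idx < n_elements) {
            out[idx] = sin(exp2(input0[idx]));
        }
    }
    "#;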
@@ -239,21 +261,22 @@ impl<T> MetalKernel for FusedElementwiseOp<T> {
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(self.kernel.as_ref().unwrap());
+ let out_size = inputs[0].1.n_physical_elements().to_usize().unwrap();

// Set function inputs
for (i, (buf, _)) in inputs.iter().enumerate() {
encoder.set_buffer(i as u64, Some(*buf), 0);
}
encoder.set_buffer(inputs.len() as u64, Some(output_buffers[0]), 0);
+ encoder.set_u32(inputs.len() + 1, out_size as u32);
input_dyn_dims(
&self.dyn_chars,
unsafe { self.dyn_map.as_ref().unwrap() },
&encoder,
- inputs.len() + 1,
+ inputs.len() + 2,
);

// Execute
- let out_size = inputs[0].1.n_physical_elements().to_usize().unwrap();
encoder.dispatch_1d(out_size);
encoder.end_encoding();
}
@@ -303,3 +326,23 @@ impl<T: MetalFloat> Operator for FusedElementwiseOp<T> {
None
}
}

+ #[cfg(test)]
+ mod tests {
+     crate::test_imports!();
+     #[test]
+     fn test_fusion() {
+         let mut cx = Graph::new();
+         let a = cx.tensor::<R1<10>>().set(random_vec(10)).keep();
+         let mut b = a.exp2().sin().retrieve();
+
+         cx.execute();
+         let unopt_b = b.data();
+         b.drop();
+
+         cx.compile(GenericCompiler::<MetalFp16Compiler>::default(), &mut b);
+         cx.execute();
+
+         assert_close(&b.data(), &unopt_b);
+     }
+ }
6 changes: 3 additions & 3 deletions src/compilers/metal/fp16/tests.rs
@@ -52,7 +52,7 @@ fn test_log2() {
let mut cx = Graph::new();
let data = random_vec(3);
let a = cx.tensor::<R1<3>>().set(data.clone());
- let mut b = a.log_2().retrieve();
+ let mut b = a.log2().retrieve();

cx.compile(MetalFp16Compiler::default(), &mut b);
cx.execute();
@@ -71,7 +71,7 @@ fn test_exp2() {
let mut cx = Graph::new();
let data = random_vec(3);
let a = cx.tensor::<R1<3>>().set(data.clone());
- let mut b = a.exp_2().retrieve();
+ let mut b = a.exp2().retrieve();

cx.compile(MetalFp16Compiler::default(), &mut b);
cx.execute();
@@ -834,7 +834,7 @@ fn test_common_buffer() {
let a1 = cx.tensor::<R1<32>>();
a1.set(data.clone());
let exped = a * a1;
- let mut b = exped.log_2().retrieve();
+ let mut b = exped.log2().retrieve();
let mut c = exped.sin().retrieve();

cx.compile(MetalFp16Compiler::default(), (&mut b, &mut c));
4 changes: 2 additions & 2 deletions src/compilers/metal/fp32/tests.rs
@@ -32,7 +32,7 @@ fn test_log2() {
let mut cx = Graph::new();
let data = random_vec(3);
let a = cx.tensor::<R1<3>>().set(data.clone());
- let mut b = a.log_2().retrieve();
+ let mut b = a.log2().retrieve();

cx.compile(MetalFp32Compiler::default(), &mut b);
cx.execute();
@@ -48,7 +48,7 @@ fn test_exp2() {
let mut cx = Graph::new();
let data = random_vec(3);
let a = cx.tensor::<R1<3>>().set(data.clone());
- let mut b = a.exp_2().retrieve();
+ let mut b = a.exp2().retrieve();

cx.compile(MetalFp32Compiler::default(), &mut b);
cx.execute();
6 changes: 6 additions & 0 deletions src/compilers/metal/prim.rs
@@ -435,6 +435,9 @@ impl<T: MetalFloat> Operator for MetalExp2<T> {
self.clone(),
)))));
}
+ if key == "elementwise" {
+     return Some(Box::new("exp2(input0)".to_string()));
+ }
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
@@ -527,6 +530,9 @@ impl<T: MetalFloat> Operator for MetalSin<T> {
self.clone(),
)))));
}
+ if key == "elementwise" {
+     return Some(Box::new("sin(input0)".to_string()));
+ }
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
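These additions let the primitive ops advertise their scalar Metal expression (in terms of `input0`) under the "elementwise" key, which is what ElementwiseFusionCompiler splices together by textual substitution. A simplified, self-contained model of that pattern, not the actual luminal Operator trait or compiler code, might look like:

    // Simplified model only: each op exposes its scalar expression, and the fusion
    // pass substitutes the inner op's expression for `input0` in the outer op's.
    trait Elementwise {
        fn elementwise(&self) -> Option<String>;
    }

    struct Exp2;
    struct Sin;

    impl Elementwise for Exp2 {
        fn elementwise(&self) -> Option<String> {
            Some("exp2(input0)".to_string())
        }
    }
    impl Elementwise for Sin {
        fn elementwise(&self) -> Option<String> {
            Some("sin(input0)".to_string())
        }
    }

    fn fuse(outer: &dyn Elementwise, inner: &dyn Elementwise) -> Option<String> {
        Some(outer.elementwise()?.replace("input0", &inner.elementwise()?))
    }

    fn main() {
        // sin(exp2(x)) collapses to a single expression, as in the new test_fusion test.
        assert_eq!(fuse(&Sin, &Exp2).as_deref(), Some("sin(exp2(input0))"));
    }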
4 changes: 2 additions & 2 deletions src/compilers/metal/storage_buffer.rs
@@ -349,8 +349,8 @@ fn test_shared_buffers() {
crate::test_imports!();
let mut cx = Graph::new();
let a = cx.tensor::<R1<5>>().set(random_vec(5)).keep();
- let b = a.exp_2();
- let c = a.log_2() * b;
+ let b = a.exp2();
+ let c = a.log2() * b;
let d = b.recip();
let mut e = (c + d).retrieve();

4 changes: 2 additions & 2 deletions src/core/op.rs
@@ -639,7 +639,7 @@ mod tests {
let mut cx = Graph::new();
let a = cx.tensor::<R1<3>>();
a.set(vec![1., 2., 3.]);
- let b = a.log_2();
+ let b = a.log2();
b.retrieve();
cx.execute();

@@ -658,7 +658,7 @@ mod tests {
let mut cx = Graph::new();
let a = cx.tensor::<R1<3>>();
a.set(vec![1., 2., 3.]);
- let b = a.exp_2();
+ let b = a.exp2();
b.retrieve();
cx.execute();

8 changes: 4 additions & 4 deletions src/hl_ops/unary.rs
@@ -11,7 +11,7 @@ impl<S: Shape> Neg for GraphTensor<S> {

impl<S: Shape> GraphTensor<S> {
/// Base 2 log
- pub fn log_2(self) -> GraphTensor<S> {
+ pub fn log2(self) -> GraphTensor<S> {
let new_id = self
.graph()
.add_op(op::Log2)
@@ -21,7 +21,7 @@ impl<S: Shape> GraphTensor<S> {
}

/// Base 2 exp
- pub fn exp_2(self) -> GraphTensor<S> {
+ pub fn exp2(self) -> GraphTensor<S> {
let new_id = self
.graph()
.add_op(op::Exp2)
Expand All @@ -32,12 +32,12 @@ impl<S: Shape> GraphTensor<S> {

/// Natural exp
pub fn exp(self) -> GraphTensor<S> {
- (self * (1.0 / f32::ln(2.))).exp_2()
+ (self * (1.0 / f32::ln(2.))).exp2()
}

/// Natural log
pub fn ln(self) -> GraphTensor<S> {
- self.log_2() * f32::ln(2.)
+ self.log2() * f32::ln(2.)
}

/// Take the reciprocal of each element
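Only the method names change here; natural exp and ln are still built on the base-2 primitives through the identities exp(x) = exp2(x / ln 2) and ln(x) = log2(x) * ln 2. A quick numeric sanity check of those identities with plain f32 (standard library only, not the graph API):

    fn main() {
        let x = 1.37_f32;
        // exp(x) == exp2(x / ln 2), up to floating-point error
        assert!((x.exp() - (x * (1.0 / f32::ln(2.0))).exp2()).abs() < 1e-5);
        // ln(x) == log2(x) * ln 2
        assert!((x.ln() - x.log2() * f32::ln(2.0)).abs() < 1e-6);
    }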
2 changes: 1 addition & 1 deletion src/tests/mod.rs
@@ -21,7 +21,7 @@ fn main() {
let e = cx.tensor::<R1<3>>().set(vec![1.0, 2.0, 3.0]);

let mut a = (b * c + g).retrieve();
- let mut d = (b * c / e).exp_2().log_2().retrieve();
+ let mut d = (b * c / e).exp2().log2().retrieve();

cx.execute();

