From 69c207b599567c739c6a4fbd3ed5b7bfd7b87a4b Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Tue, 16 Jan 2024 17:06:29 -0600 Subject: [PATCH] Fixed fusion bugs --- src/compilers/generic.rs | 2 +- src/compilers/metal/elementwise_fusion.rs | 59 ++++++++++++++++++++--- src/compilers/metal/fp16/tests.rs | 6 +-- src/compilers/metal/fp32/tests.rs | 4 +- src/compilers/metal/prim.rs | 6 +++ src/compilers/metal/storage_buffer.rs | 4 +- src/core/op.rs | 4 +- src/hl_ops/unary.rs | 8 +-- src/tests/mod.rs | 2 +- 9 files changed, 72 insertions(+), 23 deletions(-) diff --git a/src/compilers/generic.rs b/src/compilers/generic.rs index b25a4288..3dc57050 100644 --- a/src/compilers/generic.rs +++ b/src/compilers/generic.rs @@ -378,7 +378,7 @@ mod tests { fn test_log_exp() { let mut cx = Graph::new(); let a = cx.tensor::(); - let _ = a.log_2().exp_2().retrieve(); + let _ = a.log2().exp2().retrieve(); cx.compile(GenericCompiler::<()>::default(), ()); assert_eq!(cx.graph.node_count(), 1); diff --git a/src/compilers/metal/elementwise_fusion.rs b/src/compilers/metal/elementwise_fusion.rs index a6243d4e..a8722163 100644 --- a/src/compilers/metal/elementwise_fusion.rs +++ b/src/compilers/metal/elementwise_fusion.rs @@ -14,7 +14,10 @@ use crate::{ use self::symbolic::BigExpression; -use super::{compile_function, input_dyn_dims, render_dyn_dim_inputs, DispatchNElements}; +use super::{ + compile_function, get_idx_valid_exps, input_dyn_dims, render_dyn_dim_inputs, DispatchNElements, + SetInt, +}; #[derive(Default, Debug)] pub struct ElementwiseFusionCompiler(PhantomData); @@ -118,7 +121,7 @@ impl Compiler for ElementwiseFusionCompiler { .equation .replace(&format!("input{to_input}"), &a_equation); // Since we are removing the input from a, we must decrement all inputs larger than that - for i in to_input..n_edges { + for i in to_input + 1..n_edges { fused_op.equation = fused_op .equation .replace(&format!("input{i}"), &format!("input{}", i - 1)); @@ -129,7 +132,7 @@ impl Compiler for ElementwiseFusionCompiler { .unwrap(); b_equation = b_equation.replace(&format!("input{to_input}"), &a_equation); // Since we are removing the input from a, we must decrement all inputs larger than that - for i in to_input..n_edges { + for i in to_input + 1..n_edges { b_equation = b_equation.replace(&format!("input{i}"), &format!("input{}", i - 1)); } @@ -145,6 +148,16 @@ impl Compiler for ElementwiseFusionCompiler { _phantom: Default::default(), }) .finish(); + move_incoming_edge(b, new_op, &mut graph.graph); + move_outgoing_edge(b, new_op, &mut graph.graph); + move_references( + &mut remap, + &mut graph.no_delete, + &mut graph.to_retrieve, + b, + new_op, + ); + graph.graph.remove_node(b); } // Remove a move_references( @@ -190,22 +203,31 @@ impl Compiler for ElementwiseFusionCompiler { { let (dyn_chars, rendered) = render_dyn_dim_inputs(&edges.iter().map(|i| i.2).collect_vec(), 0); + for (inp_ind, _, sh) in &edges { + let (ind, val) = get_idx_valid_exps(*sh); + op.equation = op.equation.replace( + &format!("input{inp_ind}"), + &format!("({val} != 0) ? input{inp_ind}[{ind}] : 0.0"), + ); + } let kernel = format!( " #include using namespace metal; -kernel void mkernel({} uint idx [[thread_position_in_grid]]{rendered}) {{ - if (idx < n_element) {{ +kernel void mkernel({} device {type_name} *out [[buffer({})]], device uint& n_elements [[buffer({})]], uint idx [[thread_position_in_grid]]{rendered}) {{ + if (idx < n_elements) {{ out[idx] = {}; }} }}", edges .iter() .map(|(inp_ind, _, _)| format!( - "device {type_name}* input{inp_ind} [buffer({inp_ind})]," + "device {type_name}* input{inp_ind} [[buffer({inp_ind})]]," )) .collect_vec() .join(" "), + edges.len(), + edges.len() + 1, op.equation ); op.kernel = Some(compile_function("mkernel", &kernel, &device)); @@ -239,21 +261,22 @@ impl MetalKernel for FusedElementwiseOp { let encoder = command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new()); encoder.set_compute_pipeline_state(self.kernel.as_ref().unwrap()); + let out_size = inputs[0].1.n_physical_elements().to_usize().unwrap(); // Set function inputs for (i, (buf, _)) in inputs.iter().enumerate() { encoder.set_buffer(i as u64, Some(*buf), 0); } encoder.set_buffer(inputs.len() as u64, Some(output_buffers[0]), 0); + encoder.set_u32(inputs.len() + 1, out_size as u32); input_dyn_dims( &self.dyn_chars, unsafe { self.dyn_map.as_ref().unwrap() }, &encoder, - inputs.len() + 1, + inputs.len() + 2, ); // Execute - let out_size = inputs[0].1.n_physical_elements().to_usize().unwrap(); encoder.dispatch_1d(out_size); encoder.end_encoding(); } @@ -303,3 +326,23 @@ impl Operator for FusedElementwiseOp { None } } + +#[cfg(test)] +mod tests { + crate::test_imports!(); + #[test] + fn test_fusion() { + let mut cx = Graph::new(); + let a = cx.tensor::>().set(random_vec(10)).keep(); + let mut b = a.exp2().sin().retrieve(); + + cx.execute(); + let unopt_b = b.data(); + b.drop(); + + cx.compile(GenericCompiler::::default(), &mut b); + cx.execute(); + + assert_close(&b.data(), &unopt_b); + } +} diff --git a/src/compilers/metal/fp16/tests.rs b/src/compilers/metal/fp16/tests.rs index 19b14271..8996c2e8 100644 --- a/src/compilers/metal/fp16/tests.rs +++ b/src/compilers/metal/fp16/tests.rs @@ -52,7 +52,7 @@ fn test_log2() { let mut cx = Graph::new(); let data = random_vec(3); let a = cx.tensor::>().set(data.clone()); - let mut b = a.log_2().retrieve(); + let mut b = a.log2().retrieve(); cx.compile(MetalFp16Compiler::default(), &mut b); cx.execute(); @@ -71,7 +71,7 @@ fn test_exp2() { let mut cx = Graph::new(); let data = random_vec(3); let a = cx.tensor::>().set(data.clone()); - let mut b = a.exp_2().retrieve(); + let mut b = a.exp2().retrieve(); cx.compile(MetalFp16Compiler::default(), &mut b); cx.execute(); @@ -834,7 +834,7 @@ fn test_common_buffer() { let a1 = cx.tensor::>(); a1.set(data.clone()); let exped = a * a1; - let mut b = exped.log_2().retrieve(); + let mut b = exped.log2().retrieve(); let mut c = exped.sin().retrieve(); cx.compile(MetalFp16Compiler::default(), (&mut b, &mut c)); diff --git a/src/compilers/metal/fp32/tests.rs b/src/compilers/metal/fp32/tests.rs index f4123667..df1348d3 100644 --- a/src/compilers/metal/fp32/tests.rs +++ b/src/compilers/metal/fp32/tests.rs @@ -32,7 +32,7 @@ fn test_log2() { let mut cx = Graph::new(); let data = random_vec(3); let a = cx.tensor::>().set(data.clone()); - let mut b = a.log_2().retrieve(); + let mut b = a.log2().retrieve(); cx.compile(MetalFp32Compiler::default(), &mut b); cx.execute(); @@ -48,7 +48,7 @@ fn test_exp2() { let mut cx = Graph::new(); let data = random_vec(3); let a = cx.tensor::>().set(data.clone()); - let mut b = a.exp_2().retrieve(); + let mut b = a.exp2().retrieve(); cx.compile(MetalFp32Compiler::default(), &mut b); cx.execute(); diff --git a/src/compilers/metal/prim.rs b/src/compilers/metal/prim.rs index ed313152..db59c7cf 100644 --- a/src/compilers/metal/prim.rs +++ b/src/compilers/metal/prim.rs @@ -435,6 +435,9 @@ impl Operator for MetalExp2 { self.clone(), ))))); } + if key == "elementwise" { + return Some(Box::new("exp2(input0)".to_string())); + } // This op can accept non contiguous inputs if key == "non_contiguous" { return Some(Box::new(())); @@ -527,6 +530,9 @@ impl Operator for MetalSin { self.clone(), ))))); } + if key == "elementwise" { + return Some(Box::new("sin(input0)".to_string())); + } // This op can accept non contiguous inputs if key == "non_contiguous" { return Some(Box::new(())); diff --git a/src/compilers/metal/storage_buffer.rs b/src/compilers/metal/storage_buffer.rs index aedd369f..0efbf1a9 100644 --- a/src/compilers/metal/storage_buffer.rs +++ b/src/compilers/metal/storage_buffer.rs @@ -349,8 +349,8 @@ fn test_shared_buffers() { crate::test_imports!(); let mut cx = Graph::new(); let a = cx.tensor::>().set(random_vec(5)).keep(); - let b = a.exp_2(); - let c = a.log_2() * b; + let b = a.exp2(); + let c = a.log2() * b; let d = b.recip(); let mut e = (c + d).retrieve(); diff --git a/src/core/op.rs b/src/core/op.rs index 83ffff68..223b1bdf 100644 --- a/src/core/op.rs +++ b/src/core/op.rs @@ -639,7 +639,7 @@ mod tests { let mut cx = Graph::new(); let a = cx.tensor::>(); a.set(vec![1., 2., 3.]); - let b = a.log_2(); + let b = a.log2(); b.retrieve(); cx.execute(); @@ -658,7 +658,7 @@ mod tests { let mut cx = Graph::new(); let a = cx.tensor::>(); a.set(vec![1., 2., 3.]); - let b = a.exp_2(); + let b = a.exp2(); b.retrieve(); cx.execute(); diff --git a/src/hl_ops/unary.rs b/src/hl_ops/unary.rs index 719b650f..cb23db21 100644 --- a/src/hl_ops/unary.rs +++ b/src/hl_ops/unary.rs @@ -11,7 +11,7 @@ impl Neg for GraphTensor { impl GraphTensor { /// Base 2 log - pub fn log_2(self) -> GraphTensor { + pub fn log2(self) -> GraphTensor { let new_id = self .graph() .add_op(op::Log2) @@ -21,7 +21,7 @@ impl GraphTensor { } /// Base 2 exp - pub fn exp_2(self) -> GraphTensor { + pub fn exp2(self) -> GraphTensor { let new_id = self .graph() .add_op(op::Exp2) @@ -32,12 +32,12 @@ impl GraphTensor { /// Natural exp pub fn exp(self) -> GraphTensor { - (self * (1.0 / f32::ln(2.))).exp_2() + (self * (1.0 / f32::ln(2.))).exp2() } /// Natural log pub fn ln(self) -> GraphTensor { - self.log_2() * f32::ln(2.) + self.log2() * f32::ln(2.) } /// Take the reciprocal of each element diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 2d736b7b..8b0b6618 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -21,7 +21,7 @@ fn main() { let e = cx.tensor::>().set(vec![1.0, 2.0, 3.0]); let mut a = (b * c + g).retrieve(); - let mut d = (b * c / e).exp_2().log_2().retrieve(); + let mut d = (b * c / e).exp2().log2().retrieve(); cx.execute();