diff --git a/src/compilers/metal/fp16/rms_norm.rs b/src/compilers/metal/fp16/rms_norm.rs index 34722959..668c2798 100644 --- a/src/compilers/metal/fp16/rms_norm.rs +++ b/src/compilers/metal/fp16/rms_norm.rs @@ -15,91 +15,75 @@ use metal_rs::{objc::rc::autoreleasepool, *}; /// Special kernel for efficient rms norming #[derive(LuminalEq, LuminalPrint, Clone)] pub struct MetalRMSNorm { - square_mean_pipeline: ComputePipelineState, // Square-Mean kernel - rms_norm_pipeline: ComputePipelineState, // RMSNorm kernel + pipeline: ComputePipelineState, device: Device, queue: CommandQueue, - dyn_symbols: Vec, epsilon: f32, // Epsilon - dyn_map: *const HashMap, } impl MetalRMSNorm { - fn new( - epsilon: f32, - device: Device, - queue: CommandQueue, - inp_shape: ShapeTracker, - dyn_map: *const HashMap, - ) -> Self { - let (idx_exp, valid_exp) = get_idx_valid_exps(inp_shape); - let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[inp_shape], 6); - let mut square_mean_code = format!( - " -#include + fn new(epsilon: f32, device: Device, queue: CommandQueue) -> Self { + let kernel_code = "#include using namespace metal; -kernel void mkernel(device half *inp [[buffer(0)]], device float *out [[buffer(1)]], device int& n_elements [[buffer(2)]], device int& front_size [[buffer(3)]], device int& back_size [[buffer(4)]], device int& dim_size [[buffer(5)]], uint i_ [[thread_position_in_grid]]{rendered}) {{ - if (i_ < n_elements) {{ - int a_ = i_ / back_size; - int b_ = i_ % back_size; - float reduce_value = 0.0; - int add_factor = a_ * dim_size * back_size + b_; - for (int c_ = 0; c_ < dim_size * back_size; c_ += back_size) {{ - int idx = add_factor + c_; - if (({valid_exp}) != 0) {{ - float val = (float)inp[{idx_exp}]; - reduce_value += (val * val); - }} - }} - out[i_] = (reduce_value / (float)dim_size); - }} -}} -"); - let square_mean_code_name = format!("kernel_{}", hash(&square_mean_code)); - square_mean_code = square_mean_code.replace("mkernel", &square_mean_code_name); +kernel void kernel_rms_norm( + device const void * src0, + device half * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const half4 * x = (device const half4 *) ((device const char *) src0 + tgpig*nb01); - let mut meaned_shape = inp_shape; - let meaned_size = meaned_shape.remove_dim(meaned_shape.len() - 1); - meaned_shape.expand(meaned_shape.len(), meaned_size); - let (meaned_idx_exp, _) = get_idx_valid_exps(meaned_shape); - let (_, rendered) = render_dyn_dim_inputs(&[inp_shape], 4); - let mut rms_norm_code = format!(" -#include -using namespace metal; + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + sumf += (float4)x[i00] * (float4)x[i00]; + } + all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; + all_sum = simd_sum(all_sum); + if (ntg > 32) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = all_sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + all_sum = buf[tiisg]; + all_sum = simd_sum(all_sum); + } -kernel void mkernel(device float *inp [[buffer(0)]], device half *x [[buffer(1)]], device half *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{ - if (idx < n_elements) {{ - float added = inp[{meaned_idx_exp}] + {epsilon}; - float sq = sqrt(added); - float recip = 1.0f / sq; - out[idx] = (half)(recip * (float)x[{idx_exp}]); - }} -}}"); - let rms_norm_code_name = format!("kernel_{}", hash(&rms_norm_code)); - rms_norm_code = rms_norm_code.replace("mkernel", &rms_norm_code_name); + const float mean = all_sum/ne00; + const float scale = 1.0f/sqrt(mean + eps); + + device half4 * y = (device half4 *) (dst + tgpig*ne00); + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + y[i00] = (half4)(x[i00] * scale); + } +}"; Self { - square_mean_pipeline: compile_function( - &square_mean_code_name, - &square_mean_code, - &device, - ), - rms_norm_pipeline: compile_function(&rms_norm_code_name, &rms_norm_code, &device), + pipeline: compile_function(&"kernel_rms_norm", kernel_code, &device), device, queue, - dyn_symbols, epsilon, - dyn_map, } } } impl MetalKernel for MetalRMSNorm { - fn intermediate_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec { - let mut meaned_shape = input_shapes[0]; - meaned_shape.remove_dim(meaned_shape.len() - 1); - vec![meaned_shape.n_elements() * size_of::()] - } fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec { vec![input_shapes[0].n_elements() * size_of::()] } @@ -108,62 +92,58 @@ impl MetalKernel for MetalRMSNorm { &self, inputs: &[(&Buffer, ShapeTracker)], command_buffer: &CommandBufferRef, - intermediate_buffers: &[&Buffer], + _: &[&Buffer], output_buffers: &[&Buffer], ) { - let mut meaned_shape = inputs[0].1; - meaned_shape.remove_dim(meaned_shape.len() - 1); - // Setup buffers - let front_size: usize = inputs[0] - .1 - .shape() - .iter() - .take(meaned_shape.len()) - .map(|i| i.to_usize().unwrap()) - .product(); - let back_size = 1; - let dim_size = inputs[0].1.shape()[meaned_shape.len()].to_usize().unwrap(); - let encoder = command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new()); - encoder.set_compute_pipeline_state(&self.square_mean_pipeline); - let meaned_elements = meaned_shape.n_elements().to_usize().unwrap(); + encoder.set_compute_pipeline_state(&self.pipeline); + let ne00 = inputs[0].1.shape()[2].to_usize().unwrap(); + let nb01 = inputs[0].1.strides()[1].to_usize().unwrap() * size_of::(); // Set inputs encoder.set_buffer(0, Some(inputs[0].0), 0); - encoder.set_buffer(1, Some(intermediate_buffers[0]), 0); - encoder.set_int(2, meaned_elements as u32); - encoder.set_int(3, front_size as u32); - encoder.set_int(4, back_size as u32); - encoder.set_int(5, dim_size as u32); - input_dyn_dims( - &self.dyn_symbols, - unsafe { self.dyn_map.as_ref().unwrap() }, - encoder, - 6, + encoder.set_buffer(1, Some(output_buffers[0]), 0); + encoder.set_bytes( + 2, + size_of::() as u64, + &(ne00 as i64) as *const i64 as *const _, ); - - encoder.dispatch_1d(meaned_elements); - encoder.end_encoding(); - - let encoder = - command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new()); - encoder.set_compute_pipeline_state(&self.rms_norm_pipeline); - - // Set inputs - encoder.set_buffer(0, Some(intermediate_buffers[0]), 0); - encoder.set_buffer(1, Some(inputs[0].0), 0); - encoder.set_buffer(2, Some(output_buffers[0]), 0); - encoder.set_int(3, inputs[0].1.n_elements().to_usize().unwrap() as u32); - input_dyn_dims( - &self.dyn_symbols, - unsafe { self.dyn_map.as_ref().unwrap() }, - encoder, + encoder.set_bytes( + 3, + size_of::() as u64, + &(nb01 as u64) as *const u64 as *const _, + ); + encoder.set_bytes( 4, + size_of::() as u64, + &self.epsilon as *const f32 as *const _, ); - // Execute - encoder.dispatch_1d(inputs[0].1.n_elements().to_usize().unwrap()); + let mut nth = 32; // SIMD width + while nth < ne00 / 4 && nth < 1024 { + nth *= 2; + } + let n_rows = inputs[0] + .1 + .shape() + .into_iter() + .take(2) + .map(|i| i.to_usize().unwrap()) + .product::(); + encoder.set_threadgroup_memory_length(0, 32 * size_of::() as u64); + encoder.dispatch_thread_groups( + MTLSize { + width: n_rows as u64, + height: 1, + depth: 1, + }, + MTLSize { + width: nth as u64, + height: 1, + depth: 1, + }, + ); encoder.end_encoding(); } } @@ -179,18 +159,12 @@ impl Operator for MetalRMSNorm { .as_any() .downcast_ref::() .unwrap(); - let mut meaned_shape = tensors[0].1; - meaned_shape.remove_dim(meaned_shape.len() - 1); - let meaned = self.device.new_buffer( - (meaned_shape.n_elements().to_usize().unwrap() * size_of::()) as u64, - MTLResourceOptions::StorageModeShared, - ); let out = self.device.new_buffer( (tensors[0].1.n_elements().to_usize().unwrap() * size_of::()) as u64, MTLResourceOptions::StorageModeShared, ); - self.metal_forward(&[(a, tensors[0].1)], command_buffer, &[&meaned], &[&out]); + self.metal_forward(&[(a, tensors[0].1)], command_buffer, &[], &[&out]); command_buffer.commit(); command_buffer.wait_until_completed(); @@ -199,28 +173,13 @@ impl Operator for MetalRMSNorm { }) } - fn custom(&mut self, key: &str, input: Box) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { #[allow(clippy::arc_with_non_send_sync)] return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), ))))); } - // This op can accept non contiguous inputs - if key == "non_contiguous" { - return Some(Box::new(())); - } - if key == "recompile_shapes" { - if let Some(input_shapes) = input.downcast_ref::>() { - *self = Self::new( - self.epsilon, - self.device.clone(), - self.queue.clone(), - input_shapes[0], - self.dyn_map, - ) - } - } None } } @@ -286,24 +245,33 @@ impl Compiler for RMSNormCompiler { else { unreachable!() }; - let x = graph.get_sources(square)[0]; - if !graph.get_sources(square).iter().all(|(i, _, _)| *i == x.0) { + let (mut x, _, mut sh) = graph.get_sources(square)[0]; + if !graph.get_sources(square).iter().all(|(i, _, _)| *i == x) { continue; } - if !graph.get_sources(mul).iter().any(|(i, _, _)| *i == x.0) { + if !graph.get_sources(mul).iter().any(|(i, _, _)| *i == x) { continue; } + // Input must be contiguous + if !sh.is_contiguous() || sh.is_sliced() || sh.is_padded() { + x = graph + .add_op(MetalContiguous::::new( + sh, + dev.clone(), + queue.clone(), + &mut HashMap::new(), + &graph.dyn_map, + )) + .input(x, 0, sh) + .finish(); + sh = sh.contiguous(); + } + // Insert RMSNorm op let rms_norm = graph - .add_op(MetalRMSNorm::new( - epsilon_num, - dev.clone(), - queue.clone(), - x.2, - &graph.dyn_map, - )) - .input(x.0, 0, x.2) + .add_op(MetalRMSNorm::new(epsilon_num, dev.clone(), queue.clone())) + .input(x, 0, sh) .finish(); // Create edges to dests