Skip to content

Commit

Permalink
ggml rms norm
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Jan 8, 2024
1 parent f9b52f0 commit 703f4d3
Showing 1 changed file with 113 additions and 145 deletions.
258 changes: 113 additions & 145 deletions src/compilers/metal/fp16/rms_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,91 +15,75 @@ use metal_rs::{objc::rc::autoreleasepool, *};
/// Special kernel for efficient RMS normalization.
///
/// Holds the compiled Metal compute pipeline state(s) together with the
/// device and command queue handles, plus the epsilon that is added to the
/// mean of squares before the reciprocal square root is taken.
///
/// NOTE(review): this listing is a rendered commit diff, so fields from both
/// sides of the change appear together. `square_mean_pipeline`,
/// `rms_norm_pipeline`, `dyn_symbols` and `dyn_map` look like the pre-commit
/// fields and `pipeline` like their single-kernel replacement — confirm
/// against the real file before relying on this field list.
#[derive(LuminalEq, LuminalPrint, Clone)]
pub struct MetalRMSNorm {
// Pipeline for the kernel that writes the per-row mean of squares (f32).
square_mean_pipeline: ComputePipelineState, // Square-Mean kernel
// Pipeline for the kernel that scales the input by 1/sqrt(mean + epsilon).
rms_norm_pipeline: ComputePipelineState, // RMSNorm kernel
// Pipeline compiled from the single fused `kernel_rms_norm` source.
pipeline: ComputePipelineState,
// Metal device; used to compile the kernels and to allocate output buffers.
device: Device,
// Metal command queue handle kept alongside the device.
queue: CommandQueue,
// Dynamic-dimension symbols produced by `render_dyn_dim_inputs`.
dyn_symbols: Vec<char>,
// Stabilising constant added to the mean of squares before the sqrt.
epsilon: f32, // Epsilon
// Raw pointer to the symbol -> size map; it is dereferenced in an `unsafe`
// block when dynamic dims are bound, so the map must outlive this op.
dyn_map: *const HashMap<char, usize>,
}

// Construction of the RMSNorm op: builds the Metal shader source and compiles
// it into compute pipeline state(s) with `compile_function`.
//
// NOTE(review): this listing is a rendered commit diff. Two `fn new`
// signatures and the bodies of three kernel strings are interleaved below
// without +/- markers, so the text is not compilable as shown. The five-argument
// `new` with the `mkernel` format strings appears to be the removed
// implementation and the three-argument `new` with `kernel_rms_norm` the added
// one — confirm against the real file.
impl MetalRMSNorm {
/// Creates the op by compiling its Metal kernel(s) for `device`.
///
/// Two-kernel variant (`mkernel` sources): the first kernel accumulates
/// `val * val` over the reduced dimension and stores `reduce_value / dim_size`
/// as f32; the second multiplies each input element by
/// `1 / sqrt(mean + epsilon)` and stores it as half. Index/validity
/// expressions and epsilon are baked into the source via `format!`, and each
/// kernel is renamed to `kernel_<hash of its source>`.
///
/// Fused variant (`kernel_rms_norm`): one threadgroup handles the row at byte
/// offset `tgpig * nb01`. Each thread sums squares of `half4` vectors in
/// float, partial sums are combined with `simd_sum` (plus a threadgroup
/// buffer pass when `ntg > 32`), then every vector is scaled by
/// `1 / sqrt(all_sum / ne00 + eps)` and written to `dst + tgpig * ne00`.
/// Here `eps` is a kernel argument rather than text formatted into the source.
///
/// NOTE(review): both loops in the fused kernel run to `ne00/4`, so it
/// appears to assume the row length is a multiple of 4 — confirm callers
/// guarantee this.
fn new(
epsilon: f32,
device: Device,
queue: CommandQueue,
inp_shape: ShapeTracker,
dyn_map: *const HashMap<char, usize>,
) -> Self {
let (idx_exp, valid_exp) = get_idx_valid_exps(inp_shape);
let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[inp_shape], 6);
let mut square_mean_code = format!(
"
#include <metal_stdlib>
fn new(epsilon: f32, device: Device, queue: CommandQueue) -> Self {
let kernel_code = "#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device half *inp [[buffer(0)]], device float *out [[buffer(1)]], device int& n_elements [[buffer(2)]], device int& front_size [[buffer(3)]], device int& back_size [[buffer(4)]], device int& dim_size [[buffer(5)]], uint i_ [[thread_position_in_grid]]{rendered}) {{
if (i_ < n_elements) {{
int a_ = i_ / back_size;
int b_ = i_ % back_size;
float reduce_value = 0.0;
int add_factor = a_ * dim_size * back_size + b_;
for (int c_ = 0; c_ < dim_size * back_size; c_ += back_size) {{
int idx = add_factor + c_;
if (({valid_exp}) != 0) {{
float val = (float)inp[{idx_exp}];
reduce_value += (val * val);
}}
}}
out[i_] = (reduce_value / (float)dim_size);
}}
}}
");
let square_mean_code_name = format!("kernel_{}", hash(&square_mean_code));
square_mean_code = square_mean_code.replace("mkernel", &square_mean_code_name);
kernel void kernel_rms_norm(
device const void * src0,
device half * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant float & eps,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const half4 * x = (device const half4 *) ((device const char *) src0 + tgpig*nb01);
let mut meaned_shape = inp_shape;
let meaned_size = meaned_shape.remove_dim(meaned_shape.len() - 1);
meaned_shape.expand(meaned_shape.len(), meaned_size);
let (meaned_idx_exp, _) = get_idx_valid_exps(meaned_shape);
let (_, rendered) = render_dyn_dim_inputs(&[inp_shape], 4);
let mut rms_norm_code = format!("
#include <metal_stdlib>
using namespace metal;
float4 sumf = 0;
float all_sum = 0;
// parallel sum
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
sumf += (float4)x[i00] * (float4)x[i00];
}
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
all_sum = simd_sum(all_sum);
if (ntg > 32) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = all_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
all_sum = buf[tiisg];
all_sum = simd_sum(all_sum);
}
kernel void mkernel(device float *inp [[buffer(0)]], device half *x [[buffer(1)]], device half *out [[buffer(2)]], device int& n_elements [[buffer(3)]], uint idx [[thread_position_in_grid]]{rendered}) {{
if (idx < n_elements) {{
float added = inp[{meaned_idx_exp}] + {epsilon};
float sq = sqrt(added);
float recip = 1.0f / sq;
out[idx] = (half)(recip * (float)x[{idx_exp}]);
}}
}}");
let rms_norm_code_name = format!("kernel_{}", hash(&rms_norm_code));
rms_norm_code = rms_norm_code.replace("mkernel", &rms_norm_code_name);
const float mean = all_sum/ne00;
const float scale = 1.0f/sqrt(mean + eps);
device half4 * y = (device half4 *) (dst + tgpig*ne00);
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
y[i00] = (half4)(x[i00] * scale);
}
}";

// Compile the shader source(s) for `device` and keep the handles needed to
// encode and run the kernel later.
Self {
square_mean_pipeline: compile_function(
&square_mean_code_name,
&square_mean_code,
&device,
),
rms_norm_pipeline: compile_function(&rms_norm_code_name, &rms_norm_code, &device),
pipeline: compile_function(&"kernel_rms_norm", kernel_code, &device),
device,
queue,
dyn_symbols,
epsilon,
dyn_map,
}
}
}

impl MetalKernel for MetalRMSNorm {
fn intermediate_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
let mut meaned_shape = input_shapes[0];
meaned_shape.remove_dim(meaned_shape.len() - 1);
vec![meaned_shape.n_elements() * size_of::<f32>()]
}
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
vec![input_shapes[0].n_elements() * size_of::<f16>()]
}
Expand All @@ -108,62 +92,58 @@ impl MetalKernel for MetalRMSNorm {
&self,
inputs: &[(&Buffer, ShapeTracker)],
command_buffer: &CommandBufferRef,
intermediate_buffers: &[&Buffer],
_: &[&Buffer],
output_buffers: &[&Buffer],
) {
let mut meaned_shape = inputs[0].1;
meaned_shape.remove_dim(meaned_shape.len() - 1);
// Setup buffers
let front_size: usize = inputs[0]
.1
.shape()
.iter()
.take(meaned_shape.len())
.map(|i| i.to_usize().unwrap())
.product();
let back_size = 1;
let dim_size = inputs[0].1.shape()[meaned_shape.len()].to_usize().unwrap();

let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.square_mean_pipeline);
let meaned_elements = meaned_shape.n_elements().to_usize().unwrap();
encoder.set_compute_pipeline_state(&self.pipeline);
let ne00 = inputs[0].1.shape()[2].to_usize().unwrap();
let nb01 = inputs[0].1.strides()[1].to_usize().unwrap() * size_of::<f16>();

// Set inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(intermediate_buffers[0]), 0);
encoder.set_int(2, meaned_elements as u32);
encoder.set_int(3, front_size as u32);
encoder.set_int(4, back_size as u32);
encoder.set_int(5, dim_size as u32);
input_dyn_dims(
&self.dyn_symbols,
unsafe { self.dyn_map.as_ref().unwrap() },
encoder,
6,
encoder.set_buffer(1, Some(output_buffers[0]), 0);
encoder.set_bytes(
2,
size_of::<i64>() as u64,
&(ne00 as i64) as *const i64 as *const _,
);

encoder.dispatch_1d(meaned_elements);
encoder.end_encoding();

let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.rms_norm_pipeline);

// Set inputs
encoder.set_buffer(0, Some(intermediate_buffers[0]), 0);
encoder.set_buffer(1, Some(inputs[0].0), 0);
encoder.set_buffer(2, Some(output_buffers[0]), 0);
encoder.set_int(3, inputs[0].1.n_elements().to_usize().unwrap() as u32);
input_dyn_dims(
&self.dyn_symbols,
unsafe { self.dyn_map.as_ref().unwrap() },
encoder,
encoder.set_bytes(
3,
size_of::<u64>() as u64,
&(nb01 as u64) as *const u64 as *const _,
);
encoder.set_bytes(
4,
size_of::<f32>() as u64,
&self.epsilon as *const f32 as *const _,
);

// Execute
encoder.dispatch_1d(inputs[0].1.n_elements().to_usize().unwrap());
let mut nth = 32; // SIMD width
while nth < ne00 / 4 && nth < 1024 {
nth *= 2;
}
let n_rows = inputs[0]
.1
.shape()
.into_iter()
.take(2)
.map(|i| i.to_usize().unwrap())
.product::<usize>();
encoder.set_threadgroup_memory_length(0, 32 * size_of::<f32>() as u64);
encoder.dispatch_thread_groups(
MTLSize {
width: n_rows as u64,
height: 1,
depth: 1,
},
MTLSize {
width: nth as u64,
height: 1,
depth: 1,
},
);
encoder.end_encoding();
}
}
Expand All @@ -179,18 +159,12 @@ impl Operator for MetalRMSNorm {
.as_any()
.downcast_ref::<Buffer>()
.unwrap();
let mut meaned_shape = tensors[0].1;
meaned_shape.remove_dim(meaned_shape.len() - 1);
let meaned = self.device.new_buffer(
(meaned_shape.n_elements().to_usize().unwrap() * size_of::<f32>()) as u64,
MTLResourceOptions::StorageModeShared,
);
let out = self.device.new_buffer(
(tensors[0].1.n_elements().to_usize().unwrap() * size_of::<f16>()) as u64,
MTLResourceOptions::StorageModeShared,
);

self.metal_forward(&[(a, tensors[0].1)], command_buffer, &[&meaned], &[&out]);
self.metal_forward(&[(a, tensors[0].1)], command_buffer, &[], &[&out]);

command_buffer.commit();
command_buffer.wait_until_completed();
Expand All @@ -199,28 +173,13 @@ impl Operator for MetalRMSNorm {
})
}

fn custom(&mut self, key: &str, input: Box<dyn Any>) -> Option<Box<dyn Any>> {
fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "metal" {
#[allow(clippy::arc_with_non_send_sync)]
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
self.epsilon,
self.device.clone(),
self.queue.clone(),
input_shapes[0],
self.dyn_map,
)
}
}
None
}
}
Expand Down Expand Up @@ -286,24 +245,33 @@ impl Compiler for RMSNormCompiler {
else {
unreachable!()
};
let x = graph.get_sources(square)[0];
if !graph.get_sources(square).iter().all(|(i, _, _)| *i == x.0) {
let (mut x, _, mut sh) = graph.get_sources(square)[0];
if !graph.get_sources(square).iter().all(|(i, _, _)| *i == x) {
continue;
}
if !graph.get_sources(mul).iter().any(|(i, _, _)| *i == x.0) {
if !graph.get_sources(mul).iter().any(|(i, _, _)| *i == x) {
continue;
}

// Input must be contiguous
if !sh.is_contiguous() || sh.is_sliced() || sh.is_padded() {
x = graph
.add_op(MetalContiguous::<f16>::new(
sh,
dev.clone(),
queue.clone(),
&mut HashMap::new(),
&graph.dyn_map,
))
.input(x, 0, sh)
.finish();
sh = sh.contiguous();
}

// Insert RMSNorm op
let rms_norm = graph
.add_op(MetalRMSNorm::new(
epsilon_num,
dev.clone(),
queue.clone(),
x.2,
&graph.dyn_map,
))
.input(x.0, 0, x.2)
.add_op(MetalRMSNorm::new(epsilon_num, dev.clone(), queue.clone()))
.input(x, 0, sh)
.finish();

// Create edges to dests
Expand Down

0 comments on commit 703f4d3

Please sign in to comment.