Skip to content

Commit

Permalink
Fixed layer norm
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Jan 21, 2024
1 parent 8bd7598 commit 4219d8e
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 193 deletions.
2 changes: 1 addition & 1 deletion src/compilers/metal/fp16/mean_reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct MetalMeanReduce(
ComputePipelineState,
CommandQueue,
Device,
usize,
pub usize,
Vec<char>,
*const HashMap<char, usize>,
);
Expand Down
4 changes: 2 additions & 2 deletions src/compilers/metal/fp16/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use half::f16;

mod matmul;
mod mean_reduce;
mod rms_norm;
mod std_norm;

pub type MetalFp16Compiler = (
super::prim::PrimitiveCompiler<f16>,
Expand All @@ -18,7 +18,7 @@ pub type MetalFp16Compiler = (
),
matmul::MetalMatMulCompiler,
mean_reduce::MeanReduceCompiler,
rms_norm::RMSNormCompiler,
std_norm::StdNormCompiler,
super::other::CopyCompiler<f16>,
super::other::ContiguousElimination<f16>,
// super::elementwise_fusion::ElementwiseFusionCompiler<f16>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,80 +12,58 @@ use crate::{
use super::mean_reduce::MetalMeanReduce;
use metal_rs::{objc::rc::autoreleasepool, *};

/// Special kernel for efficient rms norming
/// Special kernel for efficient std norming
#[derive(LuminalEq, LuminalPrint, Clone)]
pub struct MetalRMSNorm {
pub struct MetalStdNorm {
pipeline: ComputePipelineState,
device: Device,
queue: CommandQueue,
epsilon: f32, // Epsilon
}

impl MetalRMSNorm {
impl MetalStdNorm {
fn new(epsilon: f32, device: Device, queue: CommandQueue) -> Self {
let kernel_code = "#include <metal_stdlib>
#define SIMD_WIDTH 32
using namespace metal;
kernel void kernel_rms_norm(
device const void * src0 [[buffer(0)]],
kernel void kernel_std_norm(
device const half * src0 [[buffer(0)]],
device half * dst [[buffer(1)]],
constant int64_t & ne00 [[buffer(2)]],
constant uint64_t & nb01 [[buffer(3)]],
constant float & eps [[buffer(4)]],
constant int64_t & row_size [[buffer(2)]],
constant float & eps [[buffer(3)]],
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const half4 * x = (device const half4 *) ((device const char *) src0 + tgpig*nb01);
uint threadgroup_pos[[threadgroup_position_in_grid]],
uint simdgroup_pos[[thread_index_in_simdgroup]]) {
device const half4 * x = (device const half4 *) (src0 + threadgroup_pos * row_size);
float4 sumf = 0;
float all_sum = 0;
// parallel sum
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
sumf += (float4)x[i00] * (float4)x[i00];
for (int i = simdgroup_pos; i < row_size/4; i += SIMD_WIDTH) {
sumf += (float4)x[i] * (float4)x[i];
}
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
float all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
all_sum = simd_sum(all_sum);
if (ntg > SIMD_WIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = all_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
all_sum = buf[tiisg];
all_sum = simd_sum(all_sum);
}
const float mean = all_sum/ne00;
const float mean = all_sum/row_size;
const float scale = 1.0f/sqrt(mean + eps);
device half4 * y = (device half4 *) (dst + tgpig*ne00);
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
y[i00] = (half4)(x[i00] * scale);
device half4 * y = (device half4 *) (dst + threadgroup_pos * row_size);
for (int i = simdgroup_pos; i < row_size/4; i += SIMD_WIDTH) {
y[i] = (half4)(x[i] * scale);
}
}";

Self {
pipeline: compile_function(&"kernel_rms_norm", kernel_code, &device),
pipeline: compile_function(&"kernel_std_norm", kernel_code, &device),
device,
queue,
epsilon,
}
}
}

impl MetalKernel for MetalRMSNorm {
impl MetalKernel for MetalStdNorm {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
vec![input_shapes[0].n_elements() * size_of::<f16>()]
}
Expand All @@ -100,37 +78,29 @@ impl MetalKernel for MetalRMSNorm {
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.pipeline);
let n_dims = inputs[0].1.len();
let ne00 = inputs[0].1.shape()[n_dims - 1].to_usize().unwrap();
let nb01 = inputs[0].1.strides()[n_dims - 2].to_usize().unwrap() * size_of::<f16>();
let row_size = inputs[0].1.shape().last().unwrap().to_usize().unwrap();

// Set inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(output_buffers[0]), 0);
encoder.set_i64(2, ne00 as i64);
encoder.set_u64(3, nb01 as u64);
encoder.set_f32(4, self.epsilon);

let mut nth = 32; // SIMD width
while nth < ne00 / 4 && nth < 1024 {
nth *= 2;
}
let n_rows = inputs[0]
encoder.set_i64(2, row_size as i64);
encoder.set_f32(3, self.epsilon);
let batch_size = inputs[0]
.1
.shape()
.into_iter()
.take(2)
.take(inputs[0].1.len() - 1)
.map(|i| i.to_usize().unwrap())
.product::<usize>();
encoder.set_threadgroup_memory_length(0, 32 * size_of::<f32>() as u64);
encoder.dispatch_thread_groups(
MTLSize {
width: n_rows as u64,
width: batch_size as u64,
height: 1,
depth: 1,
},
MTLSize {
width: nth as u64,
width: 32.min(row_size / 4) as u64,
height: 1,
depth: 1,
},
Expand All @@ -139,7 +109,7 @@ impl MetalKernel for MetalRMSNorm {
}
}

impl Operator for MetalRMSNorm {
impl Operator for MetalStdNorm {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
let command_buffer = self.queue.new_command_buffer();
Expand Down Expand Up @@ -177,9 +147,9 @@ impl Operator for MetalRMSNorm {

/// Replace the mean reduce pattern with a special kernel. This is meant to be ran **after** the FakeSumReduceCompiler.
#[derive(Default, Debug)]
pub struct RMSNormCompiler;
pub struct StdNormCompiler;

impl Compiler for RMSNormCompiler {
impl Compiler for StdNormCompiler {
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, mut remap: T) {
let dev = Device::system_default().unwrap();
let queue = dev.new_command_queue();
Expand Down Expand Up @@ -234,9 +204,30 @@ impl Compiler for RMSNormCompiler {
.unwrap()
.0
else {
unreachable!()
continue;
};
let (mut x, _, mut sh) = graph.get_sources(square)[0];
if let Some(mean_reduce) = graph
.graph
.node_weight(mean)
.unwrap()
.as_any()
.downcast_ref::<MetalMeanReduce>()
{
if mean_reduce.3 != sh.len() - 1 {
continue;
}
}
if sh
.shape()
.last()
.unwrap()
.to_usize()
.map(|i| i % 32 != 0 || i < 32)
.unwrap_or(true)
{
continue;
}
if !graph.get_sources(square).iter().all(|(i, _, _)| *i == x) {
continue;
}
Expand All @@ -260,7 +251,7 @@ impl Compiler for RMSNormCompiler {

// Insert RMSNorm op
let rms_norm = graph
.add_op(MetalRMSNorm::new(epsilon_num, dev.clone(), queue.clone()))
.add_op(MetalStdNorm::new(epsilon_num, dev.clone(), queue.clone()))
.input(x, 0, sh)
.finish();

Expand Down
Loading

0 comments on commit 4219d8e

Please sign in to comment.