diff --git a/src/compilers/metal/fp16/mod.rs b/src/compilers/metal/fp16/mod.rs index 070e7025..ee18c953 100644 --- a/src/compilers/metal/fp16/mod.rs +++ b/src/compilers/metal/fp16/mod.rs @@ -2,7 +2,7 @@ use half::f16; mod matmul; mod mean_reduce; -mod other; +pub mod other; mod rms_norm; pub type MetalFp16Compiler = ( diff --git a/src/compilers/metal/other.rs b/src/compilers/metal/other.rs index 6a4f1d8f..5649cdc7 100644 --- a/src/compilers/metal/other.rs +++ b/src/compilers/metal/other.rs @@ -10,7 +10,7 @@ use crate::{ prelude::*, }; -use super::binary::MetalSub; +use super::{binary::MetalSub, fp16::other::MetalExp}; /// Sometimes CopyTo -> CopyFrom and CopyFrom -> CopyTo patterns remain, so let's clean them up #[derive(LuminalPrint, Default)] @@ -373,3 +373,144 @@ impl Compiler for ContiguousElimination { } } } + +/// Special kernel for efficient mean reduction +#[derive(LuminalEq, LuminalPrint, Clone)] +pub struct MetalSoftmax { + pipeline: ComputePipelineState, + queue: CommandQueue, + device: Device, + _phantom: PhantomData, +} + +impl MetalSoftmax { + fn new(device: Device, queue: CommandQueue) -> Self { + Self { + pipeline: compile_function( + "metal_softmax", + " +#include +using namespace metal; +kernel void metal_softmax(device half *out [[buffer(0)]], device int& n_elements [[buffer(1)]], uint idx [[thread_position_in_grid]]) { +}", + &device, + ), + queue, + device, + _phantom: Default::default(), + } + } +} + +impl MetalKernel for MetalSoftmax { + fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec { + vec![BigExpression::from(0) * std::mem::size_of::()] + } + fn metal_forward( + &self, + _: &[(&Buffer, ShapeTracker)], + _: &CommandBufferRef, + _: &[&Buffer], + _: &[&Buffer], + ) { + // let encoder = + // command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new()); + // encoder.set_compute_pipeline_state(&self.pipeline); + + // // Set inputs + // encoder.set_buffer(0, Some(output_buffers[0]), 0); + // encoder.set_int(1, size as u32); + + // // Execute + // encoder.dispatch_1d(size); + // encoder.end_encoding(); + } +} + +impl Operator for MetalSoftmax { + fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec { + // autoreleasepool(|| { + // // Setup command queue / command buffer / encoder + // let command_buffer = self.1.new_command_buffer(); + // let size = self.3.exec(unsafe { self.4.as_ref().unwrap() }).unwrap(); + // let out = self.2.new_buffer( + // (size * std::mem::size_of::()) as u64, + // MTLResourceOptions::StorageModeShared, + // ); + + // self.metal_forward(&[], command_buffer, &[], &[&out]); + + // command_buffer.commit(); + // command_buffer.wait_until_completed(); + + // vec![Tensor::new(out)] + // }) + vec![] + } + + fn custom(&mut self, key: &str, _: Box) -> Option> { + if key == "metal" { + #[allow(clippy::arc_with_non_send_sync)] + return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( + self.clone(), + ))))); + } + None + } +} + +/// Replace the mean reduce pattern with a special kernel. This is meant to be ran **after** the FakeSumReduceCompiler. +#[derive(Default, LuminalPrint)] +pub struct MetalSoftmaxCompiler(PhantomData); + +impl Compiler for MetalSoftmaxCompiler { + fn compile(&self, graph: &mut Graph, _: To) { + let dev = Device::system_default().unwrap(); + let queue = dev.new_command_queue(); + let (mut x1, mut max_reduce, mut sub, mut exp, mut sum_reduce, mut recip, mut mul) = ( + NodeIndex::default(), + NodeIndex::default(), + NodeIndex::default(), + NodeIndex::default(), + NodeIndex::default(), + NodeIndex::default(), + NodeIndex::default(), + ); + + let mut searcher = SelectOp::new() + .ptr(&mut x1) + .edge( + SelectOp::new() + .ty::>() + .ptr(&mut max_reduce), + ) + .edge(SelectOp::new().ty::>().ptr(&mut sub)) + .edge(SelectOp::new().ty::().ptr(&mut exp)) + .edge( + SelectOp::new() + .ty::>() + .ptr(&mut sum_reduce), + ) + .edge(SelectOp::new().ty::>().ptr(&mut recip)) + .edge(SelectOp::new().ty::>().ptr(&mut mul)) + .search(graph); + + while searcher.next_match() { + if graph.get_sources(mul).iter().any(|(i, _, _)| *i == exp) + && graph.get_sources(sub).iter().any(|(i, _, _)| *i == x1) + { + let softmax = graph + .add_op(MetalSoftmax::::new(dev.clone(), queue.clone())) + .finish(); + move_outgoing_edge(mul, softmax, &mut graph.graph); + + graph.graph.remove_node(mul); + graph.graph.remove_node(recip); + graph.graph.remove_node(max_reduce); + graph.graph.remove_node(sum_reduce); + graph.graph.remove_node(sub); + graph.graph.remove_node(exp); + } + } + } +}