Skip to content

Commit

Permalink
Added unused softmax op
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Jan 9, 2024
1 parent 5cdc559 commit 67366e1
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/compilers/metal/fp16/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use half::f16;

mod matmul;
mod mean_reduce;
mod other;
pub mod other;
mod rms_norm;

pub type MetalFp16Compiler = (
Expand Down
143 changes: 142 additions & 1 deletion src/compilers/metal/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
prelude::*,
};

use super::binary::MetalSub;
use super::{binary::MetalSub, fp16::other::MetalExp};

/// Sometimes CopyTo -> CopyFrom and CopyFrom -> CopyTo patterns remain, so let's clean them up
#[derive(LuminalPrint, Default)]
Expand Down Expand Up @@ -373,3 +373,144 @@ impl<T: MetalFloat> Compiler for ContiguousElimination<T> {
}
}
}

/// Fused softmax operator for the Metal backend.
///
/// NOTE: this op is currently a stub. The `metal_softmax` kernel compiled in `new` has an
/// empty body, and both `metal_forward` and `process` perform no work.
#[derive(LuminalEq, LuminalPrint, Clone)]
pub struct MetalSoftmax<T: MetalFloat> {
/// Pipeline state compiled from the `metal_softmax` kernel source.
pipeline: ComputePipelineState,
/// Command queue handed in at construction time.
queue: CommandQueue,
/// Device the kernel was compiled for.
device: Device,
/// Marker tying the operator to its float type `T` without storing a value of it.
_phantom: PhantomData<T>,
}

impl<T: MetalFloat> MetalSoftmax<T> {
fn new(device: Device, queue: CommandQueue) -> Self {
Self {
pipeline: compile_function(
"metal_softmax",
"
#include <metal_stdlib>
using namespace metal;
kernel void metal_softmax(device half *out [[buffer(0)]], device int& n_elements [[buffer(1)]], uint idx [[thread_position_in_grid]]) {
}",
&device,
),
queue,
device,
_phantom: Default::default(),
}
}
}

impl<T: MetalFloat> MetalKernel for MetalSoftmax<T> {
    /// Returns the byte size of each output buffer requested by this kernel.
    ///
    /// A single entry is returned, computed as `0 * size_of::<f16>()`, i.e. a zero-length
    /// buffer. The input shapes are ignored.
    ///
    /// NOTE(review): the element size is hard-coded to `f16` rather than derived from `T`.
    fn output_buffer_sizes(&self, _: &[ShapeTracker]) -> Vec<BigExpression> {
        let element_count = BigExpression::from(0);
        let element_bytes = std::mem::size_of::<f16>();
        vec![element_count * element_bytes]
    }

    /// Encodes the kernel into a command buffer.
    ///
    /// Currently a no-op: every argument is ignored and nothing is encoded. The intended
    /// encoding is kept below as a commented-out sketch.
    fn metal_forward(
        &self,
        _: &[(&Buffer, ShapeTracker)],
        _: &CommandBufferRef,
        _: &[&Buffer],
        _: &[&Buffer],
    ) {
        // Sketch of the intended implementation (not yet enabled):
        //
        // let encoder =
        //     command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
        // encoder.set_compute_pipeline_state(&self.pipeline);
        //
        // // Set inputs
        // encoder.set_buffer(0, Some(output_buffers[0]), 0);
        // encoder.set_int(1, size as u32);
        //
        // // Execute
        // encoder.dispatch_1d(size);
        // encoder.end_encoding();
    }
}

impl<T: MetalFloat> Operator for MetalSoftmax<T> {
    /// Runs the operator directly.
    ///
    /// Currently a stub: the inputs are ignored and an empty tensor list is returned.
    fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        // Sketch of the intended implementation (not yet enabled). Note that it still
        // refers to tuple-style fields (`self.1`, `self.3`, ...) that this struct does
        // not have, so it needs adapting before it can be switched on:
        //
        // autoreleasepool(|| {
        //     // Setup command queue / command buffer / encoder
        //     let command_buffer = self.1.new_command_buffer();
        //     let size = self.3.exec(unsafe { self.4.as_ref().unwrap() }).unwrap();
        //     let out = self.2.new_buffer(
        //         (size * std::mem::size_of::<f16>()) as u64,
        //         MTLResourceOptions::StorageModeShared,
        //     );
        //
        //     self.metal_forward(&[], command_buffer, &[], &[&out]);
        //
        //     command_buffer.commit();
        //     command_buffer.wait_until_completed();
        //
        //     vec![Tensor::new(out)]
        // })
        Vec::new()
    }

    /// Answers custom queries against this operator.
    ///
    /// For the key `"metal"` a clone of this operator is returned wrapped in a
    /// `MetalKernelWrapper`; every other key yields `None`. The payload is ignored.
    fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
        if key != "metal" {
            return None;
        }

        // The wrapper stores the kernel behind an `Arc` even though it is not `Send`/`Sync`,
        // hence the lint exemption.
        #[allow(clippy::arc_with_non_send_sync)]
        let wrapper = MetalKernelWrapper(Arc::new(Box::new(self.clone())));
        Some(Box::new(wrapper))
    }
}

/// Replaces the softmax pattern (`x -> max reduce -> sub -> exp -> sum reduce -> recip -> mul`)
/// with a single [`MetalSoftmax`] op.
///
/// NOTE(review): the previous comment here was copied from the mean-reduce compiler and said
/// this should run **after** the FakeSumReduceCompiler — confirm whether any such ordering
/// requirement actually applies to this compiler.
#[derive(Default, LuminalPrint)]
pub struct MetalSoftmaxCompiler<T: MetalFloat>(PhantomData<T>);

impl<T: MetalFloat> Compiler for MetalSoftmaxCompiler<T> {
    /// Searches `graph` for the chain
    /// `x1 -> MetalMaxReduce -> MetalSub -> MetalExp -> MetalSumReduce -> MetalRecip -> MetalMul`
    /// and, for every match where `exp` is also a source of `mul` and `x1` is also a source
    /// of `sub`, swaps the chain for a single `MetalSoftmax` node.
    ///
    /// The outgoing edges of `mul` are moved onto the new node and the six matched operator
    /// nodes are removed; `x1` itself is left in place.
    ///
    /// NOTE(review): the new node is added without any input edge from `x1` — presumably
    /// because the op is still a stub; confirm before relying on this pass.
    /// NOTE(review): `MetalExp` is the non-generic fp16 op, while the rest of the pattern is
    /// generic over `T` — confirm this is intended.
    ///
    /// # Panics
    ///
    /// Panics if `Device::system_default()` returns `None`.
    fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
        let dev = Device::system_default().unwrap();
        let queue = dev.new_command_queue();

        // Slots the searcher fills in with the matched node ids on every match.
        let mut x1 = NodeIndex::default();
        let mut max_reduce = NodeIndex::default();
        let mut sub = NodeIndex::default();
        let mut exp = NodeIndex::default();
        let mut sum_reduce = NodeIndex::default();
        let mut recip = NodeIndex::default();
        let mut mul = NodeIndex::default();

        let mut searcher = SelectOp::new()
            .ptr(&mut x1)
            .edge(
                SelectOp::new()
                    .ty::<MetalMaxReduce<T>>()
                    .ptr(&mut max_reduce),
            )
            .edge(SelectOp::new().ty::<MetalSub<T>>().ptr(&mut sub))
            .edge(SelectOp::new().ty::<MetalExp>().ptr(&mut exp))
            .edge(
                SelectOp::new()
                    .ty::<MetalSumReduce<T>>()
                    .ptr(&mut sum_reduce),
            )
            .edge(SelectOp::new().ty::<MetalRecip<T>>().ptr(&mut recip))
            .edge(SelectOp::new().ty::<MetalMul<T>>().ptr(&mut mul))
            .search(graph);

        while searcher.next_match() {
            // The linear chain alone is not enough: the final multiply must also consume
            // `exp`, and the subtraction must also consume the original input.
            let exp_feeds_mul = graph
                .get_sources(mul)
                .iter()
                .any(|(src, _, _)| *src == exp);
            if !exp_feeds_mul {
                continue;
            }
            let x1_feeds_sub = graph
                .get_sources(sub)
                .iter()
                .any(|(src, _, _)| *src == x1);
            if !x1_feeds_sub {
                continue;
            }

            let fused = MetalSoftmax::<T>::new(dev.clone(), queue.clone());
            let softmax = graph.add_op(fused).finish();
            // Consumers of the old `mul` now read from the fused node.
            move_outgoing_edge(mul, softmax, &mut graph.graph);

            // Drop the now-redundant chain, in the same order as before.
            for stale in [mul, recip, max_reduce, sum_reduce, sub, exp] {
                graph.graph.remove_node(stale);
            }
        }
    }
}

0 comments on commit 67366e1

Please sign in to comment.