diff --git a/docs/blog/gpu.mdx b/docs/blog/gpu.mdx
index 581e72b0..935089f7 100644
--- a/docs/blog/gpu.mdx
+++ b/docs/blog/gpu.mdx
@@ -1,6 +1,6 @@
---
title: 'Compiling fast GPU kernels'
-description: 'Bringing support for Apple and Nvidia GPUs to Luminal through compilers'
+description: 'Bringing support for Nvidia and Apple GPUs to Luminal through compilers'
'og:image': '/images/gpu_notext.png'
'twitter:image': '/images/gpu_notext.png'
---
@@ -8,14 +8,14 @@ description: 'Bringing support for Apple and Nvidia GPUs to Luminal through comp
Image credit: https://www.exxactcorp.com/

Image Credit: https://exxactcorp.com/

**Luminal compilers can now generate CUDA and Metal kernels on the fly, yielding specialized GPU compute for each model.**
-In our day-to-day lives most computing is done on general purpose CPUs. The combination of ubuquity and flexibility makes them an attractive option for most software. However, certian types of software like graphics are very compute-intensive. CPUs execute a single stream of instructions, and therefore have very little (or no) parallelism, leading to slow performance and high power usage.
+In our day-to-day lives most computing is done on general-purpose CPUs. The combination of ubiquity and flexibility makes them an attractive option for most software. However, certain types of software like graphics are very compute-intensive, and since CPUs execute a single stream of instructions they have very little (or no) parallelism, leading to slow performance and high power usage.
As graphics improved in the 80s and 90s, especially with the onset of 3D graphics, specialized hardware was required to render complex scenes at reasonable speed. Companies like Nvidia began releasing specialized chips able to do massively parallel compute, which served graphics applications well since individual pixels tend not to depend on other pixels.
@@ -39,16 +39,62 @@ This kernel gets ran for each element of the input arrays, all in parallel.
## Compiler flow
The typical approach in Luminal for supporting new backends would be:
-1) Swap out each primop with a backend-specific primop.
-2) Add in operations to copy to device and copy from device before and after Function ops.
-3) Pattern-match to swap out chunks of ops with specialized variants.
+1) Swap out each primitive operation with a backend-specific operation.
+2) Add in operations to copy to device and copy from device before and after Function operations.
+3) Pattern-match to swap out chunks of operations with specialized variants.
4) All other optimizations.
-So let's go through how we do this for the Metal backend to support Apple GPUs.
+Since we looked at a CUDA kernel above, let's go through how we do this for the Metal backend to support Apple GPUs.
### Step 1: Metal Primops
We want to generically support all possible models in Luminal, so our first step is to replicate all primitive operations with a Metal version. Since there are 11 primitive operations, we need 11 Metal ops. You can see these in `crates/luminal_metal/src/prim.rs`. The compiler simply loops through all ops in the graph and swaps them out with the Metal variant.
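+To make that concrete, here's a deliberately toy sketch of such a swap pass. None of these types come from Luminal's actual API; they're invented purely to illustrate "walk the graph and replace each primitive with its Metal twin":
+```rust
+// Toy stand-ins for a graph of primitive ops (not Luminal's real types).
+#[derive(Debug, Clone, PartialEq)]
+enum Op {
+    Exp2,
+    Add,
+    MetalExp2,
+    MetalAdd,
+}
+
+// Map a generic primitive to its Metal-specific counterpart.
+fn to_metal(op: &Op) -> Op {
+    match op {
+        Op::Exp2 => Op::MetalExp2,
+        Op::Add => Op::MetalAdd,
+        other => other.clone(),
+    }
+}
+
+// The "compiler pass": visit every node and swap it in place.
+fn swap_to_metal(ops: &mut [Op]) {
+    for op in ops.iter_mut() {
+        let swapped = to_metal(op);
+        *op = swapped;
+    }
+}
+
+fn main() {
+    let mut graph = vec![Op::Exp2, Op::Add];
+    swap_to_metal(&mut graph);
+    assert_eq!(graph, vec![Op::MetalExp2, Op::MetalAdd]);
+}
+```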
+These primitive operations are very simple. Here's the [MetalExp2](https://github.com/jafioti/luminal/blob/d3178b3443ee7fc887f8f0988a77736b73e618d0/crates/luminal_metal/src/prim.rs#L346) op, slightly simplified for clarity:
+```rust
+#[derive(Clone)]
+pub struct MetalExp2<T> {
+    ...
+}
+impl<T: MetalFloat> MetalExp2<T> {
+    pub fn new() -> Self {
+        let type_name = T::type_name();
+        let code = format!("
+#include <metal_stdlib>
+using namespace metal;
+kernel void mkernel(device {type_name} *inp [[buffer(0)]], device {type_name} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]) {{
+    if (idx < n_elements) {{
+        out[idx] = exp2(inp[idx]);
+    }}
+}}");
+        Self {
+            pipeline: compile_function("mkernel", &code),
+            ...
+        }
+    }
+}
+impl<T: MetalFloat> MetalKernel for MetalExp2<T> {
+    fn metal_forward(
+        &self,
+        inputs: &[(&Buffer, ShapeTracker)],
+        command_buffer: &CommandBufferRef,
+        output_buffers: &[&Buffer],
+    ) {
+        let inp_size = inputs[0].1.n_elements().to_usize().unwrap();
+        let encoder = command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
+        encoder.set_compute_pipeline_state(&self.pipeline);
+
+        // Set function inputs
+        encoder.set_buffer(0, Some(inputs[0].0), 0);
+        encoder.set_buffer(1, Some(output_buffers[0]), 0);
+        encoder.set_u32(2, inp_size as u32);
+
+        // Execute
+        encoder.dispatch_1d(inp_size);
+        encoder.end_encoding();
+    }
+}
+```
+
### Step 2: Moving data around
Since our data always starts on the host device (normal RAM), we need to move it to GPU memory. This means we need to look at the remaining ops that produce data from the host, and insert a CopyToDevice op, and look at where we need to get data back to host and insert a CopyFromDevice op.
@@ -69,7 +115,11 @@ Elementwise fusion does away with that and generates a single kernel that does `
This actually is taken much further, fusing unary operations, binary operations, across reshapes, permutes, expands, etc. Turns out we can get very far with this idea!
Here's an example of how many ops fusion can merge together. On the left is the unfused graph, on the right is the functionally identical fused graph:
-![image](/images/fusion.png)
+The Luminal graph with and without kernel fusion
### Step 5: Buffer Sharing
@@ -98,7 +148,7 @@ or this:
```rust
cx.compile(<(GenericCompiler, CudaCompiler)>::default(), ());
```
-and their entire model now runs on the GPU!
+and the model now runs on the GPU!
## Wrapping up
This level of flexibility is only afforded to us because compilers can handle so much complexity internally, with correctness guarantees due to the simplicity of the graph.
diff --git a/docs/images/fusion.png b/docs/images/fusion.png
index 4c8a7c47..9dacefab 100644
Binary files a/docs/images/fusion.png and b/docs/images/fusion.png differ
diff --git a/docs/mint.json b/docs/mint.json
index a220af5c..0e0d8055 100644
--- a/docs/mint.json
+++ b/docs/mint.json
@@ -64,7 +64,8 @@
    {
      "group": "Blog",
      "pages": [
-        "blog/intro"
+        "blog/intro",
+        "blog/gpu"
      ]
    },
    {
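A closing note on the elementwise-fusion step described in the gpu.mdx changes above: the core trick is composing per-element expressions so that a whole chain of unary ops becomes a single generated kernel and a single launch. The sketch below is illustrative only (the function and kernel names are invented; this is not Luminal's implementation), but it shows the idea of folding a chain like `sin` followed by `exp2` into one kernel source string:
```rust
// Illustrative sketch of elementwise fusion (invented names, not Luminal's code):
// instead of one kernel per unary op, fold the whole chain into a single kernel body.
fn fuse_unary_chain(ops: &[&str]) -> String {
    // Build a nested expression such as exp2(sin(inp[idx])) from ["sin", "exp2"].
    let expr = ops
        .iter()
        .fold("inp[idx]".to_string(), |acc, op| format!("{op}({acc})"));
    format!("kernel void fused(device float *inp [[buffer(0)]], device float *out [[buffer(1)]], uint idx [[thread_position_in_grid]]) {{\n    out[idx] = {expr};\n}}")
}

fn main() {
    // One generated kernel (and one launch) replaces two, and the
    // intermediate buffer between the two ops disappears entirely.
    println!("{}", fuse_unary_chain(&["sin", "exp2"]));
}
```
In Luminal the same composition happens on the op graph itself, and as the post describes it extends across binary ops and movement ops like reshapes, permutes, and expands.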