From 5e3e69d109836b71f9bb90991f28a477501eb6ea Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Fri, 26 Apr 2024 20:20:17 -0500 Subject: [PATCH 01/12] Fixed cuda graph prints' : --- Cargo.toml | 2 +- crates/luminal_cuda/Cargo.toml | 1 - crates/luminal_cuda/src/binary.rs | 10 +- crates/luminal_cuda/src/lib.rs | 13 +- crates/luminal_cuda/src/matmul.rs | 4 +- crates/luminal_cuda/src/other.rs | 7 +- crates/luminal_cuda/src/prim.rs | 43 ++-- crates/luminal_cuda/src/quantized.rs | 7 +- examples/llama/src/model.rs | 2 +- examples/phi/.gitignore | 17 ++ examples/phi/Cargo.toml | 28 +++ examples/phi/src/gguf.rs | 302 +++++++++++++++++++++++ examples/phi/src/loader.rs | 227 ++++++++++++++++++ examples/phi/src/main.rs | 185 ++++++++++++++ examples/phi/src/model.rs | 345 +++++++++++++++++++++++++++ src/compiler_utils.rs | 1 + 16 files changed, 1162 insertions(+), 32 deletions(-) create mode 100644 examples/phi/.gitignore create mode 100644 examples/phi/Cargo.toml create mode 100644 examples/phi/src/gguf.rs create mode 100644 examples/phi/src/loader.rs create mode 100644 examples/phi/src/main.rs create mode 100644 examples/phi/src/model.rs diff --git a/Cargo.toml b/Cargo.toml index 106f3327..109fc7b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ num-traits = "0.2.16" petgraph = "0.6.4" rand = "0.8.5" urlencoding = "2.1.2" -webbrowser = "0.8.10" +webbrowser = "1.0.0" dyn-clone = "1.0.12" half = "*" tinyvec = "1.6.0" diff --git a/crates/luminal_cuda/Cargo.toml b/crates/luminal_cuda/Cargo.toml index fe53041e..25c58bd0 100644 --- a/crates/luminal_cuda/Cargo.toml +++ b/crates/luminal_cuda/Cargo.toml @@ -16,7 +16,6 @@ luminal_cudarc = { version="0.10.0", features = [ itertools = "0.12.1" rustc-hash = "1.1.0" num-traits = "0.2.18" -fmt-derive = "0.1.1" [dev-dependencies] dfdx = { version = "0.13", features = ["f16"] } diff --git a/crates/luminal_cuda/src/binary.rs b/crates/luminal_cuda/src/binary.rs index a47e2aad..79fdde3c 100644 --- a/crates/luminal_cuda/src/binary.rs +++ b/crates/luminal_cuda/src/binary.rs @@ -1,6 +1,5 @@ use std::{marker::PhantomData, sync::Arc}; -use fmt_derive::Debug; use luminal_cudarc::driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig}; use luminal::{ @@ -16,7 +15,7 @@ use crate::{ render_dyn_dim_inputs, CudaData, CudaFloat, }; -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaSub { function: CudaFunction, device: Arc, @@ -24,6 +23,7 @@ pub struct CudaSub { dyn_map: *const FxHashMap, _phantom: PhantomData, } +crate::debug_type!(CudaSub); impl CudaSub { pub fn new( @@ -140,7 +140,7 @@ impl Compiler for SubtractionCompiler { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaEqual { function: CudaFunction, device: Arc, @@ -148,6 +148,7 @@ pub struct CudaEqual { dyn_map: *const FxHashMap, _phantom: PhantomData, } +crate::debug_type!(CudaEqual); impl CudaEqual { pub fn new( @@ -263,13 +264,14 @@ impl Compiler for EqualCompiler { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaGather { function: CudaFunction, device: Arc, pub embed_dim: usize, _phantom: PhantomData, } +crate::debug_type!(CudaGather); impl CudaGather { pub fn new(device: Arc, embed_dim: usize) -> Self { diff --git a/crates/luminal_cuda/src/lib.rs b/crates/luminal_cuda/src/lib.rs index d1b83d8e..fa4485f9 100644 --- a/crates/luminal_cuda/src/lib.rs +++ b/crates/luminal_cuda/src/lib.rs @@ -159,7 +159,7 @@ fn render_dyn_dim_inputs(shapes: &[ShapeTracker]) -> (Vec, String) { .into_iter() .flat_map(|i| [i.0.into(), i.1.into()]), ) - 
.chain(st.slices.into_iter().flat_map(|i| [i.0.into(), i.1.into()])) + .chain(st.mask.into_iter().flat_map(|i| [i.0.into(), i.1.into()])) }) .flat_map(|d| d.to_symbols()) .unique() @@ -235,3 +235,14 @@ fn compile_and_load_kernel(mut code: String, device: &Arc) -> CudaFu } device.get_func(&name, &name).unwrap() } + +#[macro_export] +macro_rules! debug_type { + ($t: ty) => { + impl std::fmt::Debug for $t { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, stringify!($t)) + } + } + }; +} diff --git a/crates/luminal_cuda/src/matmul.rs b/crates/luminal_cuda/src/matmul.rs index c550e792..0125cad3 100644 --- a/crates/luminal_cuda/src/matmul.rs +++ b/crates/luminal_cuda/src/matmul.rs @@ -1,6 +1,5 @@ use std::{marker::PhantomData, sync::Arc}; -use fmt_derive::Debug; use luminal_cudarc::{ cublas::{sys::cublasOperation_t::*, CudaBlas}, driver::{CudaDevice, DevicePtr, DevicePtrMut}, @@ -16,8 +15,9 @@ use luminal::{ prelude::*, }; -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Matmul(Arc, Arc, PhantomData); +crate::debug_type!(Matmul); impl Operator for Matmul { fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec { diff --git a/crates/luminal_cuda/src/other.rs b/crates/luminal_cuda/src/other.rs index bb71feb1..f8598c63 100644 --- a/crates/luminal_cuda/src/other.rs +++ b/crates/luminal_cuda/src/other.rs @@ -1,9 +1,7 @@ use std::{marker::PhantomData, sync::Arc}; -use luminal_cudarc::driver::{CudaDevice, CudaFunction, LaunchAsync, LaunchConfig}; - -use fmt_derive::Debug; use luminal::prelude::*; +use luminal_cudarc::driver::{CudaDevice, CudaFunction, LaunchAsync, LaunchConfig}; use rustc_hash::FxHashMap; use crate::{ @@ -13,7 +11,7 @@ use crate::{ CudaData, CudaFloat, }; -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaARange { function: CudaFunction, device: Arc, @@ -21,6 +19,7 @@ pub struct CudaARange { dyn_map: *const FxHashMap, _phantom: PhantomData, } +crate::debug_type!(CudaARange); impl CudaARange { pub fn new( diff --git a/crates/luminal_cuda/src/prim.rs b/crates/luminal_cuda/src/prim.rs index 7a0426d3..ed4497be 100644 --- a/crates/luminal_cuda/src/prim.rs +++ b/crates/luminal_cuda/src/prim.rs @@ -1,7 +1,6 @@ use crate::{compile_and_load_kernel, get_buffer_from_tensor, input_dyn_dims, CudaData, CudaFloat}; use super::{get_idx_valid_exps, render_dyn_dim_inputs}; -use fmt_derive::Debug; use itertools::Itertools; use rustc_hash::FxHashMap; @@ -19,8 +18,9 @@ use luminal::{ }; /// Copy a tensor to the GPU -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaCopyToDevice(Arc, PhantomData); +crate::debug_type!(CudaCopyToDevice); impl CudaCopyToDevice { pub fn new(dev: Arc) -> Self { @@ -45,8 +45,9 @@ impl Operator for CudaCopyToDevice { } /// Copy a tensor from the GPU -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaCopyFromDevice(Arc, PhantomData); +crate::debug_type!(CudaCopyFromDevice); impl CudaCopyFromDevice { pub fn new(dev: Arc) -> Self { @@ -113,7 +114,7 @@ impl Operator for CudaConstant { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaContiguous { function: CudaFunction, device: Arc, @@ -121,6 +122,7 @@ pub struct CudaContiguous { dyn_symbols: Vec, dyn_map: *const FxHashMap, } +crate::debug_type!(CudaContiguous); impl CudaContiguous { pub fn new( @@ -172,12 +174,13 @@ impl Operator for CudaContiguous { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaLog2 { function: CudaFunction, device: Arc, _phantom: PhantomData, } +crate::debug_type!(CudaLog2); impl CudaLog2 { pub fn new(device: Arc) 
-> Self { @@ -227,12 +230,13 @@ impl Operator for CudaLog2 { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaExp2 { function: CudaFunction, device: Arc, _phantom: PhantomData, } +crate::debug_type!(CudaExp2); impl CudaExp2 { pub fn new(device: Arc) -> Self { @@ -281,12 +285,13 @@ impl Operator for CudaExp2 { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaSqrt { function: CudaFunction, device: Arc, _phantom: PhantomData, } +crate::debug_type!(CudaSqrt); impl CudaSqrt { pub fn new(device: Arc) -> Self { @@ -339,12 +344,13 @@ impl Operator for CudaSqrt { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaSin { function: CudaFunction, device: Arc, _phantom: PhantomData, } +crate::debug_type!(CudaSin); impl CudaSin { pub fn new(device: Arc) -> Self { @@ -394,12 +400,13 @@ impl Operator for CudaSin { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaRecip { function: CudaFunction, device: Arc, _phantom: PhantomData, } +crate::debug_type!(CudaRecip); impl CudaRecip { pub fn new(device: Arc) -> Self { @@ -453,7 +460,7 @@ impl Operator for CudaRecip { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaAdd { function: CudaFunction, device: Arc, @@ -461,6 +468,7 @@ pub struct CudaAdd { dyn_symbols: Vec, dyn_map: *const FxHashMap, } +crate::debug_type!(CudaAdd); impl CudaAdd { pub fn new( @@ -526,7 +534,7 @@ impl Operator for CudaAdd { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaMul { function: CudaFunction, device: Arc, @@ -534,6 +542,7 @@ pub struct CudaMul { dyn_symbols: Vec, dyn_map: *const FxHashMap, } +crate::debug_type!(CudaMul); impl CudaMul { pub fn new( @@ -596,7 +605,7 @@ impl Operator for CudaMul { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaMod { function: CudaFunction, device: Arc, @@ -604,6 +613,7 @@ pub struct CudaMod { dyn_symbols: Vec, dyn_map: *const FxHashMap, } +crate::debug_type!(CudaMod); impl CudaMod { pub fn new( @@ -666,7 +676,7 @@ impl Operator for CudaMod { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaLessThan { function: CudaFunction, device: Arc, @@ -674,6 +684,7 @@ pub struct CudaLessThan { dyn_symbols: Vec, dyn_map: *const FxHashMap, } +crate::debug_type!(CudaLessThan); impl CudaLessThan { pub fn new( @@ -742,7 +753,7 @@ impl Operator for CudaLessThan { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaSumReduce { function: CudaFunction, pub device: Arc, @@ -751,6 +762,7 @@ pub struct CudaSumReduce { dyn_symbols: Vec, dyn_map: *const FxHashMap, } +crate::debug_type!(CudaSumReduce); impl CudaSumReduce { pub fn new( @@ -835,7 +847,7 @@ where } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CudaMaxReduce { function: CudaFunction, pub device: Arc, @@ -844,6 +856,7 @@ pub struct CudaMaxReduce { dyn_symbols: Vec, dyn_map: *const FxHashMap, } +crate::debug_type!(CudaMaxReduce); impl CudaMaxReduce { pub fn new( diff --git a/crates/luminal_cuda/src/quantized.rs b/crates/luminal_cuda/src/quantized.rs index 08d30982..7deb3a15 100644 --- a/crates/luminal_cuda/src/quantized.rs +++ b/crates/luminal_cuda/src/quantized.rs @@ -1,6 +1,5 @@ use std::{marker::PhantomData, sync::Arc}; -use fmt_derive::Debug; use luminal_cudarc::driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig}; use petgraph::visit::EdgeRef; @@ -14,12 +13,13 @@ use crate::{ }; /// Multiplies a BxMxK matrix with a KxN matrix, resulting in a BxMxN matrix. 
This expects the first input to be a quantized 2D matrix -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct QuantizedMatmul { matvec_function: CudaFunction, device: Arc, _phantom: PhantomData, } +crate::debug_type!(QuantizedMatmul); impl QuantizedMatmul { fn new(device: Arc) -> Self { @@ -170,13 +170,14 @@ impl Operator for QuantizedMatmul { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct QuantizedGather { pipeline: CudaFunction, device: Arc, embed_dim: usize, _phantom: PhantomData, } +crate::debug_type!(QuantizedGather); impl QuantizedGather { fn new(device: Arc, embed_dim: usize) -> Self { diff --git a/examples/llama/src/model.rs b/examples/llama/src/model.rs index 90760362..ecab6862 100644 --- a/examples/llama/src/model.rs +++ b/examples/llama/src/model.rs @@ -6,7 +6,7 @@ use luminal_nn::{Embedding, PermutedLinear, RMSNorm}; // Llama3 8B Config pub const VOCAB_SIZE: usize = 128256; pub const HIDDEN_DIM: usize = 4096; -pub const NUM_LAYERS: usize = 32; +pub const NUM_LAYERS: usize = 1; pub const N_HEADS: usize = 32; pub const N_KV_HEADS: usize = 8; pub const MLP_DIM: usize = 14336; diff --git a/examples/phi/.gitignore b/examples/phi/.gitignore new file mode 100644 index 00000000..199cec8e --- /dev/null +++ b/examples/phi/.gitignore @@ -0,0 +1,17 @@ +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb +setup/*.gguf +setup/*.json +.vscode \ No newline at end of file diff --git a/examples/phi/Cargo.toml b/examples/phi/Cargo.toml new file mode 100644 index 00000000..409e0962 --- /dev/null +++ b/examples/phi/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "phi" +version = "0.1.0" +edition = "2021" + +[features] +metal = ["dep:luminal_metal", "dep:metal-rs"] +cuda = ["dep:luminal_cuda", "dep:luminal_cudarc"] + +[dependencies] +luminal = { path = "../.." } +luminal_nn = { path = "../../crates/luminal_nn" } +luminal_cpu = { path = "../../crates/luminal_cpu"} +luminal_metal = { path = "../../crates/luminal_metal", optional = true } +luminal_cuda = { path = "../../crates/luminal_cuda", optional = true } +clap = { version = "4.4.18", features = ["derive"] } +byteorder = "1.5.0" +memmap2 = "0.9.4" +metal-rs = { version = "0.27.0", package = "metal", features = [ + "mps", +], optional = true } +colored = "2.1.0" +itertools = "0.12.1" +luminal_cudarc = { version="0.10.0", features = [ + "cublas", + "f16", +], optional=true} +tokenizers = "0.15.2" diff --git a/examples/phi/src/gguf.rs b/examples/phi/src/gguf.rs new file mode 100644 index 00000000..ad4f94bd --- /dev/null +++ b/examples/phi/src/gguf.rs @@ -0,0 +1,302 @@ +//! Support for the GGUF file format. +//! +//! 
Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md + +use byteorder::{LittleEndian, ReadBytesExt}; +use std::collections::HashMap; + +pub const DEFAULT_ALIGNMENT: u64 = 32; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Magic { + Gguf, +} + +impl TryFrom for Magic { + type Error = (); + fn try_from(value: u32) -> Result { + let magic = match value { + 0x46554747 | 0x47475546 => Self::Gguf, + _ => panic!("unknown magic 0x{value:08x}"), + }; + Ok(magic) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VersionedMagic { + GgufV1, + GgufV2, + GgufV3, +} + +impl VersionedMagic { + pub fn read(reader: &mut R) -> Result { + let magic = reader.read_u32::().unwrap(); + let magic = Magic::try_from(magic).unwrap(); + let version = reader.read_u32::().unwrap(); + let versioned_magic = match (magic, version) { + (Magic::Gguf, 1) => Self::GgufV1, + (Magic::Gguf, 2) => Self::GgufV2, + (Magic::Gguf, 3) => Self::GgufV3, + _ => panic!("gguf: unsupported magic/version {magic:?}/{version}"), + }; + Ok(versioned_magic) + } +} + +#[derive(Debug)] +pub struct Content { + pub magic: VersionedMagic, + pub metadata: HashMap, + pub tensor_infos: HashMap, // buffer size and offset + pub tensor_data_offset: u64, +} + +pub fn read_string(reader: &mut R, magic: &VersionedMagic) -> Result { + let len = match magic { + VersionedMagic::GgufV1 => reader.read_u32::().unwrap() as usize, + VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { + reader.read_u64::().unwrap() as usize + } + }; + let mut v = vec![0u8; len]; + reader.read_exact(&mut v).unwrap(); + // GGUF strings are supposed to be non-null terminated but in practice this happens. + while let Some(0) = v.last() { + v.pop(); + } + // GGUF strings are utf8 encoded but there are cases that don't seem to be valid. + Ok(String::from_utf8_lossy(&v).into_owned()) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ValueType { + // The value is a 8-bit unsigned integer. + U8, + // The value is a 8-bit signed integer. + I8, + // The value is a 16-bit unsigned little-endian integer. + U16, + // The value is a 16-bit signed little-endian integer. + I16, + // The value is a 32-bit unsigned little-endian integer. + U32, + // The value is a 32-bit signed little-endian integer. + I32, + // The value is a 64-bit unsigned little-endian integer. + U64, + // The value is a 64-bit signed little-endian integer. + I64, + // The value is a 32-bit IEEE754 floating point number. + F32, + // The value is a 64-bit IEEE754 floating point number. + F64, + // The value is a boolean. + // 1-byte value where 0 is false and 1 is true. + // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. + Bool, + // The value is a UTF-8 non-null-terminated string, with length prepended. + String, + // The value is an array of other values, with the length and type prepended. + /// + // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. 
+ Array, +} + +#[derive(Debug, Clone)] +pub enum Value { + U8(u8), + I8(i8), + U16(u16), + I16(i16), + U32(u32), + I32(i32), + U64(u64), + I64(i64), + F32(f32), + F64(f64), + Bool(bool), + String(String), + Array(Vec), +} + +impl Value { + pub fn read( + reader: &mut R, + value_type: ValueType, + magic: &VersionedMagic, + ) -> Result { + let v = match value_type { + ValueType::U8 => Self::U8(reader.read_u8().unwrap()), + ValueType::I8 => Self::I8(reader.read_i8().unwrap()), + ValueType::U16 => Self::U16(reader.read_u16::().unwrap()), + ValueType::I16 => Self::I16(reader.read_i16::().unwrap()), + ValueType::U32 => Self::U32(reader.read_u32::().unwrap()), + ValueType::I32 => Self::I32(reader.read_i32::().unwrap()), + ValueType::U64 => Self::U64(reader.read_u64::().unwrap()), + ValueType::I64 => Self::I64(reader.read_i64::().unwrap()), + ValueType::F32 => Self::F32(reader.read_f32::().unwrap()), + ValueType::F64 => Self::F64(reader.read_f64::().unwrap()), + ValueType::Bool => match reader.read_u8().unwrap() { + 0 => Self::Bool(false), + 1 => Self::Bool(true), + b => panic!("unexpected bool value {b}"), + }, + ValueType::String => Self::String(read_string(reader, magic).unwrap()), + ValueType::Array => { + let value_type = reader.read_u32::().unwrap(); + let value_type = ValueType::from_u32(value_type).unwrap(); + let len = match magic { + VersionedMagic::GgufV1 => reader.read_u32::().unwrap() as usize, + VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { + reader.read_u64::().unwrap() as usize + } + }; + let mut vs = Vec::with_capacity(len); + for _ in 0..len { + vs.push(Value::read(reader, value_type, magic).unwrap()) + } + Self::Array(vs) + } + }; + Ok(v) + } +} + +impl ValueType { + pub fn from_u32(v: u32) -> Result { + let v = match v { + 0 => Self::U8, + 1 => Self::I8, + 2 => Self::U16, + 3 => Self::I16, + 4 => Self::U32, + 5 => Self::I32, + 6 => Self::F32, + 7 => Self::Bool, + 8 => Self::String, + 9 => Self::Array, + 10 => Self::U64, + 11 => Self::I64, + 12 => Self::F64, + v => panic!("unrecognized value-type {v:#08x}"), + }; + Ok(v) + } +} + +impl Content { + pub fn read(reader: &mut R) -> Result { + let magic = VersionedMagic::read(reader).unwrap(); + + let tensor_count = match magic { + VersionedMagic::GgufV1 => reader.read_u32::().unwrap() as usize, + VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { + reader.read_u64::().unwrap() as usize + } + }; + let metadata_kv_count = match magic { + VersionedMagic::GgufV1 => reader.read_u32::().unwrap() as usize, + VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { + reader.read_u64::().unwrap() as usize + } + }; + + // Read metadata + let mut metadata = HashMap::new(); + for _idx in 0..metadata_kv_count { + let key = read_string(reader, &magic).unwrap(); + let value_type = reader.read_u32::().unwrap(); + let value_type = ValueType::from_u32(value_type).unwrap(); + let value = Value::read(reader, value_type, &magic).unwrap(); + metadata.insert(key, value); + } + // Read tensor infos + let mut tensor_infos = HashMap::new(); + for _idx in 0..tensor_count { + let tensor_name = read_string(reader, &magic).unwrap(); + let n_dimensions = reader.read_u32::().unwrap(); + let n_elements = match magic { + VersionedMagic::GgufV1 => { + let mut dimensions = vec![0; n_dimensions as usize]; + reader + .read_u32_into::(&mut dimensions) + .unwrap(); + dimensions.into_iter().map(|c| c as usize).product() + } + VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { + let mut dimensions = vec![0; n_dimensions as usize]; + reader + .read_u64_into::(&mut 
dimensions) + .unwrap(); + dimensions.into_iter().map(|c| c as usize).product() + } + }; + + let ggml_dtype = reader.read_u32::().unwrap(); + let offset = reader.read_u64::().unwrap(); + tensor_infos.insert( + tensor_name, + (n_elements, offset as usize, GgmlDType::from_u32(ggml_dtype)), + ); + } + let position = reader.stream_position().unwrap(); + let alignment = match metadata.get("general.alignment") { + Some(Value::U8(v)) => *v as u64, + Some(Value::U16(v)) => *v as u64, + Some(Value::U32(v)) => *v as u64, + Some(Value::I8(v)) if *v >= 0 => *v as u64, + Some(Value::I16(v)) if *v >= 0 => *v as u64, + Some(Value::I32(v)) if *v >= 0 => *v as u64, + _ => DEFAULT_ALIGNMENT, + }; + let tensor_data_offset = (position + alignment - 1) / alignment * alignment; + Ok(Self { + magic, + metadata, + tensor_infos, + tensor_data_offset, + }) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum GgmlDType { + F32, + F16, + Q4_0, + Q4_1, + Q5_0, + Q5_1, + Q8_0, + Q8_1, + Q2K, + Q3K, + Q4K, + Q5K, + Q6K, + Q8K, +} + +impl GgmlDType { + fn from_u32(u: u32) -> Self { + match u { + 0 => Self::F32, + 1 => Self::F16, + 2 => Self::Q4_0, + 3 => Self::Q4_1, + 6 => Self::Q5_0, + 7 => Self::Q5_1, + 8 => Self::Q8_0, + 9 => Self::Q8_1, + 10 => Self::Q2K, + 11 => Self::Q3K, + 12 => Self::Q4K, + 13 => Self::Q5K, + 14 => Self::Q6K, + 15 => Self::Q8K, + _ => panic!("unknown dtype for tensor {u}"), + } + } +} diff --git a/examples/phi/src/loader.rs b/examples/phi/src/loader.rs new file mode 100644 index 00000000..5e8ff96e --- /dev/null +++ b/examples/phi/src/loader.rs @@ -0,0 +1,227 @@ +use std::fs::File; +use std::path::Path; + +use luminal::{op::Function, prelude::*}; + +#[cfg(feature = "cuda")] +use {luminal_cuda::CudaData, luminal_cudarc::driver::CudaDevice}; + +use crate::gguf::*; + +#[cfg(not(feature = "metal"))] +use { + itertools::Itertools, + std::io::{Read, Seek}, +}; +#[cfg(feature = "metal")] +use { + luminal_metal::MetalBuffer, + memmap2::Mmap, + metal_rs::{Device, MTLResourceOptions}, +}; + +#[cfg(feature = "metal")] +pub fn q8_load, M: SerializeModule>( + path: P, + model: &M, + graph: &mut Graph, +) -> Vec { + // Read metadata from file + let mut reader = File::open(&path).unwrap(); + let Content { + mut tensor_infos, + tensor_data_offset, + .. 
+ } = Content::read(&mut reader).unwrap(); + + // Create weight loading closures + let mut q8_weights = vec![]; + for (weight_name, node_index) in param_dict(model) { + if let Some(loading_node) = graph + .graph + .node_weight_mut(node_index) + .and_then(|op| op.as_any_mut().downcast_mut::()) + { + let file_path = path.as_ref().to_owned(); + let (n_elements, buffer_offset, data_type) = + tensor_infos.remove(&weight_name.replace('/', ".")).unwrap(); + let n_bytes = match data_type { + GgmlDType::F32 => n_elements * 4, + GgmlDType::Q8_0 => { + q8_weights.push(node_index); + n_elements + (n_elements / 16) + } + _ => panic!("Unsupported dtype: {data_type:?}"), + }; + loading_node.1 = Box::new(move |_| { + let mmap_buffer = unsafe { Mmap::map(&File::open(&file_path).unwrap()).unwrap() }; + let buffer = Device::system_default() + .unwrap() + .new_buffer_with_bytes_no_copy( + unsafe { + mmap_buffer + .as_ptr() + .add(buffer_offset + tensor_data_offset as usize) + as *const _ + }, + n_bytes as u64, + MTLResourceOptions::StorageModeShared, + None, + ); + vec![Tensor::new(MetalBuffer(buffer))] + }); + } + } + q8_weights +} + +#[cfg(feature = "cuda")] +pub fn q8_load, M: SerializeModule>( + path: P, + model: &M, + graph: &mut Graph, +) -> Vec { + // Read metadata from file + let mut reader = File::open(&path).unwrap(); + let Content { + mut tensor_infos, + tensor_data_offset, + .. + } = Content::read(&mut reader).unwrap(); + + // Create weight loading closures + let mut q8_weights = vec![]; + for (weight_name, node_index) in param_dict(model) { + if let Some(loading_node) = graph + .graph + .node_weight_mut(node_index) + .and_then(|op| op.as_any_mut().downcast_mut::()) + { + let file_path = path.as_ref().to_owned(); + let (n_elements, buffer_offset, data_type) = + tensor_infos.remove(&weight_name.replace('/', ".")).unwrap(); + let n_bytes = match data_type { + GgmlDType::F32 => n_elements * 4, + GgmlDType::Q8_0 => { + q8_weights.push(node_index); + n_elements + (n_elements / 16) + } + _ => panic!("Unsupported dtype: {data_type:?}"), + }; + loading_node.1 = Box::new(move |_| { + // Read bytes + let mut bytes = vec![0; n_bytes]; + let mut file = File::open(&file_path).unwrap(); + file.seek(std::io::SeekFrom::Start( + buffer_offset as u64 + tensor_data_offset, + )) + .unwrap(); + file.read_exact(&mut bytes).unwrap(); + // Copy buffer over to cuda slice + let device = CudaDevice::new(0).unwrap(); + match data_type { + GgmlDType::F32 => vec![Tensor::new( + bytes + .into_iter() + .chunks(4) + .into_iter() + .map(|c| { + let c = c.collect::>(); + f32::from_le_bytes([c[0], c[1], c[2], c[3]]) + }) + .collect::>(), + )], + GgmlDType::Q8_0 => vec![Tensor::new(CudaData( + device.htod_sync_copy::(&bytes).unwrap(), + ))], + _ => unimplemented!(), + } + }); + } + } + q8_weights +} + +#[cfg(all(not(feature = "metal"), not(feature = "cuda")))] +pub fn q8_load, M: SerializeModule>( + path: P, + model: &M, + graph: &mut Graph, +) -> Vec { + #[repr(C, packed)] + #[derive(Clone, Copy)] + struct Q8Block { + delta: f16, + weights: [i8; 32], + } + + // Read metadata from file + let mut reader = File::open(&path).unwrap(); + let Content { + mut tensor_infos, + tensor_data_offset, + .. 
+ } = Content::read(&mut reader).unwrap(); + + // Create weight loading closures + let mut q8_weights = vec![]; + for (weight_name, node_index) in param_dict(model) { + if let Some(loading_node) = graph + .graph + .node_weight_mut(node_index) + .and_then(|op| op.as_any_mut().downcast_mut::()) + { + let file_path = path.as_ref().to_owned(); + let (n_elements, buffer_offset, data_type) = + tensor_infos.remove(&weight_name.replace('/', ".")).unwrap(); + let n_bytes = match data_type { + GgmlDType::F32 => n_elements * 4, + GgmlDType::Q8_0 => { + q8_weights.push(node_index); + n_elements + (n_elements / 16) + } + _ => panic!("Unsupported dtype: {data_type:?}"), + }; + loading_node.1 = Box::new(move |_| { + // Load all bytes + let mut bytes = vec![0; n_bytes]; + let mut file = File::open(&file_path).unwrap(); + file.seek(std::io::SeekFrom::Start( + buffer_offset as u64 + tensor_data_offset, + )) + .unwrap(); + file.read_exact(&mut bytes).unwrap(); + // Dequantize into f32 + let data: Vec = match data_type { + GgmlDType::F32 => bytes + .into_iter() + .chunks(4) + .into_iter() + .map(|c| { + let c = c.collect::>(); + f32::from_le_bytes([c[0], c[1], c[2], c[3]]) + }) + .collect(), + GgmlDType::Q8_0 => bytes + .into_iter() + .chunks(34) + .into_iter() + .map(|c| { + let chunk = c.collect::>(); + unsafe { chunk.align_to::().1[0] } + }) + .flat_map(|chunk| { + chunk + .weights + .into_iter() + .map(move |i| i as f32 * chunk.delta.to_f32()) + }) + .collect(), + _ => panic!("Unsupported dtype: {data_type:?}"), + }; + vec![Tensor::new(data)] + }); + } + } + q8_weights +} diff --git a/examples/phi/src/main.rs b/examples/phi/src/main.rs new file mode 100644 index 00000000..31ac7595 --- /dev/null +++ b/examples/phi/src/main.rs @@ -0,0 +1,185 @@ +use std::{ + io::{self, Write}, + marker::PhantomData, + time::Instant, +}; + +use clap::Parser; +use colored::Colorize; +use itertools::Itertools; +use tokenizers::Tokenizer; + +mod gguf; +mod loader; +mod model; + +use crate::model::KVCache; +use luminal::prelude::*; + +// Command args parser +#[derive(Debug, Parser)] +#[command(author, version, about, long_about = None)] +pub struct CLIArgs { + /// Number of tokens to generate + #[clap(short = 't', long = "gen_tokens", default_value = "128")] + gen_tokens: i32, + + /// Prompt for the model + #[clap(short = 'p', long = "prompt", default_value = include_str!("../../llama/prompts/merge_sort.txt"))] + prompt: String, +} + +fn main() { + let cli_args = CLIArgs::parse(); + let tokenizer = Tokenizer::from_file("setup/tokenizer.json").unwrap(); + + print!("Defining graph"); + io::stdout().flush().unwrap(); + let now = Instant::now(); + + // Set up graph + let mut cx = Graph::new(); + let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input"); + let mut cache_src: Vec, Dyn<'p'>>> = (0..model::NUM_LAYERS) + .map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache"))) + .collect(); + cache_src.set_dyn(vec![], &[1, model::N_KV_HEADS, 0, model::HEAD_DIM]); + let model = model::MistralLM::initialize(&mut cx); + let mut model_weights = downstream(params(&model), &cx); + cx.keep_tensors(&model_weights); + let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::>)); + let mut logits = logits + .slice((.., (Expression::from('s') - 1).., ..)) + .retrieve(); + cache_dest.keep(); + + // Set up model loading + #[cfg(any(feature = "metal", feature = "cuda"))] + let q_weights = loader::q8_load("setup/phi-3-mini-4k-instruct.Q8_0.gguf", &model, &mut cx); + #[cfg(all(not(feature = "metal"), 
not(feature = "cuda")))] + loader::q8_load("setup/phi-3-mini-4k-instruct.Q8_0.gguf", &model, &mut cx); + println!("\t\t - {}ms", now.elapsed().as_millis()); + + print!("Compiling graph"); + io::stdout().flush().unwrap(); + let now = Instant::now(); + cx.compile( + ( + GenericCompiler::default(), + #[cfg(feature = "metal")] + luminal_metal::quantized::MetalQuantizedCompiler::::new(q_weights), + #[cfg(feature = "cuda")] + luminal_cuda::CudaQuantizedCompiler::::new(q_weights), + #[cfg(all(not(feature = "metal"), not(feature = "cuda")))] + luminal_cpu::CPUCompiler::default(), + ), + ( + &mut input, + &mut logits, + &mut cache_src, + &mut cache_dest, + &mut model_weights, + ), + ); + cx.display(); + let cache_src_set = downstream(&cache_src, &cx); + let cache_dest_set = cache_dest.to_ids(); + println!("\t\t - {}ms", now.elapsed().as_millis()); + + // Initial forward pass to load weights + print!("Loading model"); + io::stdout().flush().unwrap(); + let now = Instant::now(); + input.set_dyn(vec![1.], &[1, 1]); + cx.set_dyn_dim('t', 1); + cx.execute(); + logits.drop(); + cache_dest.drop(); + println!("\t\t - {}ms", now.elapsed().as_millis()); + + // Now that weights are loaded, delete the loading nodes so they don't run again + delete_inputs(&model_weights, &mut cx); + // Run prompt processing pass + let mut input_ids = tokenizer + .encode(&cli_args.prompt as &str, false) + .unwrap() + .get_ids() + .to_vec(); + input_ids.insert(0, 1); + input.set_dyn( + input_ids.iter().map(|i| *i as f32).collect::>(), + &[1, input_ids.len()], + ); + cx.set_dyn_dim('t', input_ids.len()); + print!("Processing Prompt"); + io::stdout().flush().unwrap(); + let now = Instant::now(); + cx.execute(); + let elapsed_ms = now.elapsed().as_millis(); + println!( + "\t - {elapsed_ms}ms ({:.2} tok/s, {} prompt tokens)", + 1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64), + input_ids.len() + ); + delete_inputs(&cache_src_set, &mut cx); + let mut output_ids = vec![sample_index(&logits.data())]; + logits.drop(); + + // Decode token + print!("{}", cli_args.prompt.white().bold()); + print!( + "{}", + tokenizer.decode(&output_ids, false).unwrap().bright_green() + ); + io::stdout().flush().unwrap(); + + // Swap caches + transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx); + + // Decode loop + let start_decode = std::time::Instant::now(); + let mut prev_output_len = 0; + for _ in 0..cli_args.gen_tokens { + input.set_dyn(vec![*output_ids.last().unwrap() as f32], &[1, 1]); + cx.set_dyn_dim('p', input_ids.len() + output_ids.len() - 1); + cx.set_dyn_dim('t', input_ids.len() + output_ids.len()); + cx.execute(); + + // Sample tokens + let output_id = sample_index(&logits.data()); + println!("{:?}", &logits.data()[..10]); + logits.drop(); + output_ids.push(output_id); + + // Get the current decoded output + let current_output = tokenizer.decode(&output_ids, false).unwrap(); + + // Print the new substring added to the decoded output + let new_substring = ¤t_output[prev_output_len..]; + print!("{}", new_substring.bright_green()); + io::stdout().flush().unwrap(); + + // Update the previous output + prev_output_len = current_output.len(); + + // Swap caches + transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx); + } + + println!(); + let avg_token_time = (std::time::Instant::now() - start_decode).as_micros() as f32 + / (output_ids.len() - 1) as f32 + / 1000.0; + println!( + "\nAverage token generated in {:.2}ms\t - ({:.2} tok/s)", + avg_token_time, + 1000.0 / avg_token_time + ); +} + +// Currently just an 
argmax, do actual sampling here +fn sample_index(dist: &[f32]) -> u32 { + dist.iter() + .position_max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .unwrap() as u32 +} diff --git a/examples/phi/src/model.rs b/examples/phi/src/model.rs new file mode 100644 index 00000000..94cb3e2b --- /dev/null +++ b/examples/phi/src/model.rs @@ -0,0 +1,345 @@ +use std::{marker::PhantomData, ops::Div}; + +use luminal::prelude::{binary::F32Pow, *}; +use luminal_nn::{Embedding, PermutedLinear, RMSNorm}; + +// Llama3 8B Config +pub const VOCAB_SIZE: usize = 32064; +pub const HIDDEN_DIM: usize = 3072; +pub const NUM_LAYERS: usize = 1; +pub const N_HEADS: usize = 32; +pub const N_KV_HEADS: usize = 8; +pub const MLP_DIM: usize = 8192; + +pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS; +pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS; +pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2; +pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS; + +pub type KVCache = ( + GraphTensor<(Batch, Const, Seq, Const)>, + GraphTensor<(Batch, Const, Seq, Const)>, +); + +pub struct Mlp { + pub gate_proj: PermutedLinear, + pub down_proj: PermutedLinear, + pub up_proj: PermutedLinear, +} + +impl Module> for Mlp +where + GraphTensor: Matmul, Output = GraphTensor>, + GraphTensor: Matmul, Output = GraphTensor>, +{ + type Output = GraphTensor; + + fn forward(&self, input: GraphTensor) -> Self::Output { + let gate = self.gate_proj.forward(input).swish(); + let up = self.up_proj.forward(input) * gate; + self.down_proj.forward(up) + } +} + +impl InitModule for Mlp { + fn initialize(cx: &mut Graph) -> Self { + Self { + gate_proj: PermutedLinear { + weight: cx.named_tensor("Gate"), + }, + up_proj: PermutedLinear { + weight: cx.named_tensor("Up"), + }, + down_proj: PermutedLinear { + weight: cx.named_tensor("Down"), + }, + } + } +} + +impl SerializeModule for Mlp { + fn serialize(&self, s: &mut Serializer) { + s.module("ffn_gate", &self.gate_proj); + s.module("ffn_up", &self.up_proj); + s.module("ffn_down", &self.down_proj); + } +} + +fn apply_rotary_embeddings_ggml( + input: GraphTensor<(Batch, Const, Seq, Const)>, + prev_seq: BigExpression, +) -> GraphTensor<(Batch, Const, Seq, Const)> { + // Get freqs + let freqs = (input.graph().arange::>() * 2.0) / (HEAD_DIM as f32); + let freqs = 500000_f32.pow(freqs); + let pos = input.graph().arange::() + prev_seq; + let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand()); + + // Split input into evens and odds + let split = input.reshape::<(Batch, Const, Seq, Const, Const<2>)>(); + let x0: GraphTensor<(Batch, Const, Seq, Const, Const<1>)> = split + .slice((.., .., .., .., ..Expression::from(1))) + .realize(); + let x1: GraphTensor<(Batch, Const, Seq, Const, Const<1>)> = split + .slice((.., .., .., .., Expression::from(1)..)) + .realize(); + + // Apply sin and cos embeddings + let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand(); + let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand(); + + // Combine back into output + x0_out + .concat_along::<(Batch, Const, Seq, Const, Const<2>), Axis<4>, _>( + x1_out, + ) + .reshape() +} + +pub struct SelfAttention { + pub q_proj: GraphTensor>, + pub k_proj: GraphTensor>, + pub v_proj: GraphTensor>, + pub o_proj: GraphTensor>, +} + +impl + Module<( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + PhantomData, + )> for SelfAttention +{ + type Output = ( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + ); + fn forward( + &self, + (x, (k_cache, v_cache), _): ( + GraphTensor<(Batch, CurSeq, Const)>, + 
KVCache, + PhantomData, + ), + ) -> Self::Output { + // Apply the Projections + let queries = x + .matmul(self.q_proj.permute()) + .reshape::<(Batch, CurSeq, Const, Const)>() + .permute::<_, Axes4<0, 2, 1, 3>>(); + + let keys = x + .matmul(self.k_proj.permute()) + .reshape::<(Batch, CurSeq, Const, Const)>() + .permute::<_, Axes4<0, 2, 1, 3>>(); + + let values = x + .matmul(self.v_proj.permute()) + .reshape::<(Batch, CurSeq, Const, Const)>() + .permute::<_, Axes4<0, 2, 1, 3>>(); + + // Rotary embed queries and keys + let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::const_size().into()); + let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::const_size().into()); + + // Add KV cache + let keys = k_cache.concat_along::<_, Axis<2>, _>(keys); + let values = v_cache.concat_along::<_, Axis<2>, _>(values); + + // Repeat the KV States for Grouped-Query Attention + let repeated_keys = keys.expand::<(_, _, Const, _, _), _>(); + let repeated_values = values.expand::<(_, _, Const, _, _), _>(); + + // Calculate attention weights + let mut attention_weights = queries + .reshape::<(_, Const, Const, _, _)>() // Split query heads into groups + .matmul(repeated_keys.permute()) + .div((HEAD_DIM as f32).sqrt()); + + let attention_mask = self.k_proj.graph().triu::(1) * f16::MIN.to_f32(); + attention_weights += attention_mask + .pad::<(CurSeq, TotSeq), _, _>(&[ + (0.into(), Expression::from(0)), + (TotSeq::const_size() - CurSeq::const_size(), 0.into()), + ]) + .expand(); + + // Calculate final outputs + let output = attention_weights + .softmax::>() + // Apply distribution to values + .matmul(repeated_values) + // Merge heads + .permute::<_, Axes5<0, 3, 1, 2, 4>>() + .reshape::<(Batch, CurSeq, Const)>(); + let output = output + // Apply output projection + .matmul(self.o_proj.permute()); + (output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph + } +} + +impl InitModule for SelfAttention { + fn initialize(cx: &mut Graph) -> Self { + Self { + q_proj: cx.named_tensor("Q Proj"), + k_proj: cx.named_tensor("K Proj"), + v_proj: cx.named_tensor("V Proj"), + o_proj: cx.named_tensor("O Proj"), + } + } +} + +impl SerializeModule for SelfAttention { + fn serialize(&self, s: &mut Serializer) { + s.tensor("attn_q/weight", self.q_proj); + s.tensor("attn_v/weight", self.v_proj); + s.tensor("attn_k/weight", self.k_proj); + s.tensor("attn_output/weight", self.o_proj); + } +} + +pub struct TransformerBlock { + pub attention: SelfAttention, + pub attention_norm: RMSNorm, + pub feed_forward: Mlp, + pub feed_forward_norm: RMSNorm, +} + +impl + Module<( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + PhantomData, + )> for TransformerBlock +{ + type Output = ( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + ); + fn forward( + &self, + (mut x, cache, _): ( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + PhantomData, + ), + ) -> Self::Output { + // Attention + let normed = self.attention_norm.forward(x); + let (y, cache) = self + .attention + .forward((normed, cache, PhantomData::)); + + // Residual Addition + x += y; + + // Feed Forward + let y = self.feed_forward.forward(self.feed_forward_norm.forward(x)); + + // Residual Addition + (x + y, cache) + } +} + +impl InitModule for TransformerBlock { + fn initialize(cx: &mut Graph) -> Self { + Self { + attention: InitModule::initialize(cx), + attention_norm: RMSNorm { + weight: cx.named_tensor("RMS Norm Weight"), + epsilon: 1e-5, + }, + feed_forward: InitModule::initialize(cx), + feed_forward_norm: 
RMSNorm { + weight: cx.named_tensor("RMS Norm Weight"), + epsilon: 1e-5, + }, + } + } +} + +impl SerializeModule for TransformerBlock { + fn serialize(&self, s: &mut Serializer) { + s.module("", &self.attention); + s.module("attn_norm", &self.attention_norm); + s.module("ffn_norm", &self.feed_forward_norm); + s.module("", &self.feed_forward); + } +} + +pub struct MistralLM { + // Token embeddings + pub embedding: Embedding, + // Transformer layers + pub layers: Vec, + // Final Norm layer + pub norm: RMSNorm, + // LM Head Layer + pub lm_head: GraphTensor>, +} + +impl + Module<( + GraphTensor<(Batch, CurSeq)>, + &[KVCache], + PhantomData, + )> for MistralLM +{ + type Output = ( + GraphTensor<(Batch, CurSeq, Const)>, + Vec>, + ); + fn forward( + &self, + (input, cache, _): ( + GraphTensor<(Batch, CurSeq)>, + &[KVCache], + PhantomData, + ), + ) -> Self::Output { + // Embed tokens + let mut x = self.embedding.forward(input); + + // Run through layers and collect new caches + let mut new_caches = vec![]; + let mut new_cache; + for (i, layer) in self.layers.iter().enumerate() { + (x, new_cache) = layer.forward((x, cache[i], PhantomData::)); + new_caches.push(new_cache); + } + // Run through last norm and output projection + let output = self.norm.forward(x).matmul(self.lm_head.permute()); + + (output, new_caches) + } +} + +impl InitModule for MistralLM { + fn initialize(cx: &mut Graph) -> Self { + Self { + embedding: Embedding { + weight: cx.named_tensor("Embedding Weight"), + }, + norm: RMSNorm { + weight: cx.named_tensor("RMS Norm Weight"), + epsilon: 1e-5, + }, + lm_head: cx.named_tensor("LM Head"), + layers: (0..NUM_LAYERS) + .map(|_| InitModule::initialize(cx)) + .collect(), + } + } +} + +impl SerializeModule for MistralLM { + fn serialize(&self, s: &mut Serializer) { + s.module("token_embd", &self.embedding); + s.module("output_norm", &self.norm); + s.tensor("output/weight", self.lm_head); + for (i, layer) in self.layers.iter().enumerate() { + s.module(&format!("blk/{i}"), layer); + } + } +} diff --git a/src/compiler_utils.rs b/src/compiler_utils.rs index c286bbf8..cb6ad9f8 100644 --- a/src/compiler_utils.rs +++ b/src/compiler_utils.rs @@ -523,6 +523,7 @@ pub fn display_graph( ); } + println!("GRAPGH STRING: {:?}", graph_string.len()); let url = format!( "https://dreampuf.github.io/GraphvizOnline/#{}", urlencoding::encode(&graph_string) From ad8908ab797a2814225e55ce0dd5a32121aa0b02 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Fri, 26 Apr 2024 21:49:50 -0500 Subject: [PATCH 02/12] Fixed cpu gather compiler --- crates/luminal_cpu/src/binary.rs | 39 +++++++++++++++++------ crates/luminal_symbolic/src/expression.rs | 18 +++++++---- crates/luminal_symbolic/src/simplify.rs | 8 ++--- crates/luminal_symbolic/src/term.rs | 10 +++--- examples/phi/src/main.rs | 3 +- examples/phi/src/model.rs | 2 +- src/op.rs | 3 +- src/shape/tracker.rs | 2 +- 8 files changed, 54 insertions(+), 31 deletions(-) diff --git a/crates/luminal_cpu/src/binary.rs b/crates/luminal_cpu/src/binary.rs index 03960a56..a92cbcd8 100644 --- a/crates/luminal_cpu/src/binary.rs +++ b/crates/luminal_cpu/src/binary.rs @@ -196,19 +196,35 @@ pub struct GatherCompiler; impl Compiler for GatherCompiler { type Output = (); fn compile(&self, graph: &mut Graph, _: To) { - let arange = op::(); - let eq = unary::(arange); - let inp = node(); - let mul = binary::(inp.clone(), eq.clone()); + let indexes = node(); + let eq = binary::(indexes.clone(), op::()); + let embedding = node(); + let mul = binary::(embedding.clone(), eq.clone()); let 
sum_reduce = unary::(mul.clone()); let mut s = sum_reduce.clone().search(graph); while s.next_match() { - if s.check_no_delete(&[sum_reduce.id]) { + if s.check_no_delete(&[embedding.id]) { continue; } + let emb_shape = graph + .edges_connecting(s.get(&embedding), s.get(&mul)) + .next() + .unwrap() + .weight() + .as_data() + .unwrap() + .2; + let index_shape = graph + .edges_connecting(s.get(&indexes), s.get(&eq)) + .next() + .unwrap() + .weight() + .as_data() + .unwrap() + .2; let embed_dim = graph .graph - .edges_connecting(s.get(&inp), s.get(&mul)) + .edges_connecting(s.get(&embedding), s.get(&mul)) .next() .unwrap() .weight() @@ -218,11 +234,14 @@ impl Compiler for GatherCompiler { .shape()[2] .to_usize() .unwrap(); - let gather = graph.add_op(Gather { embed_dim }).finish(); - move_incoming_edge(s.get(&eq), gather, &mut graph.graph); - graph.safe_remove_node(s.get(&eq), 1); - move_incoming_edge(s.get(&mul), gather, &mut graph.graph); + + let gather = graph + .add_op(Gather { embed_dim }) + .input(s.get(&indexes), 0, index_shape) + .input(s.get(&embedding), 0, emb_shape) + .finish(); move_outgoing_edge(s.get(&sum_reduce), gather, &mut graph.graph); + graph.remove_node(s.get(&sum_reduce)); s.try_delete(); } } diff --git a/crates/luminal_symbolic/src/expression.rs b/crates/luminal_symbolic/src/expression.rs index fab56cc9..b69c9bd2 100644 --- a/crates/luminal_symbolic/src/expression.rs +++ b/crates/luminal_symbolic/src/expression.rs @@ -126,6 +126,12 @@ impl Debug for GenericExpression { } } +impl std::fmt::Display for GenericExpression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + impl GenericExpression { /// Simplify the expression to its minimal terms pub fn simplify(self) -> Self { @@ -217,11 +223,11 @@ where self.exec_single_var_stack(value, &mut stack) } /// Evaluate the expression with one value for all variables. 
Uses a provided stack - pub fn exec_single_var_stack(&self, value: usize, stack: &mut Vec) -> usize { + pub fn exec_single_var_stack(&self, value: usize, stack: &mut Vec) -> usize { for term in &self.terms { match term { - Term::Num(n) => stack.push(*n), - Term::Var(_) => stack.push(value as i32), + Term::Num(n) => stack.push(*n as i64), + Term::Var(_) => stack.push(value as i64), _ => { let a = stack.pop().unwrap(); let b = stack.pop().unwrap(); @@ -239,16 +245,16 @@ where pub fn exec_stack( &self, variables: &FxHashMap, - stack: &mut Vec, + stack: &mut Vec, ) -> Option { for term in &self.terms { match term { - Term::Num(n) => stack.push(*n), + Term::Num(n) => stack.push(*n as i64), Term::Var(c) => { #[allow(clippy::needless_borrow)] if let Some(n) = variables.get(&c) { - stack.push(*n as i32) + stack.push(*n as i64) } else { return None; } diff --git a/crates/luminal_symbolic/src/simplify.rs b/crates/luminal_symbolic/src/simplify.rs index df3f426b..27f0341c 100644 --- a/crates/luminal_symbolic/src/simplify.rs +++ b/crates/luminal_symbolic/src/simplify.rs @@ -25,8 +25,8 @@ pub fn reduce_triples( let (b_ind, b_term) = stack.pop().unwrap(); triples.push((a_ind, index, b_ind)); if let (Term::Num(a), Term::Num(b)) = (a_term, b_term) { - if let Some(c) = term.as_op().unwrap()(a, b) { - stack.push((None, Term::Num(c))); + if let Some(c) = term.as_op().unwrap()(a as i64, b as i64) { + stack.push((None, Term::Num(c as i32))); } else { break; } @@ -68,8 +68,8 @@ pub fn reduce_triples( b_ind.map(|b| expr.terms[b]), ) { (Some(Term::Num(a)), term, Some(Term::Num(b))) if term.as_op().is_some() => { - if let Some(c) = term.as_op().unwrap()(a, b) { - expr.terms[unwrap_cont!(a_ind)] = Term::Num(c); + if let Some(c) = term.as_op().unwrap()(a as i64, b as i64) { + expr.terms[unwrap_cont!(a_ind)] = Term::Num(c as i32); remove_terms(&mut expr.terms, &[op_ind, unwrap_cont!(b_ind)]); } else { inner_changed = false; diff --git a/crates/luminal_symbolic/src/term.rs b/crates/luminal_symbolic/src/term.rs index af219b2c..87a88c10 100644 --- a/crates/luminal_symbolic/src/term.rs +++ b/crates/luminal_symbolic/src/term.rs @@ -43,7 +43,7 @@ impl Default for Term { } impl Term { - pub fn as_op(self) -> Option Option> { + pub fn as_op(self) -> Option Option> { match self { Term::Add => Some(|a, b| a.checked_add(b)), Term::Sub => Some(|a, b| a.checked_sub(b)), @@ -52,10 +52,10 @@ impl Term { Term::Mod => Some(|a, b| a.checked_rem(b)), Term::Max => Some(|a, b| Some(a.max(b))), Term::Min => Some(|a, b| Some(a.min(b))), - Term::And => Some(|a, b| Some((a != 0 && b != 0) as i32)), - Term::Or => Some(|a, b| Some((a != 0 || b != 0) as i32)), - Term::Gte => Some(|a, b| Some((a >= b) as i32)), - Term::Lt => Some(|a, b| Some((a < b) as i32)), + Term::And => Some(|a, b| Some((a != 0 && b != 0) as i64)), + Term::Or => Some(|a, b| Some((a != 0 || b != 0) as i64)), + Term::Gte => Some(|a, b| Some((a >= b) as i64)), + Term::Lt => Some(|a, b| Some((a < b) as i64)), _ => None, } } diff --git a/examples/phi/src/main.rs b/examples/phi/src/main.rs index 31ac7595..c2b7499b 100644 --- a/examples/phi/src/main.rs +++ b/examples/phi/src/main.rs @@ -81,7 +81,6 @@ fn main() { &mut model_weights, ), ); - cx.display(); let cache_src_set = downstream(&cache_src, &cx); let cache_dest_set = cache_dest.to_ids(); println!("\t\t - {}ms", now.elapsed().as_millis()); @@ -147,7 +146,7 @@ fn main() { // Sample tokens let output_id = sample_index(&logits.data()); - println!("{:?}", &logits.data()[..10]); + // println!("{:?}", &logits.data()[..10]); 
logits.drop(); output_ids.push(output_id); diff --git a/examples/phi/src/model.rs b/examples/phi/src/model.rs index 94cb3e2b..c219be9d 100644 --- a/examples/phi/src/model.rs +++ b/examples/phi/src/model.rs @@ -6,7 +6,7 @@ use luminal_nn::{Embedding, PermutedLinear, RMSNorm}; // Llama3 8B Config pub const VOCAB_SIZE: usize = 32064; pub const HIDDEN_DIM: usize = 3072; -pub const NUM_LAYERS: usize = 1; +pub const NUM_LAYERS: usize = 32; pub const N_HEADS: usize = 32; pub const N_KV_HEADS: usize = 8; pub const MLP_DIM: usize = 8192; diff --git a/src/op.rs b/src/op.rs index 01cf2fc8..af055145 100644 --- a/src/op.rs +++ b/src/op.rs @@ -316,7 +316,6 @@ impl Operator for SumReduce { let input = get_vec(&inp[0].0); let expr = (inp[0].1.index_expression(), inp[0].1.valid_expression()); let mut stack = vec![]; - for i in 0..front_size { for j in 0..back_size { for k in 0..dim_size { @@ -363,7 +362,7 @@ fn get_vec<'a>(tensor: &'a InputTensor<'a>) -> &'a Vec { fn get_index( data: &[f32], (ind, val): &(BigExpression, BigExpression), - stack: &mut Vec, + stack: &mut Vec, index: usize, ) -> f32 { if val.exec_single_var_stack(index, stack) != 0 { diff --git a/src/shape/tracker.rs b/src/shape/tracker.rs index 18e8bef3..9b1829dd 100644 --- a/src/shape/tracker.rs +++ b/src/shape/tracker.rs @@ -274,7 +274,7 @@ impl ShapeTracker { pub fn resolve_global_dyn_dims_stack( &mut self, dyn_dim_map: &FxHashMap, - stack: &mut Vec, + stack: &mut Vec, ) { for d in self.dims.iter_mut() { *d = d.exec_stack(dyn_dim_map, stack).unwrap().into(); From 8bf379b3eb4c9d97c3a7ac5dc8f9dc99d707ee4e Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Fri, 26 Apr 2024 22:03:22 -0500 Subject: [PATCH 03/12] Changed tests --- crates/luminal_cuda/src/tests/fp16.rs | 3 +-- crates/luminal_cuda/src/tests/fp32.rs | 5 ++--- examples/phi/src/main.rs | 1 + examples/phi/src/model.rs | 2 +- src/compiler_utils.rs | 1 - 5 files changed, 5 insertions(+), 7 deletions(-) diff --git a/crates/luminal_cuda/src/tests/fp16.rs b/crates/luminal_cuda/src/tests/fp16.rs index 50b24c8c..3af91c15 100644 --- a/crates/luminal_cuda/src/tests/fp16.rs +++ b/crates/luminal_cuda/src/tests/fp16.rs @@ -33,8 +33,7 @@ binary_test!(|a, b| a + b, |a, b| a + b, test_add, f16); binary_test!(|a, b| a - b, |a, b| a - b, test_sub, f16); binary_test!(|a, b| a * b, |a, b| a * b, test_mul, f16); binary_test!(|a, b| a / b, |a, b| a * b.recip(), test_div, f16); -// binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f16); -single_binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f16, 3); // Why don't larger max tests work? 
+binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f16); binary_test!(|a, b| a.min(b), |a, b| a.minimum(b), test_min, f16); #[test] diff --git a/crates/luminal_cuda/src/tests/fp32.rs b/crates/luminal_cuda/src/tests/fp32.rs index 78aab5b8..80f23247 100644 --- a/crates/luminal_cuda/src/tests/fp32.rs +++ b/crates/luminal_cuda/src/tests/fp32.rs @@ -20,7 +20,7 @@ use luminal::{ }, }; -use crate::{binary_test, single_binary_test, single_unary_test, unary_test, CudaCompiler}; +use crate::{binary_test, single_unary_test, unary_test, CudaCompiler}; unary_test!(|a| a.sin(), |a| a.sin(), test_sin, f32); unary_test!(|a| a.sqrt(), |a| a.sqrt(), test_sqrt, f32); @@ -32,8 +32,7 @@ binary_test!(|a, b| a + b, |a, b| a + b, test_add, f32); binary_test!(|a, b| a - b, |a, b| a - b, test_sub, f32); binary_test!(|a, b| a * b, |a, b| a * b, test_mul, f32); binary_test!(|a, b| a / b, |a, b| a * b.recip(), test_div, f32); -// binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f32); -single_binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f32, 3); // Why don't larger max tests work? +binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f32); binary_test!(|a, b| a.min(b), |a, b| a.minimum(b), test_min, f32); #[test] diff --git a/examples/phi/src/main.rs b/examples/phi/src/main.rs index c2b7499b..05585351 100644 --- a/examples/phi/src/main.rs +++ b/examples/phi/src/main.rs @@ -81,6 +81,7 @@ fn main() { &mut model_weights, ), ); + cx.display(); let cache_src_set = downstream(&cache_src, &cx); let cache_dest_set = cache_dest.to_ids(); println!("\t\t - {}ms", now.elapsed().as_millis()); diff --git a/examples/phi/src/model.rs b/examples/phi/src/model.rs index c219be9d..94cb3e2b 100644 --- a/examples/phi/src/model.rs +++ b/examples/phi/src/model.rs @@ -6,7 +6,7 @@ use luminal_nn::{Embedding, PermutedLinear, RMSNorm}; // Llama3 8B Config pub const VOCAB_SIZE: usize = 32064; pub const HIDDEN_DIM: usize = 3072; -pub const NUM_LAYERS: usize = 32; +pub const NUM_LAYERS: usize = 1; pub const N_HEADS: usize = 32; pub const N_KV_HEADS: usize = 8; pub const MLP_DIM: usize = 8192; diff --git a/src/compiler_utils.rs b/src/compiler_utils.rs index cb6ad9f8..c286bbf8 100644 --- a/src/compiler_utils.rs +++ b/src/compiler_utils.rs @@ -523,7 +523,6 @@ pub fn display_graph( ); } - println!("GRAPGH STRING: {:?}", graph_string.len()); let url = format!( "https://dreampuf.github.io/GraphvizOnline/#{}", urlencoding::encode(&graph_string) From fb84e93815fcff97485edf5941d29af1bfe8f810 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Sat, 27 Apr 2024 09:35:55 -0500 Subject: [PATCH 04/12] Metal fixes --- crates/luminal_metal/src/prim.rs | 3 +++ examples/llama/src/main.rs | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/crates/luminal_metal/src/prim.rs b/crates/luminal_metal/src/prim.rs index c2187f58..354e9127 100644 --- a/crates/luminal_metal/src/prim.rs +++ b/crates/luminal_metal/src/prim.rs @@ -1600,6 +1600,9 @@ impl Compiler for PrimitiveCompiler { graph.remove_edge(edge_id); } + if graph.no_delete.contains(&function_node) { + graph.no_delete.insert(copy_node); + } if let Some(w) = graph.to_retrieve.get(&function_node) { graph.to_retrieve.insert(copy_node, *w); } diff --git a/examples/llama/src/main.rs b/examples/llama/src/main.rs index a4abcdef..93488b46 100644 --- a/examples/llama/src/main.rs +++ b/examples/llama/src/main.rs @@ -45,7 +45,7 @@ fn main() { .collect(); cache_src.set_dyn(vec![], &[1, model::N_KV_HEADS, 0, model::HEAD_DIM]); let model = 
model::MistralLM::initialize(&mut cx); - let mut model_weights = downstream(params(&model), &cx); + let mut model_weights = params(&model); cx.keep_tensors(&model_weights); let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::>)); let mut logits = logits @@ -97,7 +97,8 @@ fn main() { println!("\t\t - {}ms", now.elapsed().as_millis()); // Now that weights are loaded, delete the loading nodes so they don't run again - delete_inputs(&model_weights, &mut cx); + delete_inputs(&downstream(model_weights, &cx), &mut cx); + // Run prompt processing pass let mut input_ids = tokenizer .encode(&cli_args.prompt as &str, false) From 9b734d6cbd878ce26fadb3debb6d3aa8e3c6afaf Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Sat, 27 Apr 2024 09:43:51 -0500 Subject: [PATCH 05/12] Fixed llama layers --- examples/llama/src/model.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llama/src/model.rs b/examples/llama/src/model.rs index ecab6862..90760362 100644 --- a/examples/llama/src/model.rs +++ b/examples/llama/src/model.rs @@ -6,7 +6,7 @@ use luminal_nn::{Embedding, PermutedLinear, RMSNorm}; // Llama3 8B Config pub const VOCAB_SIZE: usize = 128256; pub const HIDDEN_DIM: usize = 4096; -pub const NUM_LAYERS: usize = 1; +pub const NUM_LAYERS: usize = 32; pub const N_HEADS: usize = 32; pub const N_KV_HEADS: usize = 8; pub const MLP_DIM: usize = 14336; From 868e1c667ef29510510360e984da83cffae8d4d3 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Sat, 27 Apr 2024 10:00:31 -0500 Subject: [PATCH 06/12] Support fp16 on llama --- .../luminal_metal/src/elementwise_fusion.rs | 4 +- examples/llama/src/loader.rs | 66 ++++++++++++------- examples/llama/src/main.rs | 4 +- 3 files changed, 48 insertions(+), 26 deletions(-) diff --git a/crates/luminal_metal/src/elementwise_fusion.rs b/crates/luminal_metal/src/elementwise_fusion.rs index 8ed2cec1..c1de9840 100644 --- a/crates/luminal_metal/src/elementwise_fusion.rs +++ b/crates/luminal_metal/src/elementwise_fusion.rs @@ -1090,9 +1090,9 @@ mod tests { let unopt_out = out.data(); out.drop(); - cx.compile(<(GenericCompiler, MetalCompiler)>::default(), &mut out); + cx.compile(<(GenericCompiler, MetalCompiler)>::default(), &mut out); cx.execute(); - assert_close_precision(&out.data(), &unopt_out, 1e-3); + assert_close_precision(&out.data(), &unopt_out, 1e-2); } } diff --git a/examples/llama/src/loader.rs b/examples/llama/src/loader.rs index 5e8ff96e..5c557977 100644 --- a/examples/llama/src/loader.rs +++ b/examples/llama/src/loader.rs @@ -1,4 +1,6 @@ +use itertools::Itertools; use std::fs::File; +use std::io::{Read, Seek}; use std::path::Path; use luminal::{op::Function, prelude::*}; @@ -8,11 +10,6 @@ use {luminal_cuda::CudaData, luminal_cudarc::driver::CudaDevice}; use crate::gguf::*; -#[cfg(not(feature = "metal"))] -use { - itertools::Itertools, - std::io::{Read, Seek}, -}; #[cfg(feature = "metal")] use { luminal_metal::MetalBuffer, @@ -53,23 +50,48 @@ pub fn q8_load, M: SerializeModule>( } _ => panic!("Unsupported dtype: {data_type:?}"), }; - loading_node.1 = Box::new(move |_| { - let mmap_buffer = unsafe { Mmap::map(&File::open(&file_path).unwrap()).unwrap() }; - let buffer = Device::system_default() - .unwrap() - .new_buffer_with_bytes_no_copy( - unsafe { - mmap_buffer - .as_ptr() - .add(buffer_offset + tensor_data_offset as usize) - as *const _ - }, - n_bytes as u64, - MTLResourceOptions::StorageModeShared, - None, - ); - vec![Tensor::new(MetalBuffer(buffer))] - }); + if let GgmlDType::F32 = data_type { + loading_node.1 = 
Box::new(move |_| {
+                        // Read bytes
+                        let mut bytes = vec![0; n_bytes];
+                        let mut file = File::open(&file_path).unwrap();
+                        file.seek(std::io::SeekFrom::Start(
+                            buffer_offset as u64 + tensor_data_offset,
+                        ))
+                        .unwrap();
+                        file.read_exact(&mut bytes).unwrap();
+                        vec![Tensor::new(
+                            bytes
+                                .into_iter()
+                                .chunks(4)
+                                .into_iter()
+                                .map(|c| {
+                                    let c = c.collect::>();
+                                    f32::from_le_bytes([c[0], c[1], c[2], c[3]])
+                                })
+                                .collect::>(),
+                        )]
+                    });
+                } else {
+                    loading_node.1 = Box::new(move |_| {
+                        let mmap_buffer =
+                            unsafe { Mmap::map(&File::open(&file_path).unwrap()).unwrap() };
+                        let buffer = Device::system_default()
+                            .unwrap()
+                            .new_buffer_with_bytes_no_copy(
+                                unsafe {
+                                    mmap_buffer
+                                        .as_ptr()
+                                        .add(buffer_offset + tensor_data_offset as usize)
+                                        as *const _
+                                },
+                                n_bytes as u64,
+                                MTLResourceOptions::StorageModeShared,
+                                None,
+                            );
+                        vec![Tensor::new(MetalBuffer(buffer))]
+                    });
+                }
             }
         }
     }
     q8_weights
diff --git a/examples/llama/src/main.rs b/examples/llama/src/main.rs
index 93488b46..83cd8c8a 100644
--- a/examples/llama/src/main.rs
+++ b/examples/llama/src/main.rs
@@ -67,9 +67,9 @@ fn main() {
         (
             GenericCompiler::default(),
             #[cfg(feature = "metal")]
-            luminal_metal::quantized::MetalQuantizedCompiler::::new(q_weights),
+            luminal_metal::quantized::MetalQuantizedCompiler::::new(q_weights),
             #[cfg(feature = "cuda")]
-            luminal_cuda::CudaQuantizedCompiler::::new(q_weights),
+            luminal_cuda::CudaQuantizedCompiler::::new(q_weights),
             #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
             luminal_cpu::CPUCompiler::default(),
         ),

From fa2b7ac22e8c0b18fea501c101379ffc305bec17 Mon Sep 17 00:00:00 2001
From: Joe Fioti 
Date: Sun, 28 Apr 2024 12:35:14 -0500
Subject: [PATCH 07/12] Fixed many cuda bugs

---
 crates/luminal_cuda/src/lib.rs        |  15 +-
 crates/luminal_cuda/src/matmul.rs     |  13 +-
 crates/luminal_cuda/src/other.rs      |  81 ++-
 crates/luminal_cuda/src/prim.rs       |  88 +--
 crates/luminal_cuda/src/tests/fp16.rs |  35 +-
 crates/luminal_cuda/src/tests/fp32.rs |  41 +-
 crates/luminal_cuda/src/tests/mod.rs  |  77 ++-
 crates/luminal_cuda/src/unary.rs      | 781 ++++++++++++++++++++++++++
 examples/phi/src/main.rs              |   1 -
 9 files changed, 968 insertions(+), 164 deletions(-)
 create mode 100644 crates/luminal_cuda/src/unary.rs

diff --git a/crates/luminal_cuda/src/lib.rs b/crates/luminal_cuda/src/lib.rs
index fa4485f9..0b2c20cc 100644
--- a/crates/luminal_cuda/src/lib.rs
+++ b/crates/luminal_cuda/src/lib.rs
@@ -3,9 +3,11 @@ mod matmul;
 mod other;
 mod prim;
 mod quantized;
+mod unary;
 pub use quantized::*;
 
 #[cfg(test)]
+#[macro_use]
 mod tests;
 
 use itertools::Itertools;
@@ -20,14 +22,25 @@ use std::{collections::hash_map::DefaultHasher, ffi::c_void, fmt::Write, hash::H
 
 use luminal::{op::InputTensor, prelude::*};
 
+/// Compile graphs to run on CUDA-capable Nvidia GPUs in supported data formats
 pub type CudaCompiler = (
     prim::PrimitiveCompiler,
+    SpecialOpsCompiler,
+    other::CopyCompiler,
+);
+
+/// Compiler to replace cuda ops with specialized variants
+pub type SpecialOpsCompiler = (
     binary::SubtractionCompiler,
     binary::EqualCompiler,
     other::ARangeCompiler,
     binary::GatherCompiler,
+    unary::CudaExpCompiler,
+    unary::CudaCosCompiler,
+    unary::MeanReduceCompiler,
+    unary::StdNormCompiler,
+    unary::SoftmaxCompiler,
     matmul::MatMulCompiler,
-    prim::CopyCompiler,
 );
 
 pub trait CudaFloat:
diff --git a/crates/luminal_cuda/src/matmul.rs b/crates/luminal_cuda/src/matmul.rs
index 0125cad3..d558aa52 100644
--- a/crates/luminal_cuda/src/matmul.rs
+++ b/crates/luminal_cuda/src/matmul.rs
@@ -22,7 +22,6 @@ crate::debug_type!(Matmul);
 
 impl 
Operator for Matmul { fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec { let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape()); - let a_strides = inp[0].1.strides(); let (batch_size, m, k, n) = ( a_shape .iter() @@ -33,6 +32,7 @@ impl Operator for Matmul { a_shape[a_shape.len() - 1].to_usize().unwrap() as i32, b_shape[b_shape.len() - 1].to_usize().unwrap() as i32, ); + println!("{:?}", (batch_size, m, k, n)); let a = get_buffer_from_tensor::(&inp[0].0); let b = get_buffer_from_tensor::(&inp[1].0); let mut out = self @@ -49,6 +49,9 @@ impl Operator for Matmul { (false, true) => (CUBLAS_OP_N, CUBLAS_OP_T), (true, false) => (CUBLAS_OP_T, CUBLAS_OP_N), }; + + let a_dims = inp[0].1.fake.iter().filter(|f| !**f).count(); + let b_dims = inp[1].1.fake.iter().filter(|f| !**f).count(); if T::is_f32() { unsafe { luminal_cudarc::cublas::result::sgemm_strided_batched( @@ -61,10 +64,10 @@ impl Operator for Matmul { &1.0_f32 as *const f32, *b.device_ptr() as *const f32, if b_row_major { n } else { k }, - 0, + if b_dims == 2 { 0 } else { (n * k) as i64 }, *a.device_ptr() as *const f32, if a_row_major { k } else { m }, - a_strides[0].to_usize().unwrap() as i64, + if a_dims == 2 { 0 } else { (m * k) as i64 }, &0.0_f32 as *const f32, *out.device_ptr_mut() as *mut f32, n, @@ -85,10 +88,10 @@ impl Operator for Matmul { &f16::from_f32(1.0) as *const f16, *b.device_ptr() as *const f16, if b_row_major { n } else { k }, - 0, + if b_dims == 2 { 0 } else { (n * k) as i64 }, *a.device_ptr() as *const f16, if a_row_major { k } else { m }, - a_strides[0].to_usize().unwrap() as i64, + if a_dims == 2 { 0 } else { (m * k) as i64 }, &f16::from_f32(0.0) as *const f16, *out.device_ptr_mut() as *mut f16, n, diff --git a/crates/luminal_cuda/src/other.rs b/crates/luminal_cuda/src/other.rs index f8598c63..90f242c0 100644 --- a/crates/luminal_cuda/src/other.rs +++ b/crates/luminal_cuda/src/other.rs @@ -1,13 +1,14 @@ use std::{marker::PhantomData, sync::Arc}; -use luminal::prelude::*; +use itertools::Itertools; +use luminal::prelude::{petgraph::visit::EdgeRef, *}; use luminal_cudarc::driver::{CudaDevice, CudaFunction, LaunchAsync, LaunchConfig}; use rustc_hash::FxHashMap; use crate::{ binary::CudaSub, compile_and_load_kernel, constant, - prim::{CudaAdd, CudaContiguous, CudaSumReduce}, + prim::{CudaAdd, CudaContiguous, CudaCopyFromDevice, CudaCopyToDevice, CudaSumReduce}, CudaData, CudaFloat, }; @@ -120,3 +121,79 @@ impl Compiler for ARangeCompiler { } } } + +// Sometimes CopyTo -> CopyFrom and CopyFrom -> CopyTo patterns remain, so let's clean them up +#[derive(Debug, Default)] +pub struct CopyCompiler(PhantomData); + +impl Compiler for CopyCompiler { + type Output = (); + fn compile(&self, graph: &mut Graph, mut ids: To) { + for (first, second) in graph + .edge_indices() + .filter_map(|e| graph.edge_endpoints(e)) + .filter(|(a, b)| { + (graph + .node_weight(*a) + .unwrap() + .as_any() + .is::>() + && graph + .node_weight(*b) + .unwrap() + .as_any() + .is::>()) + || (graph + .node_weight(*a) + .unwrap() + .as_any() + .is::>() + && graph + .node_weight(*b) + .unwrap() + .as_any() + .is::>()) + }) + .unique_by(|n| n.0) + .unique_by(|n| n.1) + .collect::>() + { + if graph + .edges_directed(first, petgraph::Direction::Outgoing) + .filter(|e| graph.contains_node(e.target())) + .filter(|e| { + !graph + .node_weight(e.target()) + .unwrap() + .as_any() + .is::>() + && !graph + .node_weight(e.target()) + .unwrap() + .as_any() + .is::>() + }) + .count() + > 0 + || graph.no_delete.contains(&first) + { + continue; + 
} + let source = graph.get_sources(first)[0]; + move_outgoing_edge(second, source.0, graph); + remap(second, source.0, &mut ids, graph); + graph.remove_node(second); + for dest in graph + .get_dests(first) + .iter() + .map(|(i, _)| *i) + .collect::>() + { + move_outgoing_edge(dest, source.0, graph); + remap(dest, source.0, &mut ids, graph); + graph.remove_node(dest); + } + graph.remove_node(first); + } + } +} diff --git a/crates/luminal_cuda/src/prim.rs b/crates/luminal_cuda/src/prim.rs index ed4497be..a4f47277 100644 --- a/crates/luminal_cuda/src/prim.rs +++ b/crates/luminal_cuda/src/prim.rs @@ -355,8 +355,10 @@ crate::debug_type!(CudaSin); impl CudaSin { pub fn new(device: Arc) -> Self { let type_name = T::type_name(); - let code = format!( - " + Self { + function: compile_and_load_kernel( + format!( + " #include \"cuda_fp16.h\" extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel) {{ int i = blockIdx.x * blockDim.x + threadIdx.x; @@ -364,9 +366,9 @@ extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, in out[i] = sin(inp[i]); }} }}" - ); - Self { - function: compile_and_load_kernel(code, &device), + ), + &device, + ), device, _phantom: Default::default(), } @@ -1108,79 +1110,3 @@ impl Compiler for PrimitiveCompiler { } } } - -// Sometimes CopyTo -> CopyFrom and CopyFrom -> CopyTo patterns remain, so let's clean them up -#[derive(Debug, Default)] -pub struct CopyCompiler(PhantomData); - -impl Compiler for CopyCompiler { - type Output = (); - fn compile(&self, graph: &mut Graph, mut ids: To) { - for (first, second) in graph - .edge_indices() - .filter_map(|e| graph.edge_endpoints(e)) - .filter(|(a, b)| { - (graph - .node_weight(*a) - .unwrap() - .as_any() - .is::>() - && graph - .node_weight(*b) - .unwrap() - .as_any() - .is::>()) - || (graph - .node_weight(*a) - .unwrap() - .as_any() - .is::>() - && graph - .node_weight(*b) - .unwrap() - .as_any() - .is::>()) - }) - .unique_by(|n| n.0) - .unique_by(|n| n.1) - .collect::>() - { - if graph - .edges_directed(first, petgraph::Direction::Outgoing) - .filter(|e| graph.contains_node(e.target())) - .filter(|e| { - !graph - .node_weight(e.target()) - .unwrap() - .as_any() - .is::>() - && !graph - .node_weight(e.target()) - .unwrap() - .as_any() - .is::>() - }) - .count() - > 0 - || graph.no_delete.contains(&first) - { - continue; - } - let source = graph.get_sources(first)[0]; - move_outgoing_edge(second, source.0, graph); - remap(second, source.0, &mut ids, graph); - graph.remove_node(second); - for dest in graph - .get_dests(first) - .iter() - .map(|(i, _)| *i) - .collect::>() - { - move_outgoing_edge(dest, source.0, graph); - remap(dest, source.0, &mut ids, graph); - graph.remove_node(dest); - } - graph.remove_node(first); - } - } -} diff --git a/crates/luminal_cuda/src/tests/fp16.rs b/crates/luminal_cuda/src/tests/fp16.rs index 3af91c15..be07cfbf 100644 --- a/crates/luminal_cuda/src/tests/fp16.rs +++ b/crates/luminal_cuda/src/tests/fp16.rs @@ -21,20 +21,7 @@ use luminal::{ }, }; -use crate::{binary_test, single_binary_test, single_unary_test, unary_test, CudaCompiler}; - -unary_test!(|a| a.sin(), |a| a.sin(), test_sin, f16); -unary_test!(|a| a.sqrt(), |a| a.sqrt(), test_sqrt, f16); -unary_test!(|a| a.recip(), |a| a.recip(), test_recip, f16); -unary_test!(|a| a * a, |a| a.clone() * a, test_square, f16); -single_unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f16, 3); // For some reason ln fails on larger tensors - -binary_test!(|a, b| a + b, |a, b| a + b, test_add, f16); 
-binary_test!(|a, b| a - b, |a, b| a - b, test_sub, f16); -binary_test!(|a, b| a * b, |a, b| a * b, test_mul, f16); -binary_test!(|a, b| a / b, |a, b| a * b.recip(), test_div, f16); -binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f16); -binary_test!(|a, b| a.min(b), |a, b| a.minimum(b), test_min, f16); +use crate::CudaCompiler; #[test] fn test_contiguous() { @@ -54,24 +41,6 @@ fn test_contiguous() { assert_close(&b.data(), &d_b.to_dtype::().as_vec()); } -#[test] -fn test_softmax() { - let mut cx = Graph::new(); - let data = random_vec(12); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.softmax::>().retrieve(); - cx.compile(CudaCompiler::::default(), &mut b); - cx.execute(); - - let d_dev = Cpu::default(); - let d_a = d_dev - .tensor_from_vec(data, (DConst::<1>, DConst::<12>)) - .to_dtype::(); - let d_b = d_a.softmax::>(); - - assert_close(&b.data(), &d_b.to_dtype::().as_vec()); -} - #[test] fn test_rotate() { let mut cx = Graph::new(); @@ -814,7 +783,7 @@ fn test_pad_contig() { .set_dyn(a_data, &[m, k]) .retrieve(); let mut b: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = a - .pad(&[(0, 0.into()), (0, Expression::from(16) - 'K')]) + .pad(&[(0, 0.into()), (0, Expression::from(24) - 'K')]) .contiguous() .retrieve(); let mut c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = diff --git a/crates/luminal_cuda/src/tests/fp32.rs b/crates/luminal_cuda/src/tests/fp32.rs index 80f23247..61d5512a 100644 --- a/crates/luminal_cuda/src/tests/fp32.rs +++ b/crates/luminal_cuda/src/tests/fp32.rs @@ -20,21 +20,9 @@ use luminal::{ }, }; -use crate::{binary_test, single_unary_test, unary_test, CudaCompiler}; - -unary_test!(|a| a.sin(), |a| a.sin(), test_sin, f32); -unary_test!(|a| a.sqrt(), |a| a.sqrt(), test_sqrt, f32); -unary_test!(|a| a.recip(), |a| a.recip(), test_recip, f32); -unary_test!(|a| a * a, |a| a.clone() * a, test_square, f32); +use crate::{single_unary_test, CudaCompiler}; single_unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f32, 3); // For some reason ln fails on larger tensors -binary_test!(|a, b| a + b, |a, b| a + b, test_add, f32); -binary_test!(|a, b| a - b, |a, b| a - b, test_sub, f32); -binary_test!(|a, b| a * b, |a, b| a * b, test_mul, f32); -binary_test!(|a, b| a / b, |a, b| a * b.recip(), test_div, f32); -binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f32); -binary_test!(|a, b| a.min(b), |a, b| a.minimum(b), test_min, f32); - #[test] fn test_contiguous() { let mut cx = Graph::new(); @@ -51,22 +39,6 @@ fn test_contiguous() { assert_close(&b.data(), &d_b.as_vec()); } -#[test] -fn test_softmax() { - let mut cx = Graph::new(); - let data = random_vec(12); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.softmax::>().retrieve(); - cx.compile(CudaCompiler::::default(), &mut b); - cx.execute(); - - let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec(data, (DConst::<1>, DConst::<12>)); - let d_b = d_a.softmax::>(); - - assert_close(&b.data(), &d_b.as_vec()); -} - #[test] fn test_rotate() { let mut cx = Graph::new(); @@ -575,8 +547,8 @@ fn test_layer_norm() { let d_b = d_a.clone().normalize::>(1e-5); let d_c = d_a.normalize::>(1e-5); - assert_close_precision(&b.data(), &d_b.as_vec(), 0.01); - assert_close_precision(&c.data(), &d_c.as_vec(), 0.01); + assert_close_precision(&b.data(), &d_b.as_vec(), 1e-2); + assert_close_precision(&c.data(), &d_c.as_vec(), 1e-2); } #[test] @@ -746,13 +718,16 @@ fn test_pad_contig() { .set_dyn(a_data, &[m, k]) .retrieve(); let mut b: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = a - .pad(&[(0, 0.into()), (0, Expression::from(16) - 
'K')]) + .pad(&[(0, 0.into()), (0, Expression::from(24) - 'K')]) .contiguous() .retrieve(); let mut c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = (a.slice((.., ..Expression::from(k))).realize() / 1.0).retrieve(); - cx.compile(CudaCompiler::::default(), (&mut a, &mut b, &mut c)); + cx.compile( + <(GenericCompiler, CudaCompiler)>::default(), + (&mut a, &mut b, &mut c), + ); cx.execute(); // Close because b and c are going through 16 bits, while a is not diff --git a/crates/luminal_cuda/src/tests/mod.rs b/crates/luminal_cuda/src/tests/mod.rs index 2568eea1..6c36075d 100644 --- a/crates/luminal_cuda/src/tests/mod.rs +++ b/crates/luminal_cuda/src/tests/mod.rs @@ -1,3 +1,8 @@ +use dfdx::prelude::*; +use luminal::prelude::*; +use luminal::tests::random_vec_rng; +use rand::{rngs::StdRng, SeedableRng}; + mod fp16; mod fp32; @@ -6,14 +11,14 @@ macro_rules! single_unary_test { ($luminal_func: expr , $dfdx_func: expr , $name: ident, $type: ty, $size: expr) => { paste::paste! { #[test] - fn [<$name _ $size>]() { + fn [<$name _ $type _ $size>]() { let mut rng = StdRng::seed_from_u64(1); let data = random_vec_rng($size, &mut rng); let mut cx = Graph::new(); let a = cx.tensor::>().set(data.clone()); let f: fn(GraphTensor>) -> GraphTensor> = $luminal_func; let mut b = f(a).retrieve(); - cx.compile(CudaCompiler::<$type>::default(), &mut b); + cx.compile($crate::CudaCompiler::<$type>::default(), &mut b); cx.execute(); let d_dev = Cpu::default(); @@ -25,14 +30,14 @@ macro_rules! single_unary_test { ) -> dfdx::prelude::Tensor, $type, Cpu, NoneTape> = $dfdx_func; let d_b = f(d_a); - assert_close(&b.data(), &d_b.to_dtype::().as_vec()); + luminal::tests::assert_close(&b.data(), &d_b.to_dtype::().as_vec()); } } }; } #[macro_export] -macro_rules! unary_test { +macro_rules! unary_test_type { ($luminal_func: expr , $dfdx_func: expr , $name: ident, $type: ty) => { $crate::single_unary_test!($luminal_func, $dfdx_func, $name, $type, 3); $crate::single_unary_test!($luminal_func, $dfdx_func, $name, $type, 50); @@ -41,12 +46,20 @@ macro_rules! unary_test { }; } +#[macro_export] +macro_rules! unary_test { + ($luminal_func: expr , $dfdx_func: expr , $name: ident) => { + $crate::unary_test_type!($luminal_func, $dfdx_func, $name, f32); + $crate::unary_test_type!($luminal_func, $dfdx_func, $name, f16); + }; +} + #[macro_export] macro_rules! single_binary_test { ($luminal_func: expr , $dfdx_func: expr , $name: ident, $type: ty, $size: expr) => { paste::paste! { #[test] - fn [<$name _ $size>]() { + fn [<$name _ $type _ $size>]() { let mut rng = StdRng::seed_from_u64(2); let a_data = random_vec_rng($size, &mut rng); let b_data = random_vec_rng($size, &mut rng); @@ -56,7 +69,7 @@ macro_rules! single_binary_test { let f: fn(GraphTensor>, GraphTensor>) -> GraphTensor> = $luminal_func; let mut c = f(a, b).retrieve(); - cx.compile(CudaCompiler::<$type>::default(), &mut c); + cx.compile($crate::CudaCompiler::<$type>::default(), &mut c); cx.execute(); let d_dev = Cpu::default(); @@ -72,14 +85,14 @@ macro_rules! single_binary_test { ) -> dfdx::prelude::Tensor, $type, Cpu, NoneTape> = $dfdx_func; let d_c = f(d_a, d_b); - assert_close(&c.data(), &d_c.to_dtype::().as_vec()); + luminal::tests::assert_close(&c.data(), &d_c.to_dtype::().as_vec()); } } }; } #[macro_export] -macro_rules! binary_test { +macro_rules! 
binary_test_type { ($luminal_func: expr , $dfdx_func: expr , $name: ident, $type: ty) => { $crate::single_binary_test!($luminal_func, $dfdx_func, $name, $type, 3); $crate::single_binary_test!($luminal_func, $dfdx_func, $name, $type, 50); @@ -87,3 +100,51 @@ macro_rules! binary_test { $crate::single_binary_test!($luminal_func, $dfdx_func, $name, $type, 4096); }; } + +#[macro_export] +macro_rules! binary_test { + ($luminal_func: expr , $dfdx_func: expr , $name: ident) => { + $crate::binary_test_type!($luminal_func, $dfdx_func, $name, f32); + $crate::binary_test_type!($luminal_func, $dfdx_func, $name, f16); + }; +} + +pub fn assert_op_in_graph(graph: &Graph) { + assert!( + graph.node_indices().any(|i| graph.check_node_type::(i)), + "Node not found in the graph!" + ); +} + +unary_test!(|a| a.sin(), |a| a.sin(), test_sin); +unary_test!(|a| a.sqrt(), |a| a.sqrt(), test_sqrt); +unary_test!(|a| a.recip(), |a| a.recip(), test_recip); +unary_test!(|a| a * a, |a| a.clone() * a, test_square); +unary_test!(|a| a.exp(), |a| a.exp(), test_exp); +unary_test!(|a| a.cos(), |a| a.cos(), test_cos); +unary_test!(|a| a.softmax(), |a| a.softmax(), test_softmax); +unary_test!( + |a| a.mean_norm::>(), + |a| a.clone() - a.mean::<_, dfdx::prelude::Axis<0>>().broadcast(), + test_mean_norm +); +unary_test!( + |a| a.std_norm::, _>(1e-5), + |a| a.clone() / a.stddev::<_, dfdx::prelude::Axis<0>>(1e-5).broadcast(), + test_std_norm +); +unary_test!( + |a| a.layer_norm::, _>(1e-5), + |a| a.normalize::>(1e-5), + test_norm +); + +binary_test!(|a, b| a + b, |a, b| a + b, test_add); +binary_test!(|a, b| a - b, |a, b| a - b, test_sub); +binary_test!(|a, b| a * b, |a, b| a * b, test_mul); +binary_test!(|a, b| a / b, |a, b| a * b.recip(), test_div); +binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max); +binary_test!(|a, b| a.min(b), |a, b| a.minimum(b), test_min); + +single_unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f16, 3); // For some reason ln fails on larger tensors +single_unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f32, 3); // For some reason ln fails on larger tensors diff --git a/crates/luminal_cuda/src/unary.rs b/crates/luminal_cuda/src/unary.rs new file mode 100644 index 00000000..55753f22 --- /dev/null +++ b/crates/luminal_cuda/src/unary.rs @@ -0,0 +1,781 @@ +use luminal_cudarc::driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig}; +use num_traits::float::FloatConst; +use rustc_hash::FxHashMap; +use std::{any::Any, marker::PhantomData, mem::size_of, sync::Arc}; + +use petgraph::visit::EdgeRef; + +use luminal::{ + op::{ConstantValue, InputTensor, Operator}, + prelude::*, +}; + +use crate::{ + binary::CudaSub, + compile_and_load_kernel, constant, get_buffer_from_tensor, get_idx_valid_exps, input_dyn_dims, + prim::{ + CudaAdd, CudaConstant, CudaContiguous, CudaExp2, CudaMaxReduce, CudaMul, CudaRecip, + CudaSin, CudaSqrt, CudaSumReduce, + }, + render_dyn_dim_inputs, CudaData, CudaFloat, +}; + +/// Special kernel for efficient mean reduction +#[derive(Clone)] +pub struct CudaMeanReduce { + function: CudaFunction, + device: Arc, + pub dim: usize, + pub dyn_symbols: Vec, + pub dyn_map: *const FxHashMap, + _phantom: PhantomData, +} +crate::debug_type!(CudaMeanReduce); + +impl PartialEq for CudaMeanReduce { + fn eq(&self, other: &Self) -> bool { + self.dim == other.dim + } +} + +impl CudaMeanReduce { + fn new( + dev: Arc, + dim: usize, + shape: ShapeTracker, + dyn_map: *const FxHashMap, + ) -> Self { + let (idx_exp, valid_exp) = get_idx_valid_exps(shape); + let (dyn_symbols, rendered) = 
render_dyn_dim_inputs(&[shape]); + let type_name = T::type_name(); + let mut code = format!(" +#include \"cuda_fp16.h\" +extern \"C\" __global__ void kernel(const {type_name} *inp, {type_name} *out, int n_elements, int front_size, int back_size, int dim_size{rendered}) {{ + int i_ = blockIdx.x * blockDim.x + threadIdx.x; + if (i_ < n_elements) {{ + int a_ = i_ / back_size; + int b_ = i_ % back_size; + float reduce_value = 0.0; + for (int c_ = 0; c_ < dim_size; c_++) {{ + int idx = a_ * dim_size * back_size + c_ * back_size + b_; + if (({valid_exp}) != 0) {{ + reduce_value += (float)inp[{idx_exp}]; + }} + }} + out[i_] = ({type_name})(reduce_value / (float)dim_size); + }} +}}"); + code = code.replace("mkernel", "kernel_mean_reduce"); + + Self { + function: compile_and_load_kernel(code, &dev), + device: dev, + dim, + dyn_symbols, + dyn_map, + _phantom: Default::default(), + } + } +} + +impl Operator for CudaMeanReduce { + fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { + // Setup buffers + let mut sh = tensors[0].1; + sh.remove_dim(self.dim); + let inp_size = sh.n_elements().to_usize().unwrap(); + let inp_size_int = inp_size as i32; + let out = self.device.alloc_zeros::(inp_size).unwrap(); + let front_size = tensors[0] + .1 + .shape() + .iter() + .take(self.dim) + .map(|i| i.to_usize().unwrap()) + .product::() as i32; + let back_size = tensors[0] + .1 + .shape() + .iter() + .skip(self.dim + 1) + .map(|i| i.to_usize().unwrap()) + .product::() as i32; + let dim_size = tensors[0].1.shape()[self.dim].to_usize().unwrap() as i32; + let mut params = vec![ + get_buffer_from_tensor::(&tensors[0].0).as_kernel_param(), + (&out).as_kernel_param(), + inp_size_int.as_kernel_param(), + front_size.as_kernel_param(), + back_size.as_kernel_param(), + dim_size.as_kernel_param(), + ]; + input_dyn_dims(&mut params, &self.dyn_symbols, self.dyn_map); + unsafe { + self.function + .clone() + .launch(LaunchConfig::for_num_elems(inp_size as u32), &mut params) + .unwrap(); + } + vec![Tensor::new(CudaData(out))] + } +} + +/// Replace the mean reduce pattern with a special kernel. This is meant to be ran **after** the FakeSumReduceCompiler. 
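+///
+/// A minimal sketch of how the rewrite can be checked, mirroring the `test_norms`
+/// test at the bottom of this file (the `R1<32>` shape and `f16` type here are
+/// illustrative assumptions, not requirements):
+///
+/// ```ignore
+/// let mut cx = Graph::new();
+/// let a = cx.tensor::<R1<32>>().set([0.; 32]);
+/// // mean_norm subtracts a mean reduce, so the pattern below should be matched
+/// let mut b = a.mean_norm::<Axis<0>>().retrieve();
+/// cx.compile(CudaCompiler::<f16>::default(), &mut b);
+/// assert_op_in_graph::<CudaMeanReduce<f16>>(&cx);
+/// ```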
+#[derive(Default, Debug)] +pub struct MeanReduceCompiler(PhantomData); + +impl Compiler for MeanReduceCompiler { + type Output = (); + fn compile(&self, graph: &mut Graph, mut ids: To) { + let dev = CudaDevice::new(0).unwrap(); + // Look for the mean-reduce pattern + // mul(recip(fake_sum_reduce(const_ones)), sum_reduce(x)) + let fake_sum_reduce = op::>(); + let sum_reduce = op::>(); + let mul = binary::>( + sum_reduce.clone(), + unary::>(fake_sum_reduce.clone()), + ); + let mut s = mul.clone().search(graph); + while s.next_match() { + if s.check_no_delete(&[mul.id]) { + // An intermediate node can't be deleted + continue; + } + let (sum_reduce, mul) = (s.get(&sum_reduce), s.get(&mul)); + let dim = graph.get_op::>(sum_reduce).dim; + // Insert MeanReduce op + let src = graph.get_sources(sum_reduce)[0]; + let mean_reduce = graph + .add_op(CudaMeanReduce::::new( + dev.clone(), + dim, + src.2, + &graph.dyn_map, + )) + .input(src.0, 0, src.2) + .finish(); + + // Create edges to dests + move_outgoing_edge(mul, mean_reduce, graph); + remap(mul, mean_reduce, &mut ids, graph); + + // Remove the old ops + graph.remove_node(mul); + s.try_delete(); + } + } +} + +/// Special kernel for efficient std norming +#[derive(Clone)] +pub struct CudaStdNorm { + function: CudaFunction, + device: Arc, + epsilon: f32, // Epsilon + _phantom: PhantomData, +} +crate::debug_type!(CudaStdNorm); + +impl PartialEq for CudaStdNorm { + fn eq(&self, other: &Self) -> bool { + self.epsilon == other.epsilon + } +} + +impl CudaStdNorm { + fn new(epsilon: f32, device: Arc) -> Self { + let type_name = T::type_name(); + let kernel_code = format!(" +#include \"cuda_fp16.h\" +typedef struct __align__(8) {{ + __half x; + __half y; + __half z; + __half w; + }} __half4; +__device__ float warp_sum(float val) {{ + const unsigned int mask = 0xffffffff; + + for (int offset = warpSize / 2; offset > 0; offset /= 2) {{ + val += __shfl_down_sync(mask, val, offset); + }} + + return __shfl_sync(mask, val, 0); +}} +extern \"C\" __global__ void kernel(const {type_name} * src0, {type_name} * dst, const int row_size, const float eps) {{ + int threadgroup_position_in_grid = blockIdx.x; + int thread_position_in_threadgroup = threadIdx.x; + int simdgroup_index_in_threadgroup = thread_position_in_threadgroup / 32; // 32 threads in warp + int thread_index_in_simdgroup = thread_position_in_threadgroup % 32; + int threads_per_threadgroup = blockDim.x; + + extern __shared__ float buf[]; + const {type_name}4 * x = (const {type_name}4 *) (src0 + threadgroup_position_in_grid * row_size); + + float sumf = 0.; + + // parallel sum + for (int i = thread_position_in_threadgroup; i < row_size/4; i += threads_per_threadgroup) {{ + sumf += (float)x[i].x * (float)x[i].x; + sumf += (float)x[i].y * (float)x[i].y; + sumf += (float)x[i].z * (float)x[i].z; + sumf += (float)x[i].w * (float)x[i].w; + }} + float all_sum = sumf; + all_sum = warp_sum(all_sum); + + if (threads_per_threadgroup > 32) {{ + if (simdgroup_index_in_threadgroup == 0) {{ + buf[thread_index_in_simdgroup] = 0.0f; + }} + + __syncthreads(); + + if (thread_index_in_simdgroup == 0) {{ + buf[simdgroup_index_in_threadgroup] = all_sum; + }} + + __syncthreads(); + + all_sum = buf[thread_index_in_simdgroup]; + all_sum = warp_sum(all_sum); + }} + + const float mean = all_sum / row_size; + const float scale = rsqrt(mean + eps); + + {type_name}4 * y = ({type_name}4 *) (dst + threadgroup_position_in_grid * row_size); + for (int i = thread_position_in_threadgroup; i < row_size/4; i += threads_per_threadgroup) {{ + 
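+        // Write the normalized row back: scale each lane of the 4-wide vector by rsqrt(mean + eps)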
y[i].x = ({type_name})((float)x[i].x * scale);
+        y[i].y = ({type_name})((float)x[i].y * scale);
+        y[i].z = ({type_name})((float)x[i].z * scale);
+        y[i].w = ({type_name})((float)x[i].w * scale);
+    }}
+}}");

+        Self {
+            function: compile_and_load_kernel(kernel_code, &device),
+            device,
+            epsilon,
+            _phantom: Default::default(),
+        }
+    }
+}
+
+impl Operator for CudaStdNorm {
+    fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec {
+        let row_size = tensors[0].1.shape().last().unwrap().to_usize().unwrap();
+        let row_size_int = row_size as i32;
+        let out = self
+            .device
+            .alloc_zeros::(tensors[0].1.n_elements().to_usize().unwrap())
+            .unwrap();
+        let mut params = vec![
+            get_buffer_from_tensor::(&tensors[0].0).as_kernel_param(),
+            (&out).as_kernel_param(),
+            row_size_int.as_kernel_param(),
+            self.epsilon.as_kernel_param(),
+        ];
+        let batch_size = tensors[0]
+            .1
+            .shape()
+            .into_iter()
+            .take(tensors[0].1.len() - 1)
+            .map(|i| i.to_usize().unwrap())
+            .product::();
+        let mut nth = 32; // SIMD width
+        while nth < row_size / 4 && nth < 1024 {
+            nth *= 2;
+        }
+        unsafe {
+            self.function
+                .clone()
+                .launch(
+                    LaunchConfig {
+                        grid_dim: (batch_size as u32, 1, 1),
+                        block_dim: (nth as u32, 1, 1),
+                        shared_mem_bytes: 32 * size_of::() as u32,
+                    },
+                    &mut params,
+                )
+                .unwrap();
+        }
+
+        vec![Tensor::new(CudaData(out))]
+    }
+}
+
+/// Replace the std norm (RMSNorm) pattern with a special kernel. This is meant to be run **after** the FakeSumReduceCompiler.
+#[derive(Default, Debug)]
+pub struct StdNormCompiler(PhantomData);
+
+impl Compiler for StdNormCompiler {
+    type Output = ();
+    fn compile(&self, graph: &mut Graph, mut ids: To) {
+        let dev = CudaDevice::new(0).unwrap();
+        // Look for the RMSNorm pattern
+        // mul(recip(sqrt(add(mean_reduce(mul(x, x)), 1e-6))), x)
+
+        let mut eps = op::>();
+        eps.check(|op, _| {
+            if let Some(c) = op.as_any().downcast_ref::>() {
+                if let ConstantValue::Float(v) = c.value {
+                    v <= 1e-2 && v > 0.0
+                } else {
+                    false
+                }
+            } else {
+                false
+            }
+        });
+        let inp = node();
+        let square = unary::>(inp.clone()); // This should check both inputs! 
For some reason doesn't work + let mean = unary::>(square.clone()); + let add = binary::>(mean.clone(), eps.clone()); + let mul = unary::>(unary::>(unary::>(add.clone()))); + + let mut s = mul.clone().search(graph); + while s.next_match() { + if s.check_no_delete(&[mul.id, inp.id]) { + // An intermediate node can't be deleted + continue; + } + let ConstantValue::Float(epsilon_num) = + graph.get_op::>(s.get(&eps)).value + else { + continue; + }; + let (mut x, _, mut sh) = graph.get_sources(s.get(&square))[0]; + if let Some(mean_reduce) = graph.try_get_op::>(s.get(&mean)) { + if mean_reduce.dim != sh.len() - 1 { + continue; + } + } + if sh + .shape() + .last() + .unwrap() + .to_usize() + .map(|i| i % 32 != 0 || i < 32) + .unwrap_or(true) + { + continue; + } + if !graph + .get_sources(s.get(&mul)) + .iter() + .any(|(i, _, _)| *i == x) + { + continue; + } + + // Input must be contiguous + if sh.is_reshaped() { + x = graph + .add_op(CudaContiguous::::new(sh, dev.clone(), &graph.dyn_map)) + .input(x, 0, sh) + .finish(); + sh = sh.contiguous(); + } + + // Insert RMSNorm op + let rms_norm = graph + .add_op(CudaStdNorm::::new(epsilon_num, dev.clone())) + .input(x, 0, sh) + .finish(); + + // Create edges to dests + let mul = s.get(&mul); + move_outgoing_edge(mul, rms_norm, graph); + remap(mul, rms_norm, &mut ids, graph); + + // Remove the old ops + graph.remove_node(mul); + s.try_delete(); + } + } +} + +#[derive(Clone)] +pub struct CudaExp { + function: CudaFunction, + device: Arc, + _phantom: PhantomData, +} +crate::debug_type!(CudaExp); + +impl CudaExp { + fn new(device: Arc) -> Self { + let type_name = T::type_name(); + Self { + function: compile_and_load_kernel( + format!( + "#include \"cuda_fp16.h\" +extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel) {{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < numel) {{ + out[i] = exp(inp[i]); + }} +}}" + ), + &device, + ), + device, + _phantom: Default::default(), + } + } +} + +impl Operator for CudaExp { + fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { + let inp = get_buffer_from_tensor::(&tensors[0].0); + let inp_size = tensors[0].1.n_physical_elements().to_usize().unwrap(); + let mut out = self.device.alloc_zeros::(inp_size).unwrap(); + unsafe { + self.function + .clone() + .launch( + LaunchConfig::for_num_elems(inp_size as u32), + (&mut out, inp, inp_size), + ) + .unwrap(); + } + + vec![Tensor::new(CudaData(out))] + } + + fn custom(&mut self, key: &str, _: Box) -> Option> { + if key == "elementwise" { + return Some(Box::new("exp(input0)".to_string())); + } + + None + } +} + +#[derive(Default, Debug)] +pub struct CudaExpCompiler(PhantomData); + +impl Compiler for CudaExpCompiler { + type Output = (); + fn compile(&self, graph: &mut Graph, mut ids: To) { + let dev = CudaDevice::new(0).unwrap(); + // Look for the exp pattern + // exp2(mul(x, const)) + + let inp = node(); + let mul = binary::>(inp.clone(), constant::(1.0 / f32::ln(2.))); + let exp2 = unary::>(mul.clone()); + let mut s = exp2.clone().search(graph); + while s.next_match() { + if s.check_no_delete(&[exp2.id]) { + // An intermediate node can't be deleted + continue; + } + + // Insert exp op + let (_, _, src_shape) = graph + .edges_connecting(s.get(&inp), s.get(&mul)) + .next() + .unwrap() + .weight() + .as_data() + .unwrap(); + let exp = graph + .add_op(CudaExp::::new(dev.clone())) + .input(s.get(&inp), 0, src_shape) + .finish(); + + // Create edges to dests + let exp2 = s.get(&exp2); + move_outgoing_edge(exp2, exp, 
graph); + remap(exp2, exp, &mut ids, graph); + + // Remove the old ops + graph.remove_node(exp2); + s.try_delete(); + } + } +} + +/// Special kernel for cos +#[derive(Clone)] +pub struct CudaCos { + function: CudaFunction, + device: Arc, + _phantom: PhantomData, +} +crate::debug_type!(CudaCos); + +impl CudaCos { + fn new(device: Arc) -> Self { + let type_name = T::type_name(); + Self { + function: compile_and_load_kernel( + format!( + "#include \"cuda_fp16.h\" +extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel) {{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < numel) {{ + out[i] = cos(inp[i]); + }} +}}" + ), + &device, + ), + device, + _phantom: Default::default(), + } + } +} +impl Operator for CudaCos { + fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { + let inp = get_buffer_from_tensor::(&tensors[0].0); + let inp_size = tensors[0].1.n_physical_elements().to_usize().unwrap(); + let mut out = self.device.alloc_zeros::(inp_size).unwrap(); + unsafe { + self.function + .clone() + .launch( + LaunchConfig::for_num_elems(inp_size as u32), + (&mut out, inp, inp_size), + ) + .unwrap(); + } + + vec![Tensor::new(CudaData(out))] + } + + fn custom(&mut self, key: &str, _: Box) -> Option> { + if key == "elementwise" { + return Some(Box::new("cos(input0)".to_string())); + } + + None + } +} + +#[derive(Default, Debug)] +pub struct CudaCosCompiler(PhantomData); + +impl Compiler for CudaCosCompiler { + type Output = (); + fn compile(&self, graph: &mut Graph, mut ids: To) { + let dev = CudaDevice::new(0).unwrap(); + // Look for the cos pattern + // sin(add(mul(const_neg_one, x), const_pi_over_2)) + + let const_pi = constant::(f32::PI() / 2.); + let inp = node(); + let sub = binary::>(inp.clone(), const_pi.clone()); + let sin = unary::>(sub.clone()); + let mut s = sin.clone().search(graph); + while s.next_match() { + if s.check_no_delete(&[sin.id]) { + // An intermediate node can't be deleted + continue; + } + + // Insert cos op + let shape = graph + .edges_directed(s.get(&sub), petgraph::Direction::Incoming) + .filter(|e| !e.weight().is_schedule()) + .find(|e| e.source() != s.get(&const_pi)) + .unwrap() + .weight() + .as_data() + .unwrap() + .2; + let cos = graph + .add_op(CudaCos::::new(dev.clone())) + .input(s.get(&inp), 0, shape) + .finish(); + + // Create edges to dests + let sin = s.get(&sin); + move_outgoing_edge(sin, cos, graph); + remap(sin, cos, &mut ids, graph); + + // Remove the old ops + graph.remove_node(sin); + s.try_delete(); + } + } +} + +/// Special kernel for efficient softmax. 
Currently only works on the last dim +#[derive(Clone)] +pub struct CudaSoftmax { + function: CudaFunction, + device: Arc, + _phantom: PhantomData, +} +crate::debug_type!(CudaSoftmax); + +impl CudaSoftmax { + fn new(device: Arc) -> Self { + let type_name = T::type_name(); + Self { + function: compile_and_load_kernel( + format!( + " +#include \"cuda_fp16.h\" +extern \"C\" __global__ void kernel(const {type_name} * x, {type_name} * dst, const int ncols) {{ + const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int block_size = blockDim.y; + const int tid = threadIdx.y; + + {type_name} max_val = -__int_as_float(0x7f800000); + + for (int col = tid; col < ncols; col += block_size) {{ + const int i = row*ncols + col; + max_val = fmaxf(max_val, x[i]); + }} + + // find the max value in the block + #pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) {{ + max_val = fmaxf(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32)); + }} + + {type_name} tmp = 0.; + + for (int col = tid; col < ncols; col += block_size) {{ + const int i = row*ncols + col; + const {type_name} val = exp(x[i] - max_val); + tmp += static_cast<{type_name}>(val); + dst[i] = val; + }} + + // sum up partial sums + #pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) {{ + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + }} + + const {type_name} inv_tmp = ({type_name})1. / tmp; + + for (int col = tid; col < ncols; col += block_size) {{ + const int i = row*ncols + col; + dst[i] *= inv_tmp; + }} +}} +", + ), + &device, + ), + device, + _phantom: Default::default(), + } + } +} + +impl Operator for CudaSoftmax { + fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { + // Setup buffers + let inp_size = tensors[0].1.n_elements().to_usize().unwrap(); + let batch_size = tensors[0] + .1 + .shape() + .iter() + .take(tensors[0].1.len() - 1) + .map(|i| i.to_usize().unwrap()) + .product::() + .max(1); + let axis_size = tensors[0].1.shape().last().unwrap().to_usize().unwrap(); + let axis_size_int = axis_size as i32; + let out = self.device.alloc_zeros::(inp_size).unwrap(); + + let mut params = vec![ + get_buffer_from_tensor::(&tensors[0].0).as_kernel_param(), + (&out).as_kernel_param(), + axis_size_int.as_kernel_param(), + ]; + unsafe { + self.function + .clone() + .launch( + LaunchConfig { + grid_dim: (batch_size as u32, 1, 1), + block_dim: (1, 32, 1), + shared_mem_bytes: 0, + }, + &mut params, + ) + .unwrap(); + } + + vec![Tensor::new(CudaData(out))] + } +} + +/// Replace the softmax pattern with a special kernel. 
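+///
+/// A minimal sketch of the graph this targets (the `R2<1, 12>` shape and the
+/// `random_vec` helper from `luminal::tests` are illustrative assumptions;
+/// `CudaSoftmax` only handles softmax over the last axis):
+///
+/// ```ignore
+/// let mut cx = Graph::new();
+/// let a = cx.tensor::<R2<1, 12>>().set(random_vec(12));
+/// // softmax lowers to exp(x - max_reduce(x)) / sum_reduce(exp(x - max_reduce(x))),
+/// // which this compiler collapses into a single CudaSoftmax kernel.
+/// let mut b = a.softmax::<Axis<1>>().retrieve();
+/// cx.compile(CudaCompiler::<f32>::default(), &mut b);
+/// ```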
+#[derive(Default, Debug)]
+pub struct SoftmaxCompiler(PhantomData);
+
+impl Compiler for SoftmaxCompiler {
+    type Output = ();
+    fn compile(&self, graph: &mut Graph, mut ids: To) {
+        let dev = CudaDevice::new(0).unwrap();
+        // Look for the softmax pattern
+        // mul(exp(sub(x, max_reduce(x))), recip(sum_reduce(exp(sub(x, max_reduce(x))))))
+
+        let max_reduce = op::>();
+        let mul = unary::>(unary::>(unary::>(unary::<
+            CudaExp,
+        >(
+            unary::>(max_reduce.clone()),
+        ))));
+
+        let mut s = mul.clone().search(graph);
+        while s.next_match() {
+            if s.check_no_delete(&[mul.id]) {
+                // An intermediate node can't be deleted
+                continue;
+            }
+            // Insert Softmax op
+            let src = graph.get_sources(s.get(&max_reduce))[0];
+            let softmax = graph
+                .add_op(CudaSoftmax::::new(dev.clone()))
+                .input(src.0, 0, src.2)
+                .finish();
+
+            // Create edges to dests
+            let mul = s.get(&mul);
+            move_outgoing_edge(mul, softmax, graph);
+            remap(mul, softmax, &mut ids, graph);
+
+            // Remove the old ops
+            graph.remove_node(mul);
+            s.try_delete();
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use luminal::prelude::*;
+
+    use crate::tests::assert_op_in_graph;
+
+    use super::{CudaMeanReduce, CudaStdNorm};
+    #[test]
+    fn test_norms() {
+        let mut cx = Graph::new();
+        let a = cx.tensor().set([0.; 32]);
+        let mut b = a.layer_norm::, _>(1e-5).retrieve();
+
+        cx.compile(
+            <(
+                GenericCompiler,
+                crate::prim::PrimitiveCompiler,
+                crate::SpecialOpsCompiler,
+            )>::default(),
+            &mut b,
+        );
+
+        assert_op_in_graph::>(&cx);
+        assert_op_in_graph::>(&cx);
+    }
+}
diff --git a/examples/phi/src/main.rs b/examples/phi/src/main.rs
index 05585351..c2b7499b 100644
--- a/examples/phi/src/main.rs
+++ b/examples/phi/src/main.rs
@@ -81,7 +81,6 @@ fn main() {
             &mut model_weights,
         ),
     );
-    cx.display();
     let cache_src_set = downstream(&cache_src, &cx);
     let cache_dest_set = cache_dest.to_ids();
     println!("\t\t - {}ms", now.elapsed().as_millis());

From 003c824a02ff6666533d44cbd0c16a666f46307f Mon Sep 17 00:00:00 2001
From: Joe Fioti 
Date: Sun, 28 Apr 2024 15:59:02 -0500
Subject: [PATCH 08/12] Added elementwise fusion to cuda

---
 crates/luminal_cuda/Cargo.toml                |    1 +
 crates/luminal_cuda/src/binary.rs             |   22 +-
 crates/luminal_cuda/src/elementwise_fusion.rs | 1056 +++++++++++++++++
 crates/luminal_cuda/src/lib.rs                |   18 +-
 crates/luminal_cuda/src/matmul.rs             |    3 +-
 crates/luminal_cuda/src/other.rs              |    2 +-
 crates/luminal_cuda/src/prim.rs               |   44 +-
 crates/luminal_cuda/src/quantized.rs          |    4 +-
 crates/luminal_cuda/src/unary.rs              |   10 +-
 9 files changed, 1124 insertions(+), 36 deletions(-)
 create mode 100644 crates/luminal_cuda/src/elementwise_fusion.rs

diff --git a/crates/luminal_cuda/Cargo.toml b/crates/luminal_cuda/Cargo.toml
index 25c58bd0..b34a83d1 100644
--- a/crates/luminal_cuda/Cargo.toml
+++ b/crates/luminal_cuda/Cargo.toml
@@ -16,6 +16,7 @@ luminal_cudarc = { version="0.10.0", features = [
 itertools = "0.12.1"
 rustc-hash = "1.1.0"
 num-traits = "0.2.18"
+regex = "1.10.4"
 
 [dev-dependencies]
 dfdx = { version = "0.13", features = ["f16"] }
diff --git a/crates/luminal_cuda/src/binary.rs b/crates/luminal_cuda/src/binary.rs
index 79fdde3c..685c2e7e 100644
--- a/crates/luminal_cuda/src/binary.rs
+++ b/crates/luminal_cuda/src/binary.rs
@@ -1,4 +1,4 @@
-use std::{marker::PhantomData, sync::Arc};
+use std::{any::Any, marker::PhantomData, sync::Arc};
 
 use luminal_cudarc::driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig};
 
@@ -23,7 +23,7 @@ pub struct CudaSub {
     dyn_map: *const FxHashMap,
     _phantom: PhantomData,
 }
-crate::debug_type!(CudaSub);
+crate::debug_type!(CudaSub); 
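+
+// The "elementwise" custom key (implemented for CudaSub below, and for the other
+// pointwise ops touched in this patch) returns a CUDA expression string such as
+// "input0 - input1", where inputN stands for the op's Nth input. The new
+// ElementwiseFusionCompiler collects these strings via node_custom and stitches
+// chains of such ops into a single generated kernel.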
impl CudaSub { pub fn new( @@ -80,6 +80,13 @@ impl Operator for CudaSub { vec![Tensor::new(CudaData(out))] } + + fn custom(&mut self, key: &str, _: Box) -> Option> { + if key == "elementwise" { + return Some(Box::new("input0 - input1".to_string())); + } + None + } } #[derive(Debug, Default)] @@ -148,7 +155,7 @@ pub struct CudaEqual { dyn_map: *const FxHashMap, _phantom: PhantomData, } -crate::debug_type!(CudaEqual); +crate::debug_type!(CudaEqual); impl CudaEqual { pub fn new( @@ -205,6 +212,13 @@ impl Operator for CudaEqual { vec![Tensor::new(CudaData(out))] } + + fn custom(&mut self, key: &str, _: Box) -> Option> { + if key == "elementwise" { + return Some(Box::new("(float)(input0 == input1)".to_string())); + } + None + } } #[derive(Debug, Default)] @@ -271,7 +285,7 @@ pub struct CudaGather { pub embed_dim: usize, _phantom: PhantomData, } -crate::debug_type!(CudaGather); +crate::debug_type!(CudaGather); impl CudaGather { pub fn new(device: Arc, embed_dim: usize) -> Self { diff --git a/crates/luminal_cuda/src/elementwise_fusion.rs b/crates/luminal_cuda/src/elementwise_fusion.rs new file mode 100644 index 00000000..8682db4f --- /dev/null +++ b/crates/luminal_cuda/src/elementwise_fusion.rs @@ -0,0 +1,1056 @@ +use luminal_cudarc::driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig}; +use regex::Regex; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::{any::Any, fmt::Debug, iter::once, marker::PhantomData, mem::size_of, sync::Arc}; + +use itertools::Itertools; +use luminal::prelude::{ + petgraph::{visit::EdgeRef, Direction}, + *, +}; + +use crate::{ + compile_and_load_kernel, expr_to_cuda_string, get_buffer_from_tensor, prim::CudaConstant, + CudaData, CudaFloat, +}; + +use super::{input_dyn_dims, render_dyn_dim_inputs}; + +#[derive(Default, Debug)] +pub struct ElementwiseFusionCompiler(PhantomData); + +fn get_inputs(node: NodeIndex, graph: &Graph) -> Vec<(NodeIndex, u8, ShapeTracker)> { + graph + .edges_directed(node, Direction::Incoming) + .filter_map(|e| e.weight().as_data().map(|i| (e.source(), i))) + .sorted_by_key(|(_, i)| i.0) + .map(|(a, (_, b, c))| (a, b, c)) + .collect() +} + +// Check if we stack the views, does more than one view exist for one of a set of given inputs +fn is_more_than_one_view( + subexpressions: &[(String, ShapeTracker)], + subexp_indexes: &[usize], +) -> bool { + let intermediate_match = Regex::new(r"intermediate(\d+)").unwrap(); + let mut subexp_views = subexpressions + .iter() + .map(|(_, sh)| vec![*sh]) + .collect::>(); + for i in (0..subexp_views.len()).rev() { + for capture in intermediate_match.captures_iter(&subexpressions[i].0) { + let index = capture.get(1).unwrap().as_str().parse::().unwrap(); + if subexp_views[index].len() == 1 { + let v = subexp_views[i].clone(); + subexp_views[index].extend(v); + } + } + } + if !subexpressions + .iter() + .positions(|(s, _)| { + subexp_indexes + .iter() + .any(|i| s.contains(&format!("input{i}"))) + }) + .map(|subexp_index| &subexp_views[subexp_index]) + .all_equal() + { + return true; + } + false +} + +impl Compiler for ElementwiseFusionCompiler { + type Output = (); + fn compile(&self, graph: &mut Graph, mut ids: To) { + let device = CudaDevice::new(0).unwrap(); + // Track fused ops to compile later + let mut fused_ops = FxHashSet::default(); + + let mut matched = true; + let mut elementwise_ops = FxHashMap::default(); + for op in graph.node_indices().collect::>() { + if let Some(exp) = graph.node_custom::(op, "elementwise", ()) { + elementwise_ops.insert(op, exp); + } + } + let mut 
intermediate_regexes = FxHashMap::default(); + let mut input_regexes = FxHashMap::default(); + while matched { + matched = false; + for edge in graph.edge_indices().collect::>() { + let Some((a, b)) = graph.edge_endpoints(edge) else { + continue; + }; + if graph.no_delete.contains(&a) + || graph.no_delete.contains(&b) + || (!graph.check_node_type::>(a) + && graph + .edges_directed(a, Direction::Outgoing) + .filter(|e| e.target() != b) + .count() + > 0) + { + continue; // A is not a constant and is feeding into some other node + } + let (Some(expression_a), Some(expression_b)) = + (elementwise_ops.get(&a), elementwise_ops.get(&b)) + else { + continue; + }; + // a and b are elementwise ops + // Make sure all edges from a to b share the same shape + if !graph + .edges_connecting(a, b) + .map(|e| e.weight().as_data().unwrap().2) + .all_equal() + { + continue; + } + // Check if there are more than one view of this input. If so, we can't merge + let mut subexpressions_b = graph + .try_get_op::>(b) + .map(|o| o.subexpressions.clone()) + .unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::new(&[]))]); + let a_to_b_indexes = graph + .edges_connecting(a, b) + .map(|e| e.weight().as_data().unwrap().0 as usize) + .sorted() + .collect::>(); + if is_more_than_one_view(&subexpressions_b, &a_to_b_indexes) { + continue; + } + matched = true; + let a_inputs = get_inputs(a, graph); + let mut b_inputs = get_inputs(b, graph); + let (_, _, connecting_shape) = b_inputs.remove(*a_to_b_indexes.last().unwrap()); + for i in a_to_b_indexes.iter().take(a_to_b_indexes.len() - 1).rev() { + b_inputs.remove(*i); + } + // Get subexpressions + let mut subexpressions_a = graph + .try_get_op::>(a) + .map(|o| o.subexpressions.clone()) + .unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::new(&[]))]); + subexpressions_a.last_mut().unwrap().1 = connecting_shape; + // Re-reference b intermediates + for i in (0..subexpressions_b.len()).rev() { + let re = if let Some(r) = intermediate_regexes.get(&i) { + r + } else { + intermediate_regexes.insert( + i, + Regex::new(&format!(r"intermediate{i}([^0-9]|$)")).unwrap(), + ); + intermediate_regexes.get(&i).unwrap() + }; + for (exp, _) in subexpressions_b.iter_mut() { + *exp = re + .replace_all( + exp, + format!("intermediate{}$1", i + subexpressions_a.len()), + ) + .to_string(); + } + } + // Re-reference b inputs to a + for index in &a_to_b_indexes { + let re = if let Some(r) = input_regexes.get(index) { + r + } else { + input_regexes.insert( + *index, + Regex::new(&format!(r"input{index}([^0-9]|$)")).unwrap(), + ); + input_regexes.get(index).unwrap() + }; + for (exp, _) in subexpressions_b.iter_mut() { + *exp = re + .replace_all( + exp, + format!("intermediate{}$1", subexpressions_a.len() - 1), + ) + .to_string(); + } + } + // Re-reference b inputs + for (sub_factor, index) in a_to_b_indexes.iter().enumerate() { + for i in (*index - sub_factor + 1)..(b_inputs.len() + a_to_b_indexes.len()) { + let re = if let Some(r) = input_regexes.get(&i) { + r + } else { + input_regexes + .insert(i, Regex::new(&format!(r"input{i}([^0-9]|$)")).unwrap()); + input_regexes.get(&i).unwrap() + }; + for (exp, _) in subexpressions_b.iter_mut() { + *exp = re.replace_all(exp, format!("input{}$1", i - 1)).to_string(); + } + } + } + // Combine inputs for a and b + for i in (0..a_inputs.len()).rev() { + // Re-reference the a inputs + let re = if let Some(r) = input_regexes.get(&i) { + r + } else { + input_regexes + .insert(i, Regex::new(&format!(r"input{i}([^0-9]|$)")).unwrap()); + 
input_regexes.get(&i).unwrap() + }; + for (exp, _) in subexpressions_a.iter_mut() { + *exp = re + .replace_all(exp, format!("input{}$1", i + b_inputs.len())) + .to_string(); + } + } + b_inputs.extend(a_inputs); + // a intermediates should remain valid + // Combine subexpressions + for subexp in subexpressions_a.into_iter().rev() { + subexpressions_b.insert(0, subexp); + } + // Create new fused op + let output_buffer_sizes = + if let Some(o) = graph.try_get_op::>(b) { + o.output_buffer_sizes.clone() + } else { + vec![ + graph + .edges_directed(b, Direction::Incoming) + .filter_map(|e| e.weight().as_data().map(|i| i.2.n_elements())) + .reduce(|acc, e| acc.max(e)) + .unwrap() + * size_of::(), + ] + }; + let new_op = graph + .add_op(FusedElementwiseOp:: { + kernel: None, + dyn_map: &graph.dyn_map, + dyn_chars: vec![], + subexpressions: subexpressions_b.clone(), + device: device.clone(), + output_buffer_sizes, + _phantom: Default::default(), + }) + .finish(); + // Add edges to new op + move_outgoing_edge(b, new_op, graph); + for (i, (node, output, shape)) in b_inputs.into_iter().enumerate() { + graph.add_edge( + node, + new_op, + Dependency::Data { + input_order: i as u8, + output_order: output, + shape, + }, + ); + } + graph.remove_node(b); + graph.safe_remove_node(a, 0); + // Keep track of the fused op so we can compile it later + fused_ops.remove(&a); + fused_ops.remove(&b); + fused_ops.insert(new_op); + elementwise_ops.remove(&a); + elementwise_ops.remove(&b); + elementwise_ops.insert(new_op, String::new()); + if !graph.contains_node(a) { + remap(a, new_op, &mut ids, graph); + } + remap(b, new_op, &mut ids, graph); + } + } + // Compile all the kernels we placed + let type_name = T::type_name(); + let intermediate_match = Regex::new(r"intermediate(\d+)([^0-9]|$)").unwrap(); + for fused_op in fused_ops { + let inputs = graph + .edges_directed(fused_op, Direction::Incoming) + .flat_map(|e| e.weight().as_data()) + .sorted_by_key(|(i, _, _)| *i) + .map(|(_, _, sh)| sh) + .collect::>(); + let op = graph.get_op_mut::>(fused_op); + // Stack index expressions and replace them in the subexpressions + // Track all shapes used, will pull dyn dims from these + let shapes_used = op + .subexpressions + .iter() + .map(|(_, s)| *s) + .chain(inputs.clone()) + .collect::>(); + // Track the views of each subexpression by going in reverse order and appending the current subexpression's views to the referenced subexpression + let mut subexp_views = op + .subexpressions + .iter() + .map(|(_, sh)| vec![*sh]) // Start with the current view for this subexpression + .collect::>(); + for i in (0..subexp_views.len() - 1).rev() { + for capture in intermediate_match.captures_iter(&op.subexpressions[i].0) { + let index = capture.get(1).unwrap().as_str().parse::().unwrap(); + if subexp_views[index].len() == 1 { + let v = subexp_views[i].clone(); + subexp_views[index].extend(v); + } else { + assert_eq!(subexp_views[index][1..], subexp_views[i][..]); + } + } + } + // Stack views for each input by going to the first subexpression that uses it and combining it's stacked shape with the input's shape + let stacked_shapes: Vec> = inputs + .iter() + .enumerate() + .map(|(i, s)| { + // Find the first subexpression that uses this input + let re = if let Some(r) = input_regexes.get(&i) { + r + } else { + input_regexes + .insert(i, Regex::new(&format!(r"input{i}([^0-9]|$)")).unwrap()); + input_regexes.get(&i).unwrap() + }; + let using_subexp = op + .subexpressions + .iter() + .position(|(s, _)| re.is_match(s)) + .unwrap(); + + 
once(*s) + .chain( + subexp_views[using_subexp] + .iter() + .copied() + .filter(|s| !s.is_empty()), + ) + .collect() + }) + .collect(); + // Stack index expressions + let stacked_index_expressions_partial = stacked_shapes + .iter() + .map(|s| { + s.iter() + .rev() + .take(s.len() - 1) + .fold(BigExpression::from('z'), |acc, inp| { + inp.index_expression().substitute('z', acc) + }) + }) + .collect::>(); + let stacked_index_expressions = stacked_index_expressions_partial + .iter() + .cloned() + .zip(&stacked_shapes) + .map(|(partial, sh)| sh[0].index_expression().substitute('z', partial)) + .collect::>(); + let stacked_valid_expressions = stacked_index_expressions_partial + .iter() + .cloned() + .zip(&stacked_shapes) + .map(|(partial, sh)| sh[0].valid_expression().substitute('z', partial)) + .collect::>(); + + // Replace in subexpressions + let n_subexpressions = op.subexpressions.len(); + for (i, ((subexp, _), stacked_shapes)) in + op.subexpressions.iter_mut().zip(subexp_views).enumerate() + { + // Index + for (i, (ind_exp, val_exp)) in stacked_index_expressions + .iter() + .zip(&stacked_valid_expressions) + .enumerate() + { + let re = if let Some(r) = input_regexes.get(&i) { + r + } else { + input_regexes + .insert(i, Regex::new(&format!(r"input{i}([^0-9]|$)")).unwrap()); + input_regexes.get(&i).unwrap() + }; + *subexp = re + .replace_all( + subexp, + &if *val_exp != true { + format!( + "({} != 0 ? (float)input{i}[{}] : 0.0)$1", + expr_to_cuda_string(val_exp), + expr_to_cuda_string(ind_exp) + ) + } else { + format!("(float)input{i}[{}]$1", expr_to_cuda_string(ind_exp)) + }, + ) + .to_string(); + } + // Valid (not on last subexpression) + if i != n_subexpressions - 1 { + let val_exp = stacked_shapes + .iter() + .rev() + .fold( + (BigExpression::from(true), BigExpression::from('z')), + |(_, ind_acc), inp| { + ( + inp.valid_expression().substitute('z', ind_acc.clone()), + inp.index_expression().substitute('z', ind_acc), + ) + }, + ) + .0; + if val_exp != true { + *subexp = format!( + "(({} != 0) ? 
{subexp} : 0.0)", + expr_to_cuda_string(&val_exp) + ); + } + } + } + + let (dyn_chars, rendered) = render_dyn_dim_inputs(&shapes_used); + let kernel = format!( + " +#include \"cuda_fp16.h\" +extern \"C\" __global__ void kernel({} {type_name}* out, const int n_elements{rendered}) {{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n_elements) {{ + {} + out[idx] = ({type_name})({}); + }} +}}", + (0..inputs.len()) + .map(|inp_ind| format!("const {type_name}* input{inp_ind},")) + .collect::>() + .join(" "), + op.subexpressions + .iter() + .take(op.subexpressions.len() - 1) + .enumerate() + .map(|(i, (subexp, _))| format!("float intermediate{i} = {subexp};")) + .join("\n "), + op.subexpressions.last().unwrap().0 + ); + op.kernel = Some(compile_and_load_kernel(kernel, &device)); + op.dyn_chars = dyn_chars; + } + } +} + +#[derive(Clone)] +pub struct FusedElementwiseOp { + kernel: Option, + dyn_map: *const FxHashMap, + dyn_chars: Vec, + subexpressions: Vec<(String, ShapeTracker)>, + device: Arc, + output_buffer_sizes: Vec, + _phantom: PhantomData, +} +impl Debug for FusedElementwiseOp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "FusedElementwiseOp") + } +} + +impl Operator for FusedElementwiseOp { + fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { + let dyn_map = unsafe { self.dyn_map.as_ref().unwrap() }; + let out_size = + self.output_buffer_sizes[0].exec(dyn_map).unwrap() / std::mem::size_of::(); + let out_size_int = out_size as i32; + let out = self.device.alloc_zeros::(out_size).unwrap(); + + let mut params = vec![]; + for (buf, _) in &tensors { + params.push(get_buffer_from_tensor::(buf).as_kernel_param()); + } + params.push((&out).as_kernel_param()); + params.push(out_size_int.as_kernel_param()); + + input_dyn_dims(&mut params, &self.dyn_chars, self.dyn_map); + + unsafe { + self.kernel + .clone() + .unwrap() + .launch(LaunchConfig::for_num_elems(out_size as u32), &mut params) + .unwrap(); + } + + vec![Tensor::new(CudaData(out))] + } + + fn custom(&mut self, key: &str, _: Box) -> Option> { + if key == "elementwise" { + return Some(Box::::default()); + } + None + } +} + +#[cfg(test)] +mod tests { + use luminal::{ + prelude::{binary::F32Pow, *}, + tests::{assert_close, assert_close_precision, random_vec, random_vec_rng}, + }; + use luminal_nn::*; + use rand::{rngs::StdRng, SeedableRng}; + use std::{marker::PhantomData, ops::Div}; + + use crate::CudaCompiler; + + #[test] + fn test_fusion_simple() { + let mut cx = Graph::new(); + let mut rng = StdRng::seed_from_u64(0); + let inp = cx.tensor::>().set(random_vec_rng(10, &mut rng)); + let mut out = inp.exp2().cos().sqrt().retrieve(); + + cx.execute(); + let unopt_out = out.data(); + out.drop(); + + cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut out); + cx.execute(); + + assert_close(&out.data(), &unopt_out); + } + #[test] + fn test_fusion_binary() { + let mut cx = Graph::new(); + let mut rng = StdRng::seed_from_u64(0); + let a = cx.tensor::>().set(random_vec_rng(10, &mut rng)); + let b = cx.tensor::>().set(random_vec_rng(10, &mut rng)); + let mut out = (a.exp2() + b.cos()).retrieve(); + + cx.execute(); + let unopt_out = out.data(); + out.drop(); + + cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut out); + cx.execute(); + + assert_close(&out.data(), &unopt_out); + } + + #[test] + fn test_fusion_subexpression_complex() { + let mut cx = Graph::new(); + let a = cx.named_tensor::>("a").set(random_vec(10)).keep(); + let b = 
cx.named_tensor::>("b").set(random_vec(10)).keep(); + let d = cx.named_tensor::>("d").set(random_vec(10)).keep(); + let mut out = ((a.exp2() - b.sin()).sin() * 3.4).less_than(d).retrieve(); + + cx.execute(); + let unopt_out = out.data(); + out.drop(); + + cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut out); + cx.execute(); + + assert_close(&out.data(), &unopt_out); + } + + #[test] + fn test_fusion_slicing_padding() { + let mut cx = Graph::new(); + let mut rng = StdRng::seed_from_u64(0); + let inp = random_vec_rng(10, &mut rng); + let a = cx.named_tensor::>("a").set(inp); + let mut padded = a + .slice((..Expression::from(1), ..)) + .realize::>() + .cos() + .pad::, _, _>(&[(0, 1), (0, 0)]) + .exp2() + .retrieve(); + cx.execute(); + let unopt_out = padded.data(); + padded.drop(); + + cx.compile( + <(GenericCompiler, CudaCompiler)>::default(), + &mut padded, + ); + cx.execute(); + + assert_close(&padded.data(), &unopt_out); + } + + #[test] + fn test_fusion_subexpression() { + let mut cx = Graph::new(); + let mut rng = StdRng::seed_from_u64(0); + let data = random_vec_rng(10, &mut rng); + let a = cx.tensor::>().set(data); + let mut out = (a.sqrt().exp() + a.sqrt().sin()).retrieve(); + cx.execute(); + let unopt_out = out.data(); + out.drop(); + + cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut out); + cx.execute(); + + assert_close(&out.data(), &unopt_out); + } + + #[test] + fn test_fusion_rope_emb() { + let mut cx = Graph::new(); + const SEQ: usize = 2; + const HEAD_DIM: usize = 4; + const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2; + let freqs = (cx.arange::>() * 2.0) / (HEAD_DIM as f32); + let freqs = 1000000_f32.pow(freqs); + let pos = cx.arange::>() + BigExpression::from(0); + let mut emb = pos + .expand::<(_, Const<1>), _>() + .matmul(freqs.expand()) + .retrieve(); + + cx.execute(); + let unopt_out = emb.data(); + emb.drop(); + + cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut emb); + cx.execute(); + assert_close(&emb.data(), &unopt_out); + } + + #[test] + fn test_fusion_rotate() { + let mut cx = Graph::new(); + let mut rng = StdRng::seed_from_u64(0); + const SEQ: usize = 2; + const HEAD_DIM: usize = 4; + const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2; + let a = cx + .tensor::>() + .set(random_vec_rng(SEQ * HEAD_DIM, &mut rng)) + .keep(); + let b = cx + .tensor::>() + .set(random_vec_rng(SEQ * HEAD_DIM_OVER_2, &mut rng)) + .keep(); + // Split input into evens and odds + let split = a.reshape::>(); + let x0: GraphTensor> = + split.slice((.., .., ..Expression::from(1))).realize(); + let x1: GraphTensor> = + split.slice((.., .., Expression::from(1)..)).realize(); + + let x0_out = x0 * b - x1 * b.cos(); + let x1_out = x0 + x1; + + // Combine back into output + let mut out: GraphTensor> = x0_out + .concat_along::, Axis<2>, _>(x1_out) + .reshape() + .retrieve(); + cx.execute(); + + let unopt_out = out.data(); + out.drop(); + + cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut out); + cx.execute(); + assert_close(&out.data(), &unopt_out); + } + + #[test] + fn test_fusion_rope_full() { + let mut cx = Graph::new(); + let mut rng = StdRng::seed_from_u64(0); + const BATCH: usize = 1; + const N_HEADS: usize = 8; + const SEQ: usize = 2; + const HEAD_DIM: usize = 4; + const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2; + let a = cx + .named_tensor::>("a") + .set(random_vec_rng(BATCH * N_HEADS * SEQ * HEAD_DIM, &mut rng)) + .keep(); + let freqs = (cx.arange::>() * 2.0) / (HEAD_DIM as f32); + let freqs = 1000000_f32.pow(freqs); + let pos = cx.arange::>() + 
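+        // The zero BigExpression stands in for the previous-sequence-length offset
+        // used by the real RoPE path (zero here, since there is no KV cache yet).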
BigExpression::from(0); + let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand()); + // Split input into evens and odds + let split = a.reshape::>(); + let x0: GraphTensor> = split + .slice((.., .., .., .., ..Expression::from(1))) + .contiguous() + .realize(); + let x1: GraphTensor> = split + .slice((.., .., .., .., Expression::from(1)..)) + .contiguous() + .realize(); + + // Apply sin and cos embeddings + let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand(); + let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand(); + + // Combine back into output + let mut out: GraphTensor> = x0_out + .concat_along::, Axis<4>, _>(x1_out) + .reshape() + .retrieve(); + cx.execute(); + let unopt_out = out.data(); + out.drop(); + + cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut out); + cx.execute(); + + assert_close(&out.data(), &unopt_out); + } + + #[test] + fn test_fusion_transformer() { + pub const HIDDEN_DIM: usize = 128; + pub const N_HEADS: usize = 2; + pub const N_KV_HEADS: usize = 2; + pub const MLP_DIM: usize = 256; + pub const NUM_LAYERS: usize = 2; + pub const SEQ_LEN: usize = 65; + pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS; + pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS; + pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2; + pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS; + pub struct Mlp { + pub gate_proj: PermutedLinear, + pub down_proj: PermutedLinear, + pub up_proj: PermutedLinear, + } + + pub type KVCache = ( + GraphTensor<(Batch, Const, Seq, Const)>, + GraphTensor<(Batch, Const, Seq, Const)>, + ); + + impl Module> for Mlp + where + GraphTensor: Matmul, Output = GraphTensor>, + GraphTensor: Matmul, Output = GraphTensor>, + { + type Output = GraphTensor; + + fn forward(&self, input: GraphTensor) -> Self::Output { + let gate = self.gate_proj.forward(input).swish(); + let up = self.up_proj.forward(input) * gate; + self.down_proj.forward(up) + } + } + impl InitModule for Mlp { + fn initialize(cx: &mut Graph) -> Self { + Self { + gate_proj: InitModule::initialize(cx), + up_proj: InitModule::initialize(cx), + down_proj: InitModule::initialize(cx), + } + } + } + fn apply_rotary_embeddings_ggml( + input: GraphTensor<(Batch, Const, Seq, Const)>, + prev_seq: BigExpression, + ) -> GraphTensor<(Batch, Const, Seq, Const)> { + // Get freqs + let freqs = + (input.graph().arange::>() * 2.0) / (HEAD_DIM as f32); + let freqs = 1000000_f32.pow(freqs); + let pos = input.graph().arange::() + prev_seq; + let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand()); + + // Split input into evens and odds + let split = + input.reshape::<(Batch, Const, Seq, Const, Const<2>)>(); + let x0: GraphTensor<(Batch, Const, Seq, Const, Const<1>)> = + split + .slice((.., .., .., .., ..Expression::from(1))) + .contiguous() + .realize(); + let x1: GraphTensor<(Batch, Const, Seq, Const, Const<1>)> = + split + .slice((.., .., .., .., Expression::from(1)..)) + .contiguous() + .realize(); + + // Apply sin and cos embeddings + let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand(); + let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand(); + + // Combine back into output + x0_out + .concat_along::<(Batch, Const, Seq, Const, Const<2>), Axis<4>, _>( + x1_out, + ) + .reshape() + } + pub struct SelfAttention { + pub q_proj: GraphTensor>, + pub k_proj: GraphTensor>, + pub v_proj: GraphTensor>, + pub o_proj: GraphTensor>, + } + + impl + Module<( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + PhantomData, + )> for SelfAttention + { + type Output = 
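+            // Attention returns both the hidden states and the updated KV cache,
+            // which the caller threads into the next decode step.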
( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + ); + fn forward( + &self, + (x, (k_cache, v_cache), _): ( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + PhantomData, + ), + ) -> Self::Output { + // Apply the Projections + let queries = x + .matmul(self.q_proj.permute()) + .reshape::<(Batch, CurSeq, Const, Const)>() + .permute::<_, Axes4<0, 2, 1, 3>>(); + + let keys = x + .matmul(self.k_proj.permute()) + .reshape::<(Batch, CurSeq, Const, Const)>() + .permute::<_, Axes4<0, 2, 1, 3>>(); + + let values = x + .matmul(self.v_proj.permute()) + .reshape::<(Batch, CurSeq, Const, Const)>() + .permute::<_, Axes4<0, 2, 1, 3>>(); + + // Rotary embed queries and keys + let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::const_size().big()); + let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::const_size().big()); + + // Add KV cache + let (keys, values) = ( + k_cache.concat_along::<_, Axis<2>, _>(keys), + v_cache.concat_along::<_, Axis<2>, _>(values), + ); + + // Repeat the KV States for Grouped-Query Attention + let repeated_keys = keys.expand::<(_, _, Const, _, _), _>(); + let repeated_values = values.expand::<(_, _, Const, _, _), _>(); + + // Calculate attention weights + let mut attention_weights = queries + .reshape::<(_, Const, Const, _, _)>() // Split query heads into groups + .matmul(repeated_keys.permute()) + .div((HEAD_DIM as f32).sqrt()); + + let attention_mask = self.k_proj.graph().triu::(1) * f16::MIN.to_f32(); + attention_weights += attention_mask + .pad::<(CurSeq, TotSeq), _, _>(&[ + (0.into(), Expression::from(0)), + (TotSeq::const_size() - CurSeq::const_size(), 0.into()), + ]) + .expand(); + + // Calculate final outputs + let output = attention_weights + .softmax::>() + // Apply distribution to values + .matmul(repeated_values) + // Merge heads + .permute::<_, Axes5<0, 3, 1, 2, 4>>() + .reshape::<(Batch, CurSeq, Const)>(); + let output = output + // Apply output projection + .matmul(self.o_proj.permute()); + (output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph + } + } + + impl InitModule for SelfAttention { + fn initialize(cx: &mut Graph) -> Self { + Self { + q_proj: cx + .named_tensor("Q Proj") + .set(random_vec(HIDDEN_DIM * HIDDEN_DIM)), + k_proj: cx + .named_tensor("K Proj") + .set(random_vec(ATTN_PROJ_DIM * HIDDEN_DIM)), + v_proj: cx + .named_tensor("V Proj") + .set(random_vec(ATTN_PROJ_DIM * HIDDEN_DIM)), + o_proj: cx + .named_tensor("O Proj") + .set(random_vec(HIDDEN_DIM * HIDDEN_DIM)), + } + } + } + + impl SerializeModule for SelfAttention { + fn serialize(&self, s: &mut Serializer) { + s.tensor("attn_q/weight", self.q_proj); + s.tensor("attn_v/weight", self.v_proj); + s.tensor("attn_k/weight", self.k_proj); + s.tensor("attn_output/weight", self.o_proj); + } + } + + pub struct TransformerBlock { + pub attention: SelfAttention, + pub attention_norm: RMSNorm, + pub feed_forward: Mlp, + pub feed_forward_norm: RMSNorm, + } + + impl + Module<( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + PhantomData, + )> for TransformerBlock + { + type Output = ( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + ); + fn forward( + &self, + (mut x, cache, _): ( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + PhantomData, + ), + ) -> Self::Output { + // Attention + let normed = self.attention_norm.forward(x); + let (y, cache) = self + .attention + .forward((normed, cache, PhantomData::)); + + // Residual Addition + x += y; + + // Feed Forward + let y = 
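+                // Pre-norm MLP: the residual stream is normalized first, then passed
+                // through the gated feed-forward block.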
self.feed_forward.forward(self.feed_forward_norm.forward(x)); + + // Residual Addition + (x + y, cache) + } + } + + impl InitModule for TransformerBlock { + fn initialize(cx: &mut Graph) -> Self { + Self { + attention: InitModule::initialize(cx), + attention_norm: { + let mut norm = RMSNorm::initialize(cx); + norm.epsilon = 1e-5; + norm + }, + feed_forward: InitModule::initialize(cx), + feed_forward_norm: { + let mut norm = RMSNorm::initialize(cx); + norm.epsilon = 1e-5; + norm + }, + } + } + } + + pub struct MistralLM { + // Transformer layers + pub layers: Vec, + // Final Norm layer + pub norm: RMSNorm, + } + + impl + Module<( + GraphTensor<(Batch, CurSeq, Const)>, + Vec>, + PhantomData, + )> for MistralLM + { + type Output = ( + GraphTensor<(Batch, CurSeq, Const)>, + Vec>, + ); + fn forward( + &self, + (input, cache, _): ( + GraphTensor<(Batch, CurSeq, Const)>, + Vec>, + PhantomData, + ), + ) -> Self::Output { + let mut x = input; + + // Run through layers and collect new caches + let mut new_caches = vec![]; + let mut new_cache; + for (i, layer) in self.layers.iter().enumerate() { + (x, new_cache) = layer.forward((x, cache[i], PhantomData::)); + new_caches.push(new_cache); + } + // Run through last norm and output projection + let normed = self.norm.forward(x); + (normed, new_caches) + } + } + + impl InitModule for MistralLM { + fn initialize(cx: &mut Graph) -> Self { + Self { + norm: RMSNorm::initialize(cx), + layers: (0..NUM_LAYERS) + .map(|_| InitModule::initialize(cx)) + .collect(), + } + } + } + + let mut cx = Graph::new(); + let model = MistralLM::initialize(&mut cx); + let caches = (0..NUM_LAYERS) + .map(|_| { + ( + cx.tensor::<(Const<1>, Const, Dyn<'p'>, Const)>() + .set_dyn( + random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM), + &[1, N_KV_HEADS, SEQ_LEN, HEAD_DIM], + ), + cx.tensor::<(Const<1>, Const, Dyn<'p'>, Const)>() + .set_dyn( + random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM), + &[1, N_KV_HEADS, SEQ_LEN, HEAD_DIM], + ), + ) + }) + .collect(); + let input = cx + .tensor::<(Const<1>, Dyn<'s'>, luminal::shape::Const)>() + .set_dyn(random_vec(2 * HIDDEN_DIM), &[1, 2, HIDDEN_DIM]); + let (mut out, _) = model.forward((input, caches, PhantomData::>)); + out.retrieve(); + + cx.set_dyn_dim('t', SEQ_LEN + 2); + cx.execute(); + + let unopt_out = out.data(); + out.drop(); + + cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut out); + cx.execute(); + + assert_close_precision(&out.data(), &unopt_out, 1e-1); + } +} diff --git a/crates/luminal_cuda/src/lib.rs b/crates/luminal_cuda/src/lib.rs index 0b2c20cc..46db7550 100644 --- a/crates/luminal_cuda/src/lib.rs +++ b/crates/luminal_cuda/src/lib.rs @@ -1,4 +1,5 @@ mod binary; +mod elementwise_fusion; mod matmul; mod other; mod prim; @@ -22,14 +23,15 @@ use std::{collections::hash_map::DefaultHasher, ffi::c_void, fmt::Write, hash::H use luminal::{op::InputTensor, prelude::*}; -/// Compile graphs to run on Metal-supported macOS devices in supported data formats +/// Compile graphs to run on CUDA GPUs in supported data formats pub type CudaCompiler = ( prim::PrimitiveCompiler, SpecialOpsCompiler, other::CopyCompiler, + elementwise_fusion::ElementwiseFusionCompiler, ); -/// Compiler to replace metal ops with specialized variants +/// Compiler to replace cuda primops with specialized variants pub type SpecialOpsCompiler = ( binary::SubtractionCompiler, binary::EqualCompiler, @@ -121,9 +123,9 @@ impl CudaFloat for u8 { } } -fn expr_to_cuda_string(expr: BigExpression) -> String { +fn expr_to_cuda_string(expr: &BigExpression) -> String { let mut 
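     // Terms appear to be stored in postfix (RPN) order; `symbols` acts as the
     // evaluation stack while each term is rendered to CUDA source.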
symbols = vec![]; - for term in expr.terms { + for term in expr.terms.clone() { let new_symbol = match term { Term::Num(n) => n.to_string(), Term::Var(c) => { @@ -156,8 +158,8 @@ fn expr_to_cuda_string(expr: BigExpression) -> String { fn get_idx_valid_exps(shape: ShapeTracker) -> (String, String) { ( - expr_to_cuda_string(shape.index_expression()), - expr_to_cuda_string(shape.valid_expression()), + expr_to_cuda_string(&shape.index_expression()), + expr_to_cuda_string(&shape.valid_expression()), ) } @@ -251,8 +253,8 @@ fn compile_and_load_kernel(mut code: String, device: &Arc) -> CudaFu #[macro_export] macro_rules! debug_type { - ($t: ty) => { - impl std::fmt::Debug for $t { + ($t: ident) => { + impl std::fmt::Debug for $t { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, stringify!($t)) } diff --git a/crates/luminal_cuda/src/matmul.rs b/crates/luminal_cuda/src/matmul.rs index d558aa52..7e42baaf 100644 --- a/crates/luminal_cuda/src/matmul.rs +++ b/crates/luminal_cuda/src/matmul.rs @@ -17,7 +17,7 @@ use luminal::{ #[derive(Clone)] pub struct Matmul(Arc, Arc, PhantomData); -crate::debug_type!(Matmul); +crate::debug_type!(Matmul); impl Operator for Matmul { fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec { @@ -32,7 +32,6 @@ impl Operator for Matmul { a_shape[a_shape.len() - 1].to_usize().unwrap() as i32, b_shape[b_shape.len() - 1].to_usize().unwrap() as i32, ); - println!("{:?}", (batch_size, m, k, n)); let a = get_buffer_from_tensor::(&inp[0].0); let b = get_buffer_from_tensor::(&inp[1].0); let mut out = self diff --git a/crates/luminal_cuda/src/other.rs b/crates/luminal_cuda/src/other.rs index 90f242c0..b714805d 100644 --- a/crates/luminal_cuda/src/other.rs +++ b/crates/luminal_cuda/src/other.rs @@ -20,7 +20,7 @@ pub struct CudaARange { dyn_map: *const FxHashMap, _phantom: PhantomData, } -crate::debug_type!(CudaARange); +crate::debug_type!(CudaARange); impl CudaARange { pub fn new( diff --git a/crates/luminal_cuda/src/prim.rs b/crates/luminal_cuda/src/prim.rs index a4f47277..346d993d 100644 --- a/crates/luminal_cuda/src/prim.rs +++ b/crates/luminal_cuda/src/prim.rs @@ -20,7 +20,7 @@ use luminal::{ /// Copy a tensor to the GPU #[derive(Clone)] pub struct CudaCopyToDevice(Arc, PhantomData); -crate::debug_type!(CudaCopyToDevice); +crate::debug_type!(CudaCopyToDevice); impl CudaCopyToDevice { pub fn new(dev: Arc) -> Self { @@ -47,7 +47,7 @@ impl Operator for CudaCopyToDevice { /// Copy a tensor from the GPU #[derive(Clone)] pub struct CudaCopyFromDevice(Arc, PhantomData); -crate::debug_type!(CudaCopyFromDevice); +crate::debug_type!(CudaCopyFromDevice); impl CudaCopyFromDevice { pub fn new(dev: Arc) -> Self { @@ -112,6 +112,15 @@ impl Operator for CudaConstant { self.device.htod_copy_into(vec![value], &mut a).unwrap(); vec![Tensor::new(CudaData(a))] } + + fn custom(&mut self, key: &str, _: Box) -> Option> { + if key == "elementwise" { + if let ConstantValue::Float(f) = self.value { + return Some(Box::new(format!("{f:?}"))); + } + } + None + } } #[derive(Clone)] @@ -122,7 +131,7 @@ pub struct CudaContiguous { dyn_symbols: Vec, dyn_map: *const FxHashMap, } -crate::debug_type!(CudaContiguous); +crate::debug_type!(CudaContiguous); impl CudaContiguous { pub fn new( @@ -172,6 +181,13 @@ impl Operator for CudaContiguous { vec![Tensor::new(CudaData(out))] } + + fn custom(&mut self, key: &str, _: Box) -> Option> { + if key == "elementwise" { + return Some(Box::new("input0".to_string())); + } + None + } } #[derive(Clone)] @@ -180,7 +196,7 @@ pub struct 
CudaLog2 { device: Arc, _phantom: PhantomData, } -crate::debug_type!(CudaLog2); +crate::debug_type!(CudaLog2); impl CudaLog2 { pub fn new(device: Arc) -> Self { @@ -236,7 +252,7 @@ pub struct CudaExp2 { device: Arc, _phantom: PhantomData, } -crate::debug_type!(CudaExp2); +crate::debug_type!(CudaExp2); impl CudaExp2 { pub fn new(device: Arc) -> Self { @@ -291,7 +307,7 @@ pub struct CudaSqrt { device: Arc, _phantom: PhantomData, } -crate::debug_type!(CudaSqrt); +crate::debug_type!(CudaSqrt); impl CudaSqrt { pub fn new(device: Arc) -> Self { @@ -350,7 +366,7 @@ pub struct CudaSin { device: Arc, _phantom: PhantomData, } -crate::debug_type!(CudaSin); +crate::debug_type!(CudaSin); impl CudaSin { pub fn new(device: Arc) -> Self { @@ -408,7 +424,7 @@ pub struct CudaRecip { device: Arc, _phantom: PhantomData, } -crate::debug_type!(CudaRecip); +crate::debug_type!(CudaRecip); impl CudaRecip { pub fn new(device: Arc) -> Self { @@ -470,7 +486,7 @@ pub struct CudaAdd { dyn_symbols: Vec, dyn_map: *const FxHashMap, } -crate::debug_type!(CudaAdd); +crate::debug_type!(CudaAdd); impl CudaAdd { pub fn new( @@ -544,7 +560,7 @@ pub struct CudaMul { dyn_symbols: Vec, dyn_map: *const FxHashMap, } -crate::debug_type!(CudaMul); +crate::debug_type!(CudaMul); impl CudaMul { pub fn new( @@ -615,7 +631,7 @@ pub struct CudaMod { dyn_symbols: Vec, dyn_map: *const FxHashMap, } -crate::debug_type!(CudaMod); +crate::debug_type!(CudaMod); impl CudaMod { pub fn new( @@ -686,7 +702,7 @@ pub struct CudaLessThan { dyn_symbols: Vec, dyn_map: *const FxHashMap, } -crate::debug_type!(CudaLessThan); +crate::debug_type!(CudaLessThan); impl CudaLessThan { pub fn new( @@ -764,7 +780,7 @@ pub struct CudaSumReduce { dyn_symbols: Vec, dyn_map: *const FxHashMap, } -crate::debug_type!(CudaSumReduce); +crate::debug_type!(CudaSumReduce); impl CudaSumReduce { pub fn new( @@ -858,7 +874,7 @@ pub struct CudaMaxReduce { dyn_symbols: Vec, dyn_map: *const FxHashMap, } -crate::debug_type!(CudaMaxReduce); +crate::debug_type!(CudaMaxReduce); impl CudaMaxReduce { pub fn new( diff --git a/crates/luminal_cuda/src/quantized.rs b/crates/luminal_cuda/src/quantized.rs index 7deb3a15..3a7a16d3 100644 --- a/crates/luminal_cuda/src/quantized.rs +++ b/crates/luminal_cuda/src/quantized.rs @@ -19,7 +19,7 @@ pub struct QuantizedMatmul { device: Arc, _phantom: PhantomData, } -crate::debug_type!(QuantizedMatmul); +crate::debug_type!(QuantizedMatmul); impl QuantizedMatmul { fn new(device: Arc) -> Self { @@ -177,7 +177,7 @@ pub struct QuantizedGather { embed_dim: usize, _phantom: PhantomData, } -crate::debug_type!(QuantizedGather); +crate::debug_type!(QuantizedGather); impl QuantizedGather { fn new(device: Arc, embed_dim: usize) -> Self { diff --git a/crates/luminal_cuda/src/unary.rs b/crates/luminal_cuda/src/unary.rs index 55753f22..e93fa27c 100644 --- a/crates/luminal_cuda/src/unary.rs +++ b/crates/luminal_cuda/src/unary.rs @@ -30,7 +30,7 @@ pub struct CudaMeanReduce { pub dyn_map: *const FxHashMap, _phantom: PhantomData, } -crate::debug_type!(CudaMeanReduce); +crate::debug_type!(CudaMeanReduce); impl PartialEq for CudaMeanReduce { fn eq(&self, other: &Self) -> bool { @@ -175,7 +175,7 @@ pub struct CudaStdNorm { epsilon: f32, // Epsilon _phantom: PhantomData, } -crate::debug_type!(CudaStdNorm); +crate::debug_type!(CudaStdNorm); impl PartialEq for CudaStdNorm { fn eq(&self, other: &Self) -> bool { @@ -403,7 +403,7 @@ pub struct CudaExp { device: Arc, _phantom: PhantomData, } -crate::debug_type!(CudaExp); +crate::debug_type!(CudaExp); impl CudaExp { fn new(device: Arc) 
-> Self { @@ -506,7 +506,7 @@ pub struct CudaCos { device: Arc, _phantom: PhantomData, } -crate::debug_type!(CudaCos); +crate::debug_type!(CudaCos); impl CudaCos { fn new(device: Arc) -> Self { @@ -611,7 +611,7 @@ pub struct CudaSoftmax { device: Arc, _phantom: PhantomData, } -crate::debug_type!(CudaSoftmax); +crate::debug_type!(CudaSoftmax); impl CudaSoftmax { fn new(device: Arc) -> Self { From 0f63429513b759ef27eff1aa56523dbcc47042dd Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Sun, 28 Apr 2024 22:00:20 -0500 Subject: [PATCH 09/12] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 98f5171d..c6180957 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Website](https://img.shields.io/badge/Docs-Website-blue?style=for-the-badge&color=0D9373)](https://luminalai.com) [![CI Status](https://img.shields.io/github/actions/workflow/status/jafioti/luminal/test.yml?style=for-the-badge&logo=github-actions&logoColor=white&branch=main)](https://github.com/Sidekick-AI/dataflow/actions) [![Current Crates.io Version](https://img.shields.io/crates/v/luminal.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/luminal) -[![](https://dcbadge.vercel.app/api/server/VQf3j8WWNd)](https://discord.gg/VQf3j8WWNd) +[![discord](https://dcbadge.vercel.app/api/server/VQf3j8WWNd)](https://discord.gg/VQf3j8WWNd) **Deep learning at the speed of light.** From 41d6d08cd3d4da44e6ef93198cbf8713079b4432 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Sun, 28 Apr 2024 23:29:54 -0500 Subject: [PATCH 10/12] Phi 3 working --- crates/luminal_cuda/src/prim.rs | 3 ++ crates/luminal_cuda/src/tests/fp16.rs | 38 ------------------------- crates/luminal_cuda/src/tests/fp32.rs | 22 --------------- crates/luminal_metal/src/prim.rs | 2 +- examples/phi/src/main.rs | 18 ++++++------ examples/phi/src/model.rs | 40 ++++++++++----------------- src/graph.rs | 2 +- 7 files changed, 28 insertions(+), 97 deletions(-) diff --git a/crates/luminal_cuda/src/prim.rs b/crates/luminal_cuda/src/prim.rs index 346d993d..cc59e70b 100644 --- a/crates/luminal_cuda/src/prim.rs +++ b/crates/luminal_cuda/src/prim.rs @@ -990,6 +990,9 @@ impl Compiler for PrimitiveCompiler { graph.remove_edge(edge_id); } + if graph.no_delete.remove(&function_node) { + graph.no_delete.insert(copy_node); + } if let Some(v) = graph.to_retrieve.get(&function_node) { graph.to_retrieve.insert(copy_node, *v); } diff --git a/crates/luminal_cuda/src/tests/fp16.rs b/crates/luminal_cuda/src/tests/fp16.rs index be07cfbf..23bd3db8 100644 --- a/crates/luminal_cuda/src/tests/fp16.rs +++ b/crates/luminal_cuda/src/tests/fp16.rs @@ -576,28 +576,6 @@ fn test_rms_norm() { assert_close(&b.data(), &out.to_dtype::().as_vec()); } -#[test] -fn test_layer_norm() { - let mut cx = Graph::new(); - let a_data = random_vec(15 * 16 * 32); - let a = cx.tensor::>().set(a_data.clone()); - let mut b = a.layer_norm::, _>(1e-5).retrieve(); - let mut c = a.layer_norm::, _>(1e-5).retrieve(); - cx.compile( - <(GenericCompiler, CudaCompiler)>::default(), - (&mut b, &mut c), - ); - cx.execute(); - - let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec(a_data, (DConst::<15>, DConst::<16>, DConst::<32>)); - let d_b = d_a.clone().normalize::>(1e-5); - let d_c = d_a.normalize::>(1e-5); - - assert_close_precision(&b.data(), &d_b.as_vec(), 0.01); - assert_close_precision(&c.data(), &d_c.as_vec(), 0.01); -} - #[test] fn test_transformer_encoder_block() { let mut cx = Graph::new(); @@ -670,22 +648,6 @@ fn 
test_transformer_encoder_block() { assert_close_precision(&b.data(), &d_b.as_vec(), 0.01); } -#[test] -fn test_common_buffer() { - let data = random_vec(32); - let mut cx = Graph::new(); - let a = cx.tensor::>(); - a.set(data.clone()); - let a1 = cx.tensor::>(); - a1.set(data.clone()); - let exped = a * a1; - let mut b = exped.log2().retrieve(); - let mut c = exped.sin().retrieve(); - - cx.compile(CudaCompiler::::default(), (&mut b, &mut c)); - cx.execute(); -} - #[test] fn test_embedding() { let mut cx = Graph::new(); diff --git a/crates/luminal_cuda/src/tests/fp32.rs b/crates/luminal_cuda/src/tests/fp32.rs index 61d5512a..7118c6a3 100644 --- a/crates/luminal_cuda/src/tests/fp32.rs +++ b/crates/luminal_cuda/src/tests/fp32.rs @@ -529,28 +529,6 @@ fn test_rms_norm() { assert_close(&b.data(), &out.as_vec()); } -#[test] -fn test_layer_norm() { - let mut cx = Graph::new(); - let a_data = random_vec(15 * 16 * 32); - let a = cx.tensor::>().set(a_data.clone()); - let mut b = a.layer_norm::, _>(1e-5).retrieve(); - let mut c = a.layer_norm::, _>(1e-5).retrieve(); - cx.compile( - <(GenericCompiler, CudaCompiler)>::default(), - (&mut b, &mut c), - ); - cx.execute(); - - let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec(a_data, (DConst::<15>, DConst::<16>, DConst::<32>)); - let d_b = d_a.clone().normalize::>(1e-5); - let d_c = d_a.normalize::>(1e-5); - - assert_close_precision(&b.data(), &d_b.as_vec(), 1e-2); - assert_close_precision(&c.data(), &d_c.as_vec(), 1e-2); -} - #[test] fn test_transformer_encoder_block() { let mut cx = Graph::new(); diff --git a/crates/luminal_metal/src/prim.rs b/crates/luminal_metal/src/prim.rs index 354e9127..4ccf1983 100644 --- a/crates/luminal_metal/src/prim.rs +++ b/crates/luminal_metal/src/prim.rs @@ -1600,7 +1600,7 @@ impl Compiler for PrimitiveCompiler { graph.remove_edge(edge_id); } - if graph.no_delete.contains(&function_node) { + if graph.no_delete.remove(&function_node) { graph.no_delete.insert(copy_node); } if let Some(w) = graph.to_retrieve.get(&function_node) { diff --git a/examples/phi/src/main.rs b/examples/phi/src/main.rs index c2b7499b..b7c96658 100644 --- a/examples/phi/src/main.rs +++ b/examples/phi/src/main.rs @@ -43,9 +43,9 @@ fn main() { let mut cache_src: Vec, Dyn<'p'>>> = (0..model::NUM_LAYERS) .map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache"))) .collect(); - cache_src.set_dyn(vec![], &[1, model::N_KV_HEADS, 0, model::HEAD_DIM]); + cache_src.set_dyn(vec![], &[1, model::N_HEADS, 0, model::HEAD_DIM]); let model = model::MistralLM::initialize(&mut cx); - let mut model_weights = downstream(params(&model), &cx); + let mut model_weights = params(&model); cx.keep_tensors(&model_weights); let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::>)); let mut logits = logits @@ -81,8 +81,8 @@ fn main() { &mut model_weights, ), ); - let cache_src_set = downstream(&cache_src, &cx); - let cache_dest_set = cache_dest.to_ids(); + let cache_src = downstream(&cache_src, &cx); + let cache_dest = cache_dest.to_ids(); println!("\t\t - {}ms", now.elapsed().as_millis()); // Initial forward pass to load weights @@ -93,11 +93,11 @@ fn main() { cx.set_dyn_dim('t', 1); cx.execute(); logits.drop(); - cache_dest.drop(); + cx.drop_tensors(&cache_dest); println!("\t\t - {}ms", now.elapsed().as_millis()); // Now that weights are loaded, delete the loading nodes so they don't run again - delete_inputs(&model_weights, &mut cx); + delete_inputs(&downstream(model_weights, &cx), &mut cx); // Run prompt processing pass let mut 
input_ids = tokenizer .encode(&cli_args.prompt as &str, false) @@ -120,7 +120,7 @@ fn main() { 1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64), input_ids.len() ); - delete_inputs(&cache_src_set, &mut cx); + delete_inputs(&cache_src, &mut cx); let mut output_ids = vec![sample_index(&logits.data())]; logits.drop(); @@ -133,7 +133,7 @@ fn main() { io::stdout().flush().unwrap(); // Swap caches - transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx); + transfer_data_same_graph(&cache_dest, &cache_src, &mut cx); // Decode loop let start_decode = std::time::Instant::now(); @@ -162,7 +162,7 @@ fn main() { prev_output_len = current_output.len(); // Swap caches - transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx); + transfer_data_same_graph(&cache_dest, &cache_src, &mut cx); } println!(); diff --git a/examples/phi/src/model.rs b/examples/phi/src/model.rs index 94cb3e2b..72f6409c 100644 --- a/examples/phi/src/model.rs +++ b/examples/phi/src/model.rs @@ -1,4 +1,4 @@ -use std::{marker::PhantomData, ops::Div}; +use std::marker::PhantomData; use luminal::prelude::{binary::F32Pow, *}; use luminal_nn::{Embedding, PermutedLinear, RMSNorm}; @@ -6,19 +6,17 @@ use luminal_nn::{Embedding, PermutedLinear, RMSNorm}; // Llama3 8B Config pub const VOCAB_SIZE: usize = 32064; pub const HIDDEN_DIM: usize = 3072; -pub const NUM_LAYERS: usize = 1; +pub const NUM_LAYERS: usize = 32; pub const N_HEADS: usize = 32; -pub const N_KV_HEADS: usize = 8; pub const MLP_DIM: usize = 8192; -pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS; pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS; pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2; -pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS; +pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_HEADS; pub type KVCache = ( - GraphTensor<(Batch, Const, Seq, Const)>, - GraphTensor<(Batch, Const, Seq, Const)>, + GraphTensor<(Batch, Const, Seq, Const)>, + GraphTensor<(Batch, Const, Seq, Const)>, ); pub struct Mlp { @@ -127,15 +125,13 @@ impl .matmul(self.q_proj.permute()) .reshape::<(Batch, CurSeq, Const, Const)>() .permute::<_, Axes4<0, 2, 1, 3>>(); - let keys = x .matmul(self.k_proj.permute()) - .reshape::<(Batch, CurSeq, Const, Const)>() + .reshape::<(Batch, CurSeq, Const, Const)>() .permute::<_, Axes4<0, 2, 1, 3>>(); - let values = x .matmul(self.v_proj.permute()) - .reshape::<(Batch, CurSeq, Const, Const)>() + .reshape::<(Batch, CurSeq, Const, Const)>() .permute::<_, Axes4<0, 2, 1, 3>>(); // Rotary embed queries and keys @@ -146,15 +142,8 @@ impl let keys = k_cache.concat_along::<_, Axis<2>, _>(keys); let values = v_cache.concat_along::<_, Axis<2>, _>(values); - // Repeat the KV States for Grouped-Query Attention - let repeated_keys = keys.expand::<(_, _, Const, _, _), _>(); - let repeated_values = values.expand::<(_, _, Const, _, _), _>(); - // Calculate attention weights - let mut attention_weights = queries - .reshape::<(_, Const, Const, _, _)>() // Split query heads into groups - .matmul(repeated_keys.permute()) - .div((HEAD_DIM as f32).sqrt()); + let mut attention_weights = queries.matmul(keys.permute()) / (HEAD_DIM as f32).sqrt(); let attention_mask = self.k_proj.graph().triu::(1) * f16::MIN.to_f32(); attention_weights += attention_mask @@ -166,11 +155,11 @@ impl // Calculate final outputs let output = attention_weights - .softmax::>() + .softmax::>() // Apply distribution to values - .matmul(repeated_values) + .matmul(values) // Merge heads - .permute::<_, Axes5<0, 3, 1, 2, 4>>() + .permute::<_, Axes4<0, 2, 1, 3>>() .reshape::<(Batch, 
CurSeq, Const)>(); let output = output // Apply output projection @@ -226,10 +215,9 @@ impl ), ) -> Self::Output { // Attention - let normed = self.attention_norm.forward(x); - let (y, cache) = self - .attention - .forward((normed, cache, PhantomData::)); + let (y, cache) = + self.attention + .forward((self.attention_norm.forward(x), cache, PhantomData::)); // Residual Addition x += y; diff --git a/src/graph.rs b/src/graph.rs index 3c5a8441..1ec8d9ab 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -290,7 +290,7 @@ impl Graph { if self.tensors.contains_key(&(*node, 0)) { continue; } - let op_name = format!("{:?}", self.graph.node_weight(*node).unwrap()); + let op_name = format!("{:?} | {}", self.node_weight(*node).unwrap(), node.index()); print!("{}", op_name.bold().bright_green()); let mut srcs = From 4fee36107de2f318ace032dfe89d3206329031fa Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Mon, 29 Apr 2024 09:40:16 -0500 Subject: [PATCH 11/12] Small changes --- examples/llama/src/main.rs | 12 ++++++------ examples/phi/prompts/merge_sort.txt | 3 +++ examples/phi/src/main.rs | 8 ++++---- 3 files changed, 13 insertions(+), 10 deletions(-) create mode 100644 examples/phi/prompts/merge_sort.txt diff --git a/examples/llama/src/main.rs b/examples/llama/src/main.rs index 83cd8c8a..36cd1e79 100644 --- a/examples/llama/src/main.rs +++ b/examples/llama/src/main.rs @@ -81,8 +81,8 @@ fn main() { &mut model_weights, ), ); - let cache_src_set = downstream(&cache_src, &cx); - let cache_dest_set = cache_dest.to_ids(); + let cache_src = downstream(&cache_src, &cx); + let cache_dest = cache_dest.to_ids(); println!("\t\t - {}ms", now.elapsed().as_millis()); // Initial forward pass to load weights @@ -93,7 +93,7 @@ fn main() { cx.set_dyn_dim('t', 1); cx.execute(); logits.drop(); - cache_dest.drop(); + cx.drop_tensors(&cache_dest); println!("\t\t - {}ms", now.elapsed().as_millis()); // Now that weights are loaded, delete the loading nodes so they don't run again @@ -121,7 +121,7 @@ fn main() { 1000.0 * (input_ids.len() as f64) / (elapsed_ms as f64), input_ids.len() ); - delete_inputs(&cache_src_set, &mut cx); + delete_inputs(&cache_src, &mut cx); let mut output_ids = vec![sample_index(&logits.data())]; logits.drop(); @@ -134,7 +134,7 @@ fn main() { io::stdout().flush().unwrap(); // Swap caches - transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx); + transfer_data_same_graph(&cache_dest, &cache_src, &mut cx); // Decode loop let start_decode = std::time::Instant::now(); @@ -162,7 +162,7 @@ fn main() { prev_output_len = current_output.len(); // Swap caches - transfer_data_same_graph(&cache_dest_set, &cache_src_set, &mut cx); + transfer_data_same_graph(&cache_dest, &cache_src, &mut cx); } println!(); diff --git a/examples/phi/prompts/merge_sort.txt b/examples/phi/prompts/merge_sort.txt new file mode 100644 index 00000000..58a5cb70 --- /dev/null +++ b/examples/phi/prompts/merge_sort.txt @@ -0,0 +1,3 @@ +<|user|> +Please write me a python implementation of merge sort<|end|> +<|assistant|> diff --git a/examples/phi/src/main.rs b/examples/phi/src/main.rs index b7c96658..7ad9be9e 100644 --- a/examples/phi/src/main.rs +++ b/examples/phi/src/main.rs @@ -21,11 +21,11 @@ use luminal::prelude::*; #[command(author, version, about, long_about = None)] pub struct CLIArgs { /// Number of tokens to generate - #[clap(short = 't', long = "gen_tokens", default_value = "128")] + #[clap(short = 't', long = "gen_tokens", default_value = "512")] gen_tokens: i32, /// Prompt for the model - #[clap(short = 'p', long = 
"prompt", default_value = include_str!("../../llama/prompts/merge_sort.txt"))] + #[clap(short = 'p', long = "prompt", default_value = include_str!("../prompts/merge_sort.txt"))] prompt: String, } @@ -67,9 +67,9 @@ fn main() { ( GenericCompiler::default(), #[cfg(feature = "metal")] - luminal_metal::quantized::MetalQuantizedCompiler::::new(q_weights), + luminal_metal::quantized::MetalQuantizedCompiler::::new(q_weights), #[cfg(feature = "cuda")] - luminal_cuda::CudaQuantizedCompiler::::new(q_weights), + luminal_cuda::CudaQuantizedCompiler::::new(q_weights), #[cfg(all(not(feature = "metal"), not(feature = "cuda")))] luminal_cpu::CPUCompiler::default(), ), From 0efbc51e412a9d2f7023456fde9b5157a64c2c34 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Mon, 29 Apr 2024 09:59:29 -0500 Subject: [PATCH 12/12] Fixed metal tests --- crates/luminal_metal/src/quantized.rs | 13 ++++++------- crates/luminal_metal/src/tests/fp16.rs | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/crates/luminal_metal/src/quantized.rs b/crates/luminal_metal/src/quantized.rs index e4abea00..fcdd57c8 100644 --- a/crates/luminal_metal/src/quantized.rs +++ b/crates/luminal_metal/src/quantized.rs @@ -613,13 +613,12 @@ mod tests { }) .collect::>(); let dev = Device::system_default().unwrap(); - cx.tensors - .insert((weights.id, 0), quantized_buffer(&blocks, &dev)); - cx.compile( MetalQuantizedCompiler::::new(vec![weights.id]), &mut out, ); + cx.tensors + .insert((weights.id, 0), quantized_buffer(&blocks, &dev)); cx.execute(); let mut cx1 = Graph::new(); @@ -659,13 +658,13 @@ mod tests { }) .collect::>(); let dev = Device::system_default().unwrap(); - cx.tensors - .insert((weights.id, 0), quantized_buffer(&blocks, &dev)); cx.compile( MetalQuantizedCompiler::::new(vec![weights.id]), &mut out, ); + cx.tensors + .insert((weights.id, 0), quantized_buffer(&blocks, &dev)); cx.execute(); let cpu = dfdx::tensor::Cpu::default(); @@ -706,13 +705,13 @@ mod tests { }) .collect::>(); let dev = Device::system_default().unwrap(); - cx.tensors - .insert((weights.id, 0), quantized_buffer(&blocks, &dev)); cx.compile( MetalQuantizedCompiler::::new(vec![weights.id]), &mut out, ); + cx.tensors + .insert((weights.id, 0), quantized_buffer(&blocks, &dev)); cx.execute(); let cpu = dfdx::tensor::Cpu::default(); diff --git a/crates/luminal_metal/src/tests/fp16.rs b/crates/luminal_metal/src/tests/fp16.rs index e0594ff9..a7802e78 100644 --- a/crates/luminal_metal/src/tests/fp16.rs +++ b/crates/luminal_metal/src/tests/fp16.rs @@ -754,7 +754,7 @@ fn test_pad_contig() { .set_dyn(a_data, &[m, k]) .retrieve(); let mut b: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = a - .pad(&[(0, 0.into()), (0, Expression::from(16) - 'K')]) + .pad(&[(0, 0.into()), (0, Expression::from(24) - 'K')]) .contiguous() .retrieve(); let mut c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> =