diff --git a/Cargo.toml b/Cargo.toml
index 775e2917..e9ff4c45 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,7 +32,4 @@ members = [
     "crates/luminal_nn",
     "crates/luminal_training",
 ]
-exclude = [
-    "crates/luminal_metal",
-    "crates/luminal_cuda",
-]
+exclude = ["crates/luminal_metal", "crates/luminal_cuda"]
diff --git a/examples/llama_server/.gitignore b/examples/llama_server/.gitignore
new file mode 100644
index 00000000..199cec8e
--- /dev/null
+++ b/examples/llama_server/.gitignore
@@ -0,0 +1,17 @@
+# Generated by Cargo
+# will have compiled files and executables
+debug/
+target/
+
+# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
+# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
+Cargo.lock
+
+# These are backup files generated by rustfmt
+**/*.rs.bk
+
+# MSVC Windows builds of rustc generate these, which store debugging information
+*.pdb
+setup/*.gguf
+setup/*.json
+.vscode
\ No newline at end of file
diff --git a/examples/llama_server/Cargo.toml b/examples/llama_server/Cargo.toml
new file mode 100644
index 00000000..2879b703
--- /dev/null
+++ b/examples/llama_server/Cargo.toml
@@ -0,0 +1,39 @@
+[package]
+name = "llama_server"
+version = "0.1.0"
+edition = "2021"
+
+[features]
+metal = ["dep:luminal_metal", "dep:metal-rs"]
+cuda = ["dep:luminal_cuda", "dep:luminal_cudarc"]
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+luminal = { path = "../.." }
+luminal_nn = { path = "../../crates/luminal_nn" }
+luminal_cpu = { path = "../../crates/luminal_cpu" }
+luminal_metal = { path = "../../crates/luminal_metal", optional = true }
+luminal_cuda = { path = "../../crates/luminal_cuda", optional = true }
+clap = { version = "4.4.18", features = ["derive"] }
+byteorder = "1.5.0"
+memmap2 = "0.9.4"
+metal-rs = { version = "0.27.0", package = "metal", features = [
+    "mps",
+], optional = true }
+colored = "2.1.0"
+itertools = "0.12.1"
+luminal_cudarc = { version = "0.10.0", features = [
+    "cublas",
+    "f16",
+], optional = true }
+tokenizers = "0.15.2"
+axum = "0.7.5"
+serde = { version = "1.0.199", features = ["derive"] }
+tokio = { version = "1.37.0", features = ["rt-multi-thread"] }
+tracing = "0.1.40"
+tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
+chrono = "0.4.38"
+uuid = { version = "1.8.0", features = ["v4"] }
+async-trait = "0.1.80"
+serde_json = "1.0.116"
diff --git a/examples/llama_server/setup/setup.sh b/examples/llama_server/setup/setup.sh
new file mode 100644
index 00000000..efbc479e
--- /dev/null
+++ b/examples/llama_server/setup/setup.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
+echo "Downloading Model and Tokenizer..."
+curl --location https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json?download=true --output $SCRIPT_DIR/tokenizer.json
+curl --location https://huggingface.co/QuantFactory/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct.Q8_0.gguf?download=true --output $SCRIPT_DIR/llama3-8b.gguf
+echo "Done!"
diff --git a/examples/llama_server/src/chat.rs b/examples/llama_server/src/chat.rs
new file mode 100644
index 00000000..040298b3
--- /dev/null
+++ b/examples/llama_server/src/chat.rs
@@ -0,0 +1,124 @@
+use chrono::Utc;
+use serde::{Deserialize, Serialize};
+use uuid::Uuid;
+
+// src/chat.rs
+use crate::llama::setup::Model; // Import the Model struct
+
+#[derive(Deserialize)]
+pub struct ChatRequest {
+    pub messages: Vec<Message>,
+}
+
+#[derive(Deserialize, Serialize)]
+pub struct Message {
+    pub role: Role,
+    pub content: String,
+}
+
+#[derive(Deserialize, Serialize, PartialEq, Eq, Debug)]
+pub enum Role {
+    #[serde(rename = "system")]
+    System,
+    #[serde(rename = "assistant")]
+    Assistant,
+    #[serde(rename = "user")]
+    User,
+}
+
+#[derive(Serialize)]
+pub struct ChatResponse {
+    pub id: String,
+    pub object: String,
+    pub created: i64,
+    pub model: String,
+    pub choices: Vec<Choice>,
+    pub usage: Usage,
+}
+
+#[derive(Serialize)]
+pub struct Choice {
+    pub index: usize,
+    pub message: Message,
+    pub finish_reason: String,
+}
+
+#[derive(Serialize)]
+pub struct Usage {
+    pub prompt_tokens: usize,
+    pub completion_tokens: usize,
+    pub total_tokens: usize,
+}
+
+pub fn apply_chat_template(messages: Vec<Message>) -> String {
+    let mut output = "<|begin_of_text|>".to_string();
+    for message in messages {
+        output += "<|start_header_id|>";
+        if message.role == Role::Assistant {
+            output += "assistant"
+        } else if message.role == Role::User {
+            output += "user"
+        } else if message.role == Role::System {
+            output += "system"
+        }
+        output += "<|end_header_id|>";
+        output += "\n";
+        output += message.content.as_str();
+        output += "<|eot_id|>";
+    }
+    output
+}
+
+/// Respond to chat request
+pub async fn respond_chat_request(model: &mut Model, request: ChatRequest) -> ChatResponse {
+    let created = Utc::now().timestamp();
+    let raw_uuid = Uuid::new_v4();
+    let id = format!("chatcmpl-{}", raw_uuid);
+
+    let mut prompt = apply_chat_template(request.messages);
+    prompt += "<|start_header_id|>assistant<|end_header_id|>\n";
+    // let prompt = "<|begin_of_text|>Here is an implementation of merge sort:
+    //
+    // ```python"
+    //     .to_string();
+    let prompt_tokens = model.tokenizer.encode(prompt.clone(), false).unwrap();
+    let prompt_tokens = prompt_tokens.get_ids();
+    let prompt_tokens = prompt_tokens.len();
+    println!("Prompt: {:?}", prompt);
+
+    // Generate
+    let mut completion = vec![];
+    model.generate(&prompt, |token| {
+        const EOS_TOKEN: u32 = 128009;
+        if token != EOS_TOKEN {
+            completion.push(token);
+        }
+        true
+    });
+    // For now, just clear the cache each time
+    model.clear_cache();
+    let completion_str = model.tokenizer.decode(&completion, false).unwrap();
+    let completion_tokens = completion.len();
+
+    let response = ChatResponse {
+        id,
+        created,
+        object: "chat.completion".to_string(),
+        model: "meta-llama/Meta-Llama-3-70B-Instruct".to_string(),
+        choices: vec![Choice {
+            index: 0,
+            message: Message {
+                role: Role::Assistant,
+                content: completion_str,
+            },
+            finish_reason: "stop".to_string(),
+        }],
+        usage: Usage {
+            total_tokens: prompt_tokens + completion_tokens,
+            prompt_tokens,
+            completion_tokens,
+        },
+    };
+
+    response
+}
diff --git a/examples/llama_server/src/llama/gguf.rs b/examples/llama_server/src/llama/gguf.rs
new file mode 100644
index 00000000..ad4f94bd
--- /dev/null
+++ b/examples/llama_server/src/llama/gguf.rs
@@ -0,0 +1,302 @@
+//! Support for the GGUF file format.
+//!
+//!
Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md + +use byteorder::{LittleEndian, ReadBytesExt}; +use std::collections::HashMap; + +pub const DEFAULT_ALIGNMENT: u64 = 32; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Magic { + Gguf, +} + +impl TryFrom for Magic { + type Error = (); + fn try_from(value: u32) -> Result { + let magic = match value { + 0x46554747 | 0x47475546 => Self::Gguf, + _ => panic!("unknown magic 0x{value:08x}"), + }; + Ok(magic) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VersionedMagic { + GgufV1, + GgufV2, + GgufV3, +} + +impl VersionedMagic { + pub fn read(reader: &mut R) -> Result { + let magic = reader.read_u32::().unwrap(); + let magic = Magic::try_from(magic).unwrap(); + let version = reader.read_u32::().unwrap(); + let versioned_magic = match (magic, version) { + (Magic::Gguf, 1) => Self::GgufV1, + (Magic::Gguf, 2) => Self::GgufV2, + (Magic::Gguf, 3) => Self::GgufV3, + _ => panic!("gguf: unsupported magic/version {magic:?}/{version}"), + }; + Ok(versioned_magic) + } +} + +#[derive(Debug)] +pub struct Content { + pub magic: VersionedMagic, + pub metadata: HashMap, + pub tensor_infos: HashMap, // buffer size and offset + pub tensor_data_offset: u64, +} + +pub fn read_string(reader: &mut R, magic: &VersionedMagic) -> Result { + let len = match magic { + VersionedMagic::GgufV1 => reader.read_u32::().unwrap() as usize, + VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { + reader.read_u64::().unwrap() as usize + } + }; + let mut v = vec![0u8; len]; + reader.read_exact(&mut v).unwrap(); + // GGUF strings are supposed to be non-null terminated but in practice this happens. + while let Some(0) = v.last() { + v.pop(); + } + // GGUF strings are utf8 encoded but there are cases that don't seem to be valid. + Ok(String::from_utf8_lossy(&v).into_owned()) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ValueType { + // The value is a 8-bit unsigned integer. + U8, + // The value is a 8-bit signed integer. + I8, + // The value is a 16-bit unsigned little-endian integer. + U16, + // The value is a 16-bit signed little-endian integer. + I16, + // The value is a 32-bit unsigned little-endian integer. + U32, + // The value is a 32-bit signed little-endian integer. + I32, + // The value is a 64-bit unsigned little-endian integer. + U64, + // The value is a 64-bit signed little-endian integer. + I64, + // The value is a 32-bit IEEE754 floating point number. + F32, + // The value is a 64-bit IEEE754 floating point number. + F64, + // The value is a boolean. + // 1-byte value where 0 is false and 1 is true. + // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. + Bool, + // The value is a UTF-8 non-null-terminated string, with length prepended. + String, + // The value is an array of other values, with the length and type prepended. + /// + // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. 
+ Array, +} + +#[derive(Debug, Clone)] +pub enum Value { + U8(u8), + I8(i8), + U16(u16), + I16(i16), + U32(u32), + I32(i32), + U64(u64), + I64(i64), + F32(f32), + F64(f64), + Bool(bool), + String(String), + Array(Vec), +} + +impl Value { + pub fn read( + reader: &mut R, + value_type: ValueType, + magic: &VersionedMagic, + ) -> Result { + let v = match value_type { + ValueType::U8 => Self::U8(reader.read_u8().unwrap()), + ValueType::I8 => Self::I8(reader.read_i8().unwrap()), + ValueType::U16 => Self::U16(reader.read_u16::().unwrap()), + ValueType::I16 => Self::I16(reader.read_i16::().unwrap()), + ValueType::U32 => Self::U32(reader.read_u32::().unwrap()), + ValueType::I32 => Self::I32(reader.read_i32::().unwrap()), + ValueType::U64 => Self::U64(reader.read_u64::().unwrap()), + ValueType::I64 => Self::I64(reader.read_i64::().unwrap()), + ValueType::F32 => Self::F32(reader.read_f32::().unwrap()), + ValueType::F64 => Self::F64(reader.read_f64::().unwrap()), + ValueType::Bool => match reader.read_u8().unwrap() { + 0 => Self::Bool(false), + 1 => Self::Bool(true), + b => panic!("unexpected bool value {b}"), + }, + ValueType::String => Self::String(read_string(reader, magic).unwrap()), + ValueType::Array => { + let value_type = reader.read_u32::().unwrap(); + let value_type = ValueType::from_u32(value_type).unwrap(); + let len = match magic { + VersionedMagic::GgufV1 => reader.read_u32::().unwrap() as usize, + VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { + reader.read_u64::().unwrap() as usize + } + }; + let mut vs = Vec::with_capacity(len); + for _ in 0..len { + vs.push(Value::read(reader, value_type, magic).unwrap()) + } + Self::Array(vs) + } + }; + Ok(v) + } +} + +impl ValueType { + pub fn from_u32(v: u32) -> Result { + let v = match v { + 0 => Self::U8, + 1 => Self::I8, + 2 => Self::U16, + 3 => Self::I16, + 4 => Self::U32, + 5 => Self::I32, + 6 => Self::F32, + 7 => Self::Bool, + 8 => Self::String, + 9 => Self::Array, + 10 => Self::U64, + 11 => Self::I64, + 12 => Self::F64, + v => panic!("unrecognized value-type {v:#08x}"), + }; + Ok(v) + } +} + +impl Content { + pub fn read(reader: &mut R) -> Result { + let magic = VersionedMagic::read(reader).unwrap(); + + let tensor_count = match magic { + VersionedMagic::GgufV1 => reader.read_u32::().unwrap() as usize, + VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { + reader.read_u64::().unwrap() as usize + } + }; + let metadata_kv_count = match magic { + VersionedMagic::GgufV1 => reader.read_u32::().unwrap() as usize, + VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { + reader.read_u64::().unwrap() as usize + } + }; + + // Read metadata + let mut metadata = HashMap::new(); + for _idx in 0..metadata_kv_count { + let key = read_string(reader, &magic).unwrap(); + let value_type = reader.read_u32::().unwrap(); + let value_type = ValueType::from_u32(value_type).unwrap(); + let value = Value::read(reader, value_type, &magic).unwrap(); + metadata.insert(key, value); + } + // Read tensor infos + let mut tensor_infos = HashMap::new(); + for _idx in 0..tensor_count { + let tensor_name = read_string(reader, &magic).unwrap(); + let n_dimensions = reader.read_u32::().unwrap(); + let n_elements = match magic { + VersionedMagic::GgufV1 => { + let mut dimensions = vec![0; n_dimensions as usize]; + reader + .read_u32_into::(&mut dimensions) + .unwrap(); + dimensions.into_iter().map(|c| c as usize).product() + } + VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { + let mut dimensions = vec![0; n_dimensions as usize]; + reader + .read_u64_into::(&mut 
dimensions) + .unwrap(); + dimensions.into_iter().map(|c| c as usize).product() + } + }; + + let ggml_dtype = reader.read_u32::().unwrap(); + let offset = reader.read_u64::().unwrap(); + tensor_infos.insert( + tensor_name, + (n_elements, offset as usize, GgmlDType::from_u32(ggml_dtype)), + ); + } + let position = reader.stream_position().unwrap(); + let alignment = match metadata.get("general.alignment") { + Some(Value::U8(v)) => *v as u64, + Some(Value::U16(v)) => *v as u64, + Some(Value::U32(v)) => *v as u64, + Some(Value::I8(v)) if *v >= 0 => *v as u64, + Some(Value::I16(v)) if *v >= 0 => *v as u64, + Some(Value::I32(v)) if *v >= 0 => *v as u64, + _ => DEFAULT_ALIGNMENT, + }; + let tensor_data_offset = (position + alignment - 1) / alignment * alignment; + Ok(Self { + magic, + metadata, + tensor_infos, + tensor_data_offset, + }) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum GgmlDType { + F32, + F16, + Q4_0, + Q4_1, + Q5_0, + Q5_1, + Q8_0, + Q8_1, + Q2K, + Q3K, + Q4K, + Q5K, + Q6K, + Q8K, +} + +impl GgmlDType { + fn from_u32(u: u32) -> Self { + match u { + 0 => Self::F32, + 1 => Self::F16, + 2 => Self::Q4_0, + 3 => Self::Q4_1, + 6 => Self::Q5_0, + 7 => Self::Q5_1, + 8 => Self::Q8_0, + 9 => Self::Q8_1, + 10 => Self::Q2K, + 11 => Self::Q3K, + 12 => Self::Q4K, + 13 => Self::Q5K, + 14 => Self::Q6K, + 15 => Self::Q8K, + _ => panic!("unknown dtype for tensor {u}"), + } + } +} diff --git a/examples/llama_server/src/llama/loader.rs b/examples/llama_server/src/llama/loader.rs new file mode 100644 index 00000000..7fbc56f1 --- /dev/null +++ b/examples/llama_server/src/llama/loader.rs @@ -0,0 +1,249 @@ +use itertools::Itertools; +use std::fs::File; +use std::io::{Read, Seek}; +use std::path::Path; + +use luminal::{op::Function, prelude::*}; + +#[cfg(feature = "cuda")] +use {luminal_cuda::CudaData, luminal_cudarc::driver::CudaDevice}; + +use crate::llama::gguf::*; + +#[cfg(feature = "metal")] +use { + luminal_metal::MetalBuffer, + memmap2::Mmap, + metal_rs::{Device, MTLResourceOptions}, +}; + +#[cfg(feature = "metal")] +pub fn q8_load, M: SerializeModule>( + path: P, + model: &M, + graph: &mut Graph, +) -> Vec { + // Read metadata from file + let mut reader = File::open(&path).unwrap(); + let Content { + mut tensor_infos, + tensor_data_offset, + .. 
+ } = Content::read(&mut reader).unwrap(); + + // Create weight loading closures + let mut q8_weights = vec![]; + for (weight_name, node_index) in param_dict(model) { + if let Some(loading_node) = graph + .graph + .node_weight_mut(node_index) + .and_then(|op| op.as_any_mut().downcast_mut::()) + { + let file_path = path.as_ref().to_owned(); + let (n_elements, buffer_offset, data_type) = + tensor_infos.remove(&weight_name.replace('/', ".")).unwrap(); + let n_bytes = match data_type { + GgmlDType::F32 => n_elements * 4, + GgmlDType::Q8_0 => { + q8_weights.push(node_index); + n_elements + (n_elements / 16) + } + _ => panic!("Unsupported dtype: {data_type:?}"), + }; + if let GgmlDType::F32 = data_type { + loading_node.1 = Box::new(move |_| { + // Read bytes + let mut bytes = vec![0; n_bytes]; + let mut file = File::open(&file_path).unwrap(); + file.seek(std::io::SeekFrom::Start( + buffer_offset as u64 + tensor_data_offset, + )) + .unwrap(); + file.read_exact(&mut bytes).unwrap(); + vec![Tensor::new( + bytes + .into_iter() + .chunks(4) + .into_iter() + .map(|c| { + let c = c.collect::>(); + f32::from_le_bytes([c[0], c[1], c[2], c[3]]) + }) + .collect::>(), + )] + }); + } else { + loading_node.1 = Box::new(move |_| { + let mmap_buffer = + unsafe { Mmap::map(&File::open(&file_path).unwrap()).unwrap() }; + let buffer = Device::system_default() + .unwrap() + .new_buffer_with_bytes_no_copy( + unsafe { + mmap_buffer + .as_ptr() + .add(buffer_offset + tensor_data_offset as usize) + as *const _ + }, + n_bytes as u64, + MTLResourceOptions::StorageModeShared, + None, + ); + vec![Tensor::new(MetalBuffer(buffer))] + }); + } + } + } + q8_weights +} + +#[cfg(feature = "cuda")] +pub fn q8_load, M: SerializeModule>( + path: P, + model: &M, + graph: &mut Graph, +) -> Vec { + // Read metadata from file + let mut reader = File::open(&path).unwrap(); + let Content { + mut tensor_infos, + tensor_data_offset, + .. 
+ } = Content::read(&mut reader).unwrap(); + + // Create weight loading closures + let mut q8_weights = vec![]; + for (weight_name, node_index) in param_dict(model) { + if let Some(loading_node) = graph + .graph + .node_weight_mut(node_index) + .and_then(|op| op.as_any_mut().downcast_mut::()) + { + let file_path = path.as_ref().to_owned(); + let (n_elements, buffer_offset, data_type) = + tensor_infos.remove(&weight_name.replace('/', ".")).unwrap(); + let n_bytes = match data_type { + GgmlDType::F32 => n_elements * 4, + GgmlDType::Q8_0 => { + q8_weights.push(node_index); + n_elements + (n_elements / 16) + } + _ => panic!("Unsupported dtype: {data_type:?}"), + }; + loading_node.1 = Box::new(move |_| { + // Read bytes + let mut bytes = vec![0; n_bytes]; + let mut file = File::open(&file_path).unwrap(); + file.seek(std::io::SeekFrom::Start( + buffer_offset as u64 + tensor_data_offset, + )) + .unwrap(); + file.read_exact(&mut bytes).unwrap(); + // Copy buffer over to cuda slice + let device = CudaDevice::new(0).unwrap(); + match data_type { + GgmlDType::F32 => vec![Tensor::new( + bytes + .into_iter() + .chunks(4) + .into_iter() + .map(|c| { + let c = c.collect::>(); + f32::from_le_bytes([c[0], c[1], c[2], c[3]]) + }) + .collect::>(), + )], + GgmlDType::Q8_0 => vec![Tensor::new(CudaData( + device.htod_sync_copy::(&bytes).unwrap(), + ))], + _ => unimplemented!(), + } + }); + } + } + q8_weights +} + +#[cfg(all(not(feature = "metal"), not(feature = "cuda")))] +pub fn q8_load, M: SerializeModule>( + path: P, + model: &M, + graph: &mut Graph, +) -> Vec { + #[repr(C, packed)] + #[derive(Clone, Copy)] + struct Q8Block { + delta: f16, + weights: [i8; 32], + } + + // Read metadata from file + let mut reader = File::open(&path).unwrap(); + let Content { + mut tensor_infos, + tensor_data_offset, + .. 
+ } = Content::read(&mut reader).unwrap(); + + // Create weight loading closures + let mut q8_weights = vec![]; + for (weight_name, node_index) in param_dict(model) { + if let Some(loading_node) = graph + .graph + .node_weight_mut(node_index) + .and_then(|op| op.as_any_mut().downcast_mut::()) + { + let file_path = path.as_ref().to_owned(); + let (n_elements, buffer_offset, data_type) = + tensor_infos.remove(&weight_name.replace('/', ".")).unwrap(); + let n_bytes = match data_type { + GgmlDType::F32 => n_elements * 4, + GgmlDType::Q8_0 => { + q8_weights.push(node_index); + n_elements + (n_elements / 16) + } + _ => panic!("Unsupported dtype: {data_type:?}"), + }; + loading_node.1 = Box::new(move |_| { + // Load all bytes + let mut bytes = vec![0; n_bytes]; + let mut file = File::open(&file_path).unwrap(); + file.seek(std::io::SeekFrom::Start( + buffer_offset as u64 + tensor_data_offset, + )) + .unwrap(); + file.read_exact(&mut bytes).unwrap(); + // Dequantize into f32 + let data: Vec = match data_type { + GgmlDType::F32 => bytes + .into_iter() + .chunks(4) + .into_iter() + .map(|c| { + let c = c.collect::>(); + f32::from_le_bytes([c[0], c[1], c[2], c[3]]) + }) + .collect(), + GgmlDType::Q8_0 => bytes + .into_iter() + .chunks(34) + .into_iter() + .map(|c| { + let chunk = c.collect::>(); + unsafe { chunk.align_to::().1[0] } + }) + .flat_map(|chunk| { + chunk + .weights + .into_iter() + .map(move |i| i as f32 * chunk.delta.to_f32()) + }) + .collect(), + _ => panic!("Unsupported dtype: {data_type:?}"), + }; + vec![Tensor::new(data)] + }); + } + } + q8_weights +} diff --git a/examples/llama_server/src/llama/mod.rs b/examples/llama_server/src/llama/mod.rs new file mode 100644 index 00000000..afd4deec --- /dev/null +++ b/examples/llama_server/src/llama/mod.rs @@ -0,0 +1,4 @@ +pub mod gguf; +pub mod loader; +pub mod model; +pub mod setup; diff --git a/examples/llama_server/src/llama/model.rs b/examples/llama_server/src/llama/model.rs new file mode 100644 index 00000000..90760362 --- /dev/null +++ b/examples/llama_server/src/llama/model.rs @@ -0,0 +1,345 @@ +use std::{marker::PhantomData, ops::Div}; + +use luminal::prelude::{binary::F32Pow, *}; +use luminal_nn::{Embedding, PermutedLinear, RMSNorm}; + +// Llama3 8B Config +pub const VOCAB_SIZE: usize = 128256; +pub const HIDDEN_DIM: usize = 4096; +pub const NUM_LAYERS: usize = 32; +pub const N_HEADS: usize = 32; +pub const N_KV_HEADS: usize = 8; +pub const MLP_DIM: usize = 14336; + +pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS; +pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS; +pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2; +pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS; + +pub type KVCache = ( + GraphTensor<(Batch, Const, Seq, Const)>, + GraphTensor<(Batch, Const, Seq, Const)>, +); + +pub struct Mlp { + pub gate_proj: PermutedLinear, + pub down_proj: PermutedLinear, + pub up_proj: PermutedLinear, +} + +impl Module> for Mlp +where + GraphTensor: Matmul, Output = GraphTensor>, + GraphTensor: Matmul, Output = GraphTensor>, +{ + type Output = GraphTensor; + + fn forward(&self, input: GraphTensor) -> Self::Output { + let gate = self.gate_proj.forward(input).swish(); + let up = self.up_proj.forward(input) * gate; + self.down_proj.forward(up) + } +} + +impl InitModule for Mlp { + fn initialize(cx: &mut Graph) -> Self { + Self { + gate_proj: PermutedLinear { + weight: cx.named_tensor("Gate"), + }, + up_proj: PermutedLinear { + weight: cx.named_tensor("Up"), + }, + down_proj: PermutedLinear { + weight: cx.named_tensor("Down"), + }, + 
} + } +} + +impl SerializeModule for Mlp { + fn serialize(&self, s: &mut Serializer) { + s.module("ffn_gate", &self.gate_proj); + s.module("ffn_up", &self.up_proj); + s.module("ffn_down", &self.down_proj); + } +} + +fn apply_rotary_embeddings_ggml( + input: GraphTensor<(Batch, Const, Seq, Const)>, + prev_seq: BigExpression, +) -> GraphTensor<(Batch, Const, Seq, Const)> { + // Get freqs + let freqs = (input.graph().arange::>() * 2.0) / (HEAD_DIM as f32); + let freqs = 500000_f32.pow(freqs); + let pos = input.graph().arange::() + prev_seq; + let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand()); + + // Split input into evens and odds + let split = input.reshape::<(Batch, Const, Seq, Const, Const<2>)>(); + let x0: GraphTensor<(Batch, Const, Seq, Const, Const<1>)> = split + .slice((.., .., .., .., ..Expression::from(1))) + .realize(); + let x1: GraphTensor<(Batch, Const, Seq, Const, Const<1>)> = split + .slice((.., .., .., .., Expression::from(1)..)) + .realize(); + + // Apply sin and cos embeddings + let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand(); + let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand(); + + // Combine back into output + x0_out + .concat_along::<(Batch, Const, Seq, Const, Const<2>), Axis<4>, _>( + x1_out, + ) + .reshape() +} + +pub struct SelfAttention { + pub q_proj: GraphTensor>, + pub k_proj: GraphTensor>, + pub v_proj: GraphTensor>, + pub o_proj: GraphTensor>, +} + +impl + Module<( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + PhantomData, + )> for SelfAttention +{ + type Output = ( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + ); + fn forward( + &self, + (x, (k_cache, v_cache), _): ( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + PhantomData, + ), + ) -> Self::Output { + // Apply the Projections + let queries = x + .matmul(self.q_proj.permute()) + .reshape::<(Batch, CurSeq, Const, Const)>() + .permute::<_, Axes4<0, 2, 1, 3>>(); + + let keys = x + .matmul(self.k_proj.permute()) + .reshape::<(Batch, CurSeq, Const, Const)>() + .permute::<_, Axes4<0, 2, 1, 3>>(); + + let values = x + .matmul(self.v_proj.permute()) + .reshape::<(Batch, CurSeq, Const, Const)>() + .permute::<_, Axes4<0, 2, 1, 3>>(); + + // Rotary embed queries and keys + let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::const_size().into()); + let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::const_size().into()); + + // Add KV cache + let keys = k_cache.concat_along::<_, Axis<2>, _>(keys); + let values = v_cache.concat_along::<_, Axis<2>, _>(values); + + // Repeat the KV States for Grouped-Query Attention + let repeated_keys = keys.expand::<(_, _, Const, _, _), _>(); + let repeated_values = values.expand::<(_, _, Const, _, _), _>(); + + // Calculate attention weights + let mut attention_weights = queries + .reshape::<(_, Const, Const, _, _)>() // Split query heads into groups + .matmul(repeated_keys.permute()) + .div((HEAD_DIM as f32).sqrt()); + + let attention_mask = self.k_proj.graph().triu::(1) * f16::MIN.to_f32(); + attention_weights += attention_mask + .pad::<(CurSeq, TotSeq), _, _>(&[ + (0.into(), Expression::from(0)), + (TotSeq::const_size() - CurSeq::const_size(), 0.into()), + ]) + .expand(); + + // Calculate final outputs + let output = attention_weights + .softmax::>() + // Apply distribution to values + .matmul(repeated_values) + // Merge heads + .permute::<_, Axes5<0, 3, 1, 2, 4>>() + .reshape::<(Batch, CurSeq, Const)>(); + let output = output + // Apply output projection + .matmul(self.o_proj.permute()); + (output, 
(keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph + } +} + +impl InitModule for SelfAttention { + fn initialize(cx: &mut Graph) -> Self { + Self { + q_proj: cx.named_tensor("Q Proj"), + k_proj: cx.named_tensor("K Proj"), + v_proj: cx.named_tensor("V Proj"), + o_proj: cx.named_tensor("O Proj"), + } + } +} + +impl SerializeModule for SelfAttention { + fn serialize(&self, s: &mut Serializer) { + s.tensor("attn_q/weight", self.q_proj); + s.tensor("attn_v/weight", self.v_proj); + s.tensor("attn_k/weight", self.k_proj); + s.tensor("attn_output/weight", self.o_proj); + } +} + +pub struct TransformerBlock { + pub attention: SelfAttention, + pub attention_norm: RMSNorm, + pub feed_forward: Mlp, + pub feed_forward_norm: RMSNorm, +} + +impl + Module<( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + PhantomData, + )> for TransformerBlock +{ + type Output = ( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + ); + fn forward( + &self, + (mut x, cache, _): ( + GraphTensor<(Batch, CurSeq, Const)>, + KVCache, + PhantomData, + ), + ) -> Self::Output { + // Attention + let normed = self.attention_norm.forward(x); + let (y, cache) = self + .attention + .forward((normed, cache, PhantomData::)); + + // Residual Addition + x += y; + + // Feed Forward + let y = self.feed_forward.forward(self.feed_forward_norm.forward(x)); + + // Residual Addition + (x + y, cache) + } +} + +impl InitModule for TransformerBlock { + fn initialize(cx: &mut Graph) -> Self { + Self { + attention: InitModule::initialize(cx), + attention_norm: RMSNorm { + weight: cx.named_tensor("RMS Norm Weight"), + epsilon: 1e-5, + }, + feed_forward: InitModule::initialize(cx), + feed_forward_norm: RMSNorm { + weight: cx.named_tensor("RMS Norm Weight"), + epsilon: 1e-5, + }, + } + } +} + +impl SerializeModule for TransformerBlock { + fn serialize(&self, s: &mut Serializer) { + s.module("", &self.attention); + s.module("attn_norm", &self.attention_norm); + s.module("ffn_norm", &self.feed_forward_norm); + s.module("", &self.feed_forward); + } +} + +pub struct MistralLM { + // Token embeddings + pub embedding: Embedding, + // Transformer layers + pub layers: Vec, + // Final Norm layer + pub norm: RMSNorm, + // LM Head Layer + pub lm_head: GraphTensor>, +} + +impl + Module<( + GraphTensor<(Batch, CurSeq)>, + &[KVCache], + PhantomData, + )> for MistralLM +{ + type Output = ( + GraphTensor<(Batch, CurSeq, Const)>, + Vec>, + ); + fn forward( + &self, + (input, cache, _): ( + GraphTensor<(Batch, CurSeq)>, + &[KVCache], + PhantomData, + ), + ) -> Self::Output { + // Embed tokens + let mut x = self.embedding.forward(input); + + // Run through layers and collect new caches + let mut new_caches = vec![]; + let mut new_cache; + for (i, layer) in self.layers.iter().enumerate() { + (x, new_cache) = layer.forward((x, cache[i], PhantomData::)); + new_caches.push(new_cache); + } + // Run through last norm and output projection + let output = self.norm.forward(x).matmul(self.lm_head.permute()); + + (output, new_caches) + } +} + +impl InitModule for MistralLM { + fn initialize(cx: &mut Graph) -> Self { + Self { + embedding: Embedding { + weight: cx.named_tensor("Embedding Weight"), + }, + norm: RMSNorm { + weight: cx.named_tensor("RMS Norm Weight"), + epsilon: 1e-5, + }, + lm_head: cx.named_tensor("LM Head"), + layers: (0..NUM_LAYERS) + .map(|_| InitModule::initialize(cx)) + .collect(), + } + } +} + +impl SerializeModule for MistralLM { + fn serialize(&self, s: &mut Serializer) { + s.module("token_embd", 
&self.embedding); + s.module("output_norm", &self.norm); + s.tensor("output/weight", self.lm_head); + for (i, layer) in self.layers.iter().enumerate() { + s.module(&format!("blk/{i}"), layer); + } + } +} diff --git a/examples/llama_server/src/llama/setup.rs b/examples/llama_server/src/llama/setup.rs new file mode 100644 index 00000000..a771e521 --- /dev/null +++ b/examples/llama_server/src/llama/setup.rs @@ -0,0 +1,212 @@ +use std::{ + io::{self, Write}, + marker::PhantomData, + path::Path, + time::Instant, +}; + +use itertools::Itertools; +use luminal::prelude::*; +use tokenizers::Tokenizer; + +use crate::llama::{ + loader, + model::{KVCache, MistralLM, HEAD_DIM, NUM_LAYERS, N_KV_HEADS}, +}; + +use super::model::VOCAB_SIZE; + +/// Define the model +pub struct Model { + pub graph: Box, + pub input: GraphTensor<(Const<1>, Dyn<'s'>)>, + kv_cache_src_set: Vec, + kv_cache_dest_set: Vec, + logits: GraphTensor>, + pub tokenizer: Tokenizer, + pub last_generated_token: Option, +} + +unsafe impl Send for Model {} +unsafe impl Sync for Model {} + +const TOKENIZER_PATH: &str = "./setup/tokenizer.json"; +const MODEL_PATH: &str = "./setup/llama3-8b.gguf"; + +// Load the model +impl Model { + pub fn setup() -> Self { + if Path::new(TOKENIZER_PATH).exists() && Path::new(MODEL_PATH).exists() { + println!("Tokenizer and Model Exists"); + } else { + panic!("Model does not exist"); + } + + let tokenizer = Tokenizer::from_file(TOKENIZER_PATH).unwrap(); + + print!("Defining graph"); + let now = Instant::now(); + + // Set up graph + let mut cx = Box::new(Graph::new()); + + let mut input = cx.named_tensor::<(Const<1>, Dyn<'s'>)>("Input"); + let mut cache_src: Vec, Dyn<'p'>>> = (0..NUM_LAYERS) + .map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache"))) + .collect(); + cache_src.set_dyn(vec![], &[1, N_KV_HEADS, 0, HEAD_DIM]); + let model = MistralLM::initialize(&mut cx); + let mut model_weights = params(&model); + cx.keep_tensors(&model_weights); + let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::>)); + let mut logits = logits + .slice((.., (Expression::from('s') - 1).., ..)) + .retrieve() + .realize(); + cache_dest.keep(); + + // Set up model loading + #[cfg(any(feature = "metal", feature = "cuda"))] + let q_weights = loader::q8_load(MODEL_PATH, &model, &mut cx); + #[cfg(all(not(feature = "metal"), not(feature = "cuda")))] + loader::q8_load(MODEL_PATH, &model, &mut cx); + println!("\t\t - {}ms", now.elapsed().as_millis()); + + print!("Compiling graph"); + io::stdout().flush().unwrap(); + let now = Instant::now(); + cx.compile( + ( + GenericCompiler::default(), + #[cfg(feature = "metal")] + luminal_metal::quantized::MetalQuantizedCompiler::::new(q_weights), + #[cfg(feature = "cuda")] + luminal_cuda::CudaQuantizedCompiler::::new(q_weights), + #[cfg(all(not(feature = "metal"), not(feature = "cuda")))] + luminal_cpu::CPUCompiler::default(), + ), + ( + &mut input, + &mut logits, + &mut cache_src, + &mut cache_dest, + &mut model_weights, + ), + ); + let cache_src_set = downstream(&cache_src, &cx); + let cache_dest_set = cache_dest.to_ids(); + println!("\t\t - {}ms", now.elapsed().as_millis()); + + // Initial forward pass to load weights + print!("Loading model"); + io::stdout().flush().unwrap(); + let now = Instant::now(); + input.set_dyn(vec![1.], &[1, 1]); + cx.set_dyn_dim('t', 1); + cx.execute(); + logits.drop(); + cache_dest.drop(); + println!("\t\t - {}ms", now.elapsed().as_millis()); + + // Now that weights are loaded, delete the loading nodes so they don't run again + 
        delete_inputs(&downstream(model_weights, &cx), &mut cx);
+
+        Model {
+            input,
+            tokenizer,
+            kv_cache_src_set: downstream(&cache_src, &cx),
+            kv_cache_dest_set: cache_dest.to_ids(),
+            graph: cx,
+            logits,
+            last_generated_token: None,
+        }
+    }
+
+    /// Generate new tokens given some input
+    pub fn generate(&mut self, prompt: &str, mut continue_callback: impl FnMut(u32) -> bool) {
+        let input_tokens = self.tokenizer.encode(prompt, false).unwrap();
+        let input_tokens = input_tokens.get_ids();
+
+        self.generate_internal(input_tokens, |dist| {
+            let output_id = argmax(dist);
+            (output_id, continue_callback(output_id))
+        })
+    }
+
+    fn generate_internal(
+        &mut self,
+        prompt: &[u32],
+        mut callback: impl FnMut(&[f32]) -> (u32, bool),
+    ) {
+        const EOS_TOKEN: u32 = 128009; // From the llama3 vocab
+
+        let mut input_ids = prompt.to_vec();
+
+        // Set the dyn dims
+        let mut seq_len = input_ids.len();
+        let mut p = 0;
+        if self.graph.dyn_map[&'p'] != 0 {
+            input_ids.insert(0, self.last_generated_token.unwrap());
+            p = self.graph.dyn_map[&'t'];
+            seq_len += p + 1;
+        }
+
+        self.graph.set_dyn_dim('p', p);
+        self.graph.set_dyn_dim('t', seq_len);
+        self.input.set_dyn(
+            input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
+            &[1, input_ids.len()],
+        );
+
+        // First token output (from prompt processing)
+        self.graph.execute();
+
+        // Get the output token
+        let (mut output_id, mut cont) = callback(&self.logits.data());
+        self.logits.drop();
+        seq_len += 1;
+        self.last_generated_token = Some(output_id);
+
+        // Swap cache
+        transfer_data_same_graph(
+            &self.kv_cache_dest_set,
+            &self.kv_cache_src_set,
+            &mut self.graph,
+        );
+
+        // Decode loop (next token)
+        while output_id != EOS_TOKEN && cont {
+            // Set the data
+            self.graph.set_dyn_dim('p', seq_len - 1);
+            self.graph.set_dyn_dim('t', seq_len);
+            self.input.set_dyn(vec![output_id as f32], &[1, 1]);
+
+            // Execute the graph
+            self.graph.execute();
+
+            // Get the output token
+            (output_id, cont) = callback(&self.logits.data());
+            seq_len += 1;
+            self.logits.drop();
+            self.last_generated_token = Some(output_id);
+
+            // Swap cache
+            transfer_data_same_graph(
+                &self.kv_cache_dest_set,
+                &self.kv_cache_src_set,
+                &mut self.graph,
+            );
+        }
+    }
+
+    pub fn clear_cache(&mut self) {
+        self.last_generated_token = None;
+        self.graph.set_dyn_dim('p', 0);
+    }
+}
+
+fn argmax(dist: &[f32]) -> u32 {
+    dist.iter()
+        .position_max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
+        .unwrap() as u32
+}
diff --git a/examples/llama_server/src/main.rs b/examples/llama_server/src/main.rs
new file mode 100644
index 00000000..3845df31
--- /dev/null
+++ b/examples/llama_server/src/main.rs
@@ -0,0 +1,44 @@
+use axum::{
+    extract::Extension,
+    http::StatusCode,
+    routing::{get, post},
+    Json, Router,
+};
+use std::sync::Arc;
+use tokio::{net::TcpListener, sync::Mutex};
+
+mod chat;
+mod llama;
+
+use crate::llama::setup::Model;
+use chat::{respond_chat_request, ChatRequest, ChatResponse};
+
+#[tokio::main]
+async fn main() {
+    tracing_subscriber::fmt::init();
+
+    let model = Arc::new(Mutex::new(Model::setup()));
+
+    let app = Router::new()
+        .route("/", get(root))
+        .route("/chat/completions", post(chat_completions))
+        .layer(Extension(model));
+
+    let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
+    tracing::debug!("listening on {}", listener.local_addr().unwrap());
+    axum::serve(listener, app).await.unwrap();
+}
+
+async fn root() -> &'static str {
+    "Hello, World!"
+}
+
+async fn chat_completions(
+    Extension(model): Extension<Arc<Mutex<Model>>>,
+    Json(payload): Json<ChatRequest>,
+) -> (StatusCode, Json<ChatResponse>) {
+    let mut model = model.lock().await;
+
+    let response = respond_chat_request(&mut *model, payload).await;
+    (StatusCode::OK, Json(response))
+}
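
Usage sketch (illustration only, not part of the patch): once the server from main.rs is running, the endpoint can be exercised with a plain HTTP request. The port and route come from main.rs and the JSON shape from the ChatRequest/Message structs in chat.rs; the prompt text below is an arbitrary example and the completion will vary.

curl -s http://127.0.0.1:3000/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "messages": [
          {"role": "system", "content": "You are a helpful assistant."},
          {"role": "user", "content": "Write a haiku about Rust."}
        ]
      }'

The reply is an OpenAI-style chat.completion object, so the generated text is found in choices[0].message.content.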