From bf97c898734cdcacba2c9f3ee85b8364f5976d6f Mon Sep 17 00:00:00 2001
From: Joe Fioti
Date: Fri, 3 May 2024 12:49:42 -0500
Subject: [PATCH] Fixed phi example on metal

---
 examples/llama/src/main.rs               |  2 +-
 examples/llama/src/model.rs              |  8 ++++----
 examples/llama_server/src/chat.rs        |  6 ++----
 examples/llama_server/src/llama/setup.rs | 10 +++++-----
 examples/llama_server/src/main.rs        |  2 +-
 examples/phi/src/main.rs                 | 14 ++++++++------
 examples/phi/src/model.rs                | 22 ++++++++++++++--------
 7 files changed, 35 insertions(+), 29 deletions(-)

diff --git a/examples/llama/src/main.rs b/examples/llama/src/main.rs
index 36cd1e79..d5ef8981 100644
--- a/examples/llama/src/main.rs
+++ b/examples/llama/src/main.rs
@@ -44,7 +44,7 @@ fn main() {
         .map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
         .collect();
     cache_src.set_dyn(vec![], &[1, model::N_KV_HEADS, 0, model::HEAD_DIM]);
-    let model = model::MistralLM::initialize(&mut cx);
+    let model = model::Llama::initialize(&mut cx);
     let mut model_weights = params(&model);
     cx.keep_tensors(&model_weights);
     let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::>));
diff --git a/examples/llama/src/model.rs b/examples/llama/src/model.rs
index 90760362..bdf3954c 100644
--- a/examples/llama/src/model.rs
+++ b/examples/llama/src/model.rs
@@ -268,7 +268,7 @@ impl SerializeModule for TransformerBlock {
     }
 }
 
-pub struct MistralLM {
+pub struct Llama {
     // Token embeddings
     pub embedding: Embedding,
     // Transformer layers
@@ -284,7 +284,7 @@ impl
         GraphTensor<(Batch, CurSeq)>,
         &[KVCache],
         PhantomData,
-    )> for MistralLM
+    )> for Llama
 {
     type Output = (
         GraphTensor<(Batch, CurSeq, Const)>,
@@ -315,7 +315,7 @@ impl
     }
 }
 
-impl InitModule for MistralLM {
+impl InitModule for Llama {
     fn initialize(cx: &mut Graph) -> Self {
         Self {
             embedding: Embedding {
@@ -333,7 +333,7 @@ impl InitModule for MistralLM {
     }
 }
 
-impl SerializeModule for MistralLM {
+impl SerializeModule for Llama {
     fn serialize(&self, s: &mut Serializer) {
         s.module("token_embd", &self.embedding);
         s.module("output_norm", &self.norm);
diff --git a/examples/llama_server/src/chat.rs b/examples/llama_server/src/chat.rs
index 040298b3..eb03a181 100644
--- a/examples/llama_server/src/chat.rs
+++ b/examples/llama_server/src/chat.rs
@@ -100,7 +100,7 @@ pub async fn respond_chat_request(model: &mut Model, request: ChatRequest) -> Ch
     let completion_str = model.tokenizer.decode(&completion, false).unwrap();
     let completion_tokens = completion.len();
 
-    let response = ChatResponse {
+    ChatResponse {
         id,
         created,
         object: "chat.completion".to_string(),
@@ -118,7 +118,5 @@ pub async fn respond_chat_request(model: &mut Model, request: ChatRequest) -> Ch
             prompt_tokens,
             completion_tokens,
         },
-    };
-
-    response
+    }
 }
diff --git a/examples/llama_server/src/llama/setup.rs b/examples/llama_server/src/llama/setup.rs
index a771e521..327b6ca0 100644
--- a/examples/llama_server/src/llama/setup.rs
+++ b/examples/llama_server/src/llama/setup.rs
@@ -93,8 +93,8 @@ impl Model {
                 &mut model_weights,
             ),
         );
-        let cache_src_set = downstream(&cache_src, &cx);
-        let cache_dest_set = cache_dest.to_ids();
+        let cache_src = downstream(&cache_src, &cx);
+        let cache_dest = cache_dest.to_ids();
         println!("\t\t - {}ms", now.elapsed().as_millis());
 
         // Initial forward pass to load weights
@@ -105,16 +105,16 @@ impl Model {
         cx.set_dyn_dim('t', 1);
         cx.execute();
         logits.drop();
-        cache_dest.drop();
+        cx.drop_tensors(&cache_dest);
         println!("\t\t - {}ms", now.elapsed().as_millis());
 
         // Now that weights are loaded, delete the loading nodes so they don't run again
-        delete_inputs(&downstream(model_weights, &cx), &mut cx);
+        delete_inputs(downstream(model_weights, &cx), &mut cx);
 
         Model {
             input,
             tokenizer,
-            kv_cache_src_set: downstream(&cache_src, &cx),
+            kv_cache_src_set: downstream(cache_src, &cx),
             kv_cache_dest_set: cache_dest.to_ids(),
             graph: cx,
             logits,
diff --git a/examples/llama_server/src/main.rs b/examples/llama_server/src/main.rs
index 3845df31..eaae6076 100644
--- a/examples/llama_server/src/main.rs
+++ b/examples/llama_server/src/main.rs
@@ -39,6 +39,6 @@ async fn chat_completions(
 ) -> (StatusCode, Json) {
     let mut model = model.lock().await;
 
-    let response = respond_chat_request(&mut *model, payload).await;
+    let response = respond_chat_request(&mut model, payload).await;
     (StatusCode::OK, Json(response))
 }
diff --git a/examples/phi/src/main.rs b/examples/phi/src/main.rs
index 13ae99f1..a26fb8a3 100644
--- a/examples/phi/src/main.rs
+++ b/examples/phi/src/main.rs
@@ -21,7 +21,7 @@ use luminal::prelude::*;
 #[command(author, version, about, long_about = None)]
 pub struct CLIArgs {
     /// Number of tokens to generate
-    #[clap(short = 't', long = "gen_tokens", default_value = "512")]
+    #[clap(short = 't', long = "gen_tokens", default_value = "128")]
     gen_tokens: i32,
 
     /// Prompt for the model
@@ -44,7 +44,7 @@ fn main() {
         .map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
         .collect();
     cache_src.set_dyn(vec![], &[1, model::N_HEADS, 0, model::HEAD_DIM]);
-    let model = model::MistralLM::initialize(&mut cx);
+    let model = model::Phi::initialize(&mut cx);
     let mut model_weights = params(&model);
     cx.keep_tensors(&model_weights);
     let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::>));
@@ -67,9 +67,9 @@ fn main() {
         (
             GenericCompiler::default(),
             #[cfg(feature = "metal")]
-            luminal_metal::quantized::MetalQuantizedCompiler::::new(q_weights),
+            luminal_metal::quantized::MetalQuantizedCompiler::::new(q_weights),
             #[cfg(feature = "cuda")]
-            luminal_cuda::CudaQuantizedCompiler::::new(q_weights),
+            luminal_cuda::CudaQuantizedCompiler::::new(q_weights),
             #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
             luminal_cpu::CPUCompiler::default(),
         ),
@@ -98,6 +98,7 @@ fn main() {
     // Now that weights are loaded, delete the loading nodes so they don't run again
     delete_inputs(&downstream(model_weights, &cx), &mut cx);
 
+    // Run prompt processing pass
     let mut input_ids = tokenizer
         .encode(&cli_args.prompt as &str, false)
         .unwrap()
@@ -126,9 +127,9 @@ fn main() {
 
     // Decode token
     print!("{}", cli_args.prompt.white().bold());
-    let out_str = tokenizer.decode(&output_ids, false).unwrap();
+    let out_str = tokenizer.decode(&output_ids, false).unwrap().bright_green();
     let mut prev_output_len = out_str.len();
-    print!("{}", out_str.bright_green());
+    print!("{out_str}");
     io::stdout().flush().unwrap();
 
     // Swap caches
@@ -173,6 +174,7 @@ fn main() {
     );
 }
 
+// Currently just an argmax, do actual sampling here
 fn argmax(dist: &[f32]) -> u32 {
     dist.iter()
         .position_max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
diff --git a/examples/phi/src/model.rs b/examples/phi/src/model.rs
index 72f6409c..be9abe49 100644
--- a/examples/phi/src/model.rs
+++ b/examples/phi/src/model.rs
@@ -125,10 +125,12 @@ impl
             .matmul(self.q_proj.permute())
             .reshape::<(Batch, CurSeq, Const, Const)>()
             .permute::<_, Axes4<0, 2, 1, 3>>();
+
         let keys = x
             .matmul(self.k_proj.permute())
             .reshape::<(Batch, CurSeq, Const, Const)>()
             .permute::<_, Axes4<0, 2, 1, 3>>();
+
         let values = x
             .matmul(self.v_proj.permute())
             .reshape::<(Batch, CurSeq, Const, Const)>()
@@ -143,7 +145,10 @@ impl
         let values = v_cache.concat_along::<_, Axis<2>, _>(values);
 
         // Calculate attention weights
-        let mut attention_weights = queries.matmul(keys.permute()) / (HEAD_DIM as f32).sqrt();
+        let mut attention_weights = queries
+            .reshape::<(_, Const, _, _)>() // Split query heads into groups
+            .matmul(keys.permute())
+            / (HEAD_DIM as f32).sqrt();
         let attention_mask = self.k_proj.graph().triu::(1) * f16::MIN.to_f32();
         attention_weights += attention_mask
@@ -215,9 +220,10 @@ impl
         ),
     ) -> Self::Output {
         // Attention
-        let (y, cache) =
-            self.attention
-                .forward((self.attention_norm.forward(x), cache, PhantomData::));
+        let normed = self.attention_norm.forward(x);
+        let (y, cache) = self
+            .attention
+            .forward((normed, cache, PhantomData::));
 
         // Residual Addition
         x += y;
@@ -256,7 +262,7 @@ impl SerializeModule for TransformerBlock {
     }
 }
 
-pub struct MistralLM {
+pub struct Phi {
     // Token embeddings
     pub embedding: Embedding,
     // Transformer layers
@@ -272,7 +278,7 @@ impl
         GraphTensor<(Batch, CurSeq)>,
         &[KVCache],
         PhantomData,
-    )> for MistralLM
+    )> for Phi
 {
     type Output = (
         GraphTensor<(Batch, CurSeq, Const)>,
@@ -303,7 +309,7 @@ impl
     }
 }
 
-impl InitModule for MistralLM {
+impl InitModule for Phi {
     fn initialize(cx: &mut Graph) -> Self {
         Self {
             embedding: Embedding {
@@ -321,7 +327,7 @@ impl InitModule for MistralLM {
     }
 }
 
-impl SerializeModule for MistralLM {
+impl SerializeModule for Phi {
     fn serialize(&self, s: &mut Serializer) {
         s.module("token_embd", &self.embedding);
         s.module("output_norm", &self.norm);
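
Note on the attention change in examples/phi/src/model.rs: the added `reshape` (commented "Split query heads into groups") appears to group the query heads so that several of them attend against one shared cached key/value head, as in grouped-query attention. Below is a minimal std-only Rust sketch of that grouping, written with explicit loops instead of luminal's graph API; the head counts and the Vec-of-Vec layout are illustrative assumptions, not values taken from the Phi config.

// Hypothetical sketch, not part of this patch: score N_HEADS query heads
// against N_KV_HEADS shared key heads, i.e. the grouping the reshape comment refers to.
const N_HEADS: usize = 32;
const N_KV_HEADS: usize = 8;
const HEAD_DIM: usize = 4; // kept tiny for the example

fn grouped_attention_scores(
    queries: &[Vec<f32>], // N_HEADS query vectors of length HEAD_DIM
    keys: &[Vec<f32>],    // N_KV_HEADS cached key vectors of length HEAD_DIM
) -> Vec<f32> {
    let group_size = N_HEADS / N_KV_HEADS;
    (0..N_HEADS)
        .map(|q| {
            // Each group of `group_size` consecutive query heads shares one KV head.
            let kv = q / group_size;
            let dot: f32 = queries[q].iter().zip(&keys[kv]).map(|(a, b)| a * b).sum();
            // Same 1/sqrt(HEAD_DIM) scaling as the patched expression.
            dot / (HEAD_DIM as f32).sqrt()
        })
        .collect()
}

In the patched code this grouping is presumably expressed once at the graph level, so the following `matmul` against the concatenated key cache can broadcast each group over its shared head.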
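
On the sampling TODO: the comment added above `fn argmax` in examples/phi/src/main.rs ("Currently just an argmax, do actual sampling here") leaves sampling as future work. Below is a rough, std-only sketch of what temperature sampling could look like in place of the `argmax` call. It is not part of this patch: the `sample` helper is hypothetical, the inline xorshift step only stands in for a real RNG such as the `rand` crate, and top-k/top-p filtering is omitted.

// Hypothetical sketch, not part of this patch: temperature sampling over the
// logits instead of a plain argmax. `seed` must start nonzero; the xorshift64
// update below is a std-only stand-in for a proper RNG.
fn sample(logits: &[f32], temperature: f32, seed: &mut u64) -> u32 {
    // Softmax with temperature, shifted by the max logit for numerical stability.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits
        .iter()
        .map(|&l| ((l - max) / temperature).exp())
        .collect();
    let total: f32 = exps.iter().sum();

    // Draw a uniform value in [0, total) from the xorshift state.
    *seed ^= *seed << 13;
    *seed ^= *seed >> 7;
    *seed ^= *seed << 17;
    let draw = (*seed >> 11) as f32 / (1u64 << 53) as f32 * total;

    // Walk the cumulative distribution until the draw falls inside a bucket.
    let mut cumulative = 0.0;
    for (i, e) in exps.iter().enumerate() {
        cumulative += e;
        if draw < cumulative {
            return i as u32;
        }
    }
    (logits.len() - 1) as u32
}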