From a5d01c75768a1e8f84be10d9b03914ae77a0e088 Mon Sep 17 00:00:00 2001
From: Joe Fioti
Date: Fri, 29 Dec 2023 20:42:25 -0500
Subject: [PATCH] Updates

---
 examples/llama/config.rs | 28 -------------------
 examples/llama/main.rs   | 59 +++++++++++++++++++++++++++-------------
 examples/llama/model.rs  |  9 ++++++
 examples/mistral/main.rs | 35 ++++++++++++++----------
 4 files changed, 69 insertions(+), 62 deletions(-)
 delete mode 100644 examples/llama/config.rs

diff --git a/examples/llama/config.rs b/examples/llama/config.rs
deleted file mode 100644
index 4032751c..00000000
--- a/examples/llama/config.rs
+++ /dev/null
@@ -1,28 +0,0 @@
-// Common
-pub const VOCAB: usize = 32_000;
-pub const HEAD_DIM: usize = 128;
-pub const HEAD_DIM_OVER_2: usize = 64;
-
-// Sheared llama
-// pub const HIDDEN: usize = 2048;
-// pub const INTERMEDIATE: usize = 5504;
-// pub const HEADS: usize = 16;
-// pub const LAYERS: usize = 24;
-
-// 7B
-pub const HIDDEN: usize = 4096;
-pub const INTERMEDIATE: usize = 11008;
-pub const HEADS: usize = 32;
-pub const LAYERS: usize = 32;
-
-// 13B
-// pub const HIDDEN: usize = 5120;
-// pub const INTERMEDIATE: usize = 13824;
-// pub const HEADS: usize = 40;
-// pub const LAYERS: usize = 40;
-
-// 65B
-// pub const HIDDEN: usize = 8192;
-// pub const INTERMEDIATE: usize = 22016;
-// pub const HEADS: usize = 64;
-// pub const LAYERS: usize = 80;

diff --git a/examples/llama/main.rs b/examples/llama/main.rs
index 5468e7ee..5cecdd5d 100644
--- a/examples/llama/main.rs
+++ b/examples/llama/main.rs
@@ -1,25 +1,26 @@
-mod config;
 mod loader;
 mod model;
 
-use std::{marker::PhantomData, time::Instant};
+use std::{io::Write, marker::PhantomData, time::Instant};
 
-use luminal::prelude::*;
+use colored::Colorize;
+use luminal::{prelude::*, shape::symbolic::Expression};
 use model::LlamaForCausalLM;
 use rust_tokenizers::tokenizer::{
-    SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy::LongestFirst,
+    SentencePieceBpeTokenizer, Tokenizer,
+    TruncationStrategy::{self},
 };
 
 use crate::model::KVCache;
 
 type Model = LlamaForCausalLM<
-    { config::VOCAB },
-    { config::HEADS },
-    { config::HIDDEN },
-    { config::INTERMEDIATE },
-    { config::HEAD_DIM },
-    { config::HEAD_DIM_OVER_2 },
-    { config::LAYERS },
+    { model::VOCAB },
+    { model::HEADS },
+    { model::HIDDEN },
+    { model::INTERMEDIATE },
+    { model::HEAD_DIM },
+    { model::HEAD_DIM_OVER_2 },
+    { model::LAYERS },
 >;
 
 #[cfg(feature = "metal")]
@@ -37,7 +38,13 @@ fn main() {
     )
     .unwrap();
     let mut input = tokenizer
-        .encode(prompt, None, prompt.len(), &LongestFirst, 0)
+        .encode(
+            prompt,
+            None,
+            prompt.len(),
+            &TruncationStrategy::LongestFirst,
+            0,
+        )
         .token_ids
         .iter()
         .map(|&x| x as usize)
@@ -54,7 +61,7 @@ fn main() {
     );
     let (out1, cache1) = model.forward((
         inp,
-        Option::<Vec<KVCache<Const<1>, Const<0>, { config::HEADS }, { config::HEAD_DIM }>>>::None,
+        Option::<Vec<KVCache<Const<1>, Const<0>, { model::HEADS }, { model::HEAD_DIM }>>>::None,
        PhantomData::<Dyn<'s'>>,
     ));
     out1.retrieve();
@@ -72,8 +79,8 @@ fn main() {
     // Build KV cache forward graph
     let kv_model = Model::initialize(&mut cx2);
     let single_inp = cx2.named_tensor::<R2<1, 1>>("Input");
-    let cache_src: Vec<KVCache<Const<1>, Dyn<'p'>, { config::HEADS }, { config::HEAD_DIM }>> = (0
-        ..config::LAYERS)
+    let cache_src: Vec<KVCache<Const<1>, Dyn<'p'>, { model::HEADS }, { model::HEAD_DIM }>> = (0
+        ..model::LAYERS)
         .map(|_| {
             (
                 cx2.named_tensor("Key Cache"),
@@ -111,7 +118,7 @@ fn main() {
     cx1.execute_debug();
 
     let out1 = out1.data();
-    input.push(sample_index(&out1[out1.len() - 32_000..]));
+    input.push(sample_index(&out1[out1.len() - 32_000..]) as usize);
     println!(
         "{}",
         tokenizer
@@ -144,7 +151,7 @@ fn main() {
         let o = out.data();
         out.drop();
         // Sample tokens
-        input.push(sample_index(&o));
+        input.push(sample_index(&o) as usize);
         println!(
             "{}",
             tokenizer
@@ -168,11 +175,25 @@ fn main() {
     }
 }
 
+fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec<i64> {
+    let mut vector = tokenizer
+        .encode(text, None, text.len(), &TruncationStrategy::LongestFirst, 0)
+        .token_ids;
+    vector.insert(0, 1); // Start token
+    vector
+}
+
+fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String {
+    tokenizer
+        .decode(token_ids, true, false)
+        .replace("<0x0A>", "\n")
+}
+
 // Currently just an argmax, do actual sampling here
-fn sample_index(dist: &[f32]) -> usize {
+fn sample_index(dist: &[f32]) -> i64 {
     dist.iter()
         .enumerate()
         .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
         .unwrap()
-        .0
+        .0 as i64
 }

diff --git a/examples/llama/model.rs b/examples/llama/model.rs
index c2ddd22a..a7237daa 100644
--- a/examples/llama/model.rs
+++ b/examples/llama/model.rs
@@ -1,6 +1,15 @@
 #![allow(clippy::type_complexity)]
 use std::{marker::PhantomData, ops::Mul};
 
+// LLaMa 1 7B Config
+pub const VOCAB: usize = 32_000;
+pub const HEAD_DIM: usize = 128;
+pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
+pub const HIDDEN: usize = 4096;
+pub const INTERMEDIATE: usize = 11008;
+pub const HEADS: usize = 32;
+pub const LAYERS: usize = 1;
+
 use half::f16;
 use luminal::{
     nn::{embedding::Embedding, norm::RMSNorm},

diff --git a/examples/mistral/main.rs b/examples/mistral/main.rs
index 992d0794..07000197 100644
--- a/examples/mistral/main.rs
+++ b/examples/mistral/main.rs
@@ -1,4 +1,4 @@
-use std::{marker::PhantomData, time::Instant};
+use std::{io::Write, marker::PhantomData, time::Instant};
 
 use colored::Colorize;
 use rust_tokenizers::tokenizer::{SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy};
@@ -15,8 +15,8 @@ type DeviceCompiler = CudaFp16Compiler;
 #[cfg(all(not(feature = "cuda"), not(feature = "metal")))]
 type DeviceCompiler = CPUCompiler;
 
-fn main() -> Result<(), String> {
-    println!("Constructing graph...");
+fn main() {
+    println!("Creating graph...");
     let tokenizer = SentencePieceBpeTokenizer::from_file(
         "./examples/mistral/setup/mistral-7b-hf/tokenizer.model",
         false,
@@ -109,14 +109,17 @@ fn main() -> Result<(), String> {
         input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
         vec![1, input_ids.len()],
     );
+    let now = Instant::now();
     cx1.execute();
+    println!("Prompt processing took {}ms", now.elapsed().as_millis());
 
     let output_id = sample_index(&logits.data());
     input_ids.push(output_id);
 
     // Decode token
     completion.push_str(&decode(&tokenizer, &[output_id]));
-    println!("{}{}", prompt.on_black().white().bold(), completion.green());
+    print!("{}{}", prompt.white().bold(), completion.green());
+    std::io::stdout().flush().unwrap();
 
     // Transfer weights and kv cache
     transfer_weights(&model, &mut cx1, &kv_model, &mut cx2);
@@ -126,6 +129,7 @@ fn main() -> Result<(), String> {
     }
 
     // Decode loop
+    let mut token_decode_times = vec![];
     for _ in 0..100 {
         single_input.set(vec![*input_ids.last().unwrap() as f32]);
         cx2.set_dyn_dim('p', input_ids.len() - 1);
 
         let now = Instant::now();
         cx2.execute();
-        println!("Forward Pass Took {:.2}s", now.elapsed().as_secs_f32());
+        token_decode_times.push(now.elapsed().as_millis());
 
         // Sample tokens
         let output_id = sample_index(&decode_logits.data());
         decode_logits.drop();
-        completion.push_str(&decode(&tokenizer, &[output_id]));
         input_ids.push(output_id);
println!("{}{}", prompt.on_black().white().bold(), completion.green()); + print!("{}", decode(&tokenizer, &[output_id]).green()); + std::io::stdout().flush().unwrap(); // Swap caches for ((src_k, src_v), (dest_k, dest_v)) in cache_src.iter().zip(cache_dest.iter()) { // Move dest caches to src cx2.swap_tensors(*src_k, *dest_k); cx2.swap_tensors(*src_v, *dest_v); - // // Drop dest caches - // dest_k.drop(); - // dest_v.drop(); + // Drop dest caches + dest_k.drop(); + dest_v.drop(); } } - - Ok(()) + println!( + "\nAverage token generated in {}ms", + token_decode_times.iter().sum::() / token_decode_times.len() as u128 + ); } -// Method to encode text as vector -pub fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec { +fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec { let mut vector = tokenizer .encode(text, None, text.len(), &TruncationStrategy::LongestFirst, 0) .token_ids; @@ -165,7 +170,7 @@ pub fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec { vector } -pub fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String { +fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String { tokenizer .decode(token_ids, true, false) .replace("<0x0A>", "\n")