Fixed phi example on metal
jafioti committed May 3, 2024
1 parent 816deea commit bf97c89
Showing 7 changed files with 35 additions and 29 deletions.
examples/llama/src/main.rs (1 addition, 1 deletion)
@@ -44,7 +44,7 @@ fn main() {
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
.collect();
cache_src.set_dyn(vec![], &[1, model::N_KV_HEADS, 0, model::HEAD_DIM]);
- let model = model::MistralLM::initialize(&mut cx);
+ let model = model::Llama::initialize(&mut cx);
let mut model_weights = params(&model);
cx.keep_tensors(&model_weights);
let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::<Dyn<'t'>>));
examples/llama/src/model.rs (4 additions, 4 deletions)
@@ -268,7 +268,7 @@ impl SerializeModule for TransformerBlock {
}
}

- pub struct MistralLM {
+ pub struct Llama {
// Token embeddings
pub embedding: Embedding<VOCAB_SIZE, HIDDEN_DIM>,
// Transformer layers
@@ -284,7 +284,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
GraphTensor<(Batch, CurSeq)>,
&[KVCache<Batch, PrevSeq>],
PhantomData<TotSeq>,
- )> for MistralLM
+ )> for Llama
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<VOCAB_SIZE>)>,
@@ -315,7 +315,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
}
}

- impl InitModule for MistralLM {
+ impl InitModule for Llama {
fn initialize(cx: &mut Graph) -> Self {
Self {
embedding: Embedding {
@@ -333,7 +333,7 @@ impl InitModule for MistralLM {
}
}

- impl SerializeModule for MistralLM {
+ impl SerializeModule for Llama {
fn serialize(&self, s: &mut Serializer) {
s.module("token_embd", &self.embedding);
s.module("output_norm", &self.norm);
examples/llama_server/src/chat.rs (2 additions, 4 deletions)
@@ -100,7 +100,7 @@ pub async fn respond_chat_request(model: &mut Model, request: ChatRequest) -> ChatResponse
let completion_str = model.tokenizer.decode(&completion, false).unwrap();
let completion_tokens = completion.len();

- let response = ChatResponse {
+ ChatResponse {
id,
created,
object: "chat.completion".to_string(),
@@ -118,7 +118,5 @@ pub async fn respond_chat_request(model: &mut Model, request: ChatRequest) -> ChatResponse
prompt_tokens,
completion_tokens,
},
- };
-
- response
+ }
}
examples/llama_server/src/llama/setup.rs (5 additions, 5 deletions)
@@ -93,8 +93,8 @@ impl Model {
&mut model_weights,
),
);
- let cache_src_set = downstream(&cache_src, &cx);
- let cache_dest_set = cache_dest.to_ids();
+ let cache_src = downstream(&cache_src, &cx);
+ let cache_dest = cache_dest.to_ids();
println!("\t\t - {}ms", now.elapsed().as_millis());

// Initial forward pass to load weights
@@ -105,16 +105,16 @@ impl Model {
cx.set_dyn_dim('t', 1);
cx.execute();
logits.drop();
- cache_dest.drop();
+ cx.drop_tensors(&cache_dest);
println!("\t\t - {}ms", now.elapsed().as_millis());

// Now that weights are loaded, delete the loading nodes so they don't run again
- delete_inputs(&downstream(model_weights, &cx), &mut cx);
+ delete_inputs(downstream(model_weights, &cx), &mut cx);

Model {
input,
tokenizer,
- kv_cache_src_set: downstream(&cache_src, &cx),
+ kv_cache_src_set: downstream(cache_src, &cx),
kv_cache_dest_set: cache_dest.to_ids(),
graph: cx,
logits,
examples/llama_server/src/main.rs (1 addition, 1 deletion)
@@ -39,6 +39,6 @@ async fn chat_completions(
) -> (StatusCode, Json<ChatResponse>) {
let mut model = model.lock().await;

- let response = respond_chat_request(&mut *model, payload).await;
+ let response = respond_chat_request(&mut model, payload).await;
(StatusCode::OK, Json(response))
}
examples/phi/src/main.rs (8 additions, 6 deletions)
@@ -21,7 +21,7 @@ use luminal::prelude::*;
#[command(author, version, about, long_about = None)]
pub struct CLIArgs {
/// Number of tokens to generate
- #[clap(short = 't', long = "gen_tokens", default_value = "512")]
+ #[clap(short = 't', long = "gen_tokens", default_value = "128")]
gen_tokens: i32,

/// Prompt for the model
@@ -44,7 +44,7 @@ fn main() {
.map(|_| (cx.named_tensor("Key Cache"), cx.named_tensor("Value Cache")))
.collect();
cache_src.set_dyn(vec![], &[1, model::N_HEADS, 0, model::HEAD_DIM]);
- let model = model::MistralLM::initialize(&mut cx);
+ let model = model::Phi::initialize(&mut cx);
let mut model_weights = params(&model);
cx.keep_tensors(&model_weights);
let (logits, mut cache_dest) = model.forward((input, &cache_src, PhantomData::<Dyn<'t'>>));
@@ -67,9 +67,9 @@ fn main() {
(
GenericCompiler::default(),
#[cfg(feature = "metal")]
- luminal_metal::quantized::MetalQuantizedCompiler::<f16>::new(q_weights),
+ luminal_metal::quantized::MetalQuantizedCompiler::<f32>::new(q_weights),
#[cfg(feature = "cuda")]
- luminal_cuda::CudaQuantizedCompiler::<f16>::new(q_weights),
+ luminal_cuda::CudaQuantizedCompiler::<f32>::new(q_weights),
#[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
luminal_cpu::CPUCompiler::default(),
),
@@ -98,6 +98,7 @@ fn main() {

// Now that weights are loaded, delete the loading nodes so they don't run again
delete_inputs(&downstream(model_weights, &cx), &mut cx);

// Run prompt processing pass
let mut input_ids = tokenizer
.encode(&cli_args.prompt as &str, false)
@@ -126,9 +126,9 @@ fn main() {

// Decode token
print!("{}", cli_args.prompt.white().bold());
- let out_str = tokenizer.decode(&output_ids, false).unwrap();
+ let out_str = tokenizer.decode(&output_ids, false).unwrap().bright_green();
let mut prev_output_len = out_str.len();
- print!("{}", out_str.bright_green());
+ print!("{out_str}");
io::stdout().flush().unwrap();

// Swap caches
@@ -173,6 +174,7 @@ fn main() {
);
}

+ // Currently just an argmax, do actual sampling here
fn argmax(dist: &[f32]) -> u32 {
dist.iter()
.position_max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
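The new comment above marks the argmax as a placeholder. As a rough idea of what "actual sampling" could look like, here is a minimal temperature-sampling sketch; it is not part of this commit, the name sample_temperature is illustrative, and it assumes plain f32 logits, temperature > 0, and no extra crates:

// Hypothetical alternative to argmax: softmax with temperature, then sample
// from the resulting distribution. A tiny xorshift step stands in for a real
// RNG so the sketch stays dependency-free.
fn sample_temperature(logits: &[f32], temperature: f32, seed: &mut u64) -> u32 {
    // Numerically stable softmax over temperature-scaled logits
    let max = logits.iter().copied().fold(f32::MIN, f32::max);
    let exps: Vec<f32> = logits
        .iter()
        .map(|l| ((l - max) / temperature).exp())
        .collect();
    let total: f32 = exps.iter().sum();
    // xorshift64 step to draw a value in [0, total)
    *seed ^= *seed << 13;
    *seed ^= *seed >> 7;
    *seed ^= *seed << 17;
    let mut r = (*seed as f64 / u64::MAX as f64) as f32 * total;
    // Walk the cumulative distribution until the draw is covered
    for (i, e) in exps.iter().enumerate() {
        r -= e;
        if r <= 0.0 {
            return i as u32;
        }
    }
    (logits.len() - 1) as u32
}

It would slot in wherever the example currently calls argmax, with a small temperature recovering near-greedy behavior.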
examples/phi/src/model.rs (14 additions, 8 deletions)
@@ -125,10 +125,12 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
.matmul(self.q_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();

let keys = x
.matmul(self.k_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
.permute::<_, Axes4<0, 2, 1, 3>>();

let values = x
.matmul(self.v_proj.permute())
.reshape::<(Batch, CurSeq, Const<N_HEADS>, Const<HEAD_DIM>)>()
@@ -143,7 +145,10 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
let values = v_cache.concat_along::<_, Axis<2>, _>(values);

// Calculate attention weights
- let mut attention_weights = queries.matmul(keys.permute()) / (HEAD_DIM as f32).sqrt();
+ let mut attention_weights = queries
+     .reshape::<(_, Const<N_HEADS>, _, _)>() // Split query heads into groups
+     .matmul(keys.permute())
+     / (HEAD_DIM as f32).sqrt();

let attention_mask = self.k_proj.graph().triu::<CurSeq>(1) * f16::MIN.to_f32();
attention_weights += attention_mask
@@ -215,9 +220,10 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
),
) -> Self::Output {
// Attention
- let (y, cache) =
-     self.attention
-         .forward((self.attention_norm.forward(x), cache, PhantomData::<TotSeq>));
+ let normed = self.attention_norm.forward(x);
+ let (y, cache) = self
+     .attention
+     .forward((normed, cache, PhantomData::<TotSeq>));

// Residual Addition
x += y;
@@ -256,7 +262,7 @@ impl SerializeModule for TransformerBlock {
}
}

- pub struct MistralLM {
+ pub struct Phi {
// Token embeddings
pub embedding: Embedding<VOCAB_SIZE, HIDDEN_DIM>,
// Transformer layers
@@ -272,7 +278,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
GraphTensor<(Batch, CurSeq)>,
&[KVCache<Batch, PrevSeq>],
PhantomData<TotSeq>,
- )> for MistralLM
+ )> for Phi
{
type Output = (
GraphTensor<(Batch, CurSeq, Const<VOCAB_SIZE>)>,
@@ -303,7 +309,7 @@ impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
}
}

- impl InitModule for MistralLM {
+ impl InitModule for Phi {
fn initialize(cx: &mut Graph) -> Self {
Self {
embedding: Embedding {
@@ -321,7 +327,7 @@ impl InitModule for MistralLM {
}
}

- impl SerializeModule for MistralLM {
+ impl SerializeModule for Phi {
fn serialize(&self, s: &mut Serializer) {
s.module("token_embd", &self.embedding);
s.module("output_norm", &self.norm);
