Updates
Joe Fioti authored and Joe Fioti committed Dec 30, 2023
1 parent 51545ee commit a5d01c7
Showing 4 changed files with 69 additions and 62 deletions.
28 changes: 0 additions & 28 deletions examples/llama/config.rs

This file was deleted.

59 changes: 40 additions & 19 deletions examples/llama/main.rs
@@ -1,25 +1,26 @@
-mod config;
 mod loader;
 mod model;
 
-use std::{marker::PhantomData, time::Instant};
+use std::{io::Write, marker::PhantomData, time::Instant};
 
-use luminal::prelude::*;
+use colored::Colorize;
+use luminal::{prelude::*, shape::symbolic::Expression};
 use model::LlamaForCausalLM;
 use rust_tokenizers::tokenizer::{
-    SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy::LongestFirst,
+    SentencePieceBpeTokenizer, Tokenizer,
+    TruncationStrategy::{self},
 };
 
 use crate::model::KVCache;
 
 type Model = LlamaForCausalLM<
-    { config::VOCAB },
-    { config::HEADS },
-    { config::HIDDEN },
-    { config::INTERMEDIATE },
-    { config::HEAD_DIM },
-    { config::HEAD_DIM_OVER_2 },
-    { config::LAYERS },
+    { model::VOCAB },
+    { model::HEADS },
+    { model::HIDDEN },
+    { model::INTERMEDIATE },
+    { model::HEAD_DIM },
+    { model::HEAD_DIM_OVER_2 },
+    { model::LAYERS },
 >;
 
 #[cfg(feature = "metal")]
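With config.rs deleted, the type alias now pulls its compile-time dimensions from model.rs. The angle-bracket arguments are const generics, so every tensor shape below is fixed at compile time. A minimal sketch of the pattern for orientation (hypothetical names, not the commit's code):

    // Hypothetical illustration of the const-generic sizing pattern used by `Model`.
    struct Transformer<const HIDDEN: usize, const HEADS: usize>;

    impl<const HIDDEN: usize, const HEADS: usize> Transformer<HIDDEN, HEADS> {
        // Each attention head sees HIDDEN / HEADS dimensions.
        const HEAD_DIM: usize = HIDDEN / HEADS;
    }

    // LLaMA 1 7B: 4096 hidden, 32 heads -> 128 dims per head.
    type Llama7B = Transformer<4096, 32>;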
@@ -37,7 +38,13 @@ fn main() {
     )
     .unwrap();
     let mut input = tokenizer
-        .encode(prompt, None, prompt.len(), &LongestFirst, 0)
+        .encode(
+            prompt,
+            None,
+            prompt.len(),
+            &TruncationStrategy::LongestFirst,
+            0,
+        )
         .token_ids
         .iter()
         .map(|&x| x as usize)
@@ -54,7 +61,7 @@ fn main() {
     );
     let (out1, cache1) = model.forward((
         inp,
-        Option::<Vec<KVCache<Const<1>, Const<0>, { config::HEADS }, { config::HEAD_DIM }>>>::None,
+        Option::<Vec<KVCache<Const<1>, Const<0>, { model::HEADS }, { model::HEAD_DIM }>>>::None,
         PhantomData::<Dyn<'s'>>,
     ));
     out1.retrieve();
@@ -72,8 +79,8 @@ fn main() {
     // Build KV cache forward graph
     let kv_model = Model::initialize(&mut cx2);
     let single_inp = cx2.named_tensor::<R2<1, 1>>("Input");
-    let cache_src: Vec<KVCache<Const<1>, Dyn<'p'>, { config::HEADS }, { config::HEAD_DIM }>> = (0
-        ..config::LAYERS)
+    let cache_src: Vec<KVCache<Const<1>, Dyn<'p'>, { model::HEADS }, { model::HEAD_DIM }>> = (0
+        ..model::LAYERS)
         .map(|_| {
             (
                 cx2.named_tensor("Key Cache"),
@@ -111,7 +118,7 @@ fn main() {
     cx1.execute_debug();
 
     let out1 = out1.data();
-    input.push(sample_index(&out1[out1.len() - 32_000..]));
+    input.push(sample_index(&out1[out1.len() - 32_000..]) as usize);
     println!(
         "{}",
         tokenizer
@@ -144,7 +151,7 @@ fn main() {
     let o = out.data();
     out.drop();
     // Sample tokens
-    input.push(sample_index(&o));
+    input.push(sample_index(&o) as usize);
     println!(
         "{}",
         tokenizer
@@ -168,11 +175,25 @@ fn main() {
     }
 }
 
+fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec<i64> {
+    let mut vector = tokenizer
+        .encode(text, None, text.len(), &TruncationStrategy::LongestFirst, 0)
+        .token_ids;
+    vector.insert(0, 1); // Start token
+    vector
+}
+
+fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String {
+    tokenizer
+        .decode(token_ids, true, false)
+        .replace("<0x0A>", "\n")
+}
+
 // Currently just an argmax, do actual sampling here
-fn sample_index(dist: &[f32]) -> usize {
+fn sample_index(dist: &[f32]) -> i64 {
     dist.iter()
         .enumerate()
         .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
         .unwrap()
-        .0
+        .0 as i64
 }
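The comment above `sample_index` still flags the argmax as a placeholder. For reference, a minimal sketch of what weighted sampling could look like, assuming the `rand` crate (0.8) as a dependency; the function name and temperature parameter are illustrative, not part of this commit:

    use rand::distributions::{Distribution, WeightedIndex};

    // Hypothetical alternative to the argmax: softmax the logits at a given
    // temperature, then draw an index from the resulting distribution.
    fn sample_index_weighted(dist: &[f32], temperature: f32) -> i64 {
        // Subtract the max logit before exp() for numerical stability.
        let max = dist.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let weights: Vec<f32> = dist.iter().map(|&l| ((l - max) / temperature).exp()).collect();
        WeightedIndex::new(&weights)
            .unwrap()
            .sample(&mut rand::thread_rng()) as i64
    }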
9 changes: 9 additions & 0 deletions examples/llama/model.rs
@@ -1,6 +1,15 @@
 #![allow(clippy::type_complexity)]
 use std::{marker::PhantomData, ops::Mul};
 
+// LLaMa 1 7B Config
+pub const VOCAB: usize = 32_000;
+pub const HEAD_DIM: usize = 128;
+pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2;
+pub const HIDDEN: usize = 4096;
+pub const INTERMEDIATE: usize = 11008;
+pub const HEADS: usize = 32;
+pub const LAYERS: usize = 1;
+
 use half::f16;
 use luminal::{
     nn::{embedding::Embedding, norm::RMSNorm},
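Since the configuration now lives beside the model code, the relationships between these constants can be verified at compile time. A small sketch of such a guard (not in the commit), assuming the usual LLaMA relation between hidden size and heads:

    // 32 heads * 128 dims per head = 4096 hidden; rotary embeddings use half the head dim.
    const _: () = assert!(HEADS * HEAD_DIM == HIDDEN);
    const _: () = assert!(HEAD_DIM_OVER_2 == HEAD_DIM / 2);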
35 changes: 20 additions & 15 deletions examples/mistral/main.rs
@@ -1,4 +1,4 @@
-use std::{marker::PhantomData, time::Instant};
+use std::{io::Write, marker::PhantomData, time::Instant};
 
 use colored::Colorize;
 use rust_tokenizers::tokenizer::{SentencePieceBpeTokenizer, Tokenizer, TruncationStrategy};
@@ -15,8 +15,8 @@ type DeviceCompiler = CudaFp16Compiler;
 #[cfg(all(not(feature = "cuda"), not(feature = "metal")))]
 type DeviceCompiler = CPUCompiler;
 
-fn main() -> Result<(), String> {
-    println!("Constructing graph...");
+fn main() {
+    println!("Creating graph...");
     let tokenizer = SentencePieceBpeTokenizer::from_file(
         "./examples/mistral/setup/mistral-7b-hf/tokenizer.model",
         false,
@@ -109,14 +109,17 @@ fn main() -> Result<(), String> {
         input_ids.iter().map(|i| *i as f32).collect::<Vec<_>>(),
         vec![1, input_ids.len()],
     );
+    let now = Instant::now();
     cx1.execute();
+    println!("Prompt processing took {}ms", now.elapsed().as_millis());
 
     let output_id = sample_index(&logits.data());
     input_ids.push(output_id);
 
     // Decode token
     completion.push_str(&decode(&tokenizer, &[output_id]));
-    println!("{}{}", prompt.on_black().white().bold(), completion.green());
+    print!("{}{}", prompt.white().bold(), completion.green());
+    std::io::stdout().flush().unwrap();
 
     // Transfer weights and kv cache
     transfer_weights(&model, &mut cx1, &kv_model, &mut cx2);
@@ -126,46 +129,48 @@ fn main() -> Result<(), String> {
     }
 
     // Decode loop
+    let mut token_decode_times = vec![];
     for _ in 0..100 {
         single_input.set(vec![*input_ids.last().unwrap() as f32]);
         cx2.set_dyn_dim('p', input_ids.len() - 1);
         cx2.set_dyn_dim('t', input_ids.len());
 
         let now = Instant::now();
         cx2.execute();
-        println!("Forward Pass Took {:.2}s", now.elapsed().as_secs_f32());
+        token_decode_times.push(now.elapsed().as_millis());
 
         // Sample tokens
         let output_id = sample_index(&decode_logits.data());
         decode_logits.drop();
         completion.push_str(&decode(&tokenizer, &[output_id]));
         input_ids.push(output_id);
-        println!("{}{}", prompt.on_black().white().bold(), completion.green());
+        print!("{}", decode(&tokenizer, &[output_id]).green());
+        std::io::stdout().flush().unwrap();
 
         // Swap caches
         for ((src_k, src_v), (dest_k, dest_v)) in cache_src.iter().zip(cache_dest.iter()) {
             // Move dest caches to src
             cx2.swap_tensors(*src_k, *dest_k);
             cx2.swap_tensors(*src_v, *dest_v);
-            // // Drop dest caches
-            // dest_k.drop();
-            // dest_v.drop();
+            // Drop dest caches
+            dest_k.drop();
+            dest_v.drop();
         }
     }
 
-    Ok(())
+    println!(
+        "\nAverage token generated in {}ms",
+        token_decode_times.iter().sum::<u128>() / token_decode_times.len() as u128
+    );
 }

-// Method to encode text as vector
-pub fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec<i64> {
+fn encode(tokenizer: &SentencePieceBpeTokenizer, text: &str) -> Vec<i64> {
     let mut vector = tokenizer
         .encode(text, None, text.len(), &TruncationStrategy::LongestFirst, 0)
         .token_ids;
     vector.insert(0, 1); // Start token
     vector
 }
 
-pub fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String {
+fn decode(tokenizer: &SentencePieceBpeTokenizer, token_ids: &[i64]) -> String {
     tokenizer
         .decode(token_ids, true, false)
         .replace("<0x0A>", "\n")
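The decode loop now reports an average latency per generated token instead of per-pass timings. If a throughput figure is preferred, the same millisecond samples convert directly; a standalone sketch with hypothetical latencies, not part of the commit:

    fn main() {
        // Hypothetical per-token decode latencies in milliseconds,
        // as collected by `token_decode_times` in the loop above.
        let token_decode_times: Vec<u128> = vec![95, 102, 98, 101];
        let avg_ms = token_decode_times.iter().sum::<u128>() / token_decode_times.len() as u128;
        // 1000 ms per second / ms per token = tokens per second.
        println!("{avg_ms}ms/token = {:.1} tokens/s", 1000.0 / avg_ms as f64);
    }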
